#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#define real double

int qrdec(real**,real**,int,int);
int qrback(real**,real**,int,int,real*,real*);

/* Allocate a rows x cols matrix of `real`, stored column-major: the result
 * is an array of `cols` pointers, each pointing to a zero-filled column of
 * `rows` elements, so element (row i, column j) is a[j][i].
 * Returns NULL if either dimension is negative or any allocation fails; in
 * the failure case everything allocated so far is released.
 * To release: free each a[j] (j < cols), then free a itself. */
real** NewMatrix(int rows, int cols){
	/* A negative int would otherwise turn into a huge size_t request. */
	if (rows < 0 || cols < 0)
		return NULL;

	/* calloc checks the count*size product for overflow and zero-fills,
	 * so entries no routine writes read back as 0 instead of garbage. */
	real **a = calloc((size_t)cols, sizeof *a);
	if (a == NULL)
		return NULL;

	for (int j = 0; j < cols; j++) {
		a[j] = calloc((size_t)rows, sizeof *a[j]);
		/* calloc(0, ...) may legitimately return NULL; only a failed
		 * non-empty request is an error. */
		if (a[j] == NULL && rows > 0) {
			while (j-- > 0)
				free(a[j]);
			free(a);
			return NULL;
		}
	}
	return a;
}

/* Print an n-row by m-column matrix stored column-major (a[col][row]).
 * Each row goes on its own line; every entry is printed as "%.3f" followed
 * by a tab, and each row is terminated by a newline. */
void print_matrix(real** a,int n,int m){
	for (int row = 0; row < n; ++row) {
		int col = 0;
		while (col < m) {
			printf("%.3f\t", a[col][row]);
			++col;
		}
		printf("\n");
	}
}

/* Print the first n entries of a, one per line, each formatted as "%.3f". */
void print_vector(real* a,int n){
	int idx = 0;
	while (idx < n) {
		printf("%.3f\n", a[idx]);
		idx++;
	}
}

/* Matrix product c = a * b.
 * a is n x m, b is m x l and c is n x l, all stored column-major
 * (x[col][row]). c must already be allocated; every entry is overwritten. */
void a_times_b(real** a,int n,int m,real** b,int l,real** c){
	for (int row = 0; row < n; row++) {
		for (int col = 0; col < l; col++) {
			real *out = &c[col][row];
			*out = 0;
			for (int k = 0; k < m; k++)
				*out += a[k][row] * b[col][k];
		}
	}
}

/* Transposed product c = a^T * b.
 * a is n x m, b is n x l and c is m x l, all stored column-major
 * (x[col][row]). Entry (row, col) of c is the dot product of column `row`
 * of a with column `col` of b. c must already be allocated. */
void aT_times_b(real** a,int n,int m,real** b,int l,real** c){
	for (int row = 0; row < m; row++) {
		for (int col = 0; col < l; col++) {
			real *out = &c[col][row];
			*out = 0;
			for (int k = 0; k < n; k++)
				*out += a[row][k] * b[col][k];
		}
	}
}

/* Matrix-vector product c = a * b.
 * a is n x m stored column-major (a[col][row]), b has m entries and c
 * receives n entries; c must already be allocated. */
void mult_mat_vec(real ** a,int n,int m,real * b,real * c){
	for (int row = 0; row < n; ++row) {
		real *acc = &c[row];
		*acc = 0;
		for (int col = 0; col < m; ++col)
			*acc += a[col][row] * b[col];
	}
}

int main(int argc, char** argv){
int n=3,m=3;
if(argc>1) {n=atoi(argv[1]); m=n;}
if(argc>2) m=atoi(argv[2]);
if(m>n) m=n;

real** a = NewMatrix(n,m);
real** r = NewMatrix(m,m);
real** q = NewMatrix(n,m);

for(int i=0;i<n;i++)for(int j=0;j<m;j++) q[j][i]=(double)(i==j?(i+j+1):1);

printf("\nmatrix A:\n"); print_matrix(q,n,m);

qrdec(q,r,n,m);
printf("\nmatrix R (should be right-triangular) :\n");
print_matrix(r,m,m);
printf("\nmatrix Q :\n"); print_matrix(q,n,m);
aT_times_b(q,n,m,q,m,a);
printf("\nmatrix Q^T*Q (should be 1) :\n"); print_matrix(a,m,m);

real d=1; for(int i=0; i<m ; i++) d*=r[i][i];
printf("\ndeterminant of R:\t%g\n",d);
a_times_b(q,n,m,r,m,a);
printf("\nmatrix product QR (should be equal A) :\n");
print_matrix(a,n,m);

real* b = (real *)calloc(n,sizeof(real)); 
for(int i=0;i<n;i++) b[i]=i+1;
printf("\nvector b:\n"); print_vector(b,n);

real* x = (real *)calloc(m,sizeof(real)); 
real* ax = (real *)calloc(n,sizeof(real)); 
qrback(q,r,n,m,b,x);
printf("\nleast squares solution x to Ax=b:\n"); print_vector(x,m);

mult_mat_vec(a,n,m,x,ax);
printf("\ncheck: vector Ax (should be close to b):\n");
print_vector(ax,n);

}
