#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "real.h"
// Tuning constants used by strata().
// frac: fraction of a box's point budget N that strata() spends on its
//       exploratory trial sample (N_trial = N*frac, floored at min_trial).
#define frac .1
// tiny: floor applied to the two half-box variance estimates before strata()
//       splits the remaining points in proportion to them, so neither half
//       gets zero weight and varL+varR cannot be zero.
#define tiny 1e-9
// min_trial: minimum trial sample size; strata() hands boxes with
//       N < 4*min_trial directly to plainmc().
#define min_trial 15
// Allocate a zero-initialized heap array of n reals / n ints.  Each expands
// to a calloc() call, so the result may be NULL and must be released with
// free().
#define rvector(n) (real*)calloc((n),sizeof(real))
#define ivector(n) (int*)calloc((n),sizeof(int))

// Defined elsewhere; strata() forwards its own argument list to it when N is
// small.  NOTE(review): judging by the name and by how strata() uses it, this
// presumably does plain (non-stratified) Monte Carlo and stores the sum of f
// and of f*f over N points drawn with rnd into *sum and *sum2 -- confirm
// against its definition.
void plainmc(int n,real a[],real b[],real (*func)(real*),
      int N,real *sum,real *sum2,void (*rnd)(int,real*,real*,real*));

// Sample variance (mean of squares minus squared mean) of cnt values whose
// sum is s and sum of squares is s2.  An empty sample returns 0 instead of
// the 0/0 = NaN the raw formula would give.
static real strata_variance(real s,real s2,int cnt)
{
	real avr;
	if(cnt<=0)return 0;
	avr=s/cnt;
	return s2/cnt-avr*avr;
}

/*
 * strata: recursive stratified-sampling Monte Carlo over the box
 * [a[0],b[0]] x ... x [a[n-1],b[n-1]] using N evaluations of func.
 *
 * Boxes with N < 4*min_trial (or with n < 1) are handed to plainmc as-is.
 * Otherwise:
 *   1. N_trial = max((int)(N*frac), min_trial) points are drawn over the
 *      whole box.  f is added to *sum and f*f to *sum2.  For every
 *      dimension j, the point is also tallied into the lower half
 *      (x[j] < midpoint) or the upper half of that dimension.
 *   2. The dimension whose halves differ most in sample variance is chosen.
 *      Dimensions with an empty half are skipped; dimension 0 is the
 *      default.
 *   3. The remaining N - N_trial points are split between the two halves of
 *      that dimension in proportion to their variances (each floored at
 *      tiny), with at least one point per half, and strata recurses on both.
 *   4. The halves' sums are added to *sum and *sum2.
 * If any work buffer cannot be allocated, the whole budget N goes to
 * plainmc instead.
 *
 * n      number of dimensions (length of a, b and the internal buffers)
 * a, b   lower / upper corners of the box; not modified
 * func   integrand, called on an n-vector
 * N      number of integrand evaluations for this box
 * sum    out: trial-sample sum of f plus the sums of the two sub-boxes
 * sum2   out: the same for f*f
 * rnd    called as rnd(n,a,b,x) before each evaluation; presumably writes a
 *        uniformly distributed point of [a,b] into x -- TODO confirm
 *
 * NOTE(review): *sum and *sum2 are plain, unweighted sums over points whose
 * density differs between the halves, because the split follows variance,
 * not volume.  Confirm that the caller / plainmc convention accounts for
 * this; otherwise sum/N is a biased estimate of the mean.
 */
void strata(int n,real a[],real b[],real (*func)(real*),
	int N,real *sum,real *sum2,void (*rnd)(int,real*,real*,real*))
{
	int i,j,j_big,N_trial,N_rest,N_left,N_right;
	int *NL,*NR;
	real *SL,*SR,*S2L,*S2R,*x,*aR,*bL;
	real f,big,varL,varR,sumL,sum2L,sumR,sum2R;

	// Too few points to stratify, or no dimension to split.
	if(n<1 || N<min_trial*4){
		plainmc(n,a,b,func,N,sum,sum2,rnd);
		return;
	}

	SL=rvector(n); S2L=rvector(n); NL=ivector(n);
	SR=rvector(n); S2R=rvector(n); NR=ivector(n);
	x=rvector(n); aR=rvector(n); bL=rvector(n);
	if(!SL || !S2L || !NL || !SR || !S2R || !NR || !x || !aR || !bL){
		// Out of memory: release whatever was allocated (free(NULL) is a
		// no-op) and fall back to plain Monte Carlo on the whole box.
		free(SL); free(S2L); free(NL); free(SR); free(S2R); free(NR);
		free(x); free(aR); free(bL);
		plainmc(n,a,b,func,N,sum,sum2,rnd);
		return;
	}

	// 1. Exploratory sample over the whole box, tallied per half-box.
	N_trial=N*frac;
	if(N_trial<min_trial)N_trial=min_trial;
	*sum=0; *sum2=0;
	for(i=0;i<N_trial;i++){
		(*rnd)(n,a,b,x);
		f=(*func)(x);
		*sum+=f; *sum2+=f*f;
		for(j=0;j<n;j++){
			if(x[j]<(a[j]+b[j])/2){NL[j]+=1; SL[j]+=f; S2L[j]+=f*f;}
			else                  {NR[j]+=1; SR[j]+=f; S2R[j]+=f*f;}
		}
	}

	// 2. Dimension with the largest variance difference between its halves.
	big=0; j_big=0;
	for(j=0;j<n;j++){
		if(NL[j]==0 || NR[j]==0)continue; // empty half: no variance estimate
		varL=strata_variance(SL[j],S2L[j],NL[j]);
		varR=strata_variance(SR[j],S2R[j],NR[j]);
		if(fabs(varL-varR)>big){big=fabs(varL-varR); j_big=j;}
	}
	varL=strata_variance(SL[j_big],S2L[j_big],NL[j_big]);
	varR=strata_variance(SR[j_big],S2R[j_big],NR[j_big]);

	// The trial accumulators are no longer needed; release them before
	// recursing so they do not pile up along the recursion.
	free(SL); free(S2L); free(NL); free(SR); free(S2R); free(NR); free(x);

	// 3. Variance-proportional split of the remaining points.  The floor
	// keeps both weights positive.  Writing !(v>=tiny) also replaces a NaN
	// variance (e.g. from a non-finite integrand), whose conversion to int
	// below would be undefined.
	if(!(varL>=tiny))varL=tiny;
	if(!(varR>=tiny))varR=tiny;
	N_rest=N-N_trial;
	N_left=ceil(varL*N_rest/(varL+varR));
	// Both halves need at least one point.  Clamp rather than test for
	// equality: ceil() can reach N_rest, and floating-point rounding can
	// push it past N_rest.
	if(N_left>N_rest-1)N_left=N_rest-1;
	if(N_left<1)N_left=1;
	N_right=N_rest-N_left;

	// Bisect the chosen dimension: [a,bL] is the lower half, [aR,b] the upper.
	for(j=0;j<n;j++){bL[j]=b[j]; aR[j]=a[j];}
	bL[j_big]=(a[j_big]+b[j_big])/2; aR[j_big]=bL[j_big];
	strata(n,a,bL,func,N_left,&sumL,&sum2L,rnd);
	strata(n,aR,b,func,N_right,&sumR,&sum2R,rnd);
	free(aR); free(bL);

	// 4. Combine the trial sums with those of the two sub-boxes.
	*sum+=sumL+sumR;
	*sum2+=sum2L+sum2R;
}
