#ifndef __VGL_CGSOLVER_H__
#define __VGL_CGSOLVER_H__

BEGIN_VLADO

// basic resolve target - this is what the resolver works with
class CGTarget {
	public:
		real *r0, *r1, *p; // used by the resolver, must be allocated by the target
		virtual int getNumVars(void)=0; // returns the number of variables

		virtual void addChanges(real *x, real scale)=0; // adds the given values
		virtual void computeChanges(real *x)=0; // computes the required changes

		void solve(real precision, int maxSteps, real conjugateGradient) {
			int n=getNumVars();

			int s=0;
			double oldt=0.0f;
			int i;
			double t=0.0f;

			while (1) {
				// evaluate the residual
				computeChanges(r0);

				// check convergence
				t=0.0f;
				for (i=0; i<n; i++) t+=double(r0[i])*double(r0[i]);
				if ((fabs(t-oldt)<precision*precision || t<precision) && s>0) break;

				// compute direction
				if (s==0) {
					for (i=0; i<n; i++) p[i]=r0[i];
					oldt=t;
				} else {
					double b=t/oldt*conjugateGradient;
					for (i=0; i<n; i++) p[i]=real(r0[i]+b*p[i]);
					oldt=t;
				}

				// take a step in the direction
				addChanges(p, 1.0f);

				// evaluate the residual again
				computeChanges(r1);

				// compute distance along p so that [r0+(r1-r0)*t]*r0[i]=0
				double t1=0.0f;
				for (i=0; i<n; i++) t1+=double(r1[i]-r0[i])*double(r0[i]);

				if (fabs(t1)<1e-6f) break;
				double k=-t/t1;

				// move along that distance
				addChanges(p, real(k-1.0f));

				s++;
				if (s>maxSteps) break;
			}
		}
};

class CGVectorTarget {
	public:
		union {
			Vector *r0, *r1, *p; // Temp variables used by the solve(), must be allocated by the target
			float *r0f, *r1f, *pf;
		};
		virtual int getNumVars(void)=0; // Returns the number of variables

		virtual void addChanges(Vector *x, real scale)=0; // Adds the given values
		virtual void computeChanges(Vector *x, int firstCall)=0; // Computes the residual

		bool solve(real precision, int maxSteps, real conjugateGradient) {
			int n=getNumVars();

			int s=0, i;
			double oldt=0.0f;

			float startStep=1.0f;
			computeChanges(r0, true);
			addChanges(r0, startStep);

			while (1) {
				// Evaluate the residual
				computeChanges(r0, false);

				// Check convergence
				double t=0.0f;
				for (i=0; i<n; i++) t+=r0[i]*r0[i];
				if (t<precision/**sqr(startStep)*/ && s>0) break;

				// Compute direction
				if (t>oldt && s>0 && startStep>1e-8f) {
					// The error is increasing, which means that the previous guess was too far away
					// Restart and move in smaller steps
					startStep*=0.5f;
					s=0;
					oldt=0.0;
					continue;
				}

				double t0;
				if (s==0) {
					for (i=0; i<n; i++) p[i]=r0[i];
					oldt=t0=t;
				} else {
					real b=float(t/oldt)*conjugateGradient;
					if (b<1e-12f) b=0.0f;
					t0=0.0f;
					for (i=0; i<n; i++) {
						p[i]=r0[i]+b*p[i];
						t0+=r0[i]*p[i];
					}
					oldt=t;
				}

				// Take a step in the direction
				addChanges(p, startStep);

				// Evaluate the residual again
				computeChanges(r0, false);

				// Compute distance along p so that [r0+(r1-r0)*t]*p=0
				double t1=0.0f;
				for (i=0; i<n; i++) t1+=r0[i]*p[i];
				t1-=t0;

				if (fabs(t1)<1e-6f) break;
				real k=-float(t0/t1);
				if (k<0.0f) k=0.0f;

				// Move along that distance
				addChanges(p, real(k-1.0f)*startStep);

				s++;
				if (s>maxSteps) break;
			}

			return false;
		}
};

// Conjugate-gradient style solve for a target whose variables are Vectors (3 floats each, per nfr=n*3).
//
// Template parameters:
//   CGVectorTarget - target type (this name shadows the CGVectorTarget class above). From the calls
//                    below it must provide getNumVars(), get_r(), get_p(), get_r1(), get_temp(),
//                    computeChanges(Vector*, firstCall) and addChanges(Vector*, scale).
//   useSSE         - non-zero selects the SSE reduction/update paths, which treat r and p as packed
//                    float arrays via _mm_load_ps/_mm_store_ps; those intrinsics require 16-byte
//                    aligned addresses, so the target's r and p buffers must be 16-byte aligned.
// Parameters:
//   precision         - converged once maxt (largest squared residual term) is below this value.
//   maxSteps          - iteration cap; also returned as an error code when the residual norm grows.
//   conjugateGradient - scale on the ratio t/oldt of squared residual norms (0 gives steepest descent).
//   startStep         - length of the trial step used by the secant line search.
// Returns -1 if the squared residual norm is not finite, maxSteps if it grew by more than 1.5x between
// iterations, and otherwise the number of completed iterations.
template<class CGVectorTarget, int useSSE>
int cgSolveVectorTarget(CGVectorTarget &target, real precision, int maxSteps, real conjugateGradient, real startStep) {
	int n=target.getNumVars();

	int nf4=n*3/4; // Number of complete 4-float packets; the SSE paths finish the remainder in scalar tail loops
	int nfr=n*3; // The array size in floats

	int s=0, i;
	double oldt=0.0f;
	Vector *r=target.get_r();
	float *rf=(float*) r;
	Vector *p=target.get_p();
	float *pf=(float*) p;
	// NOTE(review): r1, r1f and temp are fetched but never used below.
	Vector *r1=target.get_r1();
	float *r1f=(float*) r1;

	Vector *temp=target.get_temp();

	// Initial Euler step (always with scale 1.0, not startStep). Each residual component is first
	// scaled by a small random factor (within +/-0.5% if rnd() returns values in [0,1) - TODO confirm).
	// NOTE(review): rnd is a function-local static shared by every call of this instantiation, so the
	// jitter depends on call history and concurrent solves would share its state.
	target.computeChanges(r, true);
	static Random rnd(1);
	for (int i=0; i<n*3; i++) rf[i]*=1.0f+(rnd.rnd()-0.5f)*0.01f;
	target.addChanges(r, 1.0f);

	int totalSteps=0;
	while (1) {
		// Evaluate the residual
		target.computeChanges(r, false);

		// Check convergence: t is the squared residual norm, maxt the largest single squared term
		double t, maxt;
		if (!useSSE) {
			t=0.0f; maxt=0.0f;
			for (i=0; i<n; i++) {
				// Per-Vector product (presumably the squared length of r[i])
				float k=r[i]*r[i];
				if (k>maxt) maxt=k;
				t+=k;
			}
		} else {
			// Packed accumulation of squared float components
			__m128 tf=_mm_setzero_ps(), maxtf=_mm_setzero_ps();
			int i;
			for (i=0; i<nf4; i++) {
				__m128 kf=_mm_load_ps(rf+i*4);
				kf=_mm_mul_ps(kf, kf);
				maxtf=_mm_max_ps(maxtf, kf);
				tf=_mm_add_ps(tf, kf);
			}
			// Horizontal sum: add adjacent-lane swap, then lane reversal, so lane 0 holds the total
			tf=_mm_add_ps(tf, _mm_shuffle_ps(tf, tf, _MM_SHUFFLE(2, 3, 0, 1)));
			tf=_mm_add_ps(tf, _mm_shuffle_ps(tf, tf, _MM_SHUFFLE(0, 1, 2, 3)));
			float temp;
			_mm_store_ss(&temp, tf);
			t=temp;

			// Horizontal max using the same shuffle pattern
			maxtf=_mm_max_ps(maxtf, _mm_shuffle_ps(maxtf, maxtf, _MM_SHUFFLE(2, 3, 0, 1)));
			maxtf=_mm_max_ps(maxtf, _mm_shuffle_ps(maxtf, maxtf, _MM_SHUFFLE(0, 1, 2, 3)));
			_mm_store_ss(&temp, maxtf);
			maxt=temp;

			// Scalar tail for floats that do not fill a whole packet
			i*=4;
			for (; i<nfr; i++) {
				float k=rf[i]*rf[i];
				if (k>maxt) maxt=k;
				t+=k;
			}
		}

		if (!_finite(t)) return -1;

		// NOTE(review): the scalar path takes maxt over per-Vector products r[i]*r[i], while the SSE path
		// takes it over individual squared float components, so this threshold is not equivalent
		// between the two paths - confirm which is intended.
		if (maxt<precision) break;

		// Search direction: steepest descent on the first step, then mix in the previous direction
		if (s==0) {
			for (i=0; i<n; i++) p[i]=r[i];
			oldt=t;
		} else {
			// Residual norm grew by more than 50%: give up and report maxSteps
			if (t>oldt*1.5f) return maxSteps;

			double b=Max(0.0, (t/oldt)*conjugateGradient); 
			oldt=t;
			if (!useSSE) for (int i=0; i<n; i++) p[i]=r[i]+b*p[i];
			else {
				__m128 bf=_mm_set_ps1(b);
				int i;
				for (i=0; i<nf4; i++) _mm_store_ps(pf+i*4, _mm_add_ps(_mm_load_ps(rf+i*4), _mm_mul_ps(_mm_load_ps(pf+i*4), bf)));
				i*=4;
				for (; i<nfr; i++) pf[i]=rf[i]+b*pf[i];
			}
		}

		// t0 = r.p at the current position
		double t0;
		if (!useSSE) {
			t0=0.0f;
			for (int i=0; i<n; i++) t0+=r[i]*p[i];
		} else {
			__m128 t0f=_mm_setzero_ps();
			int i;
			for (i=0; i<nf4; i++) t0f=_mm_add_ps(t0f, _mm_mul_ps(_mm_load_ps(rf+i*4), _mm_load_ps(pf+i*4)));
			t0f=_mm_add_ps(t0f, _mm_shuffle_ps(t0f, t0f, _MM_SHUFFLE(2, 3, 0, 1)));
			t0f=_mm_add_ps(t0f, _mm_shuffle_ps(t0f, t0f, _MM_SHUFFLE(0, 1, 2, 3)));
			float temp;
			_mm_store_ss(&temp, t0f);
			t0=temp;
			i*=4;
			for (; i<nfr; i++) t0+=rf[i]*pf[i];
		}

		// Trial step of length startStep along p, then re-evaluate the residual into r
		target.addChanges(p, startStep);
		target.computeChanges(r, false);

		// Compute distance along p so that [r0+(r1-r0)*t]*p=0
		double t1;
		if (!useSSE) {
			t1=0.0f;
			for (int i=0; i<n; i++) t1+=r[i]*p[i];
		} else {
			__m128 t1f=_mm_setzero_ps();
			int i;
			for (i=0; i<nf4; i++) t1f=_mm_add_ps(t1f, _mm_mul_ps(_mm_load_ps(rf+i*4), _mm_load_ps(pf+i*4)));
			t1f=_mm_add_ps(t1f, _mm_shuffle_ps(t1f, t1f, _MM_SHUFFLE(2, 3, 0, 1)));
			t1f=_mm_add_ps(t1f, _mm_shuffle_ps(t1f, t1f, _MM_SHUFFLE(0, 1, 2, 3)));
			float temp;
			_mm_store_ss(&temp, t1f);
			t1=temp;
			i*=4;
			for (; i<nfr; i++) t1+=rf[i]*pf[i];
		}

		t1-=t0;

		// NOTE(review): this exit leaves the target displaced by the trial step (the undo is commented out)
		if (fabs(t1)<1e-6f) {
			// target.addChanges(p, -startStep);
			break;
		}

		// Secant estimate of the distance along p in units of startStep; clamping to 0 means the
		// move below returns exactly to the pre-trial position
		float k=-float(t0/t1);
		if (k<0.0f) k=0.0f;

		// Move along that distance (the trial step already moved by 1*startStep)
		target.addChanges(p, (k-1.0f)*startStep);

		s++;
		totalSteps++;
		if (totalSteps>maxSteps) break;
	}

	return totalSteps;
}

// Returns the sum of x[i]*x[i] over i in [0, n), accumulated in double, and stores the largest
// single term x[i]*x[i] in maxerr. For n<=0, x is not read and both the sum and maxerr are 0.
// (Each square is formed in 'real' precision before being widened to double.)
inline double addNumbers(real *x, int n, double &maxerr) {
	double res=0.0;
	double err=0.0;
	// Start at index 0 so the first element contributes to maxerr as well as to the sum.
	for (int i=0; i<n; i++) {
		double k=x[i]*x[i];
		res+=k;
		err=Max(err, k);
	}
	maxerr=err;
	return res;
}

template<class CGTarget>
int cgSolveTarget(CGTarget &target, real precision, int maxSteps, real conjugateGradient, real startStep) {
	int n=target.getNumVars();

	int s=0, i;
	double t, oldt=0.0f;
	double oldMult=1.0f;

	real *r=target.get_r();
	real *r1=target.get_r1();
	real *p=target.get_p();

	// Initial Euler step
	target.computeChanges(r);
	target.addChanges(r, startStep);

	// target.computeChanges(r);
	// t=addNumbers(r, n);

	double maxerr;

	int totalSteps=0;
	while (1) {
		target.computeChanges(r);
		t=addNumbers(r, n, maxerr);

		// Check convergence
		if (!_finite(t)) return -1;
		if (maxerr<precision*precision) break;

		if (s==0) {
			for (i=0; i<n; i++) r1[i]=p[i]=r[i];
			oldt=t;
		} else {
			if (t>oldt) {
				for (int i=0; i<n; i++) p[i]=r[i];
			} else {
				double b=Max(0.0, (t/oldt)*conjugateGradient);
				oldt=t;
				for (int i=0; i<n; i++) p[i]=r[i]+b*p[i];
			}
		}

		double t0=0.0f;
		for (int i=0; i<n; i++) t0+=r[i]*p[i];

		target.addChanges(p, startStep);
		target.computeChanges(r);

		// Compute distances along p so that [r0+(r1-r0)*t]*p=0
		double t1=0.0f;
		for (int i=0; i<n; i++) t1+=r[i]*p[i];

		t1-=t0;
		if (fabs(t1)<1e-6f) break;

		float k=-float(t0/t1);
		if (k<=1e-6f) break;

		// Move along that distance
		target.addChanges(p, (k-1.0f)*startStep);

		// target.computeChanges(r);
		// t=addNumbers(r, n);

		s++;
		totalSteps++;
		if (totalSteps>maxSteps) break;
	}

	return totalSteps;
}

END_VLADO

#endif