#ifndef __FORCES_H__
#define __FORCES_H__

BEGIN_VLADO

// A simple linear spring force between two particles.
//
// init() records the particles' current distance as the rest length. Each
// applyForces() call pushes/pulls the particles along the line joining them
// toward a separation of scale*restLen. It also adds a damping term from
// their relative velocity along that line. Equal and opposite forces are
// passed to Particle::applyForce().
template<class Particle>
class SpringForce {
public:
	Particle *p0, *p1;	// the two connected particles (set by init())
	float restLen;		// particle distance captured by init()
	float ks, kd;		// stiffness and damping coefficients
	float scale;		// rest-length multiplier; init() resets it to 1.0f

	// Connects the spring to p0 and p1 and uses their current distance as the
	// rest length.
	// Returns true (nonzero) on success. Returns false (0) when the particles
	// are closer than 1e-6f. In that case only restLen has been written, the
	// particle pointers are NOT set, and applyForces() must not be called.
	int init(Particle &p0, Particle &p1, float ks, float kd) {
		restLen=length(p0.getPos()-p1.getPos());
		if (restLen<1e-6f) return false;
		this->p0=&p0;
		this->p1=&p1;
		this->ks=ks;
		this->kd=kd;
		scale=1.0f;
		return true;
	}

	// Computes the elastic + damping force and applies it to both particles.
	void applyForces(void) {
		Vector d=p1->getPos()-p0->getPos();
		float dlen=d.length();

		// Coincident particles: the spring axis is undefined, and the
		// divisions by dlen below would turn the force into NaN (0/0),
		// poisoning both particles. Apply no force for this step instead.
		if (dlen<=0.0f) return;

		// Elastic term relative to the (scaled) rest length. It is positive
		// when the spring is compressed (pushes apart) and negative when it
		// is stretched (pulls together).
		float dp=(scale-dlen/restLen)*ks;
		// Damping term: relative velocity projected on the spring axis,
		// divided by the rest length.
		float dv=((p1->getVel()-p0->getVel())*d/(dlen*restLen))*kd;

		// Turn d into the force on p1: unit axis times the scalar magnitude.
		d*=(dp-dv)/dlen;

		p0->applyForce(-d);
		p1->applyForce(d);
	}
};

// The stretch/shear force is the same as described by Baraff.
// It uses stretch and shear conditions on the triangle's (u,v)
// parameterization, with the variations noted in applyForces().
//
// Particle must provide getPos(), getVel() and applyForce(). The three
// particles form one triangle whose per-vertex texture coordinates are
// supplied to init().
template<class Particle>
class StretchShearForce {
public:
	Particle *parts[3];	// the triangle's three vertices
	float ku, kv, kc, kd;	// u-stretch, v-stretch, shear and damping coefficients
	float scale;		// target length of wu and wv (rest stretch); init() sets 1.0f

	float uvArea;		// |det| of the (u,v) edge matrix, i.e. twice the (u,v) triangle area
	float ac, bc, cc, dc;	// inverse of the (u,v) edge matrix
	float uc[3], vc[3];	// per-vertex texture coordinates copied in init()

	// Binds the three particles and precomputes the world->texture mapping.
	// iuc[i], ivc[i] are the texture coordinates of particle i.
	// ks is used for both the u and v stretch stiffness (ku, kv).
	// NOTE(review): a degenerate (zero-area) (u,v) triangle is not rejected.
	// D then becomes infinite and the resulting forces are inf/NaN.
	void init(Particle &p0, Particle &p1, Particle &p2, float ks, float kc, float kd, float iuc[3], float ivc[3]) {
		parts[0]=&p0;
		parts[1]=&p1;
		parts[2]=&p2;
		ku=kv=ks;
		this->kc=kc;
		this->kd=kd;
		scale=1.0f;

		// Compute the matrix for conversion from world space to texture space
		for (int i=0; i<3; i++) { uc[i]=iuc[i]; vc[i]=ivc[i]; }

		float du0=uc[1]-uc[0], du1=uc[2]-uc[0];
		float dv0=vc[1]-vc[0], dv1=vc[2]-vc[0];
		float area=du0*dv1-dv0*du1;
		float D=1.0f/area;
		if (area<0.0f) area=-area;

		uvArea=area;
		// [ac bc; cc dc] is the inverse of [du0 du1; dv0 dv1]
		ac=dv1*D;
		bc=-dv0*D;
		cc=-du1*D;
		dc=du0*D;
	}

	// Evaluates the stretch and shear conditions with damping, and applies
	// the resulting forces to the three particles.
	void applyForces(void) {
		// Compute wu and wv: the world-space derivatives of position with
		// respect to u and v under the triangle's current affine map
		Vector x0=parts[1]->getPos()-parts[0]->getPos(), x1=parts[2]->getPos()-parts[0]->getPos();
		Vector wu=x0*ac+x1*bc;
		Vector wv=x0*cc+x1*dc;

		float wu_len=length(wu);
		float wv_len=length(wv);

		// Compute derivatives (unit directions of wu and wv)
		Vector wnu=wu/wu_len;
		Vector wnv=wv/wv_len;

		// Compute conditions, all weighted by uvArea:
		// - stretch along u and v, relative to `scale`;
		// - shear, as the cosine of the angle between wu and wv.
		// NOTE(review): Baraff & Witkin's shear condition uses the
		// unnormalized dot product wu.wv; this code normalizes it.
		float cu=uvArea*(wu_len-scale);
		float cv=uvArea*(wv_len-scale);
		float cs=uvArea*(wnu*wnv);

		// Gradients of |wu| and |wv| with respect to particle 1 (ud1, vd1)
		// and particle 2 (ud2, vd2).
		// NOTE(review): unlike cu/cv these omit the uvArea factor, so the
		// elastic force scales with uvArea rather than uvArea^2.
		Vector ud1=ac*wnu, vd1=cc*wnv;
		Vector ud2=bc*wnu, vd2=dc*wnv;

		// Gradient of the shear cosine (wu.wv)/(|wu||wv|) with respect to
		// particle 1, via the quotient rule
		Vector dp1_wu_wv=ac*wv+wu*cc;
		Vector dp1_wu_len_wv_len=wu*ac*sqr(wv_len)+wv*cc*sqr(wu_len);
		Vector sd1=(dp1_wu_wv*(wu_len*wv_len)-(wnu*wnv)*dp1_wu_len_wv_len)/sqr(wu_len*wv_len);

		// ... and with respect to particle 2
		Vector dp2_wu_wv=bc*wv+wu*dc;
		Vector dp2_wu_len_wv_len=wu*bc*sqr(wv_len)+wv*dc*sqr(wu_len);
		Vector sd2=(dp2_wu_wv*(wu_len*wv_len)-(wnu*wnv)*dp2_wu_len_wv_len)/sqr(wu_len*wv_len);

		// compute du/dt, dv/dt: each rate is (gradient . relative velocity)
		// divided by |gradient|^2.
		// NOTE(review): Baraff's C-dot is the plain dot product; the extra
		// normalization here looks deliberate but is undocumented. A zero
		// gradient makes these divisions 0/0.
		Vector dvel0=parts[1]->getVel()-parts[0]->getVel();
		Vector dvel1=parts[2]->getVel()-parts[0]->getVel();

		float u_dt=(ud1*dvel0+ud2*dvel1)/(ud1*ud1+ud2*ud2);
		float v_dt=(vd1*dvel0+vd2*dvel1)/(vd1*vd1+vd2*vd2);
		float s_dt=(sd1*dvel0+sd2*dvel1)/(sd1*sd1+sd2*sd2);

		// compute -ks*c-kd*ct
		float k_u=-ku*cu-kd*u_dt;
		float k_v=-kv*cv-kd*v_dt;
		float k_s=-kc*cs-kd*s_dt;

		// compute the forces on particles 1 and 2. Particle 0 receives their
		// negated sum, so the net force on the triangle is zero.
		Vector f1=k_u*ud1+k_v*vd1+k_s*sd1;
		Vector f2=k_u*ud2+k_v*vd2+k_s*sd2;

		parts[0]->applyForce(-f1-f2);
		parts[1]->applyForce(f1);
		parts[2]->applyForce(f2);
	}
};

// The tri-spring force is based on the idea to transform a distorted triangle
// into its limit (original) state in one time step.
//
// init() flattens the rest triangle into 2D (u,v) coordinates centered on
// its centroid. applyForces() then:
// - fits the in-plane rotation that best maps that rest shape onto the
//   current triangle;
// - pulls each vertex toward its rotated, scale-multiplied rest position;
// - damps each vertex's in-plane velocity (relative to the centroid
//   velocity) toward a rigid spin about the triangle normal.
template<class Particle>
class Trispring {
public:
	Particle *parts[3];
	float ks, kd;		// stiffness and damping coefficients
	float scale;		// rest-shape scale factor (as in the other forces); init() resets it to 1.0f

	float uc[3], vc[3];	// rest (u,v) coordinates, centroid at the origin
	float ac, bc, cc, dc;	// inverse of the rest (u,v) edge matrix
	float pc, qc, rc;	// second moments: sum(u*u), sum(u*v), sum(v*v)
	float I;		// 1/(pc+rc), used to get the damping spin rate

	// Binds the three particles and captures their current positions as the
	// rest shape.
	// NOTE(review): a degenerate (zero-area) triangle is not rejected. The
	// normalizations and 1/area below then yield inf/NaN coefficients.
	void init(Particle &p0, Particle &p1, Particle &p2, float iks, float ikd) {
		parts[0]=&p0;
		parts[1]=&p1;
		parts[2]=&p2;
		ks=iks;
		kd=ikd;
		scale=1.0f;

		// Generate texture coordinates: an orthonormal (u,v) frame in the
		// triangle's plane. u is chosen perpendicular to world Y, or to
		// world X when the normal is parallel to Y.
		Vlado::Vector e0=p1.getPos()-p0.getPos();
		Vlado::Vector e1=p2.getPos()-p0.getPos();

		Vlado::Vector n=e0^e1;
		Vlado::Vector u=Vlado::Vector(0.0f, 1.0f, 0.0f)^n;
		if (u.lengthSqr()<1e-12f) u=n^Vlado::Vector(1.0f, 0.0f, 0.0f);
		u.makeNormalized();
		Vlado::Vector v=Vlado::normalize(n^u);

		uc[0]=0.0f; vc[0]=0.0f;
		uc[1]=e0*u; vc[1]=e0*v;
		uc[2]=e1*u; vc[2]=e1*v;

		// Center the texture coordinates about the origin
		float um=(uc[0]+uc[1]+uc[2])/3.0f;
		float vm=(vc[0]+vc[1]+vc[2])/3.0f;
		for (int i=0; i<3; i++) { uc[i]-=um; vc[i]-=vm; }

		// Compute the matrix for conversion from world space to texture space.
		// Only the signed determinant is needed here (for D).
		float du0=uc[1]-uc[0], du1=uc[2]-uc[0];
		float dv0=vc[1]-vc[0], dv1=vc[2]-vc[0];
		float area=du0*dv1-dv0*du1;
		float D=1.0f/area;

		ac=dv1*D;
		bc=-dv0*D;
		cc=-du1*D;
		dc=du0*D;

		pc=uc[0]*uc[0]+uc[1]*uc[1]+uc[2]*uc[2];
		qc=uc[0]*vc[0]+uc[1]*vc[1]+uc[2]*vc[2];
		rc=vc[0]*vc[0]+vc[1]*vc[1]+vc[2]*vc[2];

		I=1.0f/(pc+rc);
	}

	// Applies the shape-restoring (stretching) and damping forces.
	void applyForces(void) {
		Vector p[3]={ parts[0]->getPos(), parts[1]->getPos(), parts[2]->getPos() };
		Vector v[3]={ parts[0]->getVel(), parts[1]->getVel(), parts[2]->getVel() };

		// Current affine map from rest (u,v) coordinates to world space
		Vector x0=p[1]-p[0], x1=p[2]-p[0];
		Vector wu=x0*ac+x1*bc;
		Vector wv=x0*cc+x1*dc;

		// Best-fit in-plane rotation: su/sv are the images of the rest u/v
		// axes that maximize the alignment of the rest shape with the
		// current one
		Vector a=pc*wu+qc*wv;
		Vector b=qc*wu+rc*wv;

		Vector n=normalize(wu^wv);
		Vector su=a+(b^n);
		Vector sv=b-(a^n);
		// su and sv have equal length by construction; rescale both to unit length
		float k=sqrtf(2.0f/(su*su+sv*sv));
		su*=k;
		sv*=k;

		// Stretching: pull each vertex toward its rotated rest position.
		// The rest shape is scaled by `scale`, like the rest length in
		// SpringForce. Multiplying by the default 1.0f is exact, so unscaled
		// springs behave exactly as before. su/sv stay unit length for the
		// damping frame below.
		Vector du=su*scale-wu;
		Vector dv=sv*scale-wv;

		for (int i=0; i<3; i++) {
			Vector f=du*uc[i]+dv*vc[i];
			parts[i]->applyForce(f*ks);
		}

		// Damping: velocities relative to the centroid velocity, expressed in
		// the fitted (su, sv) frame (the component along n is ignored)
		Vector V=(v[0]+v[1]+v[2])/3.0f;
		float vu[3], vv[3];
		for (int i=0; i<3; i++) { v[i]-=V; vu[i]=v[i]*su; vv[i]=v[i]*sv; }

		// Calculate the angular momentum in final local plane coordinates,
		// then the rigid spin rate wn, treating the vertices as equal masses
		float Ln=(uc[0]*vv[0]-vc[0]*vu[0])+(uc[1]*vv[1]-vc[1]*vu[1])+(uc[2]*vv[2]-vc[2]*vu[2]);
		float wn=Ln*I;

		// Damping force: rigid-spin velocity minus the actual in-plane
		// relative velocity, so only non-rigid in-plane motion is damped
		for (int i=0; i<3; i++) {
			Vector f=(-vc[i]*wn-vu[i])*su+(uc[i]*wn-vv[i])*sv;
			parts[i]->applyForce(f*kd);
		}
	}
};

// The bend force is the same as described by Baraff
template<class Particle, class Face>
class BendAngle {
public:
	Particle *parts[4];
	Face *tri0, *tri1;
	float ks, kd;

	void init(Particle &p0, Particle &p1, Particle &p2, Particle &p3, Face &tri0, Face &tri1, float ks, float kd) {
		this->parts[0]=&p0;
		this->parts[1]=&p1;
		this->parts[2]=&p2;
		this->parts[3]=&p3;
		this->tri0=&tri0;
		this->tri1=&tri1;
		this->ks=ks;
		this->kd=kd;
	}

	void applyForces(void) {
		Vector ed=(parts[1]->getPos())-(parts[0]->getPos());
		float c=length(ed);

		Vector &n0=tri0->getNormal();
		real a=tri0->getInvNormalLength();

		Vector &n1=tri1->getNormal();
		real b=tri1->getInvNormalLength();

		real sinalpha=(n0^n1)*ed/c;
		real cosalpha=n0*n1;

		// the bend condition
		real bend=atan2f(sinalpha, cosalpha);

		// the derivatives
		Vector cd[4];
		cd[2]=n0*(-a*c);
		cd[3]=n1*(-b*c);

		float kt0=((parts[2]->getPos()-parts[0]->getPos())*ed)/(c*c);
		float kt1=((parts[3]->getPos()-parts[0]->getPos())*ed)/(c*c);

		cd[1]=-cd[2]*kt0-cd[3]*kt1;
		cd[0]=-cd[2]-cd[3]-cd[1];

		// compute bend/dt
		real bend_dt=0.0f;
		for (int i=0; i<4; i++) bend_dt+=cd[i]*parts[i]->getVel();

		// compute the forces
		real k=-(ks*bend+kd*bend_dt);
		for (i=0; i<4; i++) parts[i]->applyForce(k*cd[i]);
	}
};

// Bending force that flattens four particles toward a common plane, leaving
// their rigid rotation untouched.
// NOTE(review): the particle roles are not documented. Presumably, as in
// BendAngle, parts[0]-parts[1] is the shared edge and parts[2]/parts[3] are
// the two wing vertices; confirm against the caller.
template<class Particle>
class BendProjection {
public:
	Particle *parts[4];	// the four particles
	float ks, kd;		// stiffness and damping coefficients

	// Binds the four particles. No rest-state geometry is precomputed, so
	// the rest shape is implicitly planar.
	void init(Particle &p0, Particle &p1, Particle &p2, Particle &p3, float ks, float kd) {
		this->parts[0]=&p0;
		this->parts[1]=&p1;
		this->parts[2]=&p2;
		this->parts[3]=&p3;
		this->ks=ks;
		this->kd=kd;
	}

	// Computes two forces per particle:
	// - an elastic force toward the plane through the centroid spanned by
	//   (p1-p0) and (p3-p2);
	// - a damping force opposing velocity normal to that plane.
	// In both cases the part explained by a rigid rotation about an in-plane
	// axis is removed first.
	// NOTE(review): if p1-p0 and p3-p2 are parallel, n is zero and the
	// divisions below produce NaN.
	void applyForces(void) {
		Vector p[4], v[4];
		for (int i=0; i<4; i++) { p[i]=parts[i]->getPos(); v[i]=parts[i]->getVel(); }

		// Plane basis: uc and vc span the plane, n is its (unnormalized)
		// normal, and n1=n/|n|^2 so that x*n1 is the coefficient of n in x.
		Vector uc=p[1]-p[0], vc=p[3]-p[2];
		Vector n=uc^vc;
		Vector n1=n/(n*n);

		// Work relative to the centroid position P and velocity V. Per particle:
		// - kv[i] is the n-coefficient of the relative velocity, and v[i]
		//   becomes that normal component;
		// - f[i]=n*kp[i] is the displacement that projects p[i] onto the
		//   plane, and p[i] becomes the projected point.
		Vector P=(p[0]+p[1]+p[2]+p[3])/4.0f;
		Vector V=(v[0]+v[1]+v[2]+v[3])/4.0f;
		float kp[4], kv[4];
		Vector f[4];
		for (int i=0; i<4; i++) { p[i]-=P; v[i]-=V; kv[i]=v[i]*n1; v[i]=n*kv[i]; kp[i]=-p[i]*n1; f[i]=n*kp[i]; p[i]+=f[i]; }

		Vector q0=p[0];

		// (u,v) coordinates of the projected points in the (uc, vc) basis.
		// Only point 0 is solved for; the rest follow from p1=p0+uc,
		// p3=p2+vc and the centroid being at the origin.
		float u0=(q0^vc)*n1, v0=(uc^q0)*n1;
		float u1=u0+1.0f, v1=v0;
		float u2=-u0-0.5f, v2=-v0-0.5f;
		float u3=u2, v3=v2+1.0f;

		// Gram matrix of the (uc, vc) basis
		float uu=uc*uc, uv=uc*vc, vv=vc*vc;

		// Second moments of the (u,v) coordinates, expanded in closed form
		float pc=4.0f*u0*(u0+1.0f)+1.5f; // u0*u0+u1*u1+u2*u2+u3*u3;
		float qc=4.0f*u0*v0+2.0f*v0; // u0*v0+u1*v1+u2*v2+u3*v3;
		float rc=4.0f*v0*v0+0.5f; // v0*v0+v1*v1+v2*v2+v3*v3;

		// 2x2 system for a rotation axis uc*wu+vc*wv. Its rigid normal
		// velocity at an in-plane point (u,v) is (wu*v-wv*u)*n. With the
		// right-hand sides (Lu, Lv) below, this is equivalent to a
		// least-squares fit of that rotation to the measured n-coefficients.
		float a=uv*qc+vv*rc, b=-uv*pc-vv*qc;
		float c=-uu*qc-uv*rc, d=uu*pc+uv*qc;

		float D=1.0f/(a*d-b*c);

		// Velocity fit: first moments of the normal velocity coefficients
		float su=u0*kv[0]+u1*kv[1]+u2*kv[2]+u3*kv[3];
		float sv=v0*kv[0]+v1*kv[1]+v2*kv[2]+v3*kv[3];

		float Lu=uv*su+vv*sv;
		float Lv=-uu*su-uv*sv;

		float wu=(Lu*d-Lv*b)*D;
		float wv=(a*Lv-c*Lu)*D;

		// Fitted angular velocity (in-plane axis)
		Vector av=uc*wu+vc*wv;

		// Position fit: the same solve applied to the flattening displacements kp
		float psu=u0*kp[0]+u1*kp[1]+u2*kp[2]+u3*kp[3];
		float psv=v0*kp[0]+v1*kp[1]+v2*kp[2]+v3*kp[3];

		float pLu=uv*psu+vv*psv;
		float pLv=-uu*psu-uv*psv;

		float pwu=(pLu*d-pLv*b)*D;
		float pwv=(a*pLv-c*pLu)*D;

		// Fitted rotation explaining the flattening displacements
		Vector pav=uc*pwu+vc*pwv;

		// Forces with the rigid-rotation part removed:
		// - damping opposes the residual normal velocity;
		// - stiffness applies the residual flattening displacement.
		for (int i=0; i<4; i++) {
			Vector q=p[i];
			Vector fv=(av^q)-v[i];
			Vector fp=-(pav^q)+f[i];
			parts[i]->applyForce(fp*ks+fv*kd);
		}
	}
};

END_VLADO

#endif