#ifndef __FORCES_H__
#define __FORCES_H__

BEGIN_VLADO

// Damped linear spring connecting two particles.
//
// Particle must provide getPos(), getVel() and applyForce(Vector).
// Vector*Vector is used here as a dot product (its result is stored in a float).
template<class Particle>
class SpringForce {
	public:
		Particle *p0, *p1; // the two endpoints (set by init)
		float restLen;     // endpoint distance captured by init()
		float ks, kd;      // stiffness and damping coefficients
		float scale;       // multiplier on restLen giving the target length; init() sets 1

		// Binds the spring to two particles and records their current distance as the rest length.
		// Returns true on success. Returns false if the particles are closer than 1e-6; in that case
		// only restLen has been written and the spring must not be used.
		int init(Particle &p0, Particle &p1, float ks, float kd) {
			restLen=length(p0.getPos()-p1.getPos());
			if (restLen<1e-6f) return false;
			this->p0=&p0;
			this->p1=&p1;
			this->ks=ks;
			this->kd=kd;
			scale=1.0f;
			return true;
		}

		// Applies equal and opposite elastic + damping forces along the p0->p1 axis.
		void applyForces(void) {
			Vector d=p1->getPos()-p0->getPos();
			float dlen=d.length();
			// Coincident endpoints give no direction to act along, and both divisions
			// by dlen below would yield NaN forces; skip the spring for this step.
			if (dlen<=0.0f) return;

			// Elastic term: ks * (scale - current/rest length ratio).
			float dp=(scale-dlen/restLen)*ks;
			// Damping term: relative velocity projected onto the spring axis, normalised by restLen.
			float dv=((p1->getVel()-p0->getVel())*d/(dlen*restLen))*kd;

			d*=(dp-dv)/dlen;

			p0->applyForce(-d);
			p1->applyForce(d);
		}
};

// Per-triangle data shared by StretchShearForce (UV mapping) and BendForce (normal).
struct ForceTriplet {
	// Vertex indices; not referenced in this header (presumably particle indices — confirm).
	int v0;
	int v1;
	int v2;
	// Per-vertex texture coordinates, read by StretchShearForce::init.
	float uc[3];
	float vc[3];

	// Absolute determinant of the UV edge matrix, written by StretchShearForce::init.
	float uvArea;
	// Inverse UV edge-matrix coefficients, written by StretchShearForce::init.
	float a;
	float b;
	float c;
	float d;

	// Presumably the triangle normal and its length data; BendForce reads
	// `normal` and `invLength` but nothing in this header computes them.
	Vector normal;
	float normalLength;
	float invLength;

	// Not referenced in this header.
	float len0;
	float len1;
	float cs;
	float sn;
};

// Stretch and shear force on one triangle, driven by its UV parameterisation.
//
// wu and wv are the world-space images of the texture-space u and v directions.
// Three conditions are penalised: |wu| - scale and |wv| - scale (stretch), and the
// cosine of the angle between wu and wv (shear), each multiplied by tri->uvArea.
// NOTE(review): the structure resembles the stretch/shear conditions of
// Baraff & Witkin (1998), but uses a normalised (cosine) shear term — confirm intent.
template<class Particle>
class StretchShearForce {
	public:
		ForceTriplet *tri;    // triangle UV data; init() fills uvArea and a..d
		Particle *parts[3];   // triangle vertices; parts[0] is the base of both edges
		float ku, kv, kc, kd; // u-stretch, v-stretch, shear and damping coefficients
		float scale;          // target length of wu and wv; init() sets 1

		// Binds the force to a triangle and precomputes, into `triplet`, the
		// coefficients that map world edges to wu/wv. ks is used for both ku and kv.
		void init(Particle &p0, Particle &p1, Particle &p2, ForceTriplet &triplet, float ks, float kc, float kd) {
			this->tri=&triplet;
			parts[0]=&p0;
			parts[1]=&p1;
			parts[2]=&p2;
			ku=kv=ks;
			this->kc=kc;
			this->kd=kd;
			scale=1.0f;

			// UV-space edges from vertex 0.
			float du0=tri->uc[1]-tri->uc[0], du1=tri->uc[2]-tri->uc[0];
			float dv0=tri->vc[1]-tri->vc[0], dv1=tri->vc[2]-tri->vc[0];
			// Signed determinant of the UV edge matrix [[du0,du1],[dv0,dv1]].
			float area=du0*dv1-dv0*du1;

			if (area==0.0f) {
				// Degenerate UV triangle: the mapping is not invertible and 1/area
				// would fill a..d with inf/NaN. Zeroing them makes wu=wv=0, which
				// applyForces() treats as "no force" for this triangle.
				tri->uvArea=0.0f;
				tri->a=tri->b=tri->c=tri->d=0.0f;
				return;
			}

			// a..d are the entries of the inverse UV edge matrix, so that for
			// world edges x0, x1: wu = x0*a + x1*b and wv = x0*c + x1*d.
			float D=1.0f/area;
			tri->uvArea=(area<0.0f) ? -area : area;
			tri->a=dv1*D;
			tri->b=-dv0*D;
			tri->c=-du1*D;
			tri->d=du0*D;
		}

		// Applies stretch/shear spring + damping forces to the three vertices.
		void applyForces(void) {
			// World-space edges from vertex 0, mapped to the u and v directions.
			Vector x0=parts[1]->getPos()-parts[0]->getPos(), x1=parts[2]->getPos()-parts[0]->getPos();
			Vector wu=x0*tri->a+x1*tri->b;
			Vector wv=x0*tri->c+x1*tri->d;

			float wu_len=length(wu);
			float wv_len=length(wv);
			// A collapsed u/v direction (or a degenerate UV triangle) has no
			// direction to normalise; dividing by a zero length would emit NaN forces.
			if (wu_len<=0.0f || wv_len<=0.0f) return;

			Vector wnu=wu/wu_len;
			Vector wnv=wv/wv_len;

			// Conditions
			float cu=tri->uvArea*(wu_len-scale);
			float cv=tri->uvArea*(wv_len-scale);
			float cs=tri->uvArea*(wnu*wnv);

			// Gradients w.r.t. parts[1] (suffix 1) and parts[2] (suffix 2); the gradient
			// for parts[0] is minus their sum. NOTE(review): unlike the conditions,
			// these gradients are not multiplied by uvArea.
			Vector ud1=tri->a*wnu, vd1=tri->c*wnv;
			Vector ud2=tri->b*wnu, vd2=tri->d*wnv;

			// Shear gradient: quotient rule on (wu.wv)/(|wu||wv|).
			Vector dp1_wu_wv=tri->a*wv+wu*tri->c;
			Vector dp1_wu_len_wv_len=wu*tri->a*sqr(wv_len)+wv*tri->c*sqr(wu_len);
			Vector sd1=(dp1_wu_wv*(wu_len*wv_len)-(wnu*wnv)*dp1_wu_len_wv_len)/sqr(wu_len*wv_len);

			Vector dp2_wu_wv=tri->b*wv+wu*tri->d;
			Vector dp2_wu_len_wv_len=wu*tri->b*sqr(wv_len)+wv*tri->d*sqr(wu_len);
			Vector sd2=(dp2_wu_wv*(wu_len*wv_len)-(wnu*wnv)*dp2_wu_len_wv_len)/sqr(wu_len*wv_len);

			// Relative velocities of parts[1] and parts[2] with respect to parts[0].
			Vector dvel0=parts[1]->getVel()-parts[0]->getVel();
			Vector dvel1=parts[2]->getVel()-parts[0]->getVel();

			// Rate of each condition, divided by its squared gradient norm. A zero
			// norm means both gradients are zero (so is the numerator): use 0, not 0/0.
			float u_norm=ud1*ud1+ud2*ud2;
			float v_norm=vd1*vd1+vd2*vd2;
			float s_norm=sd1*sd1+sd2*sd2;
			float u_dt=(u_norm>0.0f) ? (ud1*dvel0+ud2*dvel1)/u_norm : 0.0f;
			float v_dt=(v_norm>0.0f) ? (vd1*dvel0+vd2*dvel1)/v_norm : 0.0f;
			float s_dt=(s_norm>0.0f) ? (sd1*dvel0+sd2*dvel1)/s_norm : 0.0f;

			// -k*C - kd*dC/dt for each condition
			float k_u=-ku*cu-kd*u_dt;
			float k_v=-kv*cv-kd*v_dt;
			float k_s=-kc*cs-kd*s_dt;

			// Forces; parts[0] receives the negated sum so the net force is zero.
			Vector f1=k_u*ud1+k_v*vd1+k_s*sd1;
			Vector f2=k_u*ud2+k_v*vd2+k_s*sd2;

			parts[0]->applyForce(-f1-f2);
			parts[1]->applyForce(f1);
			parts[2]->applyForce(f2);
		}
};

// Bending force across the edge shared by two triangles.
//
// parts[0]-parts[1] is the shared (hinge) edge. parts[2] is paired with tri0 and
// parts[3] with tri1 in the derivative computation (presumably the vertices
// opposite the hinge in each triangle — confirm with the caller). The condition
// is the signed angle between tri0->normal and tri1->normal about the hinge.
// tri0/tri1 `normal` and `invLength` are only read here; they must be kept up to
// date elsewhere before applyForces() is called.
// NOTE(review): cosalpha=n0*n1 is a cosine only if the stored normals are unit length — confirm.
template<class Particle>
class BendForce {
	public:
		Particle *parts[4];        // hinge endpoints (0,1) and the vertices paired with tri0 (2) / tri1 (3)
		ForceTriplet *tri0, *tri1; // the two triangles meeting at the hinge
		float ks, kd;              // bend stiffness and damping

		// Stores particle/triangle pointers and coefficients; no geometry is precomputed.
		void init(Particle &p0, Particle &p1, Particle &p2, Particle &p3, ForceTriplet &tri0, ForceTriplet &tri1, float ks, float kd) {
			this->parts[0]=&p0;
			this->parts[1]=&p1;
			this->parts[2]=&p2;
			this->parts[3]=&p3;
			this->tri0=&tri0;
			this->tri1=&tri1;
			this->ks=ks;
			this->kd=kd;
		}

		// Applies bend spring + damping forces to all four particles.
		void applyForces(void) {
			Vector ed=(parts[1]->getPos())-(parts[0]->getPos());
			float c=length(ed);
			// A collapsed hinge edge has no rotation axis; the divisions by c below would give NaN.
			if (c<=0.0f) return;

			Vector &n0=tri0->normal;
			real a=tri0->invLength;

			Vector &n1=tri1->normal;
			real b=tri1->invLength;

			// Sine/cosine of the angle between the normals, signed about the hinge
			// direction ed/c (`^` appears to be the cross product, `*` the dot product).
			real sinalpha=(n0^n1)*ed/c;
			real cosalpha=n0*n1;

			// the bend condition
			real bend=atan2f(sinalpha, cosalpha);

			// Gradient of the bend angle with respect to each particle.
			Vector cd[4];
			cd[2]=n0*(-a*c);
			cd[3]=n1*(-b*c);

			// Fractional positions of parts[2]/parts[3] projected onto the hinge edge.
			float kt0=((parts[2]->getPos()-parts[0]->getPos())*ed)/(c*c);
			float kt1=((parts[3]->getPos()-parts[0]->getPos())*ed)/(c*c);

			cd[1]=-cd[2]*kt0-cd[3]*kt1;
			// Gradients sum to zero, so the resulting forces have no net translation.
			cd[0]=-cd[2]-cd[3]-cd[1];

			// compute bend/dt
			real bend_dt=0.0f;
			for (int i=0; i<4; i++) bend_dt+=cd[i]*parts[i]->getVel();

			// compute the forces
			real k=-(ks*bend+kd*bend_dt);
			// `i` is re-declared: a for-init variable is not in scope after its loop in standard C++.
			for (int i=0; i<4; i++) parts[i]->applyForce(k*cd[i]);
		}
};

END_VLADO

#endif