#ifndef __VGL_COLLISION_H__
#define __VGL_COLLISION_H__

BEGIN_VLADO

#include "Sequence.h"
#include "CollisionMesh.h"

template<class Mesh>
class CollisionParticle {
	public:
		Mesh *cmesh; // Mesh that owns this particle
		int index; // Position of this particle inside cmesh

		// Binds the record to particle number `index` of mesh `cm`.
		void set(Mesh *cm, int index) {
			this->index=index;
			this->cmesh=cm;
		}
};

// A particle - the basic simulation unit.
// Holds the particle's current and integrator-computed state, the force accumulated
// for the current step, and the corrections gathered from collision response.
template<class Mesh>
struct Particle {
	Mesh *obj;
	float invMass, origInvMass; // The inverse of the particle mass; 0.0 means the particle is fixed (not affected by the simulation)

	Vlado::Vector v; // Current velocity
	Vlado::Vector nv; // New velocity computed by the integrator

	Vlado::Vector pp; // Previous position (used to detect intersections with other objects)
	Vlado::Vector rp; // Used by the implicit solver
	Vlado::Vector p; // Current position
	Vlado::Vector np; // New position computed by the integrator

	Vlado::Vector sv, sp, spp; // Used to back-step if necessary

	Vlado::Vector f; // Forces acting on the particle

	// Collision-response corrections. Slot [0] accumulates corrections from collisions
	// that involve no deflector, slot [1] from collisions that involve a deflector
	// (Collision::dispatchPos()/dispatchVel() index them with deflectorCollision).
	Vlado::Vector posImp[2]; // Position impulse - used to correct particle position for collision response
	Vlado::Vector velImp[2]; // Velocity impulse - used to correct particle velocity for collision response

	Vlado::Vector q0, q1, q2; // These are the start and end position per frame of the particle if it's a fixed one
	Vlado::Vector qv0, qv1; // The velocity of the particle if it's a fixed one
	Vlado::Vector constraintVector;
	Vlado::Vector nnp, nnv;

	unsigned char collide; // True if the particle is allowed to collide with deflectors
	unsigned char selfCollide; // True if the particle is allowed to collide with other cloth particles
	unsigned char checkIntersections; // False if the particle is not allowed to participate in intersection checks

	unsigned char correctPos;

	// Position computed by the integrator (np). Const so it can be read through const particles.
	Vector getPos(void) const { return np; }
	Vector* getPosAddr(void) { return &np; }

	// Velocity computed by the integrator (nv).
	Vector getVel(void) const { return nv; }
	Vector* getVelAddr(void) { return &nv; }

	// Accumulates a force into f. Taking a const reference lets callers pass
	// temporaries (e.g. p.applyForce(dir*k)); the original non-const reference did not.
	void applyForce(const Vector &force) { f+=force; }
	Vector *getForceAddr(void) { return &f; }
};

// Kind of contact recorded by a Collision; it decides how b0/b1 are interpreted.
enum CollisionType {
	c_vertex_face=0, // b0 is a vertex, b1 a face (contact point on the face given by uc, vc)
	c_edge_edge=1 // b0 and b1 are both edges (contact points given by t0, t1)
};

// One contact between two collision elements, reduced to a weighted set of particles.
//
// The particles of the first element (b0) are added with positive weights and those of
// the second element (b1) with negative weights. relPos()/relVel() therefore return the
// position/velocity of b0's contact point minus that of b1 (plus posOffset/velOffset).
// Elements on rigid meshes contribute the rigid mesh's particles 0..3 instead of their
// own vertices.
template<class Mesh>
class Collision {
	public:
		CollisionType type; // The type of collision; selects initVertexFace() or initEdgeEdge()
		BVolume *b0, *b1; // The colliding elements (BVertex/BFace or BEdge/BEdge, depending on type)
		int valid; // true if the collision is valid and should be considered
		union { // Specifies the collision points within the elements
			struct { float uc, vc; }; // c_vertex_face: barycentric coordinates of the contact point on face b1
			struct { float t0, t1; }; // c_edge_edge: contact parameters along edges b0 and b1
		};

		int numParts; // number of particles in the collision

		int deflectorCollision; // 1 if either element is on a deflector, else 0; indexes posImp/velImp
		Vector posOffset, velOffset; // constant terms added to the relative position/velocity (see addParticleConst())

		Particle<Mesh> *parts[8]; // the particles; 8 covers the worst case of two rigid elements (4 each)
		real weights[8]; // signed weights of the particles in the collision

		Vector normal; // the normal of the collision
		real dist; // desired distance between the elements
		real dynamicFrictionCoeff; // friction coefficient
		real staticFrictionCoeff;
		real bounceCoeff; // bounce coefficient

		float wsum; // sum over the particles of weights[i]^2*invMass (see initWSum())
		int ipath; // when nonzero, initParticles() orients normal against relPos() and keeps separating contacts

		// Computes wsum = sum(weights[i]^2 * invMass); fixed particles (invMass==0) contribute nothing.
		void initWSum(void) {
			wsum=0.0f;
			for (int i=0; i<numParts; i++) {
				wsum+=sqr(weights[i])*(parts[i]->invMass);
			}
		}

		// Weighted sum of the particles' new positions, plus posOffset.
		Vector relPos(void) {
			Vector rpos(0,0,0);
			for (int i=0; i<numParts; i++) rpos+=parts[i]->np*weights[i];
			return rpos+posOffset;
		}

		// Weighted sum of the particles' new velocities, plus velOffset.
		Vector relVel(void) {
			Vector rvel(0,0,0);
			for (int i=0; i<numParts; i++) rvel+=parts[i]->nv*weights[i];
			return rvel+velOffset;
		}

		// Weighted sum of the forces accumulated on the particles.
		Vector relForce(void) {
			Vector rf(0,0,0);
			for (int i=0; i<numParts; i++) rf+=parts[i]->f*weights[i];
			return rf;
		}

		// Like relPos(), but including the pending position impulses of both slots.
		Vector relPosImp(void) {
			Vector rpos(0,0,0);
			for (int i=0; i<numParts; i++) rpos+=(parts[i]->np+parts[i]->posImp[0]+parts[i]->posImp[1])*weights[i];
			return rpos+posOffset;
		}

		// Like relVel(), but including the pending velocity impulses of both slots.
		Vector relVelImp(void) {
			Vector rvel(0,0,0);
			for (int i=0; i<numParts; i++) rvel+=(parts[i]->nv+parts[i]->velImp[0]+parts[i]->velImp[1])*weights[i];
			return rvel+velOffset;
		}

		// Adds position impulses so that the impulse-corrected relative position reaches
		// `dist` along the normal (exactly so when normal has unit length). Each particle
		// receives a share proportional to weights[i]*invMass.
		// Divides by wsum: only call on collisions where wsum is not ~0
		// (initParticles() marks wsum<1e-6 as invalid).
		void dispatchPos(void) {
			Vector rpos=relPosImp();
			real pc=(dist-rpos*normal);
			Vector pimp=pc/wsum*normal;
			for (int i=0; i<numParts; i++) parts[i]->posImp[deflectorCollision]+=(weights[i]*parts[i]->invMass)*pimp;
		}

		// Impulse-corrected normal separation divided by the desired distance
		// (1.0 means the elements are exactly `dist` apart along the normal).
		// NOTE(review): no guard against dist==0.
		float getDistRel(void) {
			Vector rpos=relPosImp();
			real pc=rpos*normal/dist;
			return pc;
		}

		// Adds velocity impulses that cancel the normal component of the impulse-corrected
		// relative velocity (exactly so when normal has unit length). Same wsum
		// requirement as dispatchPos().
		void dispatchVel(void) {
			Vector rvel=relVelImp();
			real vc=-(rvel*normal);
			Vector vimp=vc/wsum*normal;
			for (int i=0; i<numParts; i++) parts[i]->velImp[deflectorCollision]+=(weights[i]*parts[i]->invMass)*vimp;
		}

		// Friction force proportional to the relative tangential velocity of the elements.
		// Fixed particles (invMass==0) are skipped.
		// NOTE(review): each particle receives weights[i]/invMass^2 times the force, i.e. it is
		// scaled by mass squared. Confirm this is intended rather than weights[i]/invMass.
		void dispatchFrictionForce() {
			Vector rvel=relVel(); // Relative velocity
			real nv=rvel*normal; // Normal velocity
			Vector tvel=rvel-normal*nv; // Tangential velocity
			Vector fimp=-tvel*dynamicFrictionCoeff;
			for (int i=0; i<numParts; i++) {
				if (parts[i]->invMass>0.0f) parts[i]->f+=(weights[i]/sqr(parts[i]->invMass))*fimp;
			}
		}

		// Appends particle `index` of mesh cm with the given signed weight.
		// No bounds check: callers must not exceed the 8 slots of parts/weights.
		void addParticle(Mesh *cm, int index, float weight) {
			parts[numParts]=cm->getParticle(index);
			weights[numParts]=weight;
			numParts++;
		}

		// Folds particle `index` of cm into posOffset/velOffset as a constant contribution
		// instead of adding it to the particle list.
		void addParticleConst(Mesh *cm, int index, float weight) {
			Particle<Mesh> &p=*(cm->getParticle(index));
			posOffset+=p.np*weight;
			velOffset+=p.nv*weight;
		}

		// Builds the particle list for this collision and decides whether it is valid.
		// The dynamic friction is halved and scaled by |cos| of the angle between the
		// relative velocity and the normal. Returns the new value of `valid`.
		int initParticles() {
			if (type==c_vertex_face) initVertexFace();
			else initEdgeEdge();

			if (ipath && (relPos()*normal>0.0f)) normal=-normal;

			// staticFrictionCoeff*=0.5f;
			dynamicFrictionCoeff*=0.5f;
			// bounceCoeff*=0.5f;

			initWSum();
			Vector rvel=relVel();
			float k=rvel*normal;

			// Reject contacts that cannot be resolved (all particles fixed) and, unless ipath is
			// set, contacts whose elements are already separating along the normal.
			if (wsum<1e-6f || (!ipath && k>1e-4f)) valid=false;
			else valid=true;

			float len=length(rvel);
			if (len>1e-12f) k/=len; else k=0.0f;

			dynamicFrictionCoeff*=fabsf(k);
			if (dynamicFrictionCoeff<0.0f) dynamicFrictionCoeff=0.0f;

			return valid;
		}

		// Clears the particle list, the constant offsets and the deflector flag.
		void resetParticles(void) {
			numParts=0;
			posOffset.makeZero();
			velOffset.makeZero();
			deflectorCollision=0;
		}

		// Adds the four particles of rigid mesh cm that represent its object-space point p,
		// with weights (1-p.x-p.y-p.z), p.x, p.y, p.z, each multiplied by sign.
		void addRigidParticles(Mesh *cm, const Vector &p, float sign) {
			addParticle(cm, 0, sign*(1.0f-p.x-p.y-p.z));
			addParticle(cm, 1, sign*p.x);
			addParticle(cm, 2, sign*p.y);
			addParticle(cm, 3, sign*p.z);
		}

		// Adds the particles of one colliding element whose contact point is the combination
		// of the n mesh vertices idx[k] with weights w[k]. sign is +1 for the first element (b0)
		// and -1 for the second (b1). Cloth and deflector vertices become particles directly
		// (deflectors also set deflectorCollision). Rigid vertices are mapped onto the rigid
		// mesh's particles 0..3. Other object types add nothing.
		void addElementParticles(Mesh *cm, int n, const int *idx, const float *w, float sign) {
			switch (cm->objType) {
				case obj_deflector:
				case obj_cloth: {
					for (int k=0; k<n; k++) addParticle(cm, idx[k], sign*w[k]);
					if (cm->objType==obj_deflector) deflectorCollision=1;
					break;
				}
				case obj_rigid: {
					Vector p=cm->getVertexObjPos(idx[0])*w[0];
					for (int k=1; k<n; k++) p+=cm->getVertexObjPos(idx[k])*w[k];
					addRigidParticles(cm, p, sign);
					break;
				}
			}
		}

		// Sets up a vertex (b0) vs. face (b1) collision; the contact point on the face has
		// barycentric weights (1-uc-vc, uc, vc).
		void initVertexFace(void) {
			BVertex<Mesh> *bv=(BVertex<Mesh>*) b0;
			BFace<Mesh> *bf=(BFace<Mesh>*) b1;

			resetParticles();

			int vertIdx[1];
			float vertW[1];
			vertIdx[0]=bv->idx;
			vertW[0]=1.0f;
			addElementParticles(bv->cmesh, 1, vertIdx, vertW, 1.0f);

			int faceIdx[3];
			float faceW[3];
			faceIdx[0]=bf->idx[0];
			faceIdx[1]=bf->idx[1];
			faceIdx[2]=bf->idx[2];
			faceW[0]=1.0f-uc-vc;
			faceW[1]=uc;
			faceW[2]=vc;
			addElementParticles(bf->cmesh, 3, faceIdx, faceW, -1.0f);

			// staticFrictionCoeff=bv->cmesh->getStaticFriction()+bf->cmesh->getStaticFriction();
			dynamicFrictionCoeff=bv->cmesh->getDynamicFriction()+bf->cmesh->getDynamicFriction();
			// bounceCoeff=bv->cmesh->getBounce()+bf->cmesh->getBounce();
		}

		// Sets up an edge (b0) vs. edge (b1) collision; the contact points lie at parameters
		// t0 and t1 along the respective edges.
		void initEdgeEdge(void) {
			BEdge<Mesh> *be0=(BEdge<Mesh>*) b0;
			BEdge<Mesh> *be1=(BEdge<Mesh>*) b1;

			resetParticles();

			int idx0[2];
			float w0[2];
			idx0[0]=be0->idx[0];
			idx0[1]=be0->idx[1];
			w0[0]=1.0f-t0;
			w0[1]=t0;
			addElementParticles(be0->cmesh, 2, idx0, w0, 1.0f);

			int idx1[2];
			float w1[2];
			idx1[0]=be1->idx[0];
			idx1[1]=be1->idx[1];
			w1[0]=1.0f-t1;
			w1[1]=t1;
			addElementParticles(be1->cmesh, 2, idx1, w1, -1.0f);

			// staticFrictionCoeff=be0->cmesh->getStaticFriction()+be1->cmesh->getStaticFriction();
			dynamicFrictionCoeff=be0->cmesh->getDynamicFriction()+be1->cmesh->getDynamicFriction();
			// bounceCoeff=be0->cmesh->getBounce()+be1->cmesh->getBounce();
		}
};

// typedef SequenceCounter<Collision> CollisionListCounter;
// typedef Sequence<Collision> CollisionList;

END_VLADO

#endif