mutoo
11/8/2013 - 10:09 AM

ann_gesture_recognition

ann_gesture_recognition

// Fully connected feed-forward network with one hidden layer, trained with
// online backpropagation. The input and hidden layers each get one extra bias
// node (an InputNode with constant value -1) that is wired to every node of
// the next layer; the hidden bias node itself has no incoming arcs.
class Net
{
	// Layer sizes, not counting the bias nodes.
	int numOfInputs, numOfHiddens, numOfOutputs;
	// inputNodes and hiddenNodes hold their bias node as the last element.
	ArrayList<Node> inputNodes, hiddenNodes, outputNodes;

	// inputs, hiddens, outputs: node counts per layer (bias excluded).
	// learningRate: step size handed to every hidden and output node.
	// momentum: momentum factor handed to every arc.
	Net(int inputs, int hiddens, int outputs, float learningRate, float momentum) {
		numOfInputs = inputs;
		numOfHiddens = hiddens;
		numOfOutputs = outputs;

		createNet(learningRate, momentum);
	}

	// Creates the three layers and fully connects input->hidden and
	// hidden->output (bias nodes included) with randomly weighted arcs.
	void createNet(float learningRate, float momentum) {
		// input layer + bias
		inputNodes = new ArrayList<Node>();
		for(int i=0; i<numOfInputs; i++) {
			inputNodes.add(new InputNode());
		}
		inputNodes.add(new InputNode(-1));

		// hidden layer + bias
		hiddenNodes = new ArrayList<Node>();
		for(int j=0; j<numOfHiddens; j++) {
			hiddenNodes.add(new HiddenNode(learningRate));
		}
		hiddenNodes.add(new InputNode(-1));

		// output layer (no bias)
		outputNodes = new ArrayList<Node>();
		for(int k=0; k<numOfOutputs; k++) {
			outputNodes.add(new OutputNode(learningRate));
		}

		// connect every input node (incl. bias) to every real hidden node
		for(int i=0; i<numOfInputs+1; i++) {
			for (int j = 0; j<numOfHiddens; j++){
				inputNodes.get(i).connect(new Arc(momentum), hiddenNodes.get(j));
			}
		}

		// connect every hidden node (incl. bias) to every output node
		for (int j = 0; j<numOfHiddens+1; j++){
			for (int k = 0; k<numOfOutputs; k++){
				hiddenNodes.get(j).connect(new Arc(momentum), outputNodes.get(k));
			}
		}
	}

	// Forward pass: feeds `inputs` (must hold exactly numOfInputs values)
	// through the net and returns the output-layer activations.
	// On a size mismatch it prints an error, requests exit() and returns an
	// empty list.
	FloatList run(FloatList inputs) {
		FloatList outputs = new FloatList();

		if(inputs.size()!=numOfInputs) {
			println("number of inputs is wrong!");
			exit();
			// exit() does not stop the sketch immediately (it quits once the
			// current setup()/draw() returns), so bail out here instead of
			// reading past the end of `inputs`.
			return outputs;
		}

		for(int i=0; i<numOfInputs; i++){
			inputNodes.get(i).update(inputs.get(i));
		}

		// the hidden bias node (last element) keeps its constant value
		for(int j=0; j<numOfHiddens; j++){
			hiddenNodes.get(j).update();
		}

		for(int k=0; k<numOfOutputs; k++){
			outputs.append(outputNodes.get(k).update());
		}

		return outputs;
	}

	// One epoch of online backpropagation: for every pair, run a forward pass
	// on its inputs and adjust all weights towards its expected outputs.
	void train(ArrayList<TrainingPair> trainingSets) {
		float error = 0;
		for(TrainingPair tp:trainingSets){
			run(tp.inputs);

			// 1. output deltas from the forward pass just made
			for(int k=0; k<numOfOutputs; k++) {
				OutputNode outputNode = (OutputNode)(outputNodes.get(k));
				outputNode.setError(tp.outputs.get(k));
				error += abs(outputNode.error);
			}

			// 2. hidden deltas are back-propagated through the hidden->output
			// weights, so they must be computed BEFORE those weights change.
			// (HiddenNode.train only updates input->hidden weights.)
			for(int j=0; j<numOfHiddens; j++) {
				HiddenNode hiddenNode = (HiddenNode)(hiddenNodes.get(j));
				hiddenNode.train();
				error += abs(hiddenNode.error);
			}

			// 3. only now update the hidden->output weights
			for(int k=0; k<numOfOutputs; k++) {
				((OutputNode)(outputNodes.get(k))).train();
			}
		}
		// println(error);
	}

	// Draws the network onto `pg`. Input nodes are squares filled red by value,
	// hidden and output nodes are circles filled green by value, and arcs are
	// shaded by weight. A short line leaves each output node. Bias nodes are not
	// drawn. Layers are pg.width/5 apart, and nodes fan out alternately below and
	// above the vertical centre.
	void render(PGraphics pg) {
		pg.beginDraw();
		pg.rectMode(RADIUS);
		pg.ellipseMode(RADIUS);
		pg.textAlign(CENTER, CENTER);

		int nodeRadius = 5;
		int xdist = pg.width/5;
		int x = xdist;
		int y = pg.height / 2;
		int ydist = pg.height/(max(numOfInputs, numOfHiddens, numOfOutputs)+1);

		int[] numOfLayers = {numOfInputs, numOfHiddens, numOfOutputs};
		// y positions of the previous layer's drawn nodes (arc endpoints)
		IntList nodeOldY = new IntList();
		for(int l=0; l<3; l++) {
			int numOfLayer = numOfLayers[l];

			// select the layer once per layer instead of once per node
			ArrayList<Node> layer;
			switch (l){
				case 0 :
					layer = inputNodes;
					break;
				case 1 :
					layer = hiddenNodes;
					break;
				case 2 :
					layer = outputNodes;
					break;
				default :
					println("wrong index of layer!");
					exit();
					pg.endDraw();
					return;
			}

			IntList nodeY = new IntList();
			for (int i = 0; i<numOfLayer; i++){
				// shift even-sized layers half a step so they stay centred
				if(i==0 && numOfLayer%2==0) {
					y -= ydist/2;
				}

				Node node = layer.get(i);

				// Arcs back to the previous layer. inArcs are in the same order as
				// that layer's nodes; the trailing bias arc is not drawn.
				for(int j=0; j<nodeOldY.size(); j++) {
					pg.stroke((node.inArcs.get(j).weight+1)/2*128);
					pg.line(x,y,x-xdist+nodeRadius,nodeOldY.get(j));
				}

				pg.stroke(0);
				if(l==2) {
					pg.line(x,y,x+xdist,y);
				}

				if(l==0) {
					pg.fill((node.value+1)/2*255,0,0);
					pg.rect(x, y, nodeRadius, nodeRadius);
				} else {
					pg.fill(0,node.value*255,0);
					pg.ellipse(x, y, nodeRadius, nodeRadius);
				}

				nodeY.append(y);
				// zig-zag outwards from the centre: +1, -2, +3, ... steps
				y += pow(-1, i%2)*(i+1)*ydist;
			}
			nodeOldY = nodeY;
			x += xdist;
			y = pg.height / 2;
		}

		pg.endDraw();
	}
};

// Directed, weighted connection from node `in` (source) to node `out` (target).
class Arc
{
	Node in, out;
	// weight:   current connection weight, initialised with random(-1, 1)
	// delta:    the delta passed to the most recent updateWeight() call
	// momentum: fraction of that previous delta added on each update
	float weight, delta, momentum;

	Arc(float momentum) {
		this.weight = random(-1, 1);
		this.delta = 0;
		this.momentum = momentum;
	}

	void setInputNode(Node node) {
		this.in = node;
	}

	void setOutputNode(Node node) {
		this.out = node;
	}

	// Current value of the source node.
	float getInputValue() {
		return in.value;
	}

	// Forward-pass contribution: source value times weight.
	float getWeightedInputValue() {
		return getInputValue() * weight;
	}

	// Backward-pass contribution: target error times weight.
	float getWeightedOutputError() {
		return weight * out.error;
	}

	// Adds `delta` plus momentum * (previous delta) to the weight, then
	// remembers `delta` for the next call.
	// NOTE(review): only the raw delta is stored, not the full step including
	// the momentum term, so momentum does not accumulate across updates as in
	// the classic formulation -- confirm this is intended.
	void updateWeight(float delta) {
		float step = delta + momentum * this.delta;
		weight += step;
		this.delta = delta;
	}
};

// Base network node. It holds an activation `value`, an `error` term set during
// training, and its incoming and outgoing arcs.
class Node
{
	ArrayList<Arc> inArcs, outArcs;
	float value, error;

	Node() {
		value = 0;
		error = 0;
		inArcs = new ArrayList<Arc>();
		outArcs = new ArrayList<Arc>();
	}

	// Connects this node to `node` through a freshly created arc with
	// momentum 0.
	void connect(Node node) {
		connect(new Arc(0), node);
	}

	// Wires `arc` from this node to `node`, registering it on both ends.
	void connect(Arc arc, Node node) {
		arc.setInputNode(this);
		arc.setOutputNode(node);
		addOutputArc(arc);
		node.addInputArc(arc);
	}

	void addInputArc(Arc arc) {
		inArcs.add(arc);
	}

	void addOutputArc(Arc arc) {
		outArcs.add(arc);
	}

	// Sets the activation directly and returns it.
	float update(float value) {
		this.value = value;
		return this.value;
	}

	// Sets the activation to the sigmoid of the weighted sum of all incoming
	// arcs and returns it.
	float update() {
		float weightedSum = 0;
		for (int a = 0; a < inArcs.size(); a++) {
			weightedSum += inArcs.get(a).getWeightedInputValue();
		}
		value = sigmoidTransfer(weightedSum);
		return value;
	}

	// Logistic function 1 / (1 + e^-x).
	float sigmoidTransfer(float value) {
		return 1.0 / (exp(-value) + 1.0);
	}
};

// Node whose value is assigned directly instead of computed from incoming arcs.
// The value-taking constructor presets that value (e.g. -1 for bias nodes).
class InputNode extends Node
{
	InputNode() {
		super();
	}

	// Creates an input node preset to `value`.
	InputNode(float value) {
		this();
		this.value = value;
	}
};

// Hidden-layer node trained by back-propagating the errors of the nodes it
// feeds.
class HiddenNode extends Node
{
	float learningRate;

	HiddenNode(float learningRate) {
		super();
		this.learningRate = learningRate;
	}

	// Sets this node's error to value * (1 - value) (the sigmoid derivative)
	// times the weighted sum of the errors on its outgoing arcs. It then adjusts
	// every incoming arc by input * error * learningRate via Arc.updateWeight.
	void train() {
		float backpropagated = 0;
		for (int a = 0; a < outArcs.size(); a++) {
			backpropagated += outArcs.get(a).getWeightedOutputError();
		}
		error = value * (1 - value) * backpropagated;

		for (int a = 0; a < inArcs.size(); a++) {
			Arc incoming = inArcs.get(a);
			incoming.updateWeight(incoming.getInputValue() * error * learningRate);
		}
	}
};

// Output-layer node. Its error is the sigmoid delta against a target value.
class OutputNode extends Node
{
	float learningRate;

	OutputNode(float learningRate) {
		this.learningRate = learningRate;
	}

	// Sets error = (target - value) * value * (1 - value): the output
	// difference scaled by the sigmoid derivative at the current value.
	void setError(float target) {
		this.error = (target - value) * value * (1 - value);
	}

	// Adjusts every incoming arc by input * error * learningRate.
	// This goes through Arc.updateWeight (as HiddenNode does) so the momentum
	// configured on these arcs is actually applied. Writing arc.weight directly
	// bypassed it.
	void train() {
		for(Arc arc:inArcs) {
			arc.updateWeight(arc.getInputValue() * error * learningRate);
		}
	}
};


// Off-screen canvas that the network and the user's stroke are drawn into.
PGraphics pg;

Net net;
ArrayList<TrainingPair> trainingSet;

// Creates the canvas and a 24-12-5 network (learning rate 0.5, momentum 0.9).
// It pre-trains the network for 10 epochs on the built-in gestures and seeds
// `inputs` with the first training pattern.
void setup(){
	size(512, 512);
	pg = createGraphics(width, 512);
	net = new Net(24, 12, 5, 0.5, 0.9);

	createTrainingSet();
	int epoch = 0;
	while (epoch < 10) {
		net.train(trainingSet);
		epoch++;
	}

	points = new ArrayList<PVector>();
	inputs = createFloatList(1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0);
}

// (Re)builds the built-in training set. It holds four strokes, each a single
// direction repeated 12 times as (cos, sin) pairs (the same encoding
// pointHandle() produces), mapped one-hot to outputs 0..3. Output 4 is the
// class keyPressed() assigns to the user's own gesture.
void createTrainingSet() {
	ArrayList<TrainingPair> pairs = new ArrayList<TrainingPair>();

	// (1, 0): every segment points towards +x
	pairs.add(new TrainingPair(
		createFloatList(1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0),
		createFloatList(1,0,0,0,0)));
	// (-1, 0): towards -x
	pairs.add(new TrainingPair(
		createFloatList(-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0),
		createFloatList(0,1,0,0,0)));
	// (0, 1): towards +y (down the screen)
	pairs.add(new TrainingPair(
		createFloatList(0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1),
		createFloatList(0,0,1,0,0)));
	// (0, -1): towards -y (up the screen)
	pairs.add(new TrainingPair(
		createFloatList(0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1,0,-1),
		createFloatList(0,0,0,1,0)));

	trainingSet = pairs;
}

// Network input for the current stroke: (cos, sin) pairs, as built by
// pointHandle().
FloatList inputs;

// Each frame: train 5 more epochs, run the current input through the net, then
// draw the network and the user's stroke and blit the canvas to the screen.
void draw(){
	for (int epoch = 0; epoch < 5; epoch++) {
		net.train(trainingSet);
	}

	net.run(inputs);
	pg.beginDraw();
	pg.background(200);

	pg.stroke(0);
	pg.strokeWeight(1);
	net.render(pg);

	// user stroke: blue polyline with a dot on every point after the first
	pg.fill(0,0,255);
	pg.stroke(0,0,255);
	for (int p = 1; p < points.size(); p++) {
		PVector from = points.get(p-1);
		PVector to = points.get(p);
		pg.line(from.x, from.y, to.x, to.y);
		pg.ellipse(to.x, to.y, 3, 3);
	}

	pg.endDraw();
	image(pg, 0, 0);
}

// Space: rebuild the training set with the current stroke (`inputs`) added as
// the fifth gesture class. Then replace the network with a fresh one trained
// for 100 epochs.
void keyPressed() {
	if (key != ' ') {
		return;
	}

	createTrainingSet();
	FloatList fifthClass = createFloatList(0,0,0,0,1);
	trainingSet.add(new TrainingPair(inputs, fifthClass));

	net = new Net(24, 12, 5, 0.5, 0.9);
	for (int epoch = 0; epoch < 100; epoch++) {
		net.train(trainingSet);
	}
}

// Mouse positions of the stroke being (or last) drawn.
ArrayList<PVector> points;
// True while the mouse button is held down.
boolean drawing = false;

// Starts a new stroke, discarding the previous one.
void mousePressed() {
	points.clear();
	drawing = true;
}

// While a stroke is in progress, records every dragged-to mouse position.
void mouseDragged() {
	if (!drawing) {
		return;
	}
	points.add(new PVector(mouseX, mouseY));
}

// Finishes the stroke and converts it into network inputs via pointHandle().
void mouseReleased() {
	this.drawing = false;
	pointHandle();
}

// Converts the stroke in `points` into the network's 24-value input vector.
// The stroke is walked in 12 chunks of roughly equal arc length. Each chunk
// contributes (cos, sin) of the mean direction of its segments; a chunk with
// no usable segment repeats the previous chunk's direction.
// Strokes with fewer than 2 points are rejected and `inputs` is left as is.
void pointHandle() {
	if(points.size()<2) {
		println("points do not enough!");
		return;
	}

	// total arc length of the stroke
	float totalLen = 0;
	for(int i=1;i<points.size();i++) {
		PVector prev = points.get(i-1);
		totalLen += prev.dist(points.get(i));
	}

	// NOTE(review): the length is divided into 13 parts but only 12 chunks are
	// read, so roughly the last 1/13 of the stroke is ignored -- confirm intent.
	float deltaLen = totalLen/13;
	float target = 0;
	float total = 0;
	float lastAngle = 0;
	int pointIndex = 1;

	inputs = new FloatList();
	for(int i=0; i<12; i++) {
		target += deltaLen;
		// Sum unit direction vectors rather than raw atan2() angles. Angles wrap
		// at +/-PI, so averaging e.g. 3.1 and -3.1 gives ~0 (pointing +x) for a
		// chunk that actually points towards -x.
		float sumCos = 0;
		float sumSin = 0;
		int angles = 0;
		println("i:"+i+" total:"+total+" target:"+target);
		while(total<target && pointIndex < points.size()){
			println("pointIndex:"+pointIndex);
			PVector prev = points.get(pointIndex-1);
			PVector current = points.get(pointIndex);
			float segmentLen = prev.dist(current);
			// a zero-length segment has no direction (atan2(0, 0) == 0 would
			// bias the chunk towards +x), so it only advances the walk
			if(segmentLen > 0) {
				float segmentAngle = atan2(current.y - prev.y, current.x - prev.x);
				sumCos += cos(segmentAngle);
				sumSin += sin(segmentAngle);
				angles++;
			}
			total += segmentLen;
			pointIndex++;
		}
		if(angles>0){
			lastAngle = atan2(sumSin, sumCos);
		}

		inputs.append(cos(lastAngle));
		inputs.append(sin(lastAngle));
	}
	println("inputs:"+inputs);
}

// Returns a new FloatList holding `numbers` in the given order.
FloatList createFloatList(float...numbers) {
	FloatList list = new FloatList();
	for (float n : numbers) {
		list.append(n);
	}
	return list;
}

// One training example: network inputs and the expected outputs.
// Both lists are stored by reference, not copied.
class TrainingPair
{
	FloatList inputs;
	FloatList outputs;

	TrainingPair(FloatList in, FloatList out) {
		this.inputs = in;
		this.outputs = out;
	}
}