mutoo
11/5/2013 - 8:08 AM

ann_xor.pde

// One supervised example: the values fed to the input layer together with
// the values the output layer is expected to produce for them.
class TrainingPair
{
	FloatList inputs;   // values for the (non-bias) input nodes
	FloatList outputs;  // target values for the output nodes

	TrainingPair(FloatList in, FloatList out) {
		this.inputs = in;
		this.outputs = out;
	}
}
// A fully connected feed-forward network with one hidden layer, trained by
// back-propagation. The input and hidden layers each carry one extra bias
// node whose value is fixed at -1.
class Net
{
	int numOfInputs, numOfHiddens, numOfOutputs;
	ArrayList<Node> inputNodes, hiddenNodes, outputNodes;

	// inputs/hiddens/outputs: number of non-bias nodes in each layer.
	// learningRate: passed to every hidden and output node.
	// momentum: passed to every arc.
	Net(int inputs, int hiddens, int outputs, float learningRate, float momentum) {
		// set net component
		numOfInputs = inputs;
		numOfHiddens = hiddens;
		numOfOutputs = outputs;

		createNet(learningRate, momentum);
	}

	// Builds the three layers and fully connects neighbouring layers.
	void createNet(float learningRate, float momentum) {
		// input layer, followed by a bias node fixed at -1
		inputNodes = new ArrayList<Node>();
		for(int i=0; i<numOfInputs; i++) {
			inputNodes.add(new InputNode());
		}
		inputNodes.add(new InputNode(-1));

		// hidden layer, followed by a bias node fixed at -1
		hiddenNodes = new ArrayList<Node>();
		for(int j=0; j<numOfHiddens; j++) {
			hiddenNodes.add(new HiddenNode(learningRate));
		}
		hiddenNodes.add(new InputNode(-1));

		// output layer (no bias node)
		outputNodes = new ArrayList<Node>();
		for(int k=0; k<numOfOutputs; k++) {
			outputNodes.add(new OutputNode(learningRate));
		}

		// connect every input node (bias included) to every non-bias hidden node
		for(int i=0; i<numOfInputs+1; i++) {
			for(int j=0; j<numOfHiddens; j++) {
				Arc arc = new Arc(momentum);
				inputNodes.get(i).connect(arc, hiddenNodes.get(j));
			}
		}

		// connect every hidden node (bias included) to every output node
		for(int j=0; j<numOfHiddens+1; j++) {
			for(int k=0; k<numOfOutputs; k++) {
				Arc arc = new Arc(momentum);
				hiddenNodes.get(j).connect(arc, outputNodes.get(k));
			}
		}
	}

	// Forward pass: loads `inputs` into the input nodes and propagates them
	// through the hidden and output layers. Returns the output-node values.
	// If the input count is wrong, it reports the problem, requests exit and
	// returns an empty list.
	FloatList run(FloatList inputs) {
		FloatList outputs = new FloatList();

		if(inputs.size()!=numOfInputs) {
			println("number of inputs is wrong!");
			exit();
			// exit() does not stop the current call, so bail out here rather
			// than index past the end of `inputs`.
			return outputs;
		}

		for(int i=0; i<numOfInputs; i++) {
			inputNodes.get(i).update(inputs.get(i));
		}

		for(int j=0; j<numOfHiddens; j++) {
			hiddenNodes.get(j).update();
		}

		for(int k=0; k<numOfOutputs; k++) {
			outputs.append(outputNodes.get(k).update());
		}

		return outputs;
	}

	// One epoch of online back-propagation over `trainingSets`.
	void train(ArrayList<TrainingPair> trainingSets) {
		float error = 0;
		for(TrainingPair tp : trainingSets) {
			run(tp.inputs);

			// 1. output deltas from the targets
			for(int k=0; k<numOfOutputs; k++) {
				OutputNode outputNode = (OutputNode)(outputNodes.get(k));
				outputNode.setError(tp.outputs.get(k));
				error += abs(outputNode.error);
			}

			// 2. hidden deltas and input->hidden weight updates. This must run
			//    before the hidden->output weights change, because
			//    HiddenNode.train reads those weights to back-propagate the
			//    output deltas.
			for(int j=0; j<numOfHiddens; j++) {
				HiddenNode hiddenNode = (HiddenNode)(hiddenNodes.get(j));
				hiddenNode.train();
				error += abs(hiddenNode.error);
			}

			// 3. hidden->output weight updates, now that every delta is known
			for(int k=0; k<numOfOutputs; k++) {
				((OutputNode)(outputNodes.get(k))).train();
			}
		}
		// println(error);
	}

	// Renders the net onto `pg`: input nodes as squares, hidden/output nodes
	// as circles, each labelled with its current value, and each arc labelled
	// with its weight. Bias nodes are not drawn.
	// NOTE(review): draw() also wraps this call in pg.beginDraw()/endDraw(),
	// so the begin/end pair here is nested inside that one.
	void render(PGraphics pg) {
		pg.beginDraw();
		pg.rectMode(RADIUS);
		pg.ellipseMode(RADIUS);
		pg.textAlign(CENTER, CENTER);

		int nodeRadius = 25;
		int xdist = pg.width/5;
		int x = xdist;
		int y = pg.height / 2;
		int ydist = pg.height/(max(numOfInputs, numOfHiddens, numOfOutputs)+1);

		int[] numOfLayers = {numOfInputs, numOfHiddens, numOfOutputs};
		IntList nodeOldY = new IntList();
		for(int l=0; l<3; l++) {
			int numOfLayer = numOfLayers[l];

			// select the node list once per layer instead of once per node
			ArrayList<Node> layer;
			if(l == 0) {
				layer = inputNodes;
			} else if(l == 1) {
				layer = hiddenNodes;
			} else {
				layer = outputNodes;
			}

			// layers with an even node count start half a row higher so they
			// stay centred on pg.height/2
			if(numOfLayer%2==0) {
				y -= ydist/2;
			}

			IntList nodeY = new IntList();
			for(int i=0; i<numOfLayer; i++) {
				Node node = layer.get(i);

				// draw arcs back to the previous layer, labelled with weights
				for(int j=0; j<nodeOldY.size(); j++) {
					pg.line(x,y,x-xdist+nodeRadius,nodeOldY.get(j));
					pg.fill(255);
					pg.text(node.inArcs.get(j).weight,x-xdist/2,(nodeOldY.get(j)+y)/2);
				}

				if(l==2) {
					pg.line(x,y,x+xdist,y);
				}

				// draw nodes
				pg.fill(255);
				if(l==0) {
					pg.rect(x, y, nodeRadius, nodeRadius);
				} else {
					pg.ellipse(x, y, nodeRadius, nodeRadius);
				}

				pg.fill(0);
				pg.text(node.value, x, y);

				nodeY.append(y);
				// zig-zag outward from the centre: +1, -2, +3, ... rows
				y += pow(-1, i%2)*(i+1)*ydist;
			}
			nodeOldY = nodeY;
			x += xdist;
			y = pg.height / 2;
		}

		pg.endDraw();
	}
};

// Directed, weighted connection from node `in` to node `out`.
// The last raw weight change is remembered in `delta` so the next update can
// add `momentum` times that change. Only the raw change is stored, not the
// change including the momentum term.
class Arc
{
	Node in, out;
	float weight, delta, momentum;

	// Starts with a random weight in [-1, 1) and no previous change.
	Arc(float momentum) {
		this.weight = random(-1, 1);
		this.delta = 0;
		this.momentum = momentum;
	}

	void setInputNode(Node node) {
		this.in = node;
	}

	void setOutputNode(Node node) {
		this.out = node;
	}

	// Current value of the upstream node.
	float getInputValue() {
		Node source = in;
		return source.value;
	}

	// Upstream value scaled by this arc's weight (forward contribution).
	float getWeightedInputValue() {
		float upstream = in.value;
		return upstream * weight;
	}

	// Downstream error scaled by this arc's weight (backward contribution).
	float getWeightedOutputError() {
		float downstream = out.error;
		return downstream * weight;
	}

	// Applies `delta` plus momentum times the previous raw delta, then
	// remembers `delta` for the next call.
	void updateWeight(float delta) {
		float step = delta + momentum * this.delta;
		weight += step;
		this.delta = delta;
	}
};

// Base network node: holds its incoming/outgoing arcs, its current activation
// (`value`) and its back-propagated delta (`error`).
class Node
{
	ArrayList<Arc> inArcs, outArcs;
	float value, error;

	Node() {
		value = 0;
		error = 0;
		inArcs = new ArrayList<Arc>();
		outArcs = new ArrayList<Arc>();
	}

	// Connects this node to `node` through a freshly created zero-momentum arc.
	void connect(Node node) {
		connect(new Arc(0), node);
	}

	// Connects this node to `node` through `arc`, wiring both ends.
	void connect(Arc arc, Node node) {
		addOutputArc(arc);
		node.addInputArc(arc);
		arc.setInputNode(this);
		arc.setOutputNode(node);
	}

	void addInputArc(Arc arc) {
		inArcs.add(arc);
	}

	void addOutputArc(Arc arc) {
		outArcs.add(arc);
	}

	// Sets the node's value directly and returns it.
	float update(float value) {
		this.value = value;
		return this.value;
	}

	// Sets the value to the sigmoid of the weighted sum of incoming arcs and
	// returns it.
	float update() {
		float sum = 0;
		for(int i = 0; i < inArcs.size(); i++) {
			sum += inArcs.get(i).getWeightedInputValue();
		}
		this.value = sigmoidTransfer(sum);
		return this.value;
	}

	// Logistic function 1 / (1 + e^-x).
	float sigmoidTransfer(float value) {
		float denominator = 1.0 + exp(-value);
		return 1.0 / denominator;
	}
};

// Node whose value is assigned directly instead of computed from incoming
// arcs. The one-argument form presets a value; Net uses it with -1 for the
// bias nodes.
class InputNode extends Node
{
	InputNode() {
	}

	InputNode(float value) {
		super();
		this.value = value;
	}
};

// Hidden-layer node trained by back-propagating the errors of the nodes it
// feeds.
class HiddenNode extends Node
{
	float learningRate;

	HiddenNode(float learningRate) {
		super();
		this.learningRate = learningRate;
	}

	// Computes this node's delta from the weighted downstream errors and the
	// sigmoid derivative value*(1-value), then adjusts each incoming arc
	// through Arc.updateWeight (which applies momentum).
	void train() {
		float downstream = 0;
		for(int i = 0; i < outArcs.size(); i++) {
			downstream += outArcs.get(i).getWeightedOutputError();
		}
		error = value*(1-value)*downstream;

		for(int i = 0; i < inArcs.size(); i++) {
			Arc incoming = inArcs.get(i);
			incoming.updateWeight(incoming.getInputValue() * error * learningRate);
		}
	}
};

// Output-layer node: its delta comes straight from the training target.
class OutputNode extends Node
{
	float learningRate;

	OutputNode(float learningRate) {
		this.learningRate = learningRate;
	}

	// Delta for a sigmoid output: (target - value) * value * (1 - value).
	void setError(float target) {
		this.error = (target - value) * value * (1 - value);
	}

	// Adjusts every incoming arc by input * delta * learningRate. The update
	// goes through Arc.updateWeight, like HiddenNode does, so the momentum
	// each arc was created with is actually applied. Previously the weight
	// was written directly, which silently skipped momentum for this layer.
	void train() {
		for(Arc arc:inArcs) {
			arc.updateWeight(arc.getInputValue() * error * learningRate);
		}
	}
};


// Off-screen canvas the network is rendered into every frame.
PGraphics pg;
// The network under training and the XOR examples used to train it.
Net net;
ArrayList<TrainingPair> trainingSet;

// Creates the canvas and a fresh 2-3-1 network (learning rate 0.5,
// momentum 0.9), builds the XOR training set, pre-trains for 1000 epochs and
// selects the first example for display.
void setup(){
	size(512, 512);
	pg = createGraphics(width, height);
	net = new Net(2, 3, 1, 0.5, 0.9);

	// XOR truth table rows: {a, b, a xor b}
	float[][] xorTable = {
		{0, 0, 0},
		{0, 1, 1},
		{1, 0, 1},
		{1, 1, 0}
	};
	trainingSet = new ArrayList<TrainingPair>();
	for(float[] row : xorTable) {
		trainingSet.add(new TrainingPair(
			createFloatList(row[0], row[1]), createFloatList(row[2])));
	}

	int epoch = 0;
	while(epoch < 1000) {
		net.train(trainingSet);
		epoch++;
	}

	index = 0;
}

int index = 0;  // which training example is shown; advanced by keyPressed()

// Each frame: train 50 more epochs, forward-pass the selected example so the
// node values shown belong to it, then render the network to the screen.
void draw(){
	for(int remaining = 50; remaining > 0; remaining--) {
		net.train(trainingSet);
	}

	FloatList selected = trainingSet.get(index).inputs;
	net.run(selected);

	pg.beginDraw();
	pg.background(200);
	net.render(pg);
	pg.endDraw();

	image(pg, 0, 0);
}

// Any key advances to the next training example, wrapping around at the end.
void keyPressed() {
	int next = index + 1;
	index = next % trainingSet.size();
}

// A mouse click re-runs setup(): a new randomly initialised network is built
// and pre-trained from scratch.
void mousePressed() {
	this.setup();
}

// Returns a new FloatList holding the given values in the order given.
FloatList createFloatList(float...numbers) {
	FloatList list = new FloatList();
	for(float n : numbers) {
		list.append(n);
	}
	return list;
}