ann_gesture_recognition
// Feed-forward neural network with one hidden layer, trained by online
// backpropagation. The input and hidden layers each get one extra bias node
// (an InputNode with constant value -1) appended after their regular nodes.
class Net
{
  int numOfInputs, numOfHiddens, numOfOutputs;
  ArrayList<Node> inputNodes, hiddenNodes, outputNodes;

  // inputs/hiddens/outputs: regular node counts per layer (bias excluded).
  // learningRate is given to every hidden and output node; momentum to every arc.
  Net(int inputs, int hiddens, int outputs, float learningRate, float momentum) {
    numOfInputs = inputs;
    numOfHiddens = hiddens;
    numOfOutputs = outputs;
    createNet(learningRate, momentum);
  }

  // Builds the three layers and fully connects input->hidden and
  // hidden->output, with each layer's bias node included as a source.
  void createNet(float learningRate, float momentum) {
    // input layer
    inputNodes = new ArrayList<Node>();
    for (int i = 0; i < numOfInputs; i++) {
      inputNodes.add(new InputNode());
    }
    inputNodes.add(new InputNode(-1)); // bias
    // hidden layer
    hiddenNodes = new ArrayList<Node>();
    for (int j = 0; j < numOfHiddens; j++) {
      hiddenNodes.add(new HiddenNode(learningRate));
    }
    hiddenNodes.add(new InputNode(-1)); // bias
    // output layer
    outputNodes = new ArrayList<Node>();
    for (int k = 0; k < numOfOutputs; k++) {
      outputNodes.add(new OutputNode(learningRate));
    }
    // Connect input layer -> hidden layer. "+1" includes the bias node as a
    // source; bias nodes never receive incoming arcs.
    for (int i = 0; i < numOfInputs + 1; i++) {
      for (int j = 0; j < numOfHiddens; j++) {
        inputNodes.get(i).connect(new Arc(momentum), hiddenNodes.get(j));
      }
    }
    // Connect hidden layer -> output layer.
    for (int j = 0; j < numOfHiddens + 1; j++) {
      for (int k = 0; k < numOfOutputs; k++) {
        hiddenNodes.get(j).connect(new Arc(momentum), outputNodes.get(k));
      }
    }
  }

  // Forward pass. Returns one activation per output node, or an empty list
  // (after requesting exit) when inputs does not have numOfInputs values.
  FloatList run(FloatList inputs) {
    FloatList outputs = new FloatList();
    if (inputs.size() != numOfInputs) {
      println("number of inputs is wrong!");
      exit();
      // exit() does not stop the current call, so bail out here instead of
      // indexing past the end of inputs.
      return outputs;
    }
    for (int i = 0; i < numOfInputs; i++) {
      inputNodes.get(i).update(inputs.get(i));
    }
    for (int j = 0; j < numOfHiddens; j++) {
      hiddenNodes.get(j).update();
    }
    for (int k = 0; k < numOfOutputs; k++) {
      outputs.append(outputNodes.get(k).update());
    }
    return outputs;
  }

  // One epoch of online backpropagation: a forward pass plus weight update
  // for every pair in trainingSets, in order.
  void train(ArrayList<TrainingPair> trainingSets) {
    for (TrainingPair tp : trainingSets) {
      run(tp.inputs);
      // 1. Output deltas from the targets.
      for (int k = 0; k < numOfOutputs; k++) {
        ((OutputNode) outputNodes.get(k)).setError(tp.outputs.get(k));
      }
      // 2. Hidden deltas must be computed from the hidden->output weights that
      //    produced this forward pass, so they run before those weights change.
      //    This step also updates the input->hidden weights.
      for (int j = 0; j < numOfHiddens; j++) {
        ((HiddenNode) hiddenNodes.get(j)).train();
      }
      // 3. Only now update the hidden->output weights.
      for (int k = 0; k < numOfOutputs; k++) {
        ((OutputNode) outputNodes.get(k)).train();
      }
    }
  }

  // Returns the node list for layer index 0 (input), 1 (hidden) or 2 (output).
  private ArrayList<Node> layerNodes(int l) {
    switch (l) {
      case 0:
        return inputNodes;
      case 1:
        return hiddenNodes;
      case 2:
        return outputNodes;
      default:
        println("wrong index of layer!");
        exit();
        return new ArrayList<Node>();
    }
  }

  // Render the net on a canvas. Input nodes are squares shaded red by value,
  // hidden/output nodes are circles shaded green by value, and each arc's
  // grey level follows its weight. Bias nodes are not drawn.
  void render(PGraphics pg) {
    pg.beginDraw();
    pg.rectMode(RADIUS);
    pg.ellipseMode(RADIUS);
    pg.textAlign(CENTER, CENTER);
    int nodeRadius = 5;
    int xdist = pg.width / 5;
    int x = xdist;
    int ydist = pg.height / (max(numOfInputs, numOfHiddens, numOfOutputs) + 1);
    int[] numOfLayers = {numOfInputs, numOfHiddens, numOfOutputs};
    // y positions of the previously drawn layer, used as arc endpoints.
    IntList nodeOldY = new IntList();
    for (int l = 0; l < 3; l++) {
      int numOfLayer = numOfLayers[l];
      ArrayList<Node> layer = layerNodes(l);
      int y = pg.height / 2;
      IntList nodeY = new IntList();
      for (int i = 0; i < numOfLayer; i++) {
        // Nodes are placed alternately above/below the centre; even-sized
        // layers start half a step up so they stay centred.
        if (i == 0 && numOfLayer % 2 == 0) {
          y -= ydist / 2;
        }
        Node node = layer.get(i);
        // draw arcs from the previous layer
        for (int j = 0; j < nodeOldY.size(); j++) {
          pg.stroke((node.inArcs.get(j).weight + 1) / 2 * 128);
          pg.line(x, y, x - xdist + nodeRadius, nodeOldY.get(j));
        }
        pg.stroke(0);
        if (l == 2) {
          pg.line(x, y, x + xdist, y);
        }
        // draw nodes
        if (l == 0) {
          pg.fill((node.value + 1) / 2 * 255, 0, 0);
          pg.rect(x, y, nodeRadius, nodeRadius);
        } else {
          pg.fill(0, node.value * 255, 0);
          pg.ellipse(x, y, nodeRadius, nodeRadius);
        }
        nodeY.append(y);
        y += pow(-1, i % 2) * (i + 1) * ydist;
      }
      nodeOldY = nodeY;
      x += xdist;
    }
    pg.endDraw();
  }
};
// Weighted connection from node `in` to node `out`. The weight starts at a
// random value between -1 and 1; each update also re-applies a
// momentum-scaled copy of the previous step.
class Arc
{
  Node in, out;
  float weight, delta, momentum;

  Arc(float momentum) {
    this.momentum = momentum;
    this.delta = 0;
    this.weight = random(-1, 1);
  }

  void setInputNode(Node source) {
    this.in = source;
  }

  void setOutputNode(Node target) {
    this.out = target;
  }

  // Activation of the source node.
  float getInputValue() {
    return in.value;
  }

  // Source activation scaled by this arc's weight.
  float getWeightedInputValue() {
    return in.value * weight;
  }

  // Target node's error scaled by this arc's weight.
  float getWeightedOutputError() {
    return out.error * weight;
  }

  // Adds `step` plus momentum times the previous step, then remembers `step`.
  void updateWeight(float step) {
    float applied = step + momentum * delta;
    weight += applied;
    delta = step;
  }
};
// Base network node: holds its activation (`value`), its training error, and
// the arcs entering and leaving it.
class Node
{
  ArrayList<Arc> inArcs, outArcs;
  float value, error;

  Node() {
    value = 0;
    error = 0;
    outArcs = new ArrayList<Arc>();
    inArcs = new ArrayList<Arc>();
  }

  // Connects this node to `node` through a fresh arc with zero momentum.
  void connect(Node node) {
    connect(new Arc(0), node);
  }

  // Wires `arc` as an outgoing arc of this node and an incoming arc of `node`.
  void connect(Arc arc, Node node) {
    addOutputArc(arc);
    node.addInputArc(arc);
    arc.setInputNode(this);
    arc.setOutputNode(node);
  }

  void addInputArc(Arc arc) {
    inArcs.add(arc);
  }

  void addOutputArc(Arc arc) {
    outArcs.add(arc);
  }

  // Sets the activation directly and returns it.
  float update(float v) {
    this.value = v;
    return this.value;
  }

  // Sets the activation to the sigmoid of the weighted sum of all inputs and
  // returns it.
  float update() {
    float sum = 0;
    for (int idx = 0; idx < inArcs.size(); idx++) {
      sum += inArcs.get(idx).getWeightedInputValue();
    }
    this.value = sigmoidTransfer(sum);
    return this.value;
  }

  // Logistic sigmoid 1 / (1 + e^-x).
  float sigmoidTransfer(float x) {
    return 1.0/(1.0+exp(-x));
  }
};
// Node whose value is assigned directly instead of being computed from
// incoming arcs. The float constructor presets a fixed value (the net uses
// -1 for its bias nodes).
class InputNode extends Node
{
  InputNode() {
    super();
  }

  InputNode(float value) {
    this();
    this.value = value;
  }
};
class HiddenNode extends Node
{
float learningRate;
HiddenNode(float learningRate) {
this.learningRate = learningRate;
}
void train() {
float total = 0;
for(Arc arc:outArcs){
total += arc.getWeightedOutputError();
}
this.error = value*(1-value)*total;
for(Arc arc:inArcs) {
arc.updateWeight(arc.getInputValue() * error * learningRate);
}
}
};
class OutputNode extends Node
{
float learningRate;
OutputNode(float learningRate) {
this.learningRate = learningRate;
}
void setError(float target) {
this.error = (target - value) * value * (1 - value);
}
void train() {
for(Arc arc:inArcs) {
arc.weight += arc.getInputValue() * error * learningRate;
}
}
};
PGraphics pg;
Net net;
ArrayList<TrainingPair> trainingSet;
// Creates the window, an off-screen canvas of the same size and the network
// (24 inputs, 12 hidden, 5 outputs), pre-trains it for 10 epochs on the
// built-in gestures, and seeds the current input.
void setup(){
  size(512, 512);
  // Use the window's own dimensions so the canvas always matches size().
  pg = createGraphics(width, height);
  net = new Net(24, 12, 5, 0.5, 0.9);
  createTrainingSet();
  for (int i = 0; i < 10; i++) {
    net.train(trainingSet);
  }
  points = new ArrayList<PVector>();
  inputs = createFloatList(1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0);
}
// Resets trainingSet to the four built-in gestures. Gesture c repeats one
// (cos, sin) pair 12 times (the same layout pointHandle produces) and targets
// a one-hot vector of length 5 with output c set to 1.
void createTrainingSet() {
  trainingSet = new ArrayList<TrainingPair>();
  float[][] directions = { {1, 0}, {-1, 0}, {0, 1}, {0, -1} };
  for (int c = 0; c < directions.length; c++) {
    float[] in = new float[24];
    for (int s = 0; s < 12; s++) {
      in[2 * s] = directions[c][0];
      in[2 * s + 1] = directions[c][1];
    }
    float[] target = new float[5];
    target[c] = 1;
    trainingSet.add(new TrainingPair(createFloatList(in), createFloatList(target)));
  }
}
FloatList inputs;
// Each frame: train five more epochs, evaluate the current input, redraw the
// network and the latest stroke on pg, then blit pg to the window.
void draw(){
  int epochsPerFrame = 5;
  for (int epoch = 0; epoch < epochsPerFrame; epoch++) {
    net.train(trainingSet);
  }
  net.run(inputs);
  pg.beginDraw();
  pg.background(200);
  pg.stroke(0);
  pg.strokeWeight(1);
  net.render(pg);
  // stroke polyline in blue, with a dot on every point after the first
  pg.fill(0,0,255);
  pg.stroke(0,0,255);
  for (int n = 1; n < points.size(); n++) {
    PVector from = points.get(n - 1);
    PVector to = points.get(n);
    pg.line(from.x, from.y, to.x, to.y);
    pg.ellipse(to.x, to.y, 3, 3);
  }
  pg.endDraw();
  image(pg, 0, 0);
}
// Space bar: rebuild the training set plus the current inputs as an example
// for the fifth output, then train a brand-new network for 100 epochs.
void keyPressed() {
  if (key != ' ') {
    return;
  }
  createTrainingSet();
  FloatList customTarget = createFloatList(0,0,0,0,1);
  trainingSet.add(new TrainingPair(inputs, customTarget));
  net = new Net(24, 12, 5, 0.5, 0.9);
  for (int epoch = 0; epoch < 100; epoch++) {
    net.train(trainingSet);
  }
}
// Points of the stroke currently being (or most recently) drawn.
ArrayList<PVector> points;
// True while the mouse button is held and positions are being recorded.
boolean drawing = false;
// Begins a new stroke, discarding the previous one.
void mousePressed() {
  points.clear();
  drawing = true;
}
// Appends the cursor position to the stroke while drawing is active.
void mouseDragged() {
  if (!drawing) {
    return;
  }
  points.add(new PVector(mouseX, mouseY));
}
// Ends the stroke and converts it into network inputs via pointHandle().
void mouseReleased()
{
  drawing = false;
  pointHandle();
}
// Converts the recorded stroke into the 24 network inputs. The path is split
// into 12 buckets of equal arc length; each bucket contributes the cos and sin
// of the direction of its net displacement. A bucket that covers no movement
// repeats the previous bucket's direction (initially angle 0).
void pointHandle() {
  if(points.size()<2) {
    println("points do not enough!");
    return;
  }
  final int segments = 12; // 12 (cos, sin) pairs -> 24 inputs
  float totalLen = 0;
  for(int i=1;i<points.size();i++) {
    PVector prev = points.get(i-1);
    totalLen += prev.dist(points.get(i));
  }
  // Divide by the bucket count so the final bucket ends at the stroke's end
  // (dividing by 13 left the last 1/13 of the stroke unused).
  float deltaLen = totalLen/segments;
  float target = 0;
  float total = 0;
  float lastAngle = 0;
  int pointIndex = 1;
  inputs = new FloatList();
  for(int i=0; i<segments; i++) {
    target += deltaLen;
    // Sum displacement vectors instead of averaging atan2 angles: angles wrap
    // at +/-PI, so averaging e.g. 179 and -179 degrees would give 0 degrees
    // (the opposite direction). Zero-length segments now contribute nothing.
    float sumX = 0;
    float sumY = 0;
    println("i:"+i+" total:"+total+" target:"+target);
    while(total<target && pointIndex < points.size()){
      println("pointIndex:"+pointIndex);
      PVector prev = points.get(pointIndex-1);
      PVector current = points.get(pointIndex);
      sumX += current.x - prev.x;
      sumY += current.y - prev.y;
      total += prev.dist(current);
      pointIndex++;
    }
    if(sumX != 0 || sumY != 0){
      lastAngle = atan2(sumY, sumX);
    }
    inputs.append(cos(lastAngle));
    inputs.append(sin(lastAngle));
  }
  println("inputs:"+inputs);
}
// Returns a new FloatList holding the given values in order.
FloatList createFloatList(float...numbers) {
  FloatList list = new FloatList();
  for (float n : numbers) {
    list.append(n);
  }
  return list;
}
// One supervised example: the network inputs and the desired outputs.
class TrainingPair
{
  FloatList inputs, outputs;

  TrainingPair(FloatList inputList, FloatList targetList) {
    this.inputs = inputList;
    this.outputs = targetList;
  }
}