// Fully connected feed-forward network with one hidden layer, trained by
// back-propagation. The input and hidden layers each get one extra bias node
// (an InputNode with constant value -1) appended after their regular nodes.
class Net
{
  int numOfInputs, numOfHiddens, numOfOutputs;
  ArrayList<Node> inputNodes, hiddenNodes, outputNodes;

  // inputs / hiddens / outputs: number of regular (non-bias) nodes per layer.
  // learningRate is given to every hidden and output node, momentum to every arc.
  Net(int inputs, int hiddens, int outputs, float learningRate, float momentum) {
    numOfInputs = inputs;
    numOfHiddens = hiddens;
    numOfOutputs = outputs;
    createNet(learningRate, momentum);
  }

  // Build the three layers and fully connect input->hidden and hidden->output.
  void createNet(float learningRate, float momentum) {
    // input layer, followed by its bias node
    inputNodes = new ArrayList<Node>();
    for (int i = 0; i < numOfInputs; i++) {
      inputNodes.add(new InputNode());
    }
    inputNodes.add(new InputNode(-1));
    // hidden layer, followed by its bias node (a constant InputNode, not a HiddenNode)
    hiddenNodes = new ArrayList<Node>();
    for (int j = 0; j < numOfHiddens; j++) {
      hiddenNodes.add(new HiddenNode(learningRate));
    }
    hiddenNodes.add(new InputNode(-1));
    // output layer (no bias node)
    outputNodes = new ArrayList<Node>();
    for (int k = 0; k < numOfOutputs; k++) {
      outputNodes.add(new OutputNode(learningRate));
    }
    // connect every input node (bias included) with every regular hidden node;
    // the hidden bias node receives no incoming arcs
    for (int i = 0; i < numOfInputs + 1; i++) {
      for (int j = 0; j < numOfHiddens; j++) {
        inputNodes.get(i).connect(new Arc(momentum), hiddenNodes.get(j));
      }
    }
    // connect every hidden node (bias included) with every output node
    for (int j = 0; j < numOfHiddens + 1; j++) {
      for (int k = 0; k < numOfOutputs; k++) {
        hiddenNodes.get(j).connect(new Arc(momentum), outputNodes.get(k));
      }
    }
  }

  // Forward pass. `inputs` must contain exactly numOfInputs values; returns the
  // output-layer activations in order. On a size mismatch a message is printed,
  // the sketch is asked to exit, and an empty list is returned.
  FloatList run(FloatList inputs) {
    if (inputs.size() != numOfInputs) {
      println("number of inputs is wrong!");
      exit();
      // exit() can return to the caller (Processing finishes the current frame),
      // so stop here instead of reading past the end of `inputs` below
      return new FloatList();
    }
    for (int i = 0; i < numOfInputs; i++) {
      inputNodes.get(i).update(inputs.get(i));
    }
    for (int j = 0; j < numOfHiddens; j++) {
      hiddenNodes.get(j).update();
    }
    FloatList outputs = new FloatList();
    for (int k = 0; k < numOfOutputs; k++) {
      outputs.append(outputNodes.get(k).update());
    }
    return outputs;
  }

  // One epoch of online back-propagation over every training pair.
  void train(ArrayList<TrainingPair> trainingSets) {
    for (TrainingPair tp : trainingSets) {
      // forward pass refreshes every node's value
      run(tp.inputs);
      // 1) output error terms from the targets; no weights are touched yet
      for (int k = 0; k < numOfOutputs; k++) {
        ((OutputNode) outputNodes.get(k)).setError(tp.outputs.get(k));
      }
      // 2) hidden error terms are back-propagated through the hidden->output
      //    weights, so they must run while those weights still hold the values
      //    used in the forward pass
      for (int j = 0; j < numOfHiddens; j++) {
        ((HiddenNode) hiddenNodes.get(j)).train();
      }
      // 3) only now update the hidden->output weights
      for (int k = 0; k < numOfOutputs; k++) {
        ((OutputNode) outputNodes.get(k)).train();
      }
    }
  }

  // Render the net on a canvas: one column per layer, arcs labelled with their
  // weights, nodes labelled with their current values. Bias nodes are not drawn.
  void render(PGraphics pg) {
    pg.beginDraw();
    pg.rectMode(RADIUS);
    pg.ellipseMode(RADIUS);
    pg.textAlign(CENTER, CENTER);
    int nodeRadius = 25;
    int xdist = pg.width / 5;
    int x = xdist;
    int ydist = pg.height / (max(numOfInputs, numOfHiddens, numOfOutputs) + 1);
    int[] numOfLayers = {numOfInputs, numOfHiddens, numOfOutputs};
    IntList nodeOldY = new IntList();
    for (int l = 0; l < numOfLayers.length; l++) {
      // the layer is loop-invariant for all nodes in this column
      ArrayList<Node> layer = (l == 0) ? inputNodes : (l == 1) ? hiddenNodes : outputNodes;
      int numOfLayer = numOfLayers[l];
      int y = pg.height / 2;
      IntList nodeY = new IntList();
      for (int i = 0; i < numOfLayer; i++) {
        // even-sized layers are shifted half a step so they stay centred
        if (i == 0 && numOfLayer % 2 == 0) {
          y -= ydist / 2;
        }
        Node node = layer.get(i);
        // draw arcs back to the previous column; inArcs are ordered like that layer
        for (int j = 0; j < nodeOldY.size(); j++) {
          pg.line(x, y, x - xdist + nodeRadius, nodeOldY.get(j));
          pg.fill(255);
          pg.text(node.inArcs.get(j).weight, x - xdist / 2, (nodeOldY.get(j) + y) / 2);
        }
        if (l == 2) {
          pg.line(x, y, x + xdist, y);
        }
        // draw node: squares for inputs, circles otherwise
        pg.fill(255);
        if (l == 0) {
          pg.rect(x, y, nodeRadius, nodeRadius);
        } else {
          pg.ellipse(x, y, nodeRadius, nodeRadius);
        }
        pg.fill(0);
        pg.text(node.value, x, y);
        nodeY.append(y);
        // alternate below/above the centre, moving further out each step
        y += pow(-1, i % 2) * (i + 1) * ydist;
      }
      nodeOldY = nodeY;
      x += xdist;
    }
    pg.endDraw();
  }
};
// Arc: a weighted connection between two nodes.
class Arc
{
  // `in` feeds this arc, `out` receives from it
  Node in, out;
  // weight: current weight, initialised to a random value in [-1, 1)
  // delta: the last total weight change applied, reused by the momentum term
  // momentum: fraction of the previous change added to the next one
  float weight, delta, momentum;
  Arc(float momentum) {
    weight = random(-1, 1);
    this.momentum = momentum;
    delta = 0;
  }
  void setInputNode(Node node) {
    in = node;
  }
  void setOutputNode(Node node) {
    out = node;
  }
  // value of the upstream node
  float getInputValue() {
    return in.value;
  }
  // upstream value scaled by the weight (forward-pass contribution)
  float getWeightedInputValue() {
    return in.value * weight;
  }
  // downstream error scaled by the weight (back-propagated contribution)
  float getWeightedOutputError() {
    return out.error * weight;
  }
  // Apply the gradient step `delta` plus momentum times the previous change:
  //   dw(t) = delta + momentum * dw(t-1)
  // The full change is stored (not just the raw step) so momentum accumulates
  // across updates instead of looking back a single step.
  void updateWeight(float delta) {
    float change = delta + this.momentum * this.delta;
    this.weight += change;
    this.delta = change;
  }
};
// Base neuron: holds its incoming and outgoing arcs, its current value
// (activation) and its error term.
class Node
{
  ArrayList<Arc> inArcs, outArcs;
  float value, error;
  Node() {
    value = 0;
    error = 0;
    inArcs = new ArrayList<Arc>();
    outArcs = new ArrayList<Arc>();
  }
  // Link this node to `node` through a freshly created arc with zero momentum.
  void connect(Node node) {
    connect(new Arc(0), node);
  }
  // Link this node to `node` through `arc`: the arc becomes an outgoing arc
  // here, an incoming arc on `node`, and records both endpoints.
  void connect(Arc arc, Node node) {
    addOutputArc(arc);
    node.addInputArc(arc);
    arc.setInputNode(this);
    arc.setOutputNode(node);
  }
  void addInputArc(Arc arc) {
    inArcs.add(arc);
  }
  void addOutputArc(Arc arc) {
    outArcs.add(arc);
  }
  // Set the value directly and return it.
  float update(float value) {
    this.value = value;
    return this.value;
  }
  // Set the value to the sigmoid of the weighted sum of incoming arcs and return it.
  float update() {
    float weightedSum = 0;
    for (int idx = 0; idx < inArcs.size(); idx++) {
      weightedSum += inArcs.get(idx).getWeightedInputValue();
    }
    this.value = sigmoidTransfer(weightedSum);
    return this.value;
  }
  // Logistic function 1 / (1 + e^-x).
  float sigmoidTransfer(float value) {
    return 1.0/(1.0+exp(-value));
  }
};
// Node whose value is assigned directly instead of computed from incoming arcs.
// The float constructor fixes an initial value (Net uses -1 for its bias nodes).
class InputNode extends Node
{
  InputNode() {
    super();
  }
  InputNode(float value) {
    this();
    this.value = value;
  }
};
// Hidden-layer neuron trained by back-propagating the errors of the nodes it feeds.
class HiddenNode extends Node
{
  float learningRate;
  HiddenNode(float learningRate) {
    super();
    this.learningRate = learningRate;
  }
  // error = value * (1 - value) * sum(downstream error * weight); then each
  // incoming arc receives updateWeight(input * error * learningRate).
  void train() {
    float downstream = 0;
    for (int idx = 0; idx < outArcs.size(); idx++) {
      downstream += outArcs.get(idx).getWeightedOutputError();
    }
    error = value * (1 - value) * downstream;
    for (int idx = 0; idx < inArcs.size(); idx++) {
      Arc incoming = inArcs.get(idx);
      incoming.updateWeight(incoming.getInputValue() * error * learningRate);
    }
  }
};
// Output-layer neuron: its error comes from the target value.
class OutputNode extends Node
{
  float learningRate;
  OutputNode(float learningRate) {
    this.learningRate = learningRate;
  }
  // error = (target - value) * value * (1 - value)
  void setError(float target) {
    this.error = (target - value) * value * (1 - value);
  }
  // Adjust each incoming weight by input * error * learningRate. This goes through
  // Arc.updateWeight, the same path HiddenNode uses, so the momentum configured
  // on these arcs is actually applied.
  void train() {
    for (Arc arc : inArcs) {
      arc.updateWeight(arc.getInputValue() * error * learningRate);
    }
  }
};