1    public class NeuralNetwork {
2        /* Number of input, hidden and output nodes */
3        private int ni, nh, no;
4        /* Activation vectors for the nodes */
5        private double [] ai, ah, ao;
6        /* Weight Matrices */
7        private double [][] wi; /* Weights between input nodes and hidden nodes */
8        private double [][] wo; /* Weights between hidden nodes and output nodes */
9        /* Momentum Matrices */
10       private double [][] ci;
11       private double [][] co;
12       
13       public NeuralNetwork(int ni, int nh, int no) {
14           /* +1 for the bias node */
15           this.ni = ni + 1;
16           this.nh = nh;
17           this.no = no;
18           
19           ai = new double[this.ni];
20   
21           ah = new double[this.nh];
22           ao = new double[this.no];
23           
24           ai[ni] = 1;
25           
26           wi = new double[this.ni][this.nh];
27           wo = new double[this.nh][this.no];
28           
29           /* Assign random weights to the input -> hidden node connections (Between -.1 and .1) */
30           for(int i = 0; i < this.ni; i++)
31               for(int j = 0; j < this.nh; j++) wi[i][j] = Client.nextDouble() / 10;
32           /* Assign random weights to the hidden -> output node connections */
33           for(int i = 0; i < this.nh; i++)
34               for(int j = 0; j < this.no; j++) wo[i][j] = Client.nextDouble() / 10;
35           
36           ci = new double[this.ni][this.nh];
37           co = new double[this.nh][this.no];
38       }
39       /* Assigns some default values for training */
40       public void train(double [][][] patterns) { train(patterns, 1000, 0.5, 0.1); }
41       public void train(double [][][] patterns, int iterations, double lRate, double mFactor) {
42           double error;
43           double [] inputs, targets;
44           for(int i = 0; i < iterations; i++) {
45               error = 0.0;
46               for(int j = 0; j < patterns.length; j++) {
47                   inputs = patterns[j][0];
48                   targets = patterns[j][1];
49                   update(inputs);
50                   error += backPropagate(targets, lRate, mFactor);
51               }
52               if(i % 50 == 0) System.out.println(error);
53           }   
54       }
55       
56       public double test(double [][][] patterns) {
57           double [] r;
58           int incorrect = 0;
59           for(int i = 0; i < patterns.length; i++) {
60               r = update(patterns[i][0]);
61               if(r[0] >= r[1] && patterns[i][1][0] == 0) incorrect++;
62               else if(r[0] < r[1] && patterns[i][1][1] == 0) incorrect++;
63               /* For debugging Purposes */
64               System.out.print(patterns[i][0][0] + " " + patterns[i][0][1] + " " + r[0] + " " + r[1]);
65               if(r[0] >= r[1]) {
66                   System.out.print(" -> A -> ");
67               } else {
68                   System.out.print(" -> B -> ");
69               }
70               if(patterns[i][1][0] == 1) {
71                   System.out.println("A");
72               } else {
73                   System.out.println("B");
74               }
75           }
76           return incorrect / (double)patterns.length;
77       }
78       
79       private double[] update(double [] inputs) {
80           double sum;
81           /* Input Activations */
82           for(int i = 0; i < ni - 1; i++) ai[i] = inputs[i];
83           /* Hidden Activations */
84           for(int i = 0; i < nh; i++) {
85               sum = 0;
86               for(int j = 0; j < ni; j++) sum += ai[j] * wi[j][i];
87               ah[i] = sigmoid(sum);
88           }
89           /* Output Activations */
90           for(int i = 0; i < no; i++) {
91               sum = 0;
92               for(int j = 0; j < nh; j++) sum += ah[j] * wo[j][i];
93               ao[i] = sigmoid(sum);
94           }
95           return ao;
96       }
97       
98       private double backPropagate(double [] targets, double lRate, double mFactor) {
99           double [] output_deltas = new double[no];
100          double [] hidden_deltas = new double[nh];
101          double error, change;
102          /* calculate deltas for output nodes */
103          for(int i = 0; i < no; i++) {
104              error = targets[i] - ao[i];
105              output_deltas[i] = dsigmoid(ao[i]) * error;
106          }
107          /* calculate deltas for hidden nodes */
108          for(int i = 0; i < nh; i++) {
109              error = 0.0;
110              for(int j = 0; j < no; j++) error += output_deltas[j] * wo[i][j];
111              hidden_deltas[i] = dsigmoid(ah[i]) * error;
112          }
113          /* update output weights */
114          for(int i = 0; i < nh; i++) {
115              for(int j = 0; j < no; j++) {
116                  change = output_deltas[j] * ah[i];
117                  wo[i][j] += lRate * change + mFactor * co[i][j];
118                  co[i][j] = change;
119              }
120          }
121          /* update input weights */
122          for(int i = 0; i < ni; i++) {
123              for(int j = 0; j < nh; j++) {
124                  change = hidden_deltas[j] * ai[i];
125                  wi[i][j] += lRate * change + mFactor * ci[i][j];
126                  ci[i][j] = change;
127              }
128          }
129          /* calculate error */
130          error = 0.0;
131          for(int i = 0; i < targets.length; i++) error += 0.5 * Math.pow(targets[i] - ao[i], 2);
132          return error;
133      }
134  
135      private double sigmoid(double x) { return 1 / (1 + Math.exp(-x)); }
136      private double dsigmoid(double x) { return x * (1 - x); }
137  }
138