1 public class NeuralNetwork {
2
3 private int ni, nh, no;
4
5 private double [] ai, ah, ao;
6
7 private double [][] wi;
8 private double [][] wo;
9
10 private double [][] ci;
11 private double [][] co;
12
13 public NeuralNetwork(int ni, int nh, int no) {
14
15 this.ni = ni + 1;
16 this.nh = nh;
17 this.no = no;
18
19 ai = new double[this.ni];
20
21 ah = new double[this.nh];
22 ao = new double[this.no];
23
24 ai[ni] = 1;
25
26 wi = new double[this.ni][this.nh];
27 wo = new double[this.nh][this.no];
28
29
30 for(int i = 0; i < this.ni; i++)
31 for(int j = 0; j < this.nh; j++) wi[i][j] = Client.nextDouble() / 10;
32
33 for(int i = 0; i < this.nh; i++)
34 for(int j = 0; j < this.no; j++) wo[i][j] = Client.nextDouble() / 10;
35
36 ci = new double[this.ni][this.nh];
37 co = new double[this.nh][this.no];
38 }
39
40 public void train(double [][][] patterns) { train(patterns, 1000, 0.5, 0.1); }
41 public void train(double [][][] patterns, int iterations, double lRate, double mFactor) {
42 double error;
43 double [] inputs, targets;
44 for(int i = 0; i < iterations; i++) {
45 error = 0.0;
46 for(int j = 0; j < patterns.length; j++) {
47 inputs = patterns[j][0];
48 targets = patterns[j][1];
49 update(inputs);
50 error += backPropagate(targets, lRate, mFactor);
51 }
52 if(i % 50 == 0) System.out.println(error);
53 }
54 }
55
56 public double test(double [][][] patterns) {
57 double [] r;
58 int incorrect = 0;
59 for(int i = 0; i < patterns.length; i++) {
60 r = update(patterns[i][0]);
61 if(r[0] >= r[1] && patterns[i][1][0] == 0) incorrect++;
62 else if(r[0] < r[1] && patterns[i][1][1] == 0) incorrect++;
63
64 System.out.print(patterns[i][0][0] + " " + patterns[i][0][1] + " " + r[0] + " " + r[1]);
65 if(r[0] >= r[1]) {
66 System.out.print(" -> A -> ");
67 } else {
68 System.out.print(" -> B -> ");
69 }
70 if(patterns[i][1][0] == 1) {
71 System.out.println("A");
72 } else {
73 System.out.println("B");
74 }
75 }
76 return incorrect / (double)patterns.length;
77 }
78
79 private double[] update(double [] inputs) {
80 double sum;
81
82 for(int i = 0; i < ni - 1; i++) ai[i] = inputs[i];
83
84 for(int i = 0; i < nh; i++) {
85 sum = 0;
86 for(int j = 0; j < ni; j++) sum += ai[j] * wi[j][i];
87 ah[i] = sigmoid(sum);
88 }
89
90 for(int i = 0; i < no; i++) {
91 sum = 0;
92 for(int j = 0; j < nh; j++) sum += ah[j] * wo[j][i];
93 ao[i] = sigmoid(sum);
94 }
95 return ao;
96 }
97
98 private double backPropagate(double [] targets, double lRate, double mFactor) {
99 double [] output_deltas = new double[no];
100 double [] hidden_deltas = new double[nh];
101 double error, change;
102
103 for(int i = 0; i < no; i++) {
104 error = targets[i] - ao[i];
105 output_deltas[i] = dsigmoid(ao[i]) * error;
106 }
107
108 for(int i = 0; i < nh; i++) {
109 error = 0.0;
110 for(int j = 0; j < no; j++) error += output_deltas[j] * wo[i][j];
111 hidden_deltas[i] = dsigmoid(ah[i]) * error;
112 }
113
114 for(int i = 0; i < nh; i++) {
115 for(int j = 0; j < no; j++) {
116 change = output_deltas[j] * ah[i];
117 wo[i][j] += lRate * change + mFactor * co[i][j];
118 co[i][j] = change;
119 }
120 }
121
122 for(int i = 0; i < ni; i++) {
123 for(int j = 0; j < nh; j++) {
124 change = hidden_deltas[j] * ai[i];
125 wi[i][j] += lRate * change + mFactor * ci[i][j];
126 ci[i][j] = change;
127 }
128 }
129
130 error = 0.0;
131 for(int i = 0; i < targets.length; i++) error += 0.5 * Math.pow(targets[i] - ao[i], 2);
132 return error;
133 }
134
135 private double sigmoid(double x) { return 1 / (1 + Math.exp(-x)); }
136 private double dsigmoid(double x) { return x * (1 - x); }
137 }
138