Deep Restricted Boltzmann Machine - Java

Posted on by Kenny Cason
tags = [ rbm, restricted boltzmann machine, AI, Machine Learning, Java ]

The initial RBM Contrastive Divergence algorithm implemented from this blog.

This version includes image encoding/decoding schemes, Contrastive Divergence training for a single RBM, deep RBM, and recurrent RBMs. Uses Parallel Colt for matrix processing. Also includes a Multithreaded Deep RBM.

Source Code available on GitHub: here

Results RBM(visual=6,hidden=4)

1
2
3
4
5
6
7
8
9
10
11
12
13
Training Data:
[[1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
 [1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
 [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
 [0.0, 0.0, 1.0, 1.0, 1.0, 0.0]
 [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
 [0.0, 0.0, 1.0, 1.0, 1.0, 0.0]]

Input:  [[0.0, 0.0, 0.0, 1.0, 1.0, 0.0]]
Output: [[0.0, 0.0, 1.0, 1.0, 1.0, 0.0]]

Inputs: [[0.0, 0.0, 0.0, 1.0, 1.0, 0.0] [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]]
Outputs: [0.0, 0.0, 1.0, 1.0, 1.0, 0.0] [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]]

Code for above Output:

1
2
3
4
5
6
7
8
9
10
11
final RBM rbm = RBM_FACTORY.build(6, 3);
final ContrastiveDivergence contrastiveDivergence = new ContrastiveDivergence(new LearningParameters().setEpochs(25000));

contrastiveDivergence.learn(rbm, buildBetterSampleTrainingData());

// fetch two recommendations
final Matrix testData = DenseMatrix.make(new double[][]{{0, 0, 0, 1, 1, 0}, {0, 0, 1, 1, 0, 0}});
final Matrix hidden = contrastiveDivergence.runVisible(rbm, testData);
LOGGER.info(testData);
final Matrix visual = contrastiveDivergence.runHidden(rbm, hidden);
LOGGER.info(visual);

Image Recognition

Shallow RBM - Input a 100x63 pixel image of a fighter jet at 24bit color resolution. Each RGB value is encoded as a 24 bit vector making a total input size of 100 x 24 x 63 bits.
Input
RBM Generated 24 bit
RBM Generated 8 bit
Deep RBM - Input a 400250 pixel image of a fighter jet at 24bit color resolution. Each RGB value is encoded as a 24 bit vector making a total input size of 400 24 * 250 bits. That’s 2.4 Million inputs to be learned.
Input
1 Epoch
11 Epochs
RBM - Learn 9 Pokemon Image (Full dataset contains 151 pokemon) 60x60 pixels, 24bit resolution.
RBM - Note how having a white BG (max value input) negatively affects learning, where as a Black (zero value input) converges quickly. They were trained on identical RBMs for the same number of epochs.

Number Recognition

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// INPUT
INFO  nn.rbm.TestRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□■■□■■■□□□□□
□□□□□□□□□□□■■■■■■■■■■■■□□□□□
□□□□□□□□■■■■■■■■■■□□□□□□□□□□
□□□□□□□□■■■■■■■■■■□□□□□□□□□□
□□□□□□□□□■□■■■□□□■□□□□□□□□□□
□□□□□□□□□□□■■□□□□□□□□□□□□□□□
□□□□□□□□□□□■■■□□□□□□□□□□□□□□
□□□□□□□□□□□□■■□□□□□□□□□□□□□□
□□□□□□□□□□□□□■■■□□□□□□□□□□□□
□□□□□□□□□□□□□□■■■□□□□□□□□□□□
□□□□□□□□□□□□□□□■■■■□□□□□□□□□
□□□□□□□□□□□□□□□□□■■■□□□□□□□□
□□□□□□□□□□□□□□□□□■■■□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■■□□□□□□□□
□□□□□□□□□□□□■■■■■■□□□□□□□□□□
□□□□□□□□□□■■■■■■□□□□□□□□□□□□
□□□□□□□■■■■■■■□□□□□□□□□□□□□□
□□□□□■■■■■■■■□□□□□□□□□□□□□□□
□□□□■■■■■■■□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□

INFO  nn.rbm.TestOldRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□■■■□□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■■□□□□□□□□
□□□□□□□□□□□□■■■■■■■□□□□□□□□□
□□□□□□□□□□□□■■□□■■■□□□□□□□□□
□□□□□□□□□□□■■□□■■■□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■□■■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□■■□□□□□□□□□□□□
□□□□□□□□□□■■■■■■□□□□□□□□□□□□
□□□□□□□□□□□■■■■□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□

...

INFO  nn.rbm.learn.OldContrastiveDivergence - Start Learning (7 samples)
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 0/15000, error: 1305.5197925558577, time: 0.059s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 100/15000, error: 57.300594478427854, time: 0.004s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 200/15000, error: 15.952329441261893, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 300/15000, error: 5.4044291068371155, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 400/15000, error: 2.602268788842556, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 500/15000, error: 1.4970038901297982, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 600/15000, error: 1.1067551950980756, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 700/15000, error: 0.8295110168889177, time: 0.003s
...
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 14400/15000, error: 0.002183322948934887, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 14500/15000, error: 0.0018464431984471126, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 14600/15000, error: 0.002316604784920346, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 14700/15000, error: 0.015824371649477142, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 14800/15000, error: 0.0033692543108419077, time: 0.003s
INFO  nn.rbm.learn.OldContrastiveDivergence - Epoch: 14900/15000, error: 0.006265503532066407, time: 0.003s
INFO  nn.rbm.TestOldRBM - Data Index: 0
INFO  nn.rbm.TestOldRBM -

INFO  nn.rbm.TestOldRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□■■■■□■■■■□□□□
□□□□□□□□□□□■■■■■■■■■■■■□□□□□
□□□□□□□□■■■■■■■■■■□□■□■□□□□□
□□□□□□□■■■■■■■■■■■□□□□□□□□□□
□□□□□□□□□□□■■■□□■■□□□□□□□□□□
□□□□□□□□□□□□■□□□□□□□□□□□□□□□
□□□□□□□□□□□■■■□□□□□□□□□□□□□□
□□□□□□□□□□□□■■□□□□□□□□□□□□□□
□□□□□□□□□□□□□■■□□□□□□□□□□□□□
□□□□□□□□□□□□□□■■■■□□□□□□□□□□
□□□□□□□□□□■□□□□■■■■□□□□□□□□□
□□□□□□□□□□□□□□□□□■■□□□□□□□□□
□□□□□□□□□□□□□□□□□■■■□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■□□□□□□□□□
□□□□□□□□□□■■■■■■■■□□□□□□□□□□
□□□□□□□□□□■■■■■□□□□□□□□□□□□□
□□□□□□□■■■■■■■□□□□□□□□□□□□□□
□□□□■■□■■■■■□□□□□□□□□□□□□□□□
□□□□□■■■■□■□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□

INFO  nn.rbm.TestOldRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□■■□□□□□□□□□
□□□□□□□□□□□□□□□■■■■□□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■□□□□□□□□□
□□□□□□□□□□□□■■■■■■■■□□□□□□□□
□□□□□□□□□□□■■■■□■■■□□□□□□□□□
□□□□□□□□□□□■■□□□■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■□□□□□□□□□□□□□
□□□□□□□□□□□□■■□■■□□□□□□□□□□□
□□□□□□□□□□□□■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■■□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□□■■■□□□□□□□□□□
□□□□□□□□□□■□□□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□■■■□□□□□□□□□□□
□□□□□□□□□□■■■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□

...
comments powered by Disqus