Deep Restricted Boltzmann Machine - Java
The initial RBM Contrastive Divergence algorithm implemented from this blog.
This version includes image encoding/decoding schemes, Contrastive Divergence training for a single RBM, deep RBM, and recurrent RBMs. Uses Parallel Colt for matrix processing. Also includes a Multithreaded Deep RBM.
Source Code available on GitHub: here
Results RBM(visual=6,hidden=4)
Training Data:
[[1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
[1.0, 0.0, 1.0, 0.0, 0.0, 0.0]
[1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 1.0, 1.0, 1.0, 0.0]
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
[0.0, 0.0, 1.0, 1.0, 1.0, 0.0]]
Input: [[0.0, 0.0, 0.0, 1.0, 1.0, 0.0]]
Output: [[0.0, 0.0, 1.0, 1.0, 1.0, 0.0]]
Inputs: [[0.0, 0.0, 0.0, 1.0, 1.0, 0.0] [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]]
Outputs: [0.0, 0.0, 1.0, 1.0, 1.0, 0.0] [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]]
Code for above Output:
final RBM rbm = RBM_FACTORY.build(6, 3);
final ContrastiveDivergence contrastiveDivergence = new ContrastiveDivergence(new LearningParameters().setEpochs(25000));
contrastiveDivergence.learn(rbm, buildBetterSampleTrainingData());
// fetch two recommendations
final Matrix testData = DenseMatrix.make(new double[][]{{0, 0, 0, 1, 1, 0}, {0, 0, 1, 1, 0, 0}});
final Matrix hidden = contrastiveDivergence.runVisible(rbm, testData);
LOGGER.info(testData);
final Matrix visual = contrastiveDivergence.runHidden(rbm, hidden);
LOGGER.info(visual);
Image Recognition
Shallow RBM - Input a 100x63 pixel image of a fighter jet at 24bit color resolution. Each RGB value is encoded as a 24 bit vector making a total input size of 100 x 24 x 63 bits.
Input |
RBM Generated 24 bit |
RBM Generated 8 bit |
Deep RBM - Input a 400*250 pixel image of a fighter jet at 24bit color resolution. Each RGB value is encoded as a 24 bit vector making a total input size of 400 * 24 * 250 bits. That's 2.4 Million inputs to be learned.
Input |
1 Epoch |
11 Epochs |
RBM - Learn 9 Pokemon Image (Full dataset contains 151 pokemon) 60x60 pixels, 24bit resolution.
RBM - Note how having a white BG (max value input) negatively affects learning, where as a Black (zero value input) converges quickly. They were trained on identical RBMs for the same number of epochs.
Number Recognition
// INPUT
INFO nn.rbm.TestRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□■■□■■■□□□□□
□□□□□□□□□□□■■■■■■■■■■■■□□□□□
□□□□□□□□■■■■■■■■■■□□□□□□□□□□
□□□□□□□□■■■■■■■■■■□□□□□□□□□□
□□□□□□□□□■□■■■□□□■□□□□□□□□□□
□□□□□□□□□□□■■□□□□□□□□□□□□□□□
□□□□□□□□□□□■■■□□□□□□□□□□□□□□
□□□□□□□□□□□□■■□□□□□□□□□□□□□□
□□□□□□□□□□□□□■■■□□□□□□□□□□□□
□□□□□□□□□□□□□□■■■□□□□□□□□□□□
□□□□□□□□□□□□□□□■■■■□□□□□□□□□
□□□□□□□□□□□□□□□□□■■■□□□□□□□□
□□□□□□□□□□□□□□□□□■■■□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■■□□□□□□□□
□□□□□□□□□□□□■■■■■■□□□□□□□□□□
□□□□□□□□□□■■■■■■□□□□□□□□□□□□
□□□□□□□■■■■■■■□□□□□□□□□□□□□□
□□□□□■■■■■■■■□□□□□□□□□□□□□□□
□□□□■■■■■■■□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
INFO nn.rbm.TestOldRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□■■■□□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■■□□□□□□□□
□□□□□□□□□□□□■■■■■■■□□□□□□□□□
□□□□□□□□□□□□■■□□■■■□□□□□□□□□
□□□□□□□□□□□■■□□■■■□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■□■■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□■■□□□□□□□□□□□□
□□□□□□□□□□■■■■■■□□□□□□□□□□□□
□□□□□□□□□□□■■■■□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
...
INFO nn.rbm.learn.OldContrastiveDivergence - Start Learning (7 samples)
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 0/15000, error: 1305.5197925558577, time: 0.059s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 100/15000, error: 57.300594478427854, time: 0.004s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 200/15000, error: 15.952329441261893, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 300/15000, error: 5.4044291068371155, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 400/15000, error: 2.602268788842556, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 500/15000, error: 1.4970038901297982, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 600/15000, error: 1.1067551950980756, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 700/15000, error: 0.8295110168889177, time: 0.003s
...
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 14400/15000, error: 0.002183322948934887, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 14500/15000, error: 0.0018464431984471126, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 14600/15000, error: 0.002316604784920346, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 14700/15000, error: 0.015824371649477142, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 14800/15000, error: 0.0033692543108419077, time: 0.003s
INFO nn.rbm.learn.OldContrastiveDivergence - Epoch: 14900/15000, error: 0.006265503532066407, time: 0.003s
INFO nn.rbm.TestOldRBM - Data Index: 0
INFO nn.rbm.TestOldRBM -
INFO nn.rbm.TestOldRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□■■■■□■■■■□□□□
□□□□□□□□□□□■■■■■■■■■■■■□□□□□
□□□□□□□□■■■■■■■■■■□□■□■□□□□□
□□□□□□□■■■■■■■■■■■□□□□□□□□□□
□□□□□□□□□□□■■■□□■■□□□□□□□□□□
□□□□□□□□□□□□■□□□□□□□□□□□□□□□
□□□□□□□□□□□■■■□□□□□□□□□□□□□□
□□□□□□□□□□□□■■□□□□□□□□□□□□□□
□□□□□□□□□□□□□■■□□□□□□□□□□□□□
□□□□□□□□□□□□□□■■■■□□□□□□□□□□
□□□□□□□□□□■□□□□■■■■□□□□□□□□□
□□□□□□□□□□□□□□□□□■■□□□□□□□□□
□□□□□□□□□□□□□□□□□■■■□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■□□□□□□□□□
□□□□□□□□□□■■■■■■■■□□□□□□□□□□
□□□□□□□□□□■■■■■□□□□□□□□□□□□□
□□□□□□□■■■■■■■□□□□□□□□□□□□□□
□□□□■■□■■■■■□□□□□□□□□□□□□□□□
□□□□□■■■■□■□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
INFO nn.rbm.TestOldRBM -
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□■■□□□□□□□□□
□□□□□□□□□□□□□□□■■■■□□□□□□□□□
□□□□□□□□□□□□□□□■■■■■□□□□□□□□
□□□□□□□□□□□□□■■■■■■□□□□□□□□□
□□□□□□□□□□□□■■■■■■■■□□□□□□□□
□□□□□□□□□□□■■■■□■■■□□□□□□□□□
□□□□□□□□□□□■■□□□■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■■□□□□□□□□□□□□□
□□□□□□□□□□□□■■□■■□□□□□□□□□□□
□□□□□□□□□□□□■□□■■□□□□□□□□□□□
□□□□□□□□□□□■■□□■■■□□□□□□□□□□
□□□□□□□□□□□■■□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□□■■■□□□□□□□□□□
□□□□□□□□□□■□□□□■■□□□□□□□□□□□
□□□□□□□□□□■■□□■■■□□□□□□□□□□□
□□□□□□□□□□■■■■■■□□□□□□□□□□□□
□□□□□□□□□□□□■■□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
□□□□□□□□□□□□□□□□□□□□□□□□□□□□
...