Restricted Boltzmann Machine with DeepLearnToolbox

In an attempt to learn the fundamentals of deep learning (and to complete a course assignment), I wrote a simple Restricted Boltzmann Machine (RBM) in GNU Octave by extracting it from the Deep Belief Network (DBN) code example in DeepLearnToolbox. In case you are wondering, an RBM is a generative, energy-based neural network that Geoffrey Hinton promoted as a basic building block of deep learning; stacking several RBMs yields a DBN.

This code simply loads the MNIST handwritten digit dataset (shipped with DeepLearnToolbox) and uses it to train an RBM with 100 hidden units. The actual learning happens inside the rbmtrain() function, where the network is trained with Contrastive Divergence (CD), an approximation of the log-likelihood gradient that is derived from the RBM's free energy. After 1 epoch of training, the RBM's weights are visualized and the reconstruction error is printed.
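To see what a single CD step does without opening the toolbox, here is a minimal sketch of one CD-1 update on one mini-batch. The function name cd1_step is hypothetical (it is not part of DeepLearnToolbox); it is modeled on my reading of rbmtrain() with momentum set to 0 and uses the same conventions as the toolbox's rbm struct: W is hidden × visible, b is the visible bias and c is the hidden bias. The random binary batch is toy data, not MNIST.

```octave
1;  # mark this file as a script so the function below can be defined inline

# cd1_step: hypothetical helper (not a DeepLearnToolbox function) that performs
# one CD-1 update with learning rate alpha and no momentum.
function [W, b, c, err] = cd1_step(W, b, c, v1, alpha)
  sigm = @(x) 1 ./ (1 + exp(-x));
  m = size(v1, 1);                                   # mini-batch size
  # positive phase: sample binary hidden states given the data
  h1 = double(sigm(v1 * W' + repmat(c', m, 1)) > rand(m, size(W, 1)));
  # negative phase: sample a reconstruction, then hidden probabilities for it
  v2 = double(sigm(h1 * W + repmat(b', m, 1)) > rand(m, size(W, 2)));
  h2 = sigm(v2 * W' + repmat(c', m, 1));
  # CD-1 estimates: data statistics minus reconstruction statistics
  W = W + alpha * (h1' * v1 - h2' * v2) / m;
  b = b + alpha * sum(v1 - v2)' / m;
  c = c + alpha * sum(h1 - h2)' / m;
  # mean squared reconstruction error per example
  err = sum(sum((v1 - v2) .^ 2)) / m;
end

rand("state", 0);
v = double(rand(100, 784) > 0.5);                    # toy binary batch: 100 x 784
W = zeros(100, 784); b = zeros(784, 1); c = zeros(100, 1);
[W, b, c, err] = cd1_step(W, b, c, v, 1);
printf("reconstruction error: %g\n", err);
```

In rbmtrain() the same kind of update is wrapped in loops over epochs and shuffled mini-batches, with the increments passed through momentum terms (rbm.vW, rbm.vb, rbm.vc).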

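function test_RBM
  # DeepLearnToolbox must be on the Octave path, e.g. addpath(genpath('DeepLearnToolbox'))
  t = time;

  # load dataset: MNIST training images as rows of 784 pixels, scaled to [0, 1]
  load mnist_uint8;
  train_x = double(train_x) / 255;

  # config
  opts.numepochs = 1;        # number of passes over the training set
  opts.batchsize = 100;      # mini-batch size (must divide the number of samples)
  opts.momentum = 0;
  opts.alpha = 1;            # learning rate
  rbm.sizes = [100];         # number of hidden units
  rbm.alpha = opts.alpha;
  rbm.momentum = opts.momentum;

  # setup: prepend the number of visible units (784 pixels)
  n = size(train_x, 2);
  rbm.sizes = [n, rbm.sizes];

  # weights (hidden x visible) and their momentum terms
  rbm.W = zeros(rbm.sizes(2), rbm.sizes(1));
  rbm.vW = zeros(rbm.sizes(2), rbm.sizes(1));

  # biases: b for the visible units, c for the hidden units
  rbm.b = zeros(rbm.sizes(1), 1);
  rbm.vb = zeros(rbm.sizes(1), 1);
  rbm.c = zeros(rbm.sizes(2), 1);
  rbm.vc = zeros(rbm.sizes(2), 1);

  # train with contrastive divergence; rbmtrain prints the reconstruction error per epoch
  rbm = rbmtrain(rbm, train_x, opts);

  # visualize the learned weights
  figure;
  visualize(rbm.W');
  disp(['elapsed time: ' num2str(time - t) 's']);
end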

If the code runs correctly, you will see the weight visualization pop up.

MNIST RBM weight visualization

 

DeepLearnToolbox provides easy-to-read code that helps you understand deep learning algorithms better, so I encourage you to drill down into its functions (e.g. rbmtrain()). Since Octave/Matlab syntax is close to mathematical notation, we can compare the implementation directly with the formulas in a theoretical explanation such as http://deeplearning.net/tutorial/rbm.html.
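As a starting point for that comparison, here is a sketch of the standard binary-RBM formulas written in the notation of the rbm struct used in the script (W is hidden × visible, b the visible bias, c the hidden bias, α = rbm.alpha, m = opts.batchsize). V_1 is a mini-batch with one example per row, H_1 its sampled hidden states, V_2 the sampled reconstruction and H_2 the hidden probabilities for the reconstruction. The mapping to rbmtrain() is my own reading, so check it against the toolbox source.

```latex
\begin{aligned}
E(\mathbf{v},\mathbf{h}) &= -\mathbf{b}^\top\mathbf{v} - \mathbf{c}^\top\mathbf{h} - \mathbf{h}^\top W\mathbf{v} \\
F(\mathbf{v}) &= -\mathbf{b}^\top\mathbf{v} - \sum_{i}\log\left(1 + e^{\,c_i + W_{i\cdot}\mathbf{v}}\right) \\
P(h_i = 1 \mid \mathbf{v}) &= \sigma\left(c_i + W_{i\cdot}\mathbf{v}\right), \qquad
P(v_j = 1 \mid \mathbf{h}) = \sigma\left(b_j + \textstyle\sum_i W_{ij} h_i\right) \\
\Delta W &= \frac{\alpha}{m}\left(H_1^\top V_1 - H_2^\top V_2\right) \\
\Delta\mathbf{b} &= \frac{\alpha}{m}\sum_{k=1}^{m}\left(\mathbf{v}_1^{(k)} - \mathbf{v}_2^{(k)}\right), \qquad
\Delta\mathbf{c} = \frac{\alpha}{m}\sum_{k=1}^{m}\left(\mathbf{h}_1^{(k)} - \mathbf{h}_2^{(k)}\right)
\end{aligned}
```

Here \mathbf{v}^{(k)} and \mathbf{h}^{(k)} denote the k-th rows of the corresponding batch matrices. With opts.momentum = 0, the momentum terms rbm.vW, rbm.vb and rbm.vc reduce to exactly these increments.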
