├── README.md
├── base64encode.m
├── edbn_95.52.mat
├── edbn_brian_test.py
├── edbnclean.m
├── edbnsetup.m
├── edbntest.m
├── edbntoptrain.m
├── edbntoxml.m
├── edbntrain.m
├── erbmdown.m
├── erbmtrain.m
├── erbmup.m
├── example.m
├── good_train.m
├── live_edbn.m
├── mnist_uint8.mat
├── siegert.m
├── struct2xml.m
└── visualize.m
/README.md:
--------------------------------------------------------------------------------
1 | # EDBN
2 | ====
3 | 
4 | EDBN is an event-based deep learning architecture for Matlab, first published in "Real-Time Classification and Sensor Fusion with a Spiking Deep Belief Network" by Peter O'Connor, Daniel Neil, Shih-Chii Liu, Tobi Delbruck and Michael Pfeiffer. The purpose of this code is to provide a prototype implementation of the algorithms presented in the paper, which can be modified, learned from, or extended.
5 | 
6 | The original paper can be found at:
7 | [http://www.frontiersin.org/neuromorphic%20engineering/10.3389/fnins.2013.00178/abstract](http://www.frontiersin.org/neuromorphic%20engineering/10.3389/fnins.2013.00178/abstract)
8 | 
9 | ### Features
10 | 
11 | * Fast vectorized implementation.
12 | 
13 | * Selection of add-ons to improve accuracy in select domains: persistent contrastive divergence, fast weights, sparsity and selectivity, decay, momentum, temperature, and a variable number of Gibbs steps.
14 | 
15 | * Small file count for easy modification and extension.
16 | 
17 | ### Example
18 | 
19 | ```matlab
20 | %% Load paths
21 | addpath(genpath('.'));
22 | 
23 | %% Load data
24 | load mnist_uint8;
25 | 
26 | % Convert data and rescale between 0 and 0.2
27 | train_x = double(train_x) / 255 * 0.2;
28 | test_x  = double(test_x)  / 255 * 0.2;
29 | train_y = double(train_y) * 0.2;
30 | test_y  = double(test_y)  * 0.2;
31 | 
32 | %% Train network
33 | % Setup
34 | rand('seed', 42);
35 | clear edbn opts;
36 | edbn.sizes = [784 100 10];
37 | opts.numepochs = 6;
38 | 
39 | [edbn, opts] = edbnsetup(edbn, opts);
40 | 
41 | % Train
42 | fprintf('Beginning training.\n');
43 | edbn = edbntrain(edbn, train_x, opts);
44 | % Use supervised training on the top layer
45 | edbn = edbntoptrain(edbn, train_x, opts, train_y);
46 | 
47 | % Show results
48 | figure;
49 | visualize(edbn.erbm{1}.W'); % Visualize the RBM weights
50 | er = edbntest (edbn, train_x, train_y);
51 | fprintf('Scored: %2.2f\n', (1-er)*100);
52 | 
53 | %% Show the EDBN in action
54 | spike_list = live_edbn(edbn, test_x(1, :), opts);
55 | output_idxs = (spike_list.layers == numel(edbn.sizes));
56 | 
57 | figure(2); clf;
58 | hist(spike_list.addrs(output_idxs) - 1, 0:edbn.sizes(end));
59 | title('Label Layer Classification Spikes');
60 | %% Export to xml
61 | edbntoxml(edbn, opts, 'mnist_edbn');
62 | ```
63 | 
64 | ### Installation
65 | 
66 | Unzip the repo and navigate to it within Matlab. That's it. If you'd like to test the installation, run the following Matlab file:
67 | ```matlab
68 | example.m
69 | ```
70 | This will train a network on the MNIST handwritten digit database to greater than 92% accuracy, then run it on a spiking neural network. The source has been tested successfully with Matlab R2012a and later.
71 | 
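To skip training entirely, you can evaluate the bundled pre-trained network [edbn_95.52.mat](edbn_95.52.mat) directly. The snippet below is a minimal sketch: it assumes the struct stored in that file is named `edbn` (as saved by `good_train.m`) and reuses the rescaling from the example above:

```matlab
%% Evaluate the bundled pre-trained network (sketch)
addpath(genpath('.'));

% Load and rescale the MNIST test set into the [0, 0.2] rate range
load mnist_uint8;
test_x = double(test_x) / 255 * 0.2;
test_y = double(test_y) * 0.2;

% Load the trained 'edbn' struct and measure its error rate on the test set
load('edbn_95.52.mat');
er = edbntest(edbn, test_x, test_y);
fprintf('Pre-trained network scored: %2.2f\n', (1-er)*100);
```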
72 | ### File overview
73 | The main files are referenced by the [example](example.m) file, which calls them in the following order:
74 | * example.m - runs an example.
75 | * edbnsetup.m - initializes the network and loads defaults.
76 | * edbntrain.m - performs unsupervised training of the network.
77 | * erbmtrain.m - trains a single RBM layer in the DBN. This is the *core* source file for the algorithm.
78 | * siegert.m - calculates the output spike rate of LIF neurons from their input rates and input weights.
79 | * erbmup.m / erbmdown.m - propagate rate-based activations up or down through LIF neurons.
80 | * edbntoptrain.m - performs supervised training by concatenating the top layer onto the top-2 layer, jointly training the resulting (top-2, top)<->(top-1) RBM, and then unrolling again.
81 | * edbnclean.m - cleans out all the temporary activations to save a minimum-size EDBN file.
82 | * edbntoxml.m - creates a base64-encoded XML representation of the network.
83 | * live_edbn.m - runs the weights on an actual spiking network of neurons.
84 | 
85 | The source also contains two .mat files holding Matlab structures:
86 | * edbn_95.52.mat is a spiking network that achieves 95.52% recognition accuracy on the MNIST database of handwritten digits.
87 | * mnist_uint8.mat contains the MNIST dataset in Matlab uint8 format, courtesy of the [DeepLearnToolbox](https://github.com/rasmusbergpalm/DeepLearnToolbox) and ultimately Yann LeCun.
88 | 
89 | The 95.52%-accurate MNIST network was generated with [good_train.m](good_train.m).
90 | 
91 | Finally, the script [edbn_brian_test.py](edbn_brian_test.py) demonstrates loading a trained network into the BRIAN neural network simulator, feeding it Poisson inputs from the MNIST database, and calculating the resulting output classification.
92 | 
93 | ### Questions
94 | Please feel free to reach out here if you have any questions or difficulties. I'm happy to help guide you.
95 | 
--------------------------------------------------------------------------------
/base64encode.m:
--------------------------------------------------------------------------------
1 | function y = base64encode(x, alg, isChunked, url_safe)
2 | %BASE64ENCODE Perform base64 encoding on a string.
3 | % INPUT:
4 | %   x - block of data to be encoded. Can be a string or a numeric
5 | %       vector containing integers in the range 0-255.
6 | %   alg - Algorithm to use: can take values 'java' or 'matlab'. Optional
7 | %       variable defaulting to 'java', which is a little faster. If
8 | %       'java' is chosen then the core of the code is performed by a call to
9 | %       a java library. Optionally all operations can be performed using
10 | %      Matlab code.
11 | %   isChunked - encode output into 76 character blocks. The returned
12 | %       encoded string is broken into lines of no more than
13 | %       76 characters each, and each line will end with EOL. Notice that
14 | %       if the resulting string is saved as part of an xml file, those EOL's
15 | %       are often stripped by the xmlwrite function prior to saving.
16 | %   url_safe - use Modified Base64 for URL applications ('base64url'
17 | %       encoding) "Base64 alphabet" ([A-Za-z0-9-_=]).
18 | %
19 | %
20 | % OUTPUT:
21 | %   y - character array using only "Base64 alphabet" characters
22 | %
23 | % This function may be used to encode strings into the Base64 encoding
24 | % specified in RFC 2045 - MIME (Multipurpose Internet Mail Extensions).
25 | % The Base64 encoding is designed to represent arbitrary sequences of
26 | % octets in a form that need not be humanly readable. A 65-character
27 | % subset ([A-Za-z0-9+/=]) of US-ASCII is used, enabling 6 bits to be
28 | % represented per printable character.
29 | %
30 | % See also BASE64DECODE.
31 | %
32 | % Written by Jarek Tuszynski, SAIC, jaroslaw.w.tuszynski_at_saic.com
33 | %
34 | % Matlab version based on 2004 code by Peter J.
Acklam 35 | % E-mail: pjacklam@online.no 36 | % URL: http://home.online.no/~pjacklam 37 | % http://home.online.no/~pjacklam/matlab/software/util/datautil/base64encode.m 38 | 39 | if nargin<2, alg='java'; end 40 | if nargin<3, isChunked=false; end 41 | if ~islogical(isChunked) 42 | if isnumeric(isChunked) 43 | isChunked=(isChunked>0); 44 | else 45 | isChunked=false; 46 | end 47 | end 48 | if nargin<4, url_safe=false; end 49 | if ~islogical(url_safe) 50 | if isnumeric(url_safe) 51 | url_safe=(url_safe>0); 52 | else 53 | url_safe=false; 54 | end 55 | end 56 | 57 | 58 | %% if x happen to be a filename than read the file 59 | % if (numel(x)<256) 60 | % if (exist(x, 'file')==2) 61 | % fid = fopen(x,'rb'); 62 | % x = fread(fid, 'uint8'); % read image file as a raw binary 63 | % fclose(fid); 64 | % end 65 | % end 66 | 67 | %% Perform conversion 68 | switch (alg) 69 | case 'java' 70 | base64 = org.apache.commons.codec.binary.Base64; 71 | y = base64.encodeBase64(x, isChunked); 72 | if url_safe 73 | y = strrep(y,'=','-'); 74 | y = strrep(y,'/','_'); 75 | end 76 | 77 | case 'matlab' 78 | 79 | %% add padding if necessary, to make the length of x a multiple of 3 80 | x = uint8(x(:)); 81 | ndbytes = length(x); % number of decoded bytes 82 | nchunks = ceil(ndbytes / 3); % number of chunks/groups 83 | if rem(ndbytes, 3)>0 84 | x(end+1 : 3*nchunks) = 0; % add padding 85 | end 86 | x = reshape(x, [3, nchunks]); % reshape the data 87 | y = repmat(uint8(0), 4, nchunks); % for the encoded data 88 | 89 | %% Split up every 3 bytes into 4 pieces 90 | % aaaaaabb bbbbcccc ccdddddd 91 | % to form 92 | % 00aaaaaa 00bbbbbb 00cccccc 00dddddd 93 | y(1,:) = bitshift(x(1,:), -2); % 6 highest bits of x(1,:) 94 | y(2,:) = bitshift(bitand(x(1,:), 3), 4); % 2 lowest bits of x(1,:) 95 | y(2,:) = bitor(y(2,:), bitshift(x(2,:), -4)); % 4 highest bits of x(2,:) 96 | y(3,:) = bitshift(bitand(x(2,:), 15), 2); % 4 lowest bits of x(2,:) 97 | y(3,:) = bitor(y(3,:), bitshift(x(3,:), -6)); % 2 highest bits of x(3,:) 98 | y(4,:) = bitand(x(3,:), 63); % 6 lowest bits of x(3,:) 99 | 100 | %% Perform the mapping 101 | % 0 - 25 -> A-Z 102 | % 26 - 51 -> a-z 103 | % 52 - 61 -> 0-9 104 | % 62 -> + 105 | % 63 -> / 106 | map = ['A':'Z', 'a':'z', '0':'9', '+/']; 107 | if (url_safe), map(63:64)='-_'; end 108 | y = map(y(:)+1); 109 | 110 | %% Add padding if necessary. 
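% Every 3 input bytes become 4 output characters; the zero bytes appended
% above to complete the final 3-byte group are marked here with trailing
% '=' characters, as RFC 2045 requires.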
111 | npbytes = 3 * nchunks - ndbytes; % number of padding bytes 112 | if npbytes>0 113 | y(end-npbytes+1 : end) = '='; % '=' is used for padding 114 | end 115 | 116 | %% break into lines with length LineLength 117 | if (isChunked) 118 | eol = sprintf('\n'); 119 | nebytes = numel(y); 120 | nlines = ceil(nebytes / 76); % number of lines 121 | neolbytes = length(eol); % number of bytes in eol string 122 | 123 | % pad data so it becomes a multiple of 76 elements 124 | y(nebytes + 1 : 76 * nlines) = 0; 125 | y = reshape(y, 76, nlines); 126 | 127 | % insert eol strings 128 | y(end + 1 : end + neolbytes, :) = eol(:, ones(1, nlines)); 129 | 130 | % remove padding, but keep the last eol string 131 | m = nebytes + neolbytes * (nlines - 1); 132 | n = (76+neolbytes)*nlines - neolbytes; 133 | y(m+1 : n) = []; 134 | end 135 | end 136 | 137 | %% reshape to a row vector and make it a character array 138 | y = char(reshape(y, 1, numel(y))); 139 | -------------------------------------------------------------------------------- /edbn_95.52.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyneil/edbn/a815348223c435d9cf9127af0962113126bbcfce/edbn_95.52.mat -------------------------------------------------------------------------------- /edbn_brian_test.py: -------------------------------------------------------------------------------- 1 | from brian import * 2 | from numpy import * 3 | from functools import partial 4 | import scipy.io as sio 5 | import multiprocessing 6 | 7 | ## Quick classification demonstration using BRIAN's LIF neurons 8 | ## To run: python ./edbn_brian_test.py 9 | ## Requires BRIAN, numpy, and scipy 10 | 11 | # Load MNIST data 12 | def load_data(matfile): 13 | dataset = sio.loadmat('mnist_uint8.mat') 14 | # Rescale from int to float [0 1.0], and then between [0 0.2] of max firing 15 | test_x = dataset['test_x'] / 255.0 * 0.2 16 | test_y = dataset['test_y'] / 255.0 * 0.2 17 | return test_x, test_y 18 | 19 | # Thanks to Evangelos Stromatias for this Matlab->Python interface function: 20 | def loadMatEDBNDescription(matfile): 21 | ''' 22 | This function opens a MAT description of a trainned Event-Based 23 | Deep Belief Network and returns the number of neurons in each layer of the network 24 | and a list of NUMPY arrays of the synaptic weights of each layer. 
25 | ''' 26 | edbn = sio.loadmat(matfile) 27 | # Get size and dimensions of layers 28 | edbnTopology = edbn['edbn']['sizes'][0][0][0] #number of layers and neurons 29 | # Layer weights are saved at this index: 30 | _WEIGHTStr = 10 31 | # Get weights for each layer 32 | weightList = [] 33 | for l in range(edbnTopology.size-1): 34 | weightList.append(edbn['edbn']['erbm'][0][0][0][l][0][0][_WEIGHTStr]) 35 | return edbnTopology, weightList 36 | 37 | # Core of this script: initialize a net, load a digit, and test its results 38 | def run_digit(test_data, weightlist, topology, params): 39 | reinit_default_clock() 40 | clear(True) 41 | test_x = test_data[0] 42 | test_y = test_data[1] 43 | pops = [] 44 | conns = [] 45 | # Build populations 46 | for (layer_num, layer_size) in enumerate(edbn_topology): 47 | if layer_num == 0: 48 | pops.append(PoissonGroup(layer_size)) 49 | else: 50 | pops.append(NeuronGroup(layer_size, model=params['eqs'], threshold=params['v_th'], reset=params['v_r'])) 51 | # Build connections 52 | for (layer_num, weights) in enumerate(weightlist): 53 | c = Connection(pops[layer_num], pops[layer_num+1], 'v', weight=weights.T*1000*mV) 54 | conns.append(c) 55 | 56 | # Set a spike rate of 0.2 * 30 = 6 spikes/sec for a fully-on pixel 57 | pops[0].rate = test_x * 30.0 58 | # Track the output layer spikes 59 | output_spikes = SpikeCounter(pops[-1]) 60 | # Run for one second, approximately 1000 spikes total input over all 784 pixels 61 | run(1.0*second) 62 | # Get digit guess and correct answer 63 | guessed_digit = np.argmax(output_spikes.count) 64 | correct_digit = np.argmax(test_y) 65 | # Return true if correct 66 | return guessed_digit == correct_digit 67 | 68 | if __name__ == "__main__": 69 | # How many digits to test 70 | num_to_test = 12 71 | # Set parameters 72 | params = {} 73 | params['tau_m'] = 5000 * ms # membrane time constant 74 | params['v_r'] = 0 * mV # reset potential 75 | params['v_th'] = 1000 * mV # threshold potential 76 | params['eqs'] = ''' 77 | dv/dt = -v/params['tau_m'] : volt 78 | ''' 79 | # Load network 80 | edbn_topology, weightlist = loadMatEDBNDescription('edbn_95.52.mat') 81 | # Load data 82 | test_x, test_y = load_data('mnist_uint8.mat') 83 | # Build partial function to pass to the mapper 84 | partial_run_digit = partial(run_digit, topology=edbn_topology, weightlist=weightlist, params=params) 85 | # Initialize a multiprocessing pool 86 | pool = multiprocessing.Pool(4) 87 | # Distribute the test data into a set of tuples with the label 88 | test_data = [(test_x[idx, :], test_y[idx,:]) for idx in range(num_to_test) ] 89 | # Map the function over the pool 90 | results = pool.map(partial_run_digit, test_data) 91 | # Spit out results 92 | print "%i correct answers in %i trials for %.2f accuracy." 
% (np.sum(results), num_to_test, float(np.sum(results))/num_to_test*100) 93 | -------------------------------------------------------------------------------- /edbnclean.m: -------------------------------------------------------------------------------- 1 | function edbn = edbn_clean(edbn) 2 | 3 | for l=1:numel(edbn.erbm); 4 | % Wipe the extra stuff 5 | edbn.erbm{l}.vW = []; 6 | edbn.erbm{l}.vb = []; 7 | edbn.erbm{l}.vc = []; 8 | edbn.erbm{l}.FW = []; 9 | edbn.erbm{l}.vFW = []; 10 | edbn.erbm{l}.h1 = []; 11 | edbn.erbm{l}.h2 = []; 12 | edbn.erbm{l}.v1 = []; 13 | edbn.erbm{l}.v2 = []; 14 | end 15 | -------------------------------------------------------------------------------- /edbnsetup.m: -------------------------------------------------------------------------------- 1 | function [edbn, opts] = edbnsetup(edbn, opts) 2 | assert (numel(edbn.sizes) > 1, 'Sizes must be given for all layers, including input and output.'); 3 | 4 | % Set defaults 5 | if ~isfield(opts,'alpha'), opts.alpha = 1; end 6 | if ~isfield(opts,'decay'), opts.decay = 0.0001; end 7 | if ~isfield(opts,'momentum'), opts.momentum = 0.0; end 8 | if ~isfield(opts,'temp'), opts.temp = 0.005; end 9 | if ~isfield(opts,'tau_m'), opts.tau_m = 5.0; end 10 | if ~isfield(opts,'tau_s'), opts.tau_s = 0.001; end 11 | if ~isfield(opts,'t_ref'), opts.t_ref = 0.002; end 12 | if ~isfield(opts,'v_thr'), opts.v_thr = 0.005; end 13 | if ~isfield(opts,'f_infl'), opts.f_infl = 1; end 14 | if ~isfield(opts,'f_decay'), opts.f_decay = 0.05; end 15 | if ~isfield(opts,'f_alpha'), opts.f_alpha = 5; end 16 | if ~isfield(opts,'pcd'), opts.pcd = 1; end 17 | if ~isfield(opts,'sp'), opts.sp = 0.1; end 18 | if ~isfield(opts,'sp_infl'), opts.sp_infl = 0.2; end 19 | if ~isfield(opts,'ngibbs'), opts.ngibbs = 2; end 20 | if ~isfield(opts,'initscl'), opts.initscl = 0.01; end 21 | if ~isfield(opts,'batchsize'), opts.batchsize = 50; end 22 | if ~isfield(opts,'reup'), opts.reup = 1; end 23 | if ~isfield(opts,'wtreset'), opts.wtreset = 1; end 24 | 25 | for u = 1 : numel(edbn.sizes) - 1 26 | % Set up constants for learning 27 | edbn.erbm{u}.alpha = opts.alpha; 28 | edbn.erbm{u}.decay = opts.decay; 29 | edbn.erbm{u}.momentum = opts.momentum; 30 | edbn.erbm{u}.f_infl = opts.f_infl; 31 | edbn.erbm{u}.f_decay = opts.f_decay; 32 | edbn.erbm{u}.f_alpha = opts.f_alpha; 33 | edbn.erbm{u}.pcd = opts.pcd; 34 | edbn.erbm{u}.sp = opts.sp; 35 | edbn.erbm{u}.sp_infl = opts.sp_infl; 36 | 37 | % Set up constants for the Siegert formula 38 | edbn.erbm{u}.sieg.tau_m = opts.tau_m; 39 | edbn.erbm{u}.sieg.tau_s = opts.tau_s; 40 | edbn.erbm{u}.sieg.t_ref = opts.t_ref; 41 | edbn.erbm{u}.sieg.v_thr = opts.v_thr; 42 | edbn.erbm{u}.sieg.temp = opts.temp; 43 | 44 | % Weights, lower bias, and upper bias 45 | if(opts.wtreset) 46 | edbn.erbm{u}.W = opts.initscl*randn(edbn.sizes(u+1), edbn.sizes(u)); 47 | edbn.erbm{u}.b = zeros(edbn.sizes(u ), 1); 48 | edbn.erbm{u}.c = zeros(edbn.sizes(u+1), 1); 49 | end 50 | 51 | % Delta values 52 | edbn.erbm{u}.vW = zeros(edbn.sizes(u+1), edbn.sizes(u)); 53 | edbn.erbm{u}.vb = zeros(edbn.sizes(u ), 1); 54 | edbn.erbm{u}.vc = zeros(edbn.sizes(u+1), 1); 55 | 56 | % Temporary storage for data and model intermediates as well as fast weights 57 | edbn.erbm{u}.v1 = zeros(opts.batchsize, edbn.sizes(u)); 58 | edbn.erbm{u}.h1 = zeros(opts.batchsize, edbn.sizes(u+1)); 59 | edbn.erbm{u}.v2 = zeros(opts.batchsize, edbn.sizes(u)); 60 | edbn.erbm{u}.h2 = zeros(opts.batchsize, edbn.sizes(u+1)); 61 | edbn.erbm{u}.FW = zeros(edbn.sizes(u+1), edbn.sizes(u)); 62 | edbn.erbm{u}.vFW 
= zeros(edbn.sizes(u+1), edbn.sizes(u)); 63 | end 64 | 65 | end 66 | 67 | -------------------------------------------------------------------------------- /edbntest.m: -------------------------------------------------------------------------------- 1 | function [er, bad] = edbntest(edbn, x, y) 2 | % Pass activations up 3 | for i = 1:numel(edbn.erbm) 4 | x = erbmup(edbn.erbm{i}, x)'; 5 | end 6 | 7 | % Get winners from top row labels 8 | [~, found] = max(x,[],2); 9 | [~, expected] = max(y,[],2); 10 | bad = find(found ~= expected); 11 | er = numel(bad) / size(y, 1); 12 | end 13 | -------------------------------------------------------------------------------- /edbntoptrain.m: -------------------------------------------------------------------------------- 1 | function edbn = edbntoptrain(edbn, data, opts, y) 2 | % Train the top layer by concatenating the top layer to a lower layer 3 | % and jointly training the set. 4 | orig_data = data; 5 | 6 | % Find merged layer and output layer 7 | merge_l = numel(edbn.erbm) - 1; 8 | out_l = numel(edbn.erbm); 9 | 10 | % Build composite layer to train output; take top layer and place it 11 | % next to top-2 to joint train the layers: [top-2 top] <-> [top-1] 12 | top = edbn.erbm{out_l}; 13 | top.W = [edbn.erbm{merge_l}.W edbn.erbm{out_l}.W' ]; 14 | top.vW = [edbn.erbm{merge_l}.vW edbn.erbm{out_l}.vW' ]; 15 | top.FW = [edbn.erbm{merge_l}.FW edbn.erbm{out_l}.FW' ]; 16 | top.vFW = [edbn.erbm{merge_l}.vFW edbn.erbm{out_l}.vFW']; 17 | top.v1 = [edbn.erbm{merge_l}.v1 edbn.erbm{out_l}.h1 ]; 18 | top.v2 = [edbn.erbm{merge_l}.v2 edbn.erbm{out_l}.h2 ]; 19 | top.b = [edbn.erbm{merge_l}.b; edbn.erbm{out_l}.c ]; 20 | top.vb = [edbn.erbm{merge_l}.vb; edbn.erbm{out_l}.vc ]; 21 | top.h1 = edbn.erbm{merge_l}.h1; 22 | top.h2 = edbn.erbm{merge_l}.h2; 23 | top.c = edbn.erbm{merge_l}.c; 24 | top.vc = edbn.erbm{merge_l}.vc; 25 | 26 | % Train the composite layer 27 | if(opts.reup) 28 | opts.epoch_loops = opts.numepochs; 29 | opts.numepochs = 1; 30 | else 31 | opts.epoch_loops = 1; 32 | end 33 | 34 | for i = 0:opts.epoch_loops-1 35 | opts.ep_st = i; 36 | % Pass the data up 37 | for i = 1:numel(edbn.erbm)-2 38 | data = erbmup(edbn.erbm{i}, data)'; 39 | end 40 | top = erbmtrain(top, [data y], opts); 41 | data = orig_data; 42 | end 43 | 44 | % Slice the top layer from the composite layer 45 | slice = edbn.sizes(merge_l)+1 : edbn.sizes(merge_l)+edbn.sizes(out_l+1); 46 | edbn.erbm{out_l}.W = top.W(:, slice)'; 47 | edbn.erbm{out_l}.vW = top.vW(:, slice)'; 48 | edbn.erbm{out_l}.FW = top.FW(:, slice)'; 49 | edbn.erbm{out_l}.vFW = top.vFW(:, slice)'; 50 | edbn.erbm{out_l}.h1 = top.v1(:, slice); 51 | edbn.erbm{out_l}.h2 = top.v2(:, slice); 52 | edbn.erbm{out_l}.v1 = top.h1( :, : ); 53 | edbn.erbm{out_l}.v2 = top.h2( :, : ); 54 | edbn.erbm{out_l}.b = top.c( : ); 55 | edbn.erbm{out_l}.vb = top.vc( : ); 56 | edbn.erbm{out_l}.c = top.b(slice); 57 | edbn.erbm{out_l}.vc = top.vb(slice); 58 | 59 | % Slice off the composite layer into its original layer 60 | slice = 1 : edbn.sizes(merge_l); 61 | edbn.erbm{merge_l}.W = top.W (:, slice); 62 | edbn.erbm{merge_l}.vW = top.vW(:, slice); 63 | edbn.erbm{merge_l}.FW = top.FW (:, slice); 64 | edbn.erbm{merge_l}.vFW = top.vFW(:, slice); 65 | edbn.erbm{merge_l}.h1 = top.h1( :, : ); 66 | edbn.erbm{merge_l}.h2 = top.h2( :, : ); 67 | edbn.erbm{merge_l}.v1 = top.v1(:, slice); 68 | edbn.erbm{merge_l}.v2 = top.v2(:, slice); 69 | edbn.erbm{merge_l}.b = top.b(slice); 70 | edbn.erbm{merge_l}.vb = top.vb(slice); 71 | edbn.erbm{merge_l}.c = top.c; 72 | 
edbn.erbm{merge_l}.vc = top.vc; 73 | 74 | end -------------------------------------------------------------------------------- /edbntoxml.m: -------------------------------------------------------------------------------- 1 | function edbntoxml(edbn, opts, name) 2 | 3 | % Load defaults 4 | if ~isfield(opts,'makevisdim'), opts.makevisdim = 1; end; 5 | if ~isfield(edbn,'fw_conns'), edbn.fw_conns=num2cell([2:numel(edbn.sizes) 0]); end; 6 | if ~isfield(edbn,'fw_erbm'), edbn.fw_erbm=[1:numel(edbn.erbm) 0]; end 7 | 8 | % Clean out all temporary variables 9 | edbn = edbnclean(edbn); 10 | 11 | % Set up metadata 12 | net.name = name; 13 | net.type = 'edbn'; 14 | net.notes = ''; 15 | net.nLayers = numel(edbn.sizes); 16 | net.param = edbn.erbm{1}.sieg; 17 | net.dob = datestr(now); 18 | 19 | % Build show dimensions 20 | if(opts.makevisdim) 21 | for i = 1:numel(edbn.sizes) 22 | factors = factor(edbn.sizes(i)); 23 | opts.show_dims{i} = [prod(factors(2:2:end)) prod(factors(1:2:end))]; 24 | end 25 | end 26 | 27 | % Encode data in base64 for export 28 | encode = @(x)base64encode(typecast(single(x(:)),'uint8'),'matlab'); 29 | 30 | % Build layers with all parameters 31 | net.Layer=cell(1, numel(edbn.sizes)); 32 | for i = 1:numel(edbn.sizes) 33 | lay.index = i-1; 34 | lay.nUnits = edbn.sizes(i); 35 | lay.dimx = opts.show_dims{i}(1); 36 | lay.dimy = opts.show_dims{i}(2); 37 | lay.thresh.Attributes.dt = 'base64-single'; 38 | 39 | if (edbn.fw_erbm(i) ~= 0) 40 | lay.targ = i; 41 | lay.param = edbn.erbm{edbn.fw_erbm(i)}.sieg; 42 | lay.W.Attributes.dt = 'base64-single'; 43 | lay.W.Text = encode(edbn.erbm{edbn.fw_erbm(i)}.W); 44 | lay.thresh.Text=encode(repmat(edbn.erbm{edbn.fw_erbm(i)}.sieg.v_thr, 1, lay.nUnits)); 45 | else 46 | lay.targ = '-'; 47 | lay.param = edbn.erbm{end}.sieg; 48 | lay.W.Attributes.dt = 'base64-single'; 49 | lay.W.Text = encode([]); 50 | lay.thresh.Text=encode(repmat(edbn.erbm{end}.sieg.v_thr, 1, lay.nUnits)); 51 | end 52 | 53 | net.Layer{i} = lay; 54 | end 55 | 56 | % Encapsulate 57 | outnet.Network = net; 58 | % Write XML 59 | struct2xml(outnet, strcat(name, '.xml')); -------------------------------------------------------------------------------- /edbntrain.m: -------------------------------------------------------------------------------- 1 | function edbn = edbntrain(edbn, data, opts) 2 | % Initialize the training count 3 | opts.ep_st = 0; 4 | 5 | % Expect the final two layers to be top-trained 6 | for l = 1:numel(edbn.erbm) - 2 7 | edbn.erbm{l} = erbmtrain(edbn.erbm{l}, data, opts); 8 | data = erbmup(edbn.erbm{l}, data)'; 9 | end 10 | end -------------------------------------------------------------------------------- /erbmdown.m: -------------------------------------------------------------------------------- 1 | function x = erbmdown(erbm, x) 2 | erbm.sieg.temp = 0; 3 | x = siegert(x', erbm.W', erbm.sieg); 4 | end 5 | -------------------------------------------------------------------------------- /erbmtrain.m: -------------------------------------------------------------------------------- 1 | function erbm = erbmtrain(erbm, x, opts) 2 | % Check the inputs 3 | assert(isfloat(x), 'Data must be a float.'); 4 | m = size(x, 1); 5 | numbatches = m / opts.batchsize; 6 | assert(rem(numbatches, 1) == 0, 'Numbatches is not an integer.'); 7 | 8 | % Preallocate 9 | linsel = linspace(0, 1, opts.batchsize); 10 | linsp = linspace(0, 1, size(erbm.W,1)); 11 | 12 | % Loop and train 13 | for ep = 1 : opts.numepochs 14 | kk = randperm(m); 15 | for l = 1 : numbatches 16 | batch = x(kk((l - 1) * 
opts.batchsize + 1 : l * opts.batchsize), :); 17 | 18 | % Obtain data sample 19 | erbm.v1 = batch; 20 | erbm.h1 = siegert(erbm.v1', erbm.W , erbm.sieg)'; 21 | if(~erbm.pcd) 22 | erbm.h2 = erbm.h1; 23 | end 24 | 25 | % Obtain model sample, using fast weights to explore quickly 26 | for g = 1:opts.ngibbs 27 | erbm.v2 = siegert(erbm.h2', erbm.W' + ... 28 | erbm.f_infl * erbm.FW', erbm.sieg)'; 29 | erbm.h2 = siegert(erbm.v2', erbm.W + ... 30 | erbm.f_infl * erbm.FW , erbm.sieg)'; 31 | end 32 | 33 | % Sparsify; see Goh, Thome, Cord 34 | [~, ixsp] = sort(erbm.h1, 2); 35 | [~, ordersp] = sort(ixsp , 2); 36 | ranksp = linsp(ordersp); 37 | h1sp = ranksp.^(1/erbm.sp-1); 38 | [~, ixsel] = sort(h1sp, 1); 39 | [~, ordersel] = sort(ixsel, 1); 40 | ranksel = linsel(ordersel); 41 | h1sp = ranksel.^(1/erbm.sp-1); 42 | erbm.h1 = erbm.sp_infl * h1sp + (1 - erbm.sp_infl) * erbm.h1; 43 | 44 | % Calculate activation correlations 45 | c1 = erbm.h1' * erbm.v1; 46 | c2 = erbm.h2' * erbm.v2; 47 | 48 | % Update fast weights and biases 49 | erbm.vFW = erbm.f_alpha / opts.batchsize * (c1 - c2); 50 | dW = erbm.alpha / opts.batchsize * (c1 - c2); 51 | db = erbm.alpha / opts.batchsize * sum(erbm.v1 - erbm.v2)'; 52 | dc = erbm.alpha / opts.batchsize * sum(erbm.h1 - erbm.h2)'; 53 | 54 | % Incorporate decay 55 | erbm.FW = (1 - erbm.f_decay) * erbm.FW + erbm.vFW; 56 | dW = dW - erbm.decay * erbm.alpha * erbm.W; 57 | db = db - erbm.decay * erbm.alpha * erbm.b; 58 | dc = dc - erbm.decay * erbm.alpha * erbm.c; 59 | 60 | % Incorporate momentum 61 | erbm.vW = opts.momentum * erbm.vW + dW; 62 | erbm.vb = opts.momentum * erbm.vb + db; 63 | erbm.vc = opts.momentum * erbm.vc + dc; 64 | 65 | % Update final values 66 | erbm.W = erbm.W + erbm.vW; 67 | erbm.b = erbm.b + erbm.vb; 68 | erbm.c = erbm.c + erbm.vc; 69 | end 70 | 71 | % Inform the user 72 | fprintf('Epoch %i: mean error: %1.5f.\n', ... 
73 | ep+opts.ep_st, mean(abs(sum(erbm.v2) - sum(erbm.v1))) / opts.batchsize); 74 | end 75 | end -------------------------------------------------------------------------------- /erbmup.m: -------------------------------------------------------------------------------- 1 | function x = erbmup(erbm, x) 2 | %erbm.sieg.temp = 0; 3 | x = siegert(x', erbm.W, erbm.sieg); 4 | end 5 | -------------------------------------------------------------------------------- /example.m: -------------------------------------------------------------------------------- 1 | %% Load paths 2 | addpath(genpath('.')); 3 | 4 | %% Load data 5 | load mnist_uint8; 6 | 7 | % Convert data and rescale between 0 and 0.2 8 | train_x = double(train_x) / 255 * 0.2; 9 | test_x = double(test_x) / 255 * 0.2; 10 | train_y = double(train_y) * 0.2; 11 | test_y = double(test_y) * 0.2; 12 | 13 | %% Train network 14 | % Setup 15 | rand('seed', 42); 16 | clear edbn opts; 17 | edbn.sizes = [784 100 10]; 18 | opts.numepochs = 6; 19 | 20 | [edbn, opts] = edbnsetup(edbn, opts); 21 | 22 | % Train 23 | fprintf('Beginning training.\n'); 24 | edbn = edbntrain(edbn, train_x, opts); 25 | % Use supervised training on the top layer 26 | edbn = edbntoptrain(edbn, train_x, opts, train_y); 27 | 28 | % Show results 29 | figure; 30 | visualize(edbn.erbm{1}.W'); % Visualize the RBM weights 31 | er = edbntest (edbn, test_x, test_y); 32 | fprintf('Scored: %2.2f\n', (1-er)*100); 33 | 34 | %% Show the EDBN in action 35 | spike_list = live_edbn(edbn, test_x(1, :), opts); 36 | output_idxs = (spike_list.layers == numel(edbn.sizes)); 37 | 38 | figure(2); clf; 39 | hist(spike_list.addrs(output_idxs) - 1, 0:edbn.sizes(end)); 40 | xlabel('Digit Guessed'); 41 | ylabel('Histogram Spike Count'); 42 | title('Label Layer Classification Spikes'); 43 | %% Export to xml 44 | edbntoxml(edbn, opts, 'mnist_edbn'); -------------------------------------------------------------------------------- /good_train.m: -------------------------------------------------------------------------------- 1 | %% Load paths 2 | addpath(genpath('.')); 3 | 4 | %% Load data 5 | load mnist_uint8; 6 | 7 | train_x = double(train_x) / 255 * 0.2; 8 | test_x = double(test_x) / 255 * 0.2; 9 | train_y = double(train_y) * 0.2; 10 | test_y = double(test_y) * 0.2; 11 | 12 | %% Train network 13 | rand('seed', 42); 14 | clear edbn opts; 15 | edbn.sizes = [784 500 500 10]; 16 | opts.numepochs = 2; 17 | opts.alpha = 0.005; 18 | [edbn, opts] = edbnsetup(edbn, opts); 19 | 20 | opts.momentum = 0.0; opts.numepochs = 2; 21 | edbn = edbntrain(edbn, train_x, opts); 22 | edbn = edbntoptrain(edbn, train_x, opts, train_y); 23 | 24 | opts.momentum = 0.8; opts.numepochs = 60; 25 | edbn = edbntrain(edbn, train_x, opts); 26 | 27 | edbn = edbntrain(edbn, train_x, opts); 28 | edbn = edbntoptrain(edbn, train_x, opts, train_y); 29 | 30 | % Show results 31 | figure; 32 | visualize(edbn.erbm{1}.W'); % Visualize the RBM weights 33 | er = edbntest (edbn, train_x, train_y); 34 | fprintf('Scored: %2.2f\n', (1-er)*100); 35 | filename = sprintf('good_mnist_%2.2f-%s.mat',(1-er)*100, date()); 36 | edbnclean(edbn); 37 | save(filename,'edbn'); 38 | 39 | opts.momentum = 0.8; 40 | opts.numepochs = 80; 41 | edbn = edbntoptrain(edbn, train_x, opts, train_y); 42 | 43 | % Show results 44 | figure; 45 | visualize(edbn.erbm{1}.W'); % Visualize the RBM weights 46 | er = edbntest (edbn, train_x, train_y); 47 | fprintf('Scored: %2.2f\n', (1-er)*100); 48 | filename = sprintf('good_mnist_%2.2f-%s.mat',(1-er)*100, date()); 49 | edbnclean(edbn); 50 | 
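% Checkpoint the trained network under a filename tagged with its
% training-set score and today's date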
save(filename,'edbn'); 51 | 52 | %% Show the EDBN in action 53 | spike_list = live_edbn(edbn, test_x(1, :), opts); 54 | output_idxs = (spike_list.layers == numel(edbn.sizes)); 55 | 56 | figure(2); clf; 57 | hist(spike_list.addrs(output_idxs) - 1, 0:edbn.sizes(end)); 58 | 59 | %% Export to xml to load into JSpikeStack 60 | edbntoxml(edbn, opts, 'mnist_edbn'); -------------------------------------------------------------------------------- /live_edbn.m: -------------------------------------------------------------------------------- 1 | function [spike_list, edbn, opts] = live_edbn(edbn, x, opts) 2 | % Set defaults 3 | if ~isfield(edbn,'fw_conns'), edbn.fw_conns=num2cell([2:numel(edbn.sizes) 0]); end; 4 | if ~isfield(edbn,'fb_conns'), edbn.fb_conns=num2cell([0 1:numel(edbn.sizes)-1]); end; 5 | if ~isfield(edbn,'fw_erbm'), edbn.fw_erbm=[1:numel(edbn.erbm) 0]; end 6 | 7 | if ~isfield(opts,'recreate'), opts.recreate = 1; end; 8 | if ~isfield(opts,'timespan'), opts.timespan = 4; end; 9 | if ~isfield(opts,'numspikes'), opts.numspikes = 2000; end; 10 | if ~isfield(opts,'delay'), opts.delay = 0.001; end; 11 | if ~isfield(opts,'show_dt'), opts.show_dt = 0.010; end; 12 | if ~isfield(opts,'vis_tau'), opts.vis_tau = 0.05; end; 13 | if ~isfield(opts,'makespikes'), opts.makespikes = 1; end; 14 | if ~isfield(opts,'makevisdim'), opts.makevisdim = 1; end; 15 | if ~isfield(opts,'vis_handle'), opts.vis_handle = figure; end; 16 | if ~isfield(opts,'vis_layers'), opts.vis_layers = 1:numel(edbn.sizes); end; 17 | 18 | if ~isfield(opts,'ff'), opts.ff = [ones(1, numel(edbn.sizes)-1) 0]; end; 19 | if ~isfield(opts,'fb'), opts.fb = zeros(1, numel(edbn.sizes)); end; 20 | 21 | % Create reconstruction layer as necessary 22 | if(opts.recreate) 23 | edbn.sizes = [edbn.sizes edbn.sizes(1)]; 24 | edbn.erbm = [edbn.erbm edbn.erbm(1)]; 25 | edbn.fb_conns{2} = numel(edbn.sizes); 26 | edbn.fw_conns = [edbn.fw_conns edbn.fw_conns{1}]; 27 | edbn.fb_conns = [edbn.fb_conns 0]; 28 | opts.ff = [opts.ff 0]; 29 | opts.fb(2) = 1; 30 | opts.fb = [opts.fb 0]; 31 | edbn.fw_erbm = [edbn.fw_erbm numel(edbn.erbm)]; 32 | opts.vis_layers = [opts.vis_layers numel(edbn.sizes)]; 33 | end 34 | 35 | % Create spikes proportional to intensity and assign a random time 36 | if(opts.makespikes) 37 | inp.addr = randsample(numel(x), opts.numspikes, true, x(:))'; 38 | inp.times = sort(rand(1,opts.numspikes)) .* opts.timespan; 39 | x = inp; 40 | end 41 | 42 | % Initialize neurons 43 | layers = numel(edbn.sizes); 44 | mem = cell(1, layers); 45 | mem_time = cell(1, layers); 46 | refrac_end = cell(1, layers); 47 | for i = 1:layers 48 | mem{i} = zeros(1, edbn.sizes(i)); 49 | mem_time{i} = 0; 50 | refrac_end{i} = zeros(1, edbn.sizes(i)); 51 | end 52 | 53 | % Initialize spike queues 54 | queue.times = x.times; 55 | queue.layers = 1 * ones(1, numel(x.addr)); 56 | queue.addrs = num2cell(x.addr); 57 | 58 | % Build show dimensions 59 | if(opts.makevisdim) 60 | for i = 1:layers 61 | factors = factor(edbn.sizes(i)); 62 | opts.show_dims{i} = [prod(factors(2:2:end)) prod(factors(1:2:end))]; 63 | end 64 | end 65 | 66 | % Build plotspace 67 | figure(opts.vis_handle); clf; 68 | last_spiked = cell(1, layers); 69 | image_handle = cell(1, layers); 70 | for i=1:numel(edbn.sizes) 71 | % Create storage for visualization 72 | last_spiked{i} = zeros(1, edbn.sizes(i)); 73 | 74 | if(any(i==opts.vis_layers)) 75 | % Plot and store a handle 76 | subplot(1, numel(opts.vis_layers), find(i==opts.vis_layers)); 77 | image(reshape(last_spiked{i}, opts.show_dims{i})', 'CDataMapping','scaled'); 
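% Grey-scale panel with no axes; keep a handle to the image object so later
% frames can be refreshed by updating its CData directly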
78 | colormap bone; axis image; axis off; 79 | image_handle{i} = get(gca,'Children'); 80 | end 81 | end 82 | 83 | % Run the event-based network 84 | tic; 85 | idx = 1; 86 | last_show = 0; 87 | while(idx < size(queue.times, 2)) 88 | % Pull out spikes to process 89 | curr_time = queue.times(idx); 90 | from_addr = queue.addrs{idx}; 91 | from_layer = queue.layers(idx); 92 | 93 | % Update display 94 | last_spiked{from_layer}(from_addr) = last_spiked{from_layer}(from_addr) + 1; 95 | 96 | % Process forward and backward connections 97 | for direction=1:2 98 | if(direction == 1), conns = opts.ff(from_layer) .* edbn.fw_conns{from_layer}; end 99 | if(direction == 2), conns = opts.fb(from_layer) .* edbn.fb_conns{from_layer}; end 100 | 101 | % Skip if nothing to do 102 | if(~any(conns)), continue; end 103 | 104 | for l_idx = 1:numel(conns) 105 | layer = conns(l_idx); 106 | 107 | % Decay 108 | time_gap = curr_time - mem_time{layer}; 109 | decayfac = exp(-time_gap / opts.tau_m); 110 | mem{layer} = mem{layer} .* decayfac; 111 | 112 | % Add Impulse 113 | not_refrac = curr_time > refrac_end{layer}; 114 | if(direction == 1) 115 | impulse = edbn.erbm{edbn.fw_erbm(from_layer)}.W(:, from_addr)'; 116 | else 117 | impulse = edbn.erbm{edbn.fw_erbm(layer)}.W(from_addr, :); 118 | end 119 | mem{layer} = mem{layer} + sum(bsxfun(@times, not_refrac, impulse), 1); 120 | 121 | % Store new 'old' time 122 | mem_time{layer} = curr_time; 123 | 124 | % Check for firing; reset potential and store refrac 125 | firings = find(mem{layer} > opts.v_thr); 126 | mem{layer}(firings) = 0; 127 | refrac_end{layer}(firings) = curr_time + opts.t_ref; 128 | 129 | % Update queues 130 | if(numel(firings > 0)) 131 | insert_idx = find(curr_time + opts.delay < queue.times, 1, 'first'); 132 | if(isempty(insert_idx)) 133 | insert_idx = size(queue, 2); 134 | end 135 | queue.times = horzcat(queue.times(1:insert_idx-1), ... 136 | curr_time + opts.delay, ... 137 | queue.times(insert_idx:end)); 138 | queue.layers = horzcat(queue.layers(1:insert_idx-1), ... 139 | layer, ... 140 | queue.layers(insert_idx:end)); 141 | queue.addrs = horzcat(queue.addrs(1:insert_idx-1), ... 142 | {firings}, ... 143 | queue.addrs(insert_idx:end)); 144 | end 145 | end 146 | end 147 | 148 | % Display if desired 149 | if(curr_time - last_show > opts.show_dt) 150 | for i = 1:numel(edbn.sizes) 151 | last_spiked{i} = last_spiked{i} .* exp(-opts.show_dt / opts.vis_tau); 152 | if(any(i==opts.vis_layers)) 153 | set(image_handle{i}, 'CData', reshape(last_spiked{i}, opts.show_dims{i})'); 154 | end 155 | end 156 | pause(0.01); 157 | last_show = curr_time; 158 | end 159 | 160 | % Pick next time 161 | idx = idx + 1; 162 | end 163 | % Report finish and the timing 164 | fprintf('Completed %i input spikes occurring over %2.2f seconds, in %2.3f seconds of real time.\n', ... 
165 | opts.numspikes, opts.timespan, toc); 166 | 167 | % De-parallelize output 168 | spike_list.addrs = cell2mat(queue.addrs); 169 | numrepeats = cellfun(@(x)size(x,2),queue.addrs); 170 | spike_list.times = cell2mat(arrayfun(@(time, repeats) ones(1,repeats)*time, queue.times, numrepeats, 'UniformOutput', 0)); 171 | spike_list.layers = cell2mat(arrayfun(@(layer, repeats) ones(1,repeats)*layer, queue.layers, numrepeats, 'UniformOutput', 0)); 172 | -------------------------------------------------------------------------------- /mnist_uint8.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyneil/edbn/a815348223c435d9cf9127af0962113126bbcfce/mnist_uint8.mat -------------------------------------------------------------------------------- /siegert.m: -------------------------------------------------------------------------------- 1 | function y=siegert(x,w,opts) 2 | % Compute the ouput of an array of Siegert Neurons. Siegert Neurons have 3 | % outputs that approximate the mean firing rate of leaky integrate-and-fire 4 | % neurons with the same parameters 5 | % 6 | % y=siegert(x,w,P) 7 | % 8 | % Example Usage: 9 | % nInputs=5; 10 | % nSamples=100; 11 | % nUnits=3; 12 | % x=bsxfun(@plus,rand(nInputs,nSamples),linspace(0,500,nSamples)); 13 | % w=.4*randn(nUnits,nInputs)+.8; 14 | % P=struct; opts.Vth=2; opts.tref=.01; 15 | % plot(siegert(x,w,P)'); xlabel 'input rate';ylabel 'output rate'; 16 | % 17 | % 18 | % Input Dim Description 19 | % x [nIn,nSamples] set of input vectors representing Poisson firing rates of inputs 20 | % w [nOut,nIn] weight matrix 21 | % P - structure identifying neuron parameters (see below). 22 | % 23 | % Ouput Dims Description 24 | % y [nOut,nSamples] set of output vectors representing Poisson output rate 25 | % 26 | % P is a structure array as fields of P 27 | % Vrest Resting Potential 28 | % Vth Threshold. This can be a scalar or a vector with 1 element per unit. 29 | % Vreset Post-Spike Reset Potential 30 | % taum Membrane Time Constent 31 | % tausyn Synaptic response time constant. LIF-equivalent synapic response is assumed to follow exp(x/tausyn)-exp(-x/(2*tausyn)) 32 | % tref Absolute refractory Time 33 | % normalize Normalize output to the maximum possible firing rate (1/tref) 34 | % nPoints Number of points in integral approximation 35 | % polyapprox Boolean indicating whether to approximate with a polynomial. Generally it's best to leave this true. 36 | % imagrownadult Don't waste time checking input to ensure that all inputs are positive. The caller can do this. 37 | % 38 | % See Florian Jug's poster explaining the Siegert neuron at: 39 | % http://www.cadmo.ethz.ch/as/people/members/fjug/personal_home/posters/2012_SiegertPoster.pdf 40 | % 41 | % Peter 42 | % oconnorp ..at.. ethz ..dot.. ch 43 | % 44 | 45 | 46 | %% Process Inputs 47 | % Add noise 48 | x = max(x + opts.temp * randn(size(x)), 0); 49 | 50 | % Check inputs 51 | k=sqrt( opts.tau_s / opts.tau_m); 52 | assert( size(w,2) == size(x,1),'Weight matrix columns must equal input matrix rows.'); 53 | assert( k < 1,'Tau_m must be greater than tau_s.'); 54 | assert( all(x(:) >= 0),'Input rate cannot be negative.'); 55 | %% Perform computation! 
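% Summary of the quantities computed below (see the Siegert poster linked
% above): Y is the mean membrane drive, Vrest + tau_m * w * x, and G is the
% standard deviation of the membrane fluctuations, sqrt(tau_m * w.^2 * x / 2).
% The mean firing rate is then 1 / (t_ref + (tau_m ./ G) * sqrt(pi/2) * I),
% where I is an integral of erfcx between Vreset and v_thr, approximated
% with a three-point Simpson's rule when opts.polyapprox is set.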
56 | % State constants 57 | opts.Vrest = 0; 58 | opts.Vreset = 0; 59 | opts.nPoints = 3; 60 | opts.polyapprox = 1; 61 | 62 | % Compute intermediates 63 | Y = opts.Vrest + opts.tau_m * w * x; 64 | G = sqrt(opts.tau_m * w.^2 * x / 2); 65 | 66 | % Load constant 67 | persistent SIEGERT_GAM; 68 | if(isempty(SIEGERT_GAM)) 69 | SIEGERT_GAM = abs(zeta(0.5)); 70 | end 71 | gam = SIEGERT_GAM; 72 | 73 | % Approximate the integral 74 | tmp = reshape(bsxfun(@plus, bsxfun(@times, ... 75 | meshgrid(1:opts.nPoints, 0:size(w,1)-1), ... 76 | opts.v_thr(:)-opts.Vreset(:))/(opts.nPoints-1), ... 77 | opts.Vreset), ... 78 | [size(w,1) 1 opts.nPoints]); 79 | u = bsxfun(@plus, (k * gam) * G, tmp); 80 | 81 | du = (opts.v_thr - opts.Vreset) / (opts.nPoints - 1); 82 | YminusuoverGroot2 = bsxfun(@rdivide, bsxfun(@minus, Y, u) , G * sqrt(2)); 83 | 84 | % Approximate the integral. 85 | % Note: erfcx is a more numerically stable version of exp(x.^2)*erfc(x) 86 | if opts.polyapprox 87 | z = erfcx(YminusuoverGroot2); 88 | % Apply Simpon's rule to approximate the integral. Generally, a good 89 | % approximation as the integration curve is close to a parabola. 90 | integral = bsxfun(@times, (opts.v_thr(:) - opts.Vreset(:)) / 6, ... 91 | z(:,:,1) + 4 * z(:,:,2) + z(:,:,3) ); 92 | else 93 | integral = du * sum(erfcx(YminusuoverGroot2), 3); 94 | end 95 | 96 | % Compute the output 97 | y = 1 ./ (opts.t_ref + (opts.tau_m ./ G) .* (sqrt(pi / 2)) .* integral); 98 | 99 | % Normalize to 1/t_ref 100 | y = y * opts.t_ref; -------------------------------------------------------------------------------- /struct2xml.m: -------------------------------------------------------------------------------- 1 | function varargout = struct2xml( s, varargin ) 2 | %Convert a MATLAB structure into a xml file 3 | % [ ] = struct2xml( s, file ) 4 | % xml = struct2xml( s ) 5 | % 6 | % A structure containing: 7 | % s.XMLname.Attributes.attrib1 = "Some value"; 8 | % s.XMLname.Element.Text = "Some text"; 9 | % s.XMLname.DifferentElement{1}.Attributes.attrib2 = "2"; 10 | % s.XMLname.DifferentElement{1}.Text = "Some more text"; 11 | % s.XMLname.DifferentElement{2}.Attributes.attrib3 = "2"; 12 | % s.XMLname.DifferentElement{2}.Attributes.attrib4 = "1"; 13 | % s.XMLname.DifferentElement{2}.Text = "Even more text"; 14 | % 15 | % Will produce: 16 | % 17 | % Some text 18 | % Some more text 19 | % Even more text 20 | % 21 | % 22 | % Please note that the following strings are substituted 23 | % '_dash_' by '-', '_colon_' by ':' and '_dot_' by '.' 24 | % 25 | % Written by W. Falkena, ASTI, TUDelft, 27-08-2010 26 | % On-screen output functionality added by P. Orth, 01-12-2010 27 | % Multiple space to single space conversion adapted for speed by T. Lohuis, 11-04-2011 28 | % Val2str subfunction bugfix by H. Gsenger, 19-9-2011 29 | 30 | if (nargin ~= 2) 31 | if(nargout ~= 1 || nargin ~= 1) 32 | error(['Supported function calls:' sprintf('\n')... 33 | '[ ] = struct2xml( s, file )' sprintf('\n')... 
34 | 'xml = struct2xml( s )']); 35 | end 36 | end 37 | 38 | if(nargin == 2) 39 | file = varargin{1}; 40 | 41 | if (isempty(file)) 42 | error('Filename can not be empty'); 43 | end 44 | 45 | if (isempty(strfind(file,'.xml'))) 46 | file = [file '.xml']; 47 | end 48 | end 49 | 50 | if (~isstruct(s)) 51 | error([inputname(1) ' is not a structure']); 52 | end 53 | 54 | if (length(fieldnames(s)) > 1) 55 | error(['Error processing the structure:' sprintf('\n') 'There should be a single field in the main structure.']); 56 | end 57 | xmlname = fieldnames(s); 58 | xmlname = xmlname{1}; 59 | 60 | %substitute special characters 61 | xmlname_sc = xmlname; 62 | xmlname_sc = strrep(xmlname_sc,'_dash_','-'); 63 | xmlname_sc = strrep(xmlname_sc,'_colon_',':'); 64 | xmlname_sc = strrep(xmlname_sc,'_dot_','.'); 65 | 66 | %create xml structure 67 | docNode = com.mathworks.xml.XMLUtils.createDocument(xmlname_sc); 68 | 69 | %process the rootnode 70 | docRootNode = docNode.getDocumentElement; 71 | 72 | %append childs 73 | parseStruct(s.(xmlname),docNode,docRootNode,[inputname(1) '.' xmlname '.']); 74 | 75 | if(nargout == 0) 76 | %save xml file 77 | xmlwrite(file,docNode); 78 | else 79 | varargout{1} = xmlwrite(docNode); 80 | end 81 | end 82 | 83 | % ----- Subfunction parseStruct ----- 84 | function [] = parseStruct(s,docNode,curNode,pName) 85 | 86 | fnames = fieldnames(s); 87 | for i = 1:length(fnames) 88 | curfield = fnames{i}; 89 | 90 | %substitute special characters 91 | curfield_sc = curfield; 92 | curfield_sc = strrep(curfield_sc,'_dash_','-'); 93 | curfield_sc = strrep(curfield_sc,'_colon_',':'); 94 | curfield_sc = strrep(curfield_sc,'_dot_','.'); 95 | 96 | if (strcmp(curfield,'Attributes')) 97 | %Attribute data 98 | if (isstruct(s.(curfield))) 99 | attr_names = fieldnames(s.Attributes); 100 | for a = 1:length(attr_names) 101 | cur_attr = attr_names{a}; 102 | 103 | %substitute special characters 104 | cur_attr_sc = cur_attr; 105 | cur_attr_sc = strrep(cur_attr_sc,'_dash_','-'); 106 | cur_attr_sc = strrep(cur_attr_sc,'_colon_',':'); 107 | cur_attr_sc = strrep(cur_attr_sc,'_dot_','.'); 108 | 109 | [cur_str,succes] = val2str(s.Attributes.(cur_attr)); 110 | if (succes) 111 | curNode.setAttribute(cur_attr_sc,cur_str); 112 | else 113 | disp(['Warning. The text in ' pName curfield '.' cur_attr ' could not be processed.']); 114 | end 115 | end 116 | else 117 | disp(['Warning. The attributes in ' pName curfield ' could not be processed.']); 118 | disp(['The correct syntax is: ' pName curfield '.attribute_name = ''Some text''.']); 119 | end 120 | elseif (strcmp(curfield,'Text')) 121 | %Text data 122 | [txt,succes] = val2str(s.Text); 123 | if (succes) 124 | curNode.appendChild(docNode.createTextNode(txt)); 125 | else 126 | disp(['Warning. The text in ' pName curfield ' could not be processed.']); 127 | end 128 | else 129 | %Sub-element 130 | if (isstruct(s.(curfield))) 131 | %single element 132 | curElement = docNode.createElement(curfield_sc); 133 | curNode.appendChild(curElement); 134 | parseStruct(s.(curfield),docNode,curElement,[pName curfield '.']) 135 | elseif (iscell(s.(curfield))) 136 | %multiple elements 137 | for c = 1:length(s.(curfield)) 138 | curElement = docNode.createElement(curfield_sc); 139 | curNode.appendChild(curElement); 140 | if (isstruct(s.(curfield){c})) 141 | parseStruct(s.(curfield){c},docNode,curElement,[pName curfield '{' num2str(c) '}.']) 142 | else 143 | disp(['Warning. 
The cell ' pName curfield '{' num2str(c) '} could not be processed, since it contains no structure.']); 144 | end 145 | end 146 | else 147 | %eventhough the fieldname is not text, the field could 148 | %contain text. Create a new element and use this text 149 | curElement = docNode.createElement(curfield_sc); 150 | curNode.appendChild(curElement); 151 | [txt,succes] = val2str(s.(curfield)); 152 | if (succes) 153 | curElement.appendChild(docNode.createTextNode(txt)); 154 | else 155 | disp(['Warning. The text in ' pName curfield ' could not be processed.']); 156 | end 157 | end 158 | end 159 | end 160 | end 161 | 162 | %----- Subfunction val2str ----- 163 | function [str,succes] = val2str(val) 164 | 165 | succes = true; 166 | str = []; 167 | 168 | if (isempty(val)) 169 | return; %bugfix from H. Gsenger 170 | elseif (ischar(val)) 171 | %do nothing 172 | elseif (isnumeric(val)) 173 | val = num2str(val); 174 | else 175 | succes = false; 176 | end 177 | 178 | if (ischar(val)) 179 | %add line breaks to all lines except the last (for multiline strings) 180 | lines = size(val,1); 181 | val = [val char(sprintf('\n')*[ones(lines-1,1);0])]; 182 | 183 | %transpose is required since indexing (i.e., val(nonspace) or val(:)) produces a 1-D vector. 184 | %This should be row based (line based) and not column based. 185 | valt = val'; 186 | 187 | remove_multiple_white_spaces = true; 188 | if (remove_multiple_white_spaces) 189 | %remove multiple white spaces using isspace, suggestion of T. Lohuis 190 | whitespace = isspace(val); 191 | nonspace = (whitespace + [zeros(lines,1) whitespace(:,1:end-1)])~=2; 192 | nonspace(:,end) = [ones(lines-1,1);0]; %make sure line breaks stay intact 193 | str = valt(nonspace'); 194 | else 195 | str = valt(:); 196 | end 197 | end 198 | end 199 | -------------------------------------------------------------------------------- /visualize.m: -------------------------------------------------------------------------------- 1 | function r=visualize(X, mm, s1, s2) 2 | %FROM RBMLIB http://code.google.com/p/matrbm/ 3 | %Visualize weights X. If the function is called as a void method, 4 | %it does the plotting. But if the function is assigned to a variable 5 | %outside of this code, the formed image is returned instead. 6 | if ~exist('mm','var') 7 | mm = [min(X(:)) max(X(:))]; 8 | end 9 | if ~exist('s1','var') 10 | s1 = 0; 11 | end 12 | if ~exist('s2','var') 13 | s2 = 0; 14 | end 15 | 16 | [D,N]= size(X); 17 | s=sqrt(D); 18 | if s==floor(s) || (s1 ~=0 && s2 ~=0) 19 | if (s1 ==0 || s2 ==0) 20 | s1 = s; s2 = s; 21 | end 22 | %its a square, so data is probably an image 23 | num=ceil(sqrt(N)); 24 | a=mm(2)*ones(num*s2+num-1,num*s1+num-1); 25 | x=0; 26 | y=0; 27 | for i=1:N 28 | im = reshape(X(:,i),s1,s2)'; 29 | a(x*s2+1+x : x*s2+s2+x, y*s1+1+y : y*s1+s1+y)=im; 30 | x=x+1; 31 | if(x>=num) 32 | x=0; 33 | y=y+1; 34 | end 35 | end 36 | d=true; 37 | else 38 | %there is not much we can do 39 | a=X; 40 | end 41 | 42 | %return the image, or plot the image 43 | if nargout==1 44 | r=a; 45 | else 46 | 47 | imagesc(a, [mm(1) mm(2)]); 48 | axis equal 49 | colormap gray 50 | 51 | end 52 | --------------------------------------------------------------------------------