├── doc
│   ├── MILA.png
│   ├── copy.png
│   └── hrr.png
├── 20_images_from_imagenet.npy
├── README.md
├── .gitignore
├── LICENSE
├── utils.py
├── holographic_memory.py
├── main.py
└── bricks.py

/doc/MILA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/Associative_LSTM/HEAD/doc/MILA.png
--------------------------------------------------------------------------------

/doc/copy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/Associative_LSTM/HEAD/doc/copy.png
--------------------------------------------------------------------------------

/doc/hrr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/Associative_LSTM/HEAD/doc/hrr.png
--------------------------------------------------------------------------------

/20_images_from_imagenet.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mpezeshki/Associative_LSTM/HEAD/20_images_from_imagenet.npy
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
![MILA Logo](/doc/MILA.png)

# Associative Long Short-Term Memory (LSTM)

A Blocks and Theano implementation of Associative LSTM ([arXiv](http://arxiv.org/abs/1602.03032)).


## How to run

To test Holographic Reduced Representations (HRR), run ```holographic_memory.py```.
![HRR](/doc/hrr.png)


To reproduce the results on the episodic copy task, run ```main.py```.
![copy](/doc/copy.png)

--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 Mohammad Pezeshki

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/utils.py:
--------------------------------------------------------------------------------
import theano
import numpy as np
from blocks.initialization import NdarrayInitialization
import logging
from blocks.extensions import SimpleExtension
logger = logging.getLogger('main.utils')


class SaveLog(SimpleExtension):
    def __init__(self, show=None, **kwargs):
        super(SaveLog, self).__init__(**kwargs)
        self.add_condition(('before_training',), self.do)
        self.add_condition(('after_training',), self.do)
        self.add_condition(('on_interrupt',), self.do)

    def do(self, which_callback, *args):
        epoch = self.main_loop.status['iterations_done']
        current_row = self.main_loop.log.current_row
        logger.info("\niterations_done:%d" % epoch)
        for element in current_row:
            logger.info(str(element) + ":%f" % current_row[element])


class Glorot(NdarrayInitialization):
    def generate(self, rng, shape):
        if len(shape) == 2:
            input_size, output_size = shape

            # If it is an LSTM's concatenated weight matrix, use the
            # per-gate fan-out (Associative LSTM packs 4.5 * dim columns,
            # vanilla LSTM packs 4 * dim columns).
            if (input_size * 4.5 == output_size):
                print 'Glorot 2'
                output_size = output_size / 4.5
            elif (input_size * 4 == output_size):
                print 'Glorot 1'
                output_size = output_size / 4

            high = np.sqrt(6) / np.sqrt(input_size + output_size)
            m = rng.uniform(-high, high, size=shape)
            return m.astype(theano.config.floatX)
--------------------------------------------------------------------------------

/holographic_memory.py:
--------------------------------------------------------------------------------
import numpy as np
import theano
import theano.tensor as T
# import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

B = 10              # number of key/value pairs stored at once
F = 110 * 110 * 3   # number of features (a flattened 110 x 110 RGB image)
C = 20              # number of redundant (permuted) memory copies

# Build C random permutations of the F/2 complex components; each
# permutation is applied identically to the real and imaginary halves.
permutations = []
indices = np.arange(F / 2)
for i in range(C):
    np.random.shuffle(indices)
    permutations.append(np.concatenate(
        [indices,
         [ind + F / 2 for ind in indices]]))
# C x F (numpy)
PERMUTATIONS = np.vstack(permutations)


# input: B x F
# output: C x B x F
def permute(input):
    inputs_permuted = []
    for i in range(PERMUTATIONS.shape[0]):
        inputs_permuted.append(
            input[:, PERMUTATIONS[i]].dimshuffle('x', 0, 1))
    return T.concatenate(inputs_permuted, axis=0)


# Element-wise complex multiplication: the first F/2 entries of the last
# axis hold the real parts, the last F/2 entries the imaginary parts.
# r: C x B x F
# u: C x 1 x F if it is the memory, 1 x B x F if it is a value
def complex_mult(r, u, inverse_r=False, moduli_1=False):
    _, _, F = u.shape
    r_rl = r[:, :, :F / 2]
    r_im = r[:, :, F / 2:]
    if inverse_r:
        if moduli_1:
            # For unit-modulus r, the inverse is simply the conjugate.
            r_im = -r_im
        else:
            tmp = r_rl / (r_rl ** 2 + r_im ** 2)
            r_im = -r_im / (r_rl ** 2 + r_im ** 2)
            r_rl = tmp
    u_rl = u[:, :, :F / 2]
    u_im = u[:, :, F / 2:]
    res_rl = r_rl * u_rl - r_im * u_im
    res_im = r_rl * u_im + r_im * u_rl
    res = T.concatenate([res_rl, res_im], axis=2)
    # C x B x F
    return res


# key: B x F
# mem: C x F
# returns the retrieved values: B x F
def read(key, mem):
    value = complex_mult(
        permute(key),
        mem.dimshuffle(0, 'x', 1),
        inverse_r=True, moduli_1=True)
    return value.mean(axis=0)


# key: B x F
# value: B x F
# returns the memory trace: C x F
def write(key, value):
    coded_value = complex_mult(permute(key), value.dimshuffle('x', 0, 1))
    # C x F
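    # Summing over the batch axis superimposes all B key-value bindings
    # into a single trace per permuted copy; read() later unbinds with the
    # conjugate key, so the other stored items contribute only crosstalk
    # noise, which averaging over the C copies suppresses.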
    return coded_value.sum(axis=1)

if __name__ == "__main__":
    # B x F
    key = T.matrix('key')
    # B x F
    value = T.matrix('value')
    # C x F
    mem = T.matrix('mem')

    read_func = theano.function([key, mem], read(key, mem))
    write_func = theano.function([key, value], write(key, value))

    # shape: 20 x 110 x 110 x 3
    data = np.load('20_images_from_imagenet.npy')[:B]
    VALUES = data.reshape(B, F) - np.mean(data.reshape(B, F),
                                          axis=1, keepdims=True)

    # Random keys: unit-modulus complex numbers with uniform random phases.
    phis = np.random.random((B, F / 2)) * 2 * np.pi
    KEYS = np.concatenate([np.cos(phis), np.sin(phis)], axis=1)

    MEM = write_func(KEYS, VALUES)

    all_imgs = read_func(KEYS, MEM)

    # Undo the mean-centering and show the original images...
    VALUES = VALUES + np.mean(data.reshape(B, F), axis=1, keepdims=True)
    VALUES = VALUES.reshape(B, 110, 110, 3)
    VALUES = np.swapaxes(VALUES, 0, 1)
    VALUES = np.reshape(VALUES, (110, 110 * B, 3))
    plt.imshow(VALUES[:, :110 * B])
    plt.show()

    # ... and the images retrieved from the superimposed memory.
    all_imgs = all_imgs + np.mean(data.reshape(B, F), axis=1, keepdims=True)
    all_imgs = all_imgs.reshape(B, 110, 110, 3)
    all_imgs = np.swapaxes(all_imgs, 0, 1)
    all_imgs = np.reshape(all_imgs, (110, 110 * B, 3))
    plt.imshow(all_imgs[:, :110 * B])
    plt.show()
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
import theano
import os
import numpy as np
from theano import tensor
from blocks.initialization import Constant
from blocks.bricks import Linear, Tanh, NDimensionalSoftmax
from bricks import AssociativeLSTM, LSTM
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
from blocks.model import Model
from blocks.bricks.cost import CategoricalCrossEntropy
from blocks.algorithms import (GradientDescent,
                               StepClipping, CompositeRule,
                               Adam)
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.main_loop import MainLoop
from blocks.extensions import Printing
from blocks.graph import ComputationGraph
import logging
from utils import SaveLog, Glorot
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)
floatX = theano.config.floatX


# Episodic copy task: the input is a sequence of `n_sequence` random digits
# in {1, ..., 8}, followed by a delay of blanks (0), a marker (9), and more
# blanks; the target is to reproduce the digit sequence after the marker.
def get_episodic_copy_data(time_steps, n_data, n_sequence, batch_size):
    seq = np.random.randint(1, high=9, size=(n_data, n_sequence))
    zeros1 = np.zeros((n_data, time_steps - 1))
    zeros2 = np.zeros((n_data, time_steps))
    marker = 9 * np.ones((n_data, 1))
    zeros3 = np.zeros((n_data, n_sequence))

    x = np.concatenate((seq, zeros1, marker, zeros3), axis=1).astype('int32')
    y = np.concatenate((zeros3, zeros2, seq), axis=1).astype('int32')

    # Reshape to (n_batches, T, batch_size) and one-hot encode.
    x = x.reshape(n_data / batch_size, batch_size, 1, -1)
    x = np.swapaxes(x, 2, 3)
    x = np.swapaxes(x, 1, 2)
    x = x[..., 0]
    z = np.zeros(x.shape)
    one_hot_x = np.zeros((x.shape[0], x.shape[1], x.shape[2], 10))
    for c in range(10):
        z = z * 0
        z[np.where(x == c)] = 1
        one_hot_x[..., c] += z

    y = y.reshape(n_data / batch_size, batch_size, 1, -1)
    y = np.swapaxes(y, 2, 3)
    y = np.swapaxes(y, 1, 2)
    y = y[..., 0]
    z = np.zeros(y.shape)
    one_hot_y = np.zeros((y.shape[0], y.shape[1], y.shape[2], 9))
    for c in range(9):
        z = z * 0
        z[np.where(y == c)] = 1
        one_hot_y[..., c] += z

    return one_hot_x, one_hot_y

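# A quick shape check of the generator above (hypothetical toy parameters,
# not the ones used below): with time_steps=5, n_data=4, n_sequence=3 and
# batch_size=2, the total length is T = 2 * n_sequence + time_steps = 11, so
#     x, y = get_episodic_copy_data(5, 4, 3, 2)
#     x.shape == (2, 11, 2, 10)   # (n_batches, T, B, 10 input symbols)
#     y.shape == (2, 11, 2, 9)    # (n_batches, T, B, 9 target symbols)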
batch_size = 2
num_copies = 1
x_dim = 10
h_dim = 128
o_dim = 9
model_type = 'alstm'
if model_type == 'lstm':
    coeff = 4
    bias = 0
    save_path = 'lstm_path'
elif model_type == 'lstm_f1':
    # LSTM with the forget-gate bias initialized to 1.
    coeff = 4
    bias = 1
    save_path = 'lstm_f1_path'
elif model_type == 'alstm':
    coeff = 4.5
    use_W_xu = False
    save_path = 'alstm_path'

print 'Building model ...'
# T x B x F
x = tensor.tensor3('x', dtype=floatX)
# T x B x F'
y = tensor.tensor3('y', dtype=floatX)

x_to_h = Linear(name='x_to_h',
                input_dim=x_dim,
                output_dim=int(coeff * h_dim))
x_transform = x_to_h.apply(x)
if model_type == 'alstm':
    lstm = AssociativeLSTM(activation=Tanh(),
                           dim=h_dim,
                           num_copies=num_copies,
                           use_W_xu=use_W_xu,
                           name="lstm")
else:
    lstm = LSTM(activation=Tanh(),
                dim=h_dim,
                bias=bias,
                name="lstm")
h, c = lstm.apply(x_transform)
h_to_o = Linear(name='h_to_o',
                input_dim=h_dim,
                output_dim=o_dim)
o = h_to_o.apply(h)
o = NDimensionalSoftmax().apply(o, extra_ndim=1)

for brick in (lstm, x_to_h, h_to_o):
    brick.weights_init = Glorot()
    brick.biases_init = Constant(0)
    brick.initialize()

cost = CategoricalCrossEntropy().apply(y, o)
cost.name = 'CE'

print 'Building training process...'
shapes = []
for param in ComputationGraph(cost).parameters:
    # shapes.append((param.name, param.eval().shape))
    shapes.append(np.prod(list(param.eval().shape)))
print "Total number of parameters: " + str(np.sum(shapes))

if not os.path.exists(save_path):
    os.makedirs(save_path)
log_path = save_path + '/log.txt'
fh = logging.FileHandler(filename=log_path)
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

algorithm = GradientDescent(cost=cost,
                            parameters=ComputationGraph(cost).parameters,
                            step_rule=CompositeRule([StepClipping(10.0),
                                                     Adam(1e-3)]))  # 3e-4
monitor_cost = TrainingDataMonitoring([cost],
                                      prefix='train',
                                      after_epoch=False,
                                      before_training=True,
                                      every_n_batches=1000)

data = get_episodic_copy_data(100, int(1e6), 10, batch_size)
dataset = IterableDataset({'x': data[0].astype('int8'),
                           'y': data[1].astype('int8')})
stream = DataStream(dataset)

model = Model(cost)
main_loop = MainLoop(data_stream=stream, algorithm=algorithm,
                     extensions=[monitor_cost,
                                 Printing(after_epoch=False,
                                          every_n_batches=1000),
                                 SaveLog(every_n_batches=1000)],
                     model=model)

print 'Starting training ...'
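# On this task the network has to hold the 10 random digits across the
# 100-step delay and emit them only after the marker; the 'CE' cost is
# printed and logged every 1000 batches by the extensions above.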
main_loop.run()
--------------------------------------------------------------------------------

/bricks.py:
--------------------------------------------------------------------------------
from blocks.bricks import Initializable, Tanh, Logistic
from blocks.bricks.base import application, lazy
from blocks.roles import add_role, WEIGHT, INITIAL_STATE
from blocks.utils import shared_floatx_nans, shared_floatx_zeros
from blocks.bricks.recurrent import BaseRecurrent, recurrent
import theano.tensor as tensor
import numpy
from holographic_memory import complex_mult


class AssociativeLSTM(BaseRecurrent, Initializable):
    @lazy(allocation=['dim'])
    def __init__(self, dim, num_copies, use_W_xu, activation=None,
                 gate_activation=None, **kwargs):
        self.dim = dim
        self.num_copies = num_copies
        self.use_W_xu = use_W_xu

        # Build C random permutations of the dim/2 complex components; each
        # permutation is applied identically to both halves.
        permutations = []
        indices = numpy.arange(self.dim / 2)
        for i in range(self.num_copies):
            numpy.random.shuffle(indices)
            permutations.append(numpy.concatenate(
                [indices,
                 [ind + self.dim / 2 for ind in indices]]))
        # C x F (numpy)
        self.permutations = numpy.vstack(permutations)

        if not activation:
            activation = Tanh()
        if not gate_activation:
            gate_activation = Logistic()
        self.activation = activation
        self.gate_activation = gate_activation

        children = ([self.activation, self.gate_activation] +
                    kwargs.get('children', []))
        super(AssociativeLSTM, self).__init__(children=children, **kwargs)

    def get_dim(self, name):
        if name == 'inputs':
            return self.dim * 4.5
        if name in ['states', 'cells']:
            return self.dim
        if name == 'mask':
            return 0
        return super(AssociativeLSTM, self).get_dim(name)

    def _allocate(self):
        self.W_state = shared_floatx_nans((self.dim, int(4.5 * self.dim)),
                                          name='W_state')
        # The underscore is required to prevent collision with
        # the `initial_state` application method
        self.initial_state_ = shared_floatx_zeros((self.dim,),
                                                  name="initial_state")
        self.initial_cells = shared_floatx_zeros((self.num_copies, self.dim),
                                                 name="initial_cells")
        add_role(self.W_state, WEIGHT)
        # add_role(self.initial_state_, INITIAL_STATE)
        # add_role(self.initial_cells, INITIAL_STATE)

        self.parameters = [self.W_state]

    def _initialize(self):
        self.weights_init.initialize(self.parameters[0], self.rng)

    # Bounds the modulus of each complex element to at most 1; the first
    # dim/2 entries are real parts and the last dim/2 are imaginary parts.
    # input_: B x F
    def bound(self, input_):
        sq = input_ ** 2
        d = tensor.sqrt(tensor.maximum(
            1, sq[:, :self.dim / 2] + sq[:, self.dim / 2:]))
        d = tensor.concatenate([d, d], axis=1)
        return input_ / d

    # input: B x F
    # output: C x B x F
    def permute(self, input):
        inputs_permuted = []
        for i in range(self.permutations.shape[0]):
            inputs_permuted.append(
                input[:, self.permutations[i]].dimshuffle('x', 0, 1))
        return tensor.concatenate(inputs_permuted, axis=0)

    @recurrent(sequences=['inputs', 'mask'], states=['states', 'cells'],
               contexts=[], outputs=['states', 'cells'])
    def apply(self, inputs, states, cells, mask=None):
        def slice_(x, no):
            # Gates dimension is dim/2.
            if no in [0, 1, 2]:
                return x[:, no * self.dim / 2: (no + 1) * self.dim / 2]
            # Keys and u dimension is dim.
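            # The three gates occupy the first 3 * dim/2 = 1.5 * dim columns
            # of the 4.5 * dim activation; the two keys and the update u
            # take dim columns each, starting at 1.5 * dim -- hence the
            # (no - 1.5) offset below.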
            elif no in [3, 4, 5]:
                return x[:, int((no - 1.5) * self.dim):
                         int((no - 0.5) * self.dim)]

        activation = tensor.dot(states, self.W_state) + inputs

        in_gate = self.gate_activation.apply(slice_(activation, 0))
        in_gate = tensor.concatenate([in_gate, in_gate], axis=1)
        forget_gate = self.gate_activation.apply(slice_(activation, 1))
        forget_gate = tensor.concatenate([forget_gate, forget_gate], axis=1)
        out_gate = self.gate_activation.apply(slice_(activation, 2))
        out_gate = tensor.concatenate([out_gate, out_gate], axis=1)

        in_key = self.bound(slice_(activation, 3))
        # B x F --> C x B x F
        in_keys = self.permute(in_key)
        out_key = self.bound(slice_(activation, 4))
        # B x F --> C x B x F
        out_keys = self.permute(out_key)

        if self.use_W_xu:
            u = self.bound(slice_(activation, 5))
        else:
            # Effectively drop the recurrent contribution to u by scaling
            # it down by 1e-5, leaving only the input contribution.
            u = self.bound(slice_(
                tensor.dot(states, self.W_state * 0.00001) + inputs, 5))

        # 1 x B x F , C x B x F --> C x B x F
        f_x_c = forget_gate.dimshuffle('x', 0, 1) * cells
        # B x F , B x F --> 1 x B x F
        i_x_u = (in_gate * u).dimshuffle('x', 0, 1)
        next_cells = (f_x_c + complex_mult(in_keys, i_x_u))

        # C x B x F , C x B x F --> C x B x F
        o_x_c = complex_mult(out_keys, next_cells)
        next_states = out_gate * self.bound(tensor.mean(o_x_c, axis=0))

        if mask:
            next_states = (mask[:, None] * next_states +
                           (1 - mask[:, None]) * states)
            next_cells = (mask[:, None] * next_cells +
                          (1 - mask[:, None]) * cells)

        return next_states, next_cells

    @application(outputs=apply.states)
    def initial_states(self, batch_size, *args, **kwargs):
        return [tensor.repeat(self.initial_state_[None, :], batch_size, 0),
                tensor.repeat(self.initial_cells[:, None, :], batch_size, 1)]


class LSTM(BaseRecurrent, Initializable):
    @lazy(allocation=['dim'])
    def __init__(self, dim, bias, activation=None,
                 gate_activation=None, **kwargs):
        self.dim = dim
        self.bias = bias

        if not activation:
            activation = Tanh()
        if not gate_activation:
            gate_activation = Logistic()
        self.activation = activation
        self.gate_activation = gate_activation

        children = ([self.activation, self.gate_activation] +
                    kwargs.get('children', []))
        super(LSTM, self).__init__(children=children, **kwargs)

    def get_dim(self, name):
        if name == 'inputs':
            return self.dim * 4
        if name in ['states', 'cells']:
            return self.dim
        if name == 'mask':
            return 0
        return super(LSTM, self).get_dim(name)

    def _allocate(self):
        self.W_state = shared_floatx_nans((self.dim, 4 * self.dim),
                                          name='W_state')
        # The underscore is required to prevent collision with
        # the `initial_state` application method
        self.initial_state_ = shared_floatx_zeros((self.dim,),
                                                  name="initial_state")
        self.initial_cells = shared_floatx_zeros((self.dim,),
                                                 name="initial_cells")
        add_role(self.W_state, WEIGHT)
        add_role(self.initial_state_, INITIAL_STATE)
        add_role(self.initial_cells, INITIAL_STATE)

        self.parameters = [
            self.W_state, self.initial_state_, self.initial_cells]

    def _initialize(self):
        for weights in self.parameters[:1]:
            self.weights_init.initialize(weights, self.rng)

    @recurrent(sequences=['inputs', 'mask'],
               states=['states', 'cells'],
               contexts=[], outputs=['states', 'cells'])
    def apply(self, inputs, states, cells, mask=None):
        def slice_last(x, no):
            return x[:, no * self.dim: (no + 1) * self.dim]

        activation = tensor.dot(states, self.W_state) + inputs
        in_gate = self.gate_activation.apply(
            slice_last(activation, 0))
        pre = slice_last(activation, 1)
        forget_gate = self.gate_activation.apply(
            pre + self.bias * tensor.ones_like(pre))
        next_cells = (
            forget_gate * cells +
            in_gate * self.activation.apply(slice_last(activation, 2)))
        out_gate = self.gate_activation.apply(
            slice_last(activation, 3))
        next_states = out_gate * self.activation.apply(next_cells)

        if mask:
            next_states = (mask[:, None] * next_states +
                           (1 - mask[:, None]) * states)
            next_cells = (mask[:, None] * next_cells +
                          (1 - mask[:, None]) * cells)

        return next_states, next_cells

    @application(outputs=apply.states)
    def initial_states(self, batch_size, *args, **kwargs):
        return [tensor.repeat(self.initial_state_[None, :], batch_size, 0),
                tensor.repeat(self.initial_cells[None, :], batch_size, 0)]
--------------------------------------------------------------------------------