├── .gitignore
├── README.md
├── about.md
├── data
│   ├── __init__.py
│   ├── mnist.py
│   └── text.py
├── layers
│   ├── __init__.py
│   ├── activations.py
│   ├── base.py
│   ├── cwrnn.py
│   ├── cwrnn_l1.py
│   ├── cwrnn_norm.py
│   ├── layer_utils.py
│   ├── linear.py
│   ├── softmax_ce_loss.py
│   └── tests.py
├── network.py
├── train_mnist.py
└── train_ptb.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.pyc
3 |
4 | results/
5 |
6 | datasets/mnist.pkl.gz
7 |
8 | data/mnist.pkl.gz
9 |
10 | data/ptb.txt
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # lrh
2 | Learning RNN Hierarchies
3 |
4 | `about.md` has the basic details of what the code is trying to do and why.
5 |
6 | * Layers are defined in `layers/`, along with tests
7 | * All layers derive from the abstract base class defined in `base.py`
8 | * All data and data-prepping scripts are in the `data/` folder
9 | * `network.py` has tools for treating a list of layers as a model: layer-by-layer forward/backward passes, getting gradients, and setting/getting parameters (see the sketch below)
10 | * `train_ptb.py` trains a model on the Penn Treebank text file, which has to be placed in the `data/` folder
11 | * `train_mnist.py` trains a model on Sequential MNIST; `mnist.pkl.gz` has to be placed in the `data/` folder
12 | * As the network trains, logs are generated. Final logs and models are stored as pickle objects in `results/experiment_name`, where `experiment_name` is a string defined in the `train_*` scripts
13 |
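A minimal usage sketch of `network.py` (the layer sizes and data shapes here are made up; it mirrors what the `train_*` scripts do, but is not taken from them verbatim):

```python
import numpy as np
from layers import CWRNN, Linear, SoftmaxCrossEntropyLoss
from network import forward, backward, extract_params, extract_grads, load_params

# hypothetical sizes: 3 input features, 16 hidden units in 4 modules, 5 classes
model = [
    CWRNN(n_input=3, n_hidden=16, n_modules=4, T_max=10, last_state_only=True),
    Linear(16, 5),
    SoftmaxCrossEntropyLoss()
]

X = np.random.randn(10, 3, 2)          # (time, features, batch)
Y = np.zeros((1, 5, 2))                # one-hot targets for the last state
Y[0, 0, :] = 1.0

loss = forward(model, X, Y)            # layer-by-layer forward pass
backward(model)                        # backpropagate through the list
W, dW = extract_params(model), extract_grads(model)
load_params(model, W - 1e-3 * dW)      # e.g. one plain gradient descent step
```
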
14 | ## Requires:
15 | * NumPy
16 | * SciPy (for one special function used to compute entropy)
17 | * matplotlib
18 | * climin
19 |
20 |
--------------------------------------------------------------------------------
/about.md:
--------------------------------------------------------------------------------
1 | # Learning RNN Hierarchies
2 |
3 | ## Introduction
4 |
5 | Recognizing patterns in sequences requires detecting temporal **events** that occur at different levels of an abstraction hierarchy. An event is a simple, specific pattern observed over time in the sequence that is useful for recognizing a much more complex pattern. Lower-level events are due to local structure in the input stream (example: phonemes). Higher-level events could be a combination of several lower-level events and even higher-level events from the past (example: the mood of a speaker depends on the conversational tone and the words they speak). RNNs can not only see a composition of such events like a regular NN, but they can also see the overall variation of these events over arbitrary gaps in time, and hence are very powerful.
6 |
7 | In general, vanilla RNNs are not that useful because they forget events from the past (belonging to any level of abstraction). This is because their multiplicative update rule for the hidden state, repeated over all time steps, causes the memories of events to decay. The common and now successful approach to this problem is the LSTM family of RNNs, which replaces the multiplicative update rule with an additive update. This makes the RNN prone to explosion and instability, so a protective gating mechanism is put in place. While this solves the Vanishing Gradients problem, a single LSTM layer won't give the best performance. There is abundant empirical evidence that stacking LSTMs (and RNNs in general) offers better performance than a single LSTM layer with the memory size held fixed. If LSTMs can remember everything from the past, and if LSTMs are already very deep in time, why stack them at all?
8 |
9 | The most intuitive and commonly given reason is that lower RNNs specialize in local events, while higher-level RNNs can focus on more abstract events. For example, the seq2seq architecture uses a 4-layer stacked LSTM encoder to compress the input sequence to a fixed-length vector. Other possible reasons include ease of optimization, a reduction in the number of parameters required per cell of memory, increased nonlinear depth per time step, and more (it is still an open research question). But it is clear that stacking RNNs is essential for good performance on complex tasks.
10 |
11 | If we can simultaneously do both of the following:
12 |
13 | 1. Solve the vanishing gradients problem and, at the same time,
14 |
15 | 2. Make our models better at handling events at multiple levels of abstraction,
16 |
17 | using a single, simpler model, such a system would be more efficient than an LSTM. The main objective of this work is to find such a model, taking inspiration from previous methods and combining them with our novel contributions.
18 |
19 | ## Background
20 |
21 |
21 | We can split up our big RNN into multiple smaller RNN modules. A module can be either active or inactive at a particular time step. The less frequently a module is active, the more memory retention capability it possesses - these are the slow modules. The more frequently a module is active, the less memory retention capability it possesses - these are the fast modules. Thus a combination of slower and faster RNNs can together retain memory for longer durations, making it possible to recognize patterns that depend on temporally distant events.
22 |
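A toy illustration of this idea (a sketch based on the leaky module update used later in `layers/cwrnn.py`; the activation schedules here are made up): a module keeps its previous state on every step where it is inactive, so rarely-active modules retain information for longer.

```python
import numpy as np

def run_module(active, T=16, n=4, seed=0):
    """Roll one module forward; active(t) says whether it updates at step t."""
    rng = np.random.RandomState(seed)
    h = np.zeros(n)
    for t in range(T):
        h_candidate = np.tanh(rng.randn(n))                  # stand-in for tanh(W . [x_t, h, 1])
        h = active(t) * h_candidate + (1 - active(t)) * h    # leaky update
    return h

fast = run_module(lambda t: 1)                 # updates every step: short memory
slow = run_module(lambda t: int(t % 8 == 0))   # updates every 8th step: long memory
```
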
23 | There have been a few attempts to do this in the past. Here we discuss their approaches, strengths and weaknesses:
24 |
25 | 1. __Chunker/Neural History Compressor (1991):__ A stack of simple RNNs. The lowest RNN layer gets the actual inputs. Higher levels take inputs only from the layer below and give their outputs as inputs to the layer above. Each RNN layer, starting from the lowest, is trained to predict the input it is going to receive at the next timestep, based on the history of inputs it has received so far. This is an unsupervised step, similar to greedy autoencoder training. The main trick is to activate an RNN at a given level in proportion to the extent to which the RNN layer below it failed to predict its current input. If the predictions of the lower RNN are frequently correct, then the higher RNN is rarely on and thus has longer memory. The higher-level RNN is then trained to predict its inputs from the layer below, which arrive only at the timesteps where the lower layer failed to predict. This is done iteratively over all RNNs in the stack. *Each RNN layer has now learned to expect what is unexpected to the RNN below it*. Schmidhuber calls this history compression, as predictability increases with more layers.
26 | __Pros__: Unsupervised. Triggers for higher RNNs are event-driven, meaning an RNN in a higher layer can come in when an unexpected event occurs and gather all the data it needs.
27 | __Cons__: Local predictability is necessary - not always possible. Cannot combine information from multiple levels effectively. Needs layer-wise pretraining and then supervised fine-tuning.
28 |
29 | 2. __Clockwork RNN (2014)__: This is a supervised variant of the Chunker. RNNs are present in a cluster. Each RNN has a dedicated timer or clock, which activates the RNN module only once every ___P___ time steps. P is chosen to form a hierarchy (for example, P_i = 2^i, where i is the module index; see the sketch below). Further, connections are restricted to go only from slow RNNs to fast RNNs and not vice versa. This is similar to the Chunker, except that RNNs are triggered not by events but by clock pulses: they activate according to a predefined period. This allows for supervised training, as RNNs specialize automatically at their level.
30 | __Pros__: Supervised. Has lower complexity than a vanilla RNN due to the restricted activity and connectivity scheme.
31 | __Cons__: The major problem is that it requires hand-engineered clock periods, which depend on the memory requirements of the task. These vary widely from task to task, so a lot of domain knowledge is required to set up the initial hierarchy, and even then setting it up is not trivial. If there is a large gap between the activities of two RNN modules, the slower RNN can miss the contents of the faster RNN's memory, as they decay with time. Lastly, the connection scheme between modules is not good, both intuitively and in practice.
32 |
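For concreteness, a Clockwork-style activation schedule with periods P_i = 2^i can be written as a binary mask (a small sketch, not the original CW-RNN code):

```python
import numpy as np

def clockwork_mask(T, n_modules):
    """mask[t, i] = 1 if module i (period 2**i) is active at time step t."""
    periods = 2 ** np.arange(n_modules)          # [1, 2, 4, 8, ...]
    t = np.arange(T).reshape(-1, 1)
    return (t % periods == 0).astype(float)      # shape (T, n_modules)

# module 0 fires every step, module 3 only every 8th step
print(clockwork_mask(8, 4))
```
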
33 |
34 |
35 | ## Proposed Method
36 |
37 | We cast the learning process as a combination of
38 |
39 | 1. learning to do the task and
40 | 2. learning the hierarchical interconnected RNN architecture
41 |
42 | That is, we want to come up with a model that can:
43 |
44 | 1. Learn the hierarchy
45 |
46 | 2. Learn how the modules are interconnected
47 |
48 | 3. Learn to activate based on events
49 |
50 | The methods for the last two have been developed and tested. The first one is tricky, and this repo tries to do only that. How it is done is described below.
51 |
52 |
53 | ### Learning the hierarchy
54 |
55 | This is the most important and most challenging aspect of designing the model. Clock frequencies (the inverse of clock periods) are a really good characterization of an RNN's position in a hierarchy. An RNN with a low frequency naturally has to depend on the contents stored in other (faster) RNNs, as it rarely gets any input. Accordingly, this low-frequency RNN learns to operate on a more abstract form of inputs, thus forming the higher levels of the hierarchy. Conversely, fast clocks form the lower levels of the hierarchy. This makes the clock frequency a sufficient parameter to describe an RNN's position in a hierarchy. Thus, learning the clock periods of a set of RNNs is equivalent to learning the hierarchy.
56 |
57 | This is more powerful than it seems. For example, learning a symmetric set of clock frequencies between two sets of RNNs is equivalent to learning the seq2seq model itself (see figure below). A stack of RNNs with continuously decreasing frequencies forms an abstraction pyramid. If this is combined with another set of RNNs, connected only to the topmost RNN but with continuously increasing frequencies, we have a crude seq2seq model.
58 |
59 | 
60 |
61 | (Not claiming that this learning capability has been achieved, but just showing the representational power)
62 |
63 | Learning clock frequencies is not as trivial as learning just another parameter. The clocks used in the Clockwork RNN were binary clocks. If we move to a smoother version of them, i.e. the sine wave, we now have to learn the frequency of this sine wave. This sine wave represents the activation pattern of an RNN module in the hierarchy.
64 |
65 | Unfortunately, learning the frequency directly is not possible. This is because of the extremely large number of local minima present. Consider the following example: the current wave frequency is 1/4, but the required wave frequency is 1/8. If the frequency decreases slightly, to say 1/5, it is actually worse than 1/4, as there is less agreement between 1/8 and 1/5 than between 1/8 and 1/4. That is, the agreement between two clocks is governed by the LCM of their periods, so local minima appear all along the frequency axis. Thus learning the frequency directly is not possible (learnt the hard way weeks before the ICML deadline).
66 |
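The agreement argument can be checked numerically (a small sketch; the periods 4, 5 and 8 follow the example above):

```python
import numpy as np

def clock(period, T=40):
    """Binary clock that ticks at t = 0, period, 2*period, ..."""
    return (np.arange(T) % period == 0).astype(int)

target = clock(8)
agreement_4 = int((clock(4) * target).sum())   # both tick at multiples of lcm(4, 8) = 8
agreement_5 = int((clock(5) * target).sum())   # both tick only at multiples of lcm(5, 8) = 40

# agreement_4 == 5 but agreement_5 == 1: nudging the period from 4 towards 8
# via 5 first makes the agreement worse, which is exactly the local minimum
```
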
67 | Instead of operating in the amplitude-time domain, we move to the amplitude-frequency domain. That is, we express the wave we want as a set of DCT coefficients, perform an inverse DCT to get the wave, and use the wave to activate the modules. The error derivatives are transferred back to the frequency domain during the backward pass using the DCT. This does not have the above problem of minima.
68 |
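A hedged sketch of this parameterization (this is not the code in this repo; as the note at the end of this section says, the repo ends up using binary bases instead of DCT bases): the activation wave is the inverse DCT of a coefficient vector, and since the orthonormal inverse DCT is a linear map, its transpose (the forward DCT) carries the wave's gradient back to the coefficients.

```python
import numpy as np
from scipy.fftpack import dct, idct

T_max = 64
coeffs = np.zeros(T_max)
coeffs[4] = 1.0                        # one dominant frequency component

wave = idct(coeffs, norm='ortho')      # amplitude-time activation pattern

# backward pass: dL/dcoeffs = DCT of dL/dwave (transpose of the orthonormal idct)
dL_dwave = np.random.randn(T_max)      # pretend gradient arriving from the RNN
dL_dcoeffs = dct(dL_dwave, norm='ortho')
```
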
69 | This can be viewed as regularization of activities in the frequency domain. There are many ways to restrict the learnt wave to have just one major frequency component:
70 |
71 | 1. L1 penalty over coefficients
72 |
73 | 2. A softmax over the coefficients, for discriminatively choosing frequencies.
74 |
75 |
76 |
77 | The code in this repo is only for this purpose. The other components are not here, but in a separate repo. They have been independently tested to work, but not as a whole unit.
78 |
79 | Note: For a few reasons, binary clocks seemed like a better fit. So instead of DCT bases, binary bases are used, and this whole "transform" is implemented simply as a dot product of a coefficient vector and a basis matrix.
80 |
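This is essentially what `layers/cwrnn.py` does: the columns of a binary matrix `C` are clocks of every period up to `T_max`, a softmax over the coefficients `d` gives each module a distribution over those clocks, and the per-step activation wave is just `C` dotted with that distribution. A condensed sketch of that construction:

```python
import numpy as np

T_max, n_modules = 16, 4

# C[t, p] = 1 if a clock of period p+1 ticks at step t (same kernel as cwrnn.py)
C = ((np.arange(T_max).reshape(1, -1) %
      np.arange(1, T_max + 1).reshape(-1, 1)) == 0).T * 1.0

d = np.zeros((T_max, n_modules))                       # learnable clock coefficients
D = np.exp(d) / np.exp(d).sum(axis=0, keepdims=True)   # softmax over periods
a = np.dot(C, D)                                       # (T_max, n_modules) activation waves
```
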
81 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pranv/lrh/fec5fc6355c1fee3456ef35568815759867474f8/data/__init__.py
--------------------------------------------------------------------------------
/data/mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import gzip, cPickle, sys
4 |
5 | def to_categorical(y):
6 | y = np.asarray(y)
7 | Y = np.zeros((len(y), 10))
8 | for i in range(len(y)):
9 | Y[i, y[i]] = 1.
10 | return Y
11 |
12 | class loader(object):
13 | def __init__(self, batch_size=50, permuted=False):
14 | path = 'data/mnist.pkl.gz'
15 | f = gzip.open(path, 'rb')
16 | (X_train, y_train), (X_val, y_val), (X_test, y_test) = cPickle.load(f)
17 | f.close()
18 |
19 | X_train = X_train.reshape(X_train.shape[0], -1, 1)
20 | X_val = X_val.reshape(X_val.shape[0], -1, 1)
21 | X_test = X_test.reshape(X_test.shape[0], -1, 1)
22 |
23 | X_train = X_train.swapaxes(0, 1).swapaxes(1, 2)
24 | X_val = X_val.swapaxes(0, 1).swapaxes(1, 2)
25 | X_test = X_test.swapaxes(0, 1).swapaxes(1, 2)
26 |
27 | if permuted:
28 | p = range(28*28)
29 | np.random.shuffle(p)
30 | X_train = X_train[p]
31 | X_val = X_val[p]
32 | X_test = X_test[p]
33 |
34 | self.i = 0
35 |
36 | self.X_train = X_train
37 | self.X_val = X_val
38 | self.X_test = X_test
39 | self.y_train = to_categorical(y_train).T.reshape(1, 10, -1)
40 | self.y_val = to_categorical(y_val ).T.reshape(1, 10, -1)
41 | self.y_test = to_categorical(y_test).T.reshape(1, 10, -1)
42 |
43 | self.batch_size = batch_size
44 | self.permuted = permuted
45 | self.epoch = 1
46 | self.epoch_complete = False
47 |
48 | def fetch_train(self):
49 | X = self.X_train[:, :, self.i * self.batch_size: (self.i + 1) * self.batch_size]
50 | y = self.y_train[:, :, self.i * self.batch_size: (self.i + 1) * self.batch_size]
51 | self.i = (self.i + 1)
52 | if (self.i * self.batch_size) >= self.X_train.shape[2]:
53 | self.epoch_complete = True
54 | self.epoch += 1
55 | self.i = self.i % (self.X_train.shape[2] / self.batch_size)
56 | return (X, y)
57 |
58 | def fetch_val(self):
59 | return self.X_val, self.y_val
60 |
61 | def fetch_test(self):
62 | return self.X_test, self.y_test
63 |
--------------------------------------------------------------------------------
/data/text.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class OneHot(object):
4 | def __init__(self, alphabet_size, char_to_i):
5 | self.alphabet_size = alphabet_size
6 | self.matrix = np.eye(alphabet_size, dtype='uint8')
7 | self.to_i = char_to_i
8 |
9 | def __call__(self, chars):
10 | I = np.zeros((len(chars), self.alphabet_size))
11 | for c in range(len(chars)):
12 | i = self.to_i[chars[c]]
13 | I[c][i] = 1
14 | return I
15 |
16 |
17 | class UnOneHot(object):
18 | def __init__(self, i_to_char):
19 | self.to_c = i_to_char
20 |
21 | def __call__(self, vectors):
22 | chars = ''
23 | for vector in vectors:
24 | i = vector.argmax()
25 | chars += self.to_c[i]
26 | return chars
27 |
28 | class loader(object):
29 | def __init__(self, filename, sequence_length, batch_size):
30 | f = open(filename, 'r')
31 | lines = f.readlines()
32 |
33 | string = ''.join(lines)
34 |
35 | vocabulary = list(set(string))
36 | vocabulary_size = len(vocabulary)
37 | data_size = len(string)
38 |
39 | char_to_i = {ch:i for i,ch in enumerate(vocabulary)}
40 | i_to_char = {i:ch for i,ch in enumerate(vocabulary)}
41 |
42 | encoder = OneHot(vocabulary_size, char_to_i)
43 | decoder = UnOneHot(i_to_char)
44 |
45 | chars_per_batch = data_size / batch_size
46 | total_used_chars = (data_size / chars_per_batch) * chars_per_batch
47 | string = string[:total_used_chars]
48 | data_size = len(string)
49 | chars_per_batch = data_size / batch_size
50 | iterators = range(0, total_used_chars, chars_per_batch)
51 |
52 | self.sequence_length = sequence_length
53 | self.batch_size = batch_size
54 | self.string = string
55 | self.vocabulary = vocabulary
56 | self.vocabulary_size = vocabulary_size
57 | self.data_size = data_size
58 | self.char_to_i = char_to_i
59 | self.i_to_char = i_to_char
60 | self.encoder = encoder
61 | self.decoder = decoder
62 | self.chars_per_batch = chars_per_batch
63 | self.total_used_chars = total_used_chars
64 | self.string = string
65 | self.chars_per_batch = chars_per_batch
66 | self.iterators = iterators
67 |
68 | def fetch_train(self):
69 | T = self.sequence_length
70 | batch_string = ''
71 |
72 | for i in range(len(self.iterators)):
73 | batch_string += self.string[self.iterators[i]:self.iterators[i] + T]
74 | self.iterators[i] += T
75 |
76 | if self.iterators[0] + T >= self.chars_per_batch:
77 | self.iterators = range(0, self.total_used_chars, self.chars_per_batch)
78 |
79 | batch_x = self.encoder(batch_string)
80 | batch_y = self.encoder(batch_string[1:] + batch_string[0])
81 |
82 | X = batch_x.reshape((self.batch_size, T, self.vocabulary_size), order='C').swapaxes(1, 2).swapaxes(0, 2)
83 | Y = batch_y.reshape((self.batch_size, T, self.vocabulary_size), order='C').swapaxes(1, 2).swapaxes(0, 2)
84 |
85 | return X, Y
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from linear import Linear
2 | from softmax_ce_loss import SoftmaxCrossEntropyLoss
3 | from cwrnn import CWRNN
4 | from cwrnn_norm import CWRNN_NORM
5 | from cwrnn_l1 import CWRNN_L1
6 | from activations import *
7 |
--------------------------------------------------------------------------------
/layers/activations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from base import Layer
4 |
5 | class TanH(Layer):
6 | def forward(self, X):
7 | Y = np.tanh(X)
8 | self.Y = Y
9 | return Y
10 |
11 | def backward(self, dY):
12 | Y = self.Y
13 | dX = (1.0 - Y ** 2) * dY
14 | return dX
15 |
16 |
17 | class Sigmoid(Layer):
18 | def forward(self, X):
19 | Y = 1.0 / (1.0 + np.exp(-X))
20 | self.Y = Y
21 | return Y
22 |
23 | def backward(self, dY):
24 | Y = self.Y
25 | dX = Y * (1.0 - Y) * dY
26 | return dX
27 |
28 |
29 |
--------------------------------------------------------------------------------
/layers/base.py:
--------------------------------------------------------------------------------
1 | class Layer(object):
2 | def __init__(self):
3 | pass
4 |
5 | def forward(self):
6 | pass
7 |
8 | def backward(self):
9 | pass
10 |
11 | def get_params(self):
12 | pass
13 |
14 | def set_params(self):
15 | pass
16 |
17 | def get_grads(self):
18 | pass
19 |
20 | def clear_grads(self):
21 | pass
22 |
23 | def forget(self):
24 | pass
25 |
26 | def remember(self):
27 | pass
28 |
29 | def print_info(self):
30 | pass
--------------------------------------------------------------------------------
/layers/cwrnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from base import Layer
4 | from layer_utils import glorotize, orthogonalize
5 |
6 |
7 | class Softmax(Layer):
8 | def forward(self, X):
9 | exp = np.exp(X)
10 | probs = exp / np.sum(exp, axis=0, keepdims=True)
11 | self.probs = probs
12 | return probs
13 |
14 | def backward(self, dY):
15 | Y = self.probs
16 | dX = Y * dY
17 | sumdX = np.sum(dX, axis=0, keepdims=True)
18 | dX -= Y * sumdX
19 | return dX
20 |
21 |
22 | class CWRNN(Layer):
23 | def __init__(self, n_input, n_hidden, n_modules, T_max, last_state_only=False):
24 | assert(n_hidden % n_modules == 0)
25 |
26 | W = np.random.randn(n_hidden, n_input + n_hidden + 1) # +1 for bias, single combined matrix
27 | # for recurrent and input projections
28 |
29 | # glorotize and orthogonalize the non recurrent and recurrent aspects respectively
30 | W[:, :n_input] = glorotize(W[:, :n_input])
31 | W[:, n_input:-1] = orthogonalize(W[:, n_input:-1])
32 |
33 | # time kernel (T_max x n_clocks)
34 | C = np.repeat(np.arange(T_max).reshape(1, -1), T_max, axis=0)
35 | C = ((C % np.arange(1, T_max + 1).reshape(-1, 1)) == 0) * 1.0
36 | C = C.T
37 |
38 | # distribution over clocks for each module (T_max x n_modules)
39 | d = np.zeros((T_max, n_modules))
40 |
41 | self.softmax = Softmax()
42 |
43 | self.W = W
44 | self.d = d
45 | self.C = C
46 | self.n_input, self.n_hidden, self.n_modules, self.T_max, self.last_state_only = n_input, n_hidden, n_modules, T_max, last_state_only
47 |
48 |
49 | def forward(self, X):
50 | T, n, B = X.shape
51 | n_input = self.n_input
52 | n_hidden = self.n_hidden
53 | n_modules = self.n_modules
54 |
55 | D = self.softmax.forward(self.d) # get activations
56 | a = np.dot(self.C, D)
57 | A = np.repeat(a, n_hidden / n_modules, axis=1) # for each state in a module
58 | A = A[:, :, np.newaxis]
59 |
60 | V = np.zeros((T, n_input + n_hidden + 1, B))
61 | h_new = np.zeros((T, n_hidden, B))
62 | H_new = np.zeros((T, n_hidden, B))
63 | H = np.zeros((T, n_hidden, B))
64 |
65 | H_prev = np.zeros((n_hidden, B))
66 |
67 | for t in xrange(T):
68 | V[t] = np.concatenate([X[t], H_prev, np.ones((1, B))], axis=0)
69 | h_new[t] = np.dot(self.W, V[t])
70 | H_new[t] = np.tanh(h_new[t])
71 | H[t] = A[t] * H_new[t] + (1 - A[t]) * H_prev # leaky update
72 | H_prev = H[t]
73 |
74 | self.A, self.a = A, a
75 | self.V, self.h_new, self.H_new, self.H = V, h_new, H_new, H
76 |
77 | if self.last_state_only:
78 | return H[-1:]
79 | else:
80 | return H
81 |
82 | def backward(self, dH):
83 | if self.last_state_only:
84 | last_step_error = dH.copy()
85 | dH = np.zeros_like(self.H)
86 | dH[-1:] = last_step_error[:]
87 |
88 | T, _, B = dH.shape
89 | n_input = self.n_input
90 | n_hidden = self.n_hidden
91 | n_modules = self.n_modules
92 |
93 | A = self.A
94 | V, h_new, H_new, H = self.V, self.h_new, self.H_new, self.H
95 | dA, dH_prev, dW, dX = np.zeros_like(A), np.zeros((n_hidden, B)), \
96 | np.zeros_like(self.W), np.zeros((T, n_input, B))
97 |
98 | for t in reversed(xrange(T)):
99 | if t == 0:
100 | H_prev = np.zeros((n_hidden, B))
101 | else:
102 | H_prev = H[t - 1]
103 |
104 | dH_t = dH[t] + dH_prev
105 |
106 | dH_new = A[t] * dH_t
107 | dH_prev = (1 - A[t]) * dH_t
108 | dA[t] = np.sum((H_new[t] - H_prev) * dH_t, axis=1, keepdims=True)
109 |
110 | dh_new = (1.0 - H_new[t] ** 2) * dH_new
111 |
112 | dW += np.dot(dh_new, V[t].T)
113 | dV = np.dot(self.W.T, dh_new)
114 |
115 | dX[t] = dV[:n_input]
116 | dH_prev += dV[n_input:-1]
117 |
118 | dA = dA[:, :, 0]
119 | da = dA.reshape(self.T_max, -1, n_hidden / n_modules).sum(axis=-1)
120 | dD = np.dot(self.C.T, da)
121 | dd = self.softmax.backward(dD)
122 |
123 |
124 | self.dW = dW
125 | self.dd = dd
126 |
127 | return dX
128 |
129 | def get_params(self):
130 | W = self.W.flatten()
131 | d = self.d.flatten()
132 | return np.concatenate([W, d])
133 |
134 | def set_params(self, P):
135 | a, b = self.W.size, self.d.size
136 | W, d = np.split(P, [a])
137 | self.W = W.reshape(self.W.shape)
138 | self.d = d.reshape(self.d.shape)
139 |
140 | def get_grads(self):
141 | dW = self.dW.flatten()
142 | dd = self.dd.flatten()
143 | return np.concatenate([dW, dd])
144 |
145 | def clear_grads(self):
146 | self.dW = None
147 | self.dd = None
148 |
149 | def forget(self):
150 | pass
151 |
152 | def remember(self):
153 | pass
154 |
155 | def print_info(self):
156 | print 'dominant wave period: ', self.d.argmax(axis=0) + 1
157 | print 'avg. power (all coefficients): ', np.abs(self.d).mean()
158 | print 'avg. power activation waves: ', self.A.mean()
159 |
--------------------------------------------------------------------------------
/layers/cwrnn_l1.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from base import Layer
4 | from layer_utils import glorotize, orthogonalize
5 |
6 |
7 | class CWRNN_L1(Layer):
8 | def __init__(self, n_input, n_hidden, n_modules, T_max, last_state_only=False):
9 | assert(n_hidden % n_modules == 0)
10 |
11 | W = np.random.randn(n_hidden, n_input + n_hidden + 1) # +1 for bias, single combined matrix
12 | # for recurrent and input projections
13 |
14 | # glorotize and orthogonalize the non recurrent and recurrent aspects respectively
15 | W[:, :n_input] = glorotize(W[:, :n_input])
16 | W[:, n_input:-1] = orthogonalize(W[:, n_input:-1])
17 |
18 | # time kernel (T_max x n_clocks)
19 | C = np.repeat(np.arange(T_max).reshape(1, -1), T_max, axis=0)
20 | C = ((C % np.arange(1, T_max + 1).reshape(-1, 1)) == 0) * 1.0
21 | C = C.T
22 |
23 | # distribution over clocks for each module (T_max x n_modules)
24 | d = np.zeros((T_max, n_modules))
25 |
26 | self.W = W
27 | self.d = d
28 | self.C = C
29 | self.n_input, self.n_hidden, self.n_modules, self.T_max, self.last_state_only = n_input, n_hidden, n_modules, T_max, last_state_only
30 |
31 |
32 | def forward(self, X):
33 | T, n, B = X.shape
34 | n_input = self.n_input
35 | n_hidden = self.n_hidden
36 | n_modules = self.n_modules
37 |
38 | D = self.d # get activations
39 | a = np.dot(self.C, D)
40 | a = np.clip(a, 0.0, 1.0)
41 | A = np.repeat(a, n_hidden / n_modules, axis=1) # for each state in a module
42 | A = A[:, :, np.newaxis]
43 |
44 | V = np.zeros((T, n_input + n_hidden + 1, B))
45 | h_new = np.zeros((T, n_hidden, B))
46 | H_new = np.zeros((T, n_hidden, B))
47 | H = np.zeros((T, n_hidden, B))
48 |
49 | H_prev = np.zeros((n_hidden, B))
50 |
51 | for t in xrange(T):
52 | V[t] = np.concatenate([X[t], H_prev, np.ones((1, B))], axis=0)
53 | h_new[t] = np.dot(self.W, V[t])
54 | H_new[t] = np.tanh(h_new[t])
55 | H[t] = A[t] * H_new[t] + (1 - A[t]) * H_prev # leaky update
56 | H_prev = H[t]
57 |
58 | self.A, self.a = A, a
59 | self.V, self.h_new, self.H_new, self.H = V, h_new, H_new, H
60 |
61 | if self.last_state_only:
62 | return H[-1:]
63 | else:
64 | return H
65 |
66 | def backward(self, dH):
67 | if self.last_state_only:
68 | last_step_error = dH.copy()
69 | dH = np.zeros_like(self.H)
70 | dH[-1:] = last_step_error[:]
71 |
72 | T, _, B = dH.shape
73 | n_input = self.n_input
74 | n_hidden = self.n_hidden
75 | n_modules = self.n_modules
76 |
77 | A = self.A
78 | V, h_new, H_new, H = self.V, self.h_new, self.H_new, self.H
79 | dA, dH_prev, dW, dX = np.zeros_like(A), np.zeros((n_hidden, B)), \
80 | np.zeros_like(self.W), np.zeros((T, n_input, B))
81 |
82 | for t in reversed(xrange(T)):
83 | if t == 0:
84 | H_prev = np.zeros((n_hidden, B))
85 | else:
86 | H_prev = H[t - 1]
87 |
88 | dH_t = dH[t] + dH_prev
89 |
90 | dH_new = A[t] * dH_t
91 | dH_prev = (1 - A[t]) * dH_t
92 | dA[t] = np.sum((H_new[t] - H_prev) * dH_t, axis=1, keepdims=True)
93 |
94 | dh_new = (1.0 - H_new[t] ** 2) * dH_new
95 |
96 | dW += np.dot(dh_new, V[t].T)
97 | dV = np.dot(self.W.T, dh_new)
98 |
99 | dX[t] = dV[:n_input]
100 | dH_prev += dV[n_input:-1]
101 |
102 | dA = dA[:, :, 0]
103 | da = dA.reshape(self.T_max, -1, n_hidden / n_modules).sum(axis=-1)
104 | dD = np.dot(self.C.T, da)
105 | dd = dD
106 |
107 | self.dW = dW
108 | self.dd = dd + 0.01 * np.sign(self.d)
109 |
110 | return dX
111 |
112 | def get_params(self):
113 | W = self.W.flatten()
114 | d = self.d.flatten()
115 | return np.concatenate([W, d])
116 |
117 | def set_params(self, P):
118 | a, b = self.W.size, self.d.size
119 | W, d = np.split(P, [a])
120 | self.W = W.reshape(self.W.shape)
121 | self.d = d.reshape(self.d.shape)
122 |
123 | def get_grads(self):
124 | dW = self.dW.flatten()
125 | dd = self.dd.flatten()
126 | return np.concatenate([dW, dd])
127 |
128 | def clear_grads(self):
129 | self.dW = None
130 | self.dd = None
131 |
132 | def print_info(self):
133 | print 'dominant wave period: ', self.d.argmax(axis=0) + 1
134 | print 'avg. power (all coefficients): ', np.abs(self.d).mean()
135 | print 'avg. power activation waves: ', self.A.mean()
136 |
--------------------------------------------------------------------------------
/layers/cwrnn_norm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from base import Layer
4 | from layer_utils import glorotize, orthogonalize
5 |
6 | from scipy import special
7 |
8 |
9 | class Normalize(Layer):
10 | def __init__(self, axis=0):
11 | self.axis = axis
12 |
13 | def forward(self, X):
14 | mn = X.min(axis=self.axis)
15 | mx = X.max(axis=self.axis)
16 |
17 | amn = X.argmin(axis=self.axis)
18 | amx = X.argmax(axis=self.axis)
19 |
20 | Y = (X - mn) / (mx - mn)
21 |
22 | self.mn = mn
23 | self.mx = mx
24 | self.amn = amn
25 | self.amx = amx
26 | self.X = X
27 | self.Y = Y
28 | return Y
29 |
30 | def backward(self, dY):
31 | dY[self.amx, range(dY.shape[1])] = 0.0
32 | dY[self.amn, range(dY.shape[1])] = 0.0
33 | dX = dY / (self.mx - self.mn)
34 | dX[self.amx, range(dY.shape[1])] = np.sum(-self.Y * dY / (self.mx - self.mn), axis=self.axis)
35 | dX[self.amn, range(dY.shape[1])] = np.sum((self.X - self.mx) * dY / (self.mx - self.mn) ** 2, axis=self.axis)
36 |
37 | return dX
38 |
39 |
40 | class Softmax(Layer):
41 | def forward(self, X):
42 | exp = np.exp(X)
43 | probs = exp / np.sum(exp, axis=0, keepdims=True)
44 | self.probs = probs
45 | return probs
46 |
47 | def backward(self, dY):
48 | Y = self.probs
49 | dX = Y * dY
50 | sumdX = np.sum(dX, axis=0, keepdims=True)
51 | dX -= Y * sumdX
52 | return dX
53 |
54 |
55 | class CWRNN_NORM(Layer):
56 | def __init__(self, n_input, n_hidden, n_modules, T_max, last_state_only=False):
57 | assert(n_hidden % n_modules == 0)
58 |
59 | W = np.random.randn(n_hidden, n_input + n_hidden + 1) # +1 for bias, single combined matrix
60 | # for recurrent and input projections
61 |
62 | # glorotize and orthogonalize the non recurrent and recurrent aspects respectively
63 | W[:, :n_input] = glorotize(W[:, :n_input])
64 | W[:, n_input:-1] = orthogonalize(W[:, n_input:-1])
65 |
66 | # time kernel (T_max x T_max)
67 | C = np.repeat(np.arange(1, T_max + 1).reshape(1, -1), T_max, axis=0)
68 | C = ((C % np.arange(1, T_max + 1).reshape(-1, 1)) == 0) * 1.0
69 | C = C.T
70 |
71 | # distribution over clocks for each module (T_max x n_modules)
72 | d = np.random.randn(T_max, n_modules)
73 |
74 | self.softmax = Softmax()
75 | self.norm = Normalize()
76 |
77 | self.W = W
78 | self.d = d
79 | self.C = C
80 | self.n_input, self.n_hidden, self.n_modules, self.T_max, self.last_state_only = n_input, n_hidden, n_modules, T_max, last_state_only
81 |
82 |
83 | def forward(self, X):
84 | T, n, B = X.shape
85 | n_input = self.n_input
86 | n_hidden = self.n_hidden
87 | n_modules = self.n_modules
88 |
89 | D = self.softmax.forward(self.d) # get activations
90 | a = np.dot(self.C, D)
91 | a = self.norm.forward(a)
92 | A = np.repeat(a, n_hidden / n_modules, axis=1) # for each state in a module
93 | A = A[:, :, np.newaxis]
94 |
95 | V = np.zeros((T, n_input + n_hidden + 1, B))
96 | h_new = np.zeros((T, n_hidden, B))
97 | H_new = np.zeros((T, n_hidden, B))
98 | H = np.zeros((T, n_hidden, B))
99 |
100 | H_prev = np.zeros((n_hidden, B))
101 |
102 | for t in xrange(T):
103 | V[t] = np.concatenate([X[t], H_prev, np.ones((1, B))], axis=0)
104 | h_new[t] = np.dot(self.W, V[t])
105 | H_new[t] = np.tanh(h_new[t])
106 | H[t] = A[t] * H_new[t] + (1 - A[t]) * H_prev # leaky update
107 | H_prev = H[t]
108 |
109 | self.D, self.A, self.a = D, A, a
110 | self.V, self.h_new, self.H_new, self.H = V, h_new, H_new, H
111 |
112 | if self.last_state_only:
113 | return H[-1:]
114 | else:
115 | return H
116 |
117 | def backward(self, dH):
118 | if self.last_state_only:
119 | last_step_error = dH.copy()
120 | dH = np.zeros_like(self.H)
121 | dH[-1:] = last_step_error[:]
122 |
123 | T, _, B = dH.shape
124 | n_input = self.n_input
125 | n_hidden = self.n_hidden
126 | n_modules = self.n_modules
127 |
128 | A = self.A
129 | V, h_new, H_new, H = self.V, self.h_new, self.H_new, self.H
130 | dA, dH_prev, dW, dX = np.zeros_like(A), np.zeros((n_hidden, B)), \
131 | np.zeros_like(self.W), np.zeros((T, n_input, B))
132 |
133 | for t in reversed(xrange(T)):
134 | if t == 0:
135 | H_prev = np.zeros((n_hidden, B))
136 | else:
137 | H_prev = H[t - 1]
138 |
139 | dH_t = dH[t] + dH_prev
140 |
141 | dH_new = A[t] * dH_t
142 | dH_prev = (1 - A[t]) * dH_t
143 | dA[t] = np.sum((H_new[t] - H_prev) * dH_t, axis=1, keepdims=True)
144 |
145 | dh_new = (1.0 - H_new[t] ** 2) * dH_new
146 |
147 | dW += np.dot(dh_new, V[t].T)
148 | dV = np.dot(self.W.T, dh_new)
149 |
150 | dX[t] = dV[:n_input]
151 | dH_prev += dV[n_input:-1]
152 |
153 | dA = dA[:, :, 0]
154 | da = dA.reshape(self.T_max, -1, n_hidden / n_modules).sum(axis=-1)
155 | da = self.norm.backward(da)
156 | dD = np.dot(self.C.T, da)
157 | dd = self.softmax.backward(dD)
158 |
159 | self.dW = dW
160 | self.dd = dd
161 |
162 | return dX
163 |
164 | def get_params(self):
165 | W = self.W.flatten()
166 | d = self.d.flatten()
167 | return np.concatenate([W, d])
168 |
169 | def set_params(self, P):
170 | a, b = self.W.size, self.d.size
171 | W, d = np.split(P, [a])
172 | self.W = W.reshape(self.W.shape)
173 | self.d = d.reshape(self.d.shape)
174 |
175 | def get_grads(self):
176 | dW = self.dW.flatten()
177 | dd = self.dd.flatten()
178 | return np.concatenate([dW, dd])
179 |
180 | def clear_grads(self):
181 | self.dW = None
182 | self.dd = None
183 |
184 | def forget(self):
185 | pass
186 |
187 | def remember(self):
188 | pass
189 |
190 | def print_info(self):
191 | _D = self.d.copy()
192 | print 'dominant wave period: \n', _D.argmax(axis=0) + 1
193 | print '\n\navg. power (all):\t', np.abs(_D).mean()
194 | print 'avg. power waves:\t', self.A.mean()
195 | print '\n\nentropy:\n', special.entr(self.D).sum(axis=0)
196 | print '\n\n mean entropy: ', np.mean(special.entr(self.D).sum(axis=0))
197 | d = self.d.copy()
198 | maxx = d.max(axis=0)
199 | d[d.argmax(axis=0), range(d.shape[1])] = -10.0
200 | print '\n\ndiff b/w top 2:\n', maxx - d.max(axis=0)
201 |
202 |
203 |
204 |
--------------------------------------------------------------------------------
/layers/layer_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def glorotize(W):
5 | W *= np.sqrt(6)
6 | W /= np.sqrt(np.sum(W.shape))
7 | return W
8 |
9 |
10 | def orthogonalize(W):
11 | W, _, _ = np.linalg.svd(W)
12 | return W
13 |
--------------------------------------------------------------------------------
/layers/linear.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from base import Layer
4 | from layer_utils import glorotize
5 |
6 |
7 | class Linear(Layer):
8 | def __init__(self, n_input, n_output):
9 | W = np.random.randn(n_output, n_input + 1) # +1 for the bias
10 | W = glorotize(W)
11 | self.W = W
12 |
13 | self.n_input = n_input
14 | self.n_output = n_output
15 |
16 | def forward(self, X):
17 | T, n, B = X.shape
18 |
19 | X_flat = X.swapaxes(0, 1).reshape(n, -1) # flatten over time and batch
20 | X_flat = np.concatenate([X_flat, np.ones((1, B * T))], axis=0) # add bias
21 |
22 | Y = np.dot(self.W, X_flat)
23 | Y = Y.reshape((-1, T, B)).swapaxes(0,1)
24 |
25 | self.X_flat = X_flat
26 |
27 | return Y
28 |
29 | def backward(self, dY):
30 | T, n, B = dY.shape
31 |
32 | dY = dY.swapaxes(0,1).reshape(n, -1)
33 |
34 | self.dW = np.dot(dY, self.X_flat.T)
35 |
36 | dX = np.dot(self.W.T, dY)
37 |
38 | dX = dX[:-1] # skip the bias we added above
39 | dX = dX.reshape((-1, T, B)).swapaxes(0,1)
40 |
41 | return dX
42 |
43 | def get_params(self):
44 | return self.W.flatten()
45 |
46 | def set_params(self, W):
47 | self.W = W.reshape(self.W.shape)
48 |
49 | def get_grads(self):
50 | return self.dW.flatten()
51 |
52 | def clear_grads(self):
53 | self.dW = None
54 |
--------------------------------------------------------------------------------
/layers/softmax_ce_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from base import Layer
4 |
5 |
6 | class SoftmaxCrossEntropyLoss(Layer):
7 | def forward(self, X, target):
8 | T, _, B = X.shape
9 |
10 | exp = np.exp(X)
11 | probs = exp / np.sum(exp, axis=1, keepdims=True)
12 |
13 | loss = -1.0 * np.sum(target * np.log(probs)) / (T * B)
14 |
15 | self.probs = probs
16 | self.target, self.T, self.B = target, T, B
17 |
18 | return loss
19 |
20 | def backward(self):
21 | target, T, B = self.target, self.T, self.B
22 |
23 | dX = self.probs - target
24 |
25 | return dX / (T * B)
--------------------------------------------------------------------------------
/layers/tests.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from __init__ import *
4 |
5 |
6 | def finite_difference_check(layer, fwd, all_values, backpropagated_gradients, names, delta, error_threshold):
7 | error_count = 0
8 | for v in range(len(names)):
9 | values = all_values[v]
10 | dvalues = backpropagated_gradients[v]
11 | name = names[v]
12 |
13 | for i in range(values.size):
14 | actual = values.flat[i]
15 | values.flat[i] = actual + delta
16 | loss_plus = fwd()
17 | values.flat[i] = actual - delta
18 | loss_minus = fwd()
19 | values.flat[i] = actual
20 | backpropagated_gradient = dvalues.flat[i]
21 | numerical_gradient = (loss_plus - loss_minus) / (2 * delta)
22 |
23 | if numerical_gradient == 0 and backpropagated_gradient == 0:
24 | error = 0
25 | elif abs(numerical_gradient) < 1e-7 and abs(backpropagated_gradient) < 1e-7:
26 | error = 0
27 | else:
28 | error = abs(backpropagated_gradient - numerical_gradient) / abs(numerical_gradient + backpropagated_gradient)
29 |
30 | if error > error_threshold:
31 | print 'FAILURE!!!\n'
32 | print '\tparameter: ', name, '\tindex: ', np.unravel_index(i, values.shape)
33 | print '\tvalues: ', actual
34 | print '\tbackpropagated_gradient: ', backpropagated_gradient
35 | print '\tnumerical_gradient', numerical_gradient
36 | print '\terror: ', error
37 | print '----' * 20
38 | print '\n\n'
39 |
40 | error_count += 1
41 |
42 | if error_count == 0:
43 | print layer.__class__.__name__, 'Layer Gradient Check Passed'
44 | else:
45 | param_count = 0
46 | for val in all_values:
47 | param_count += val.size
48 | print layer.__class__.__name__, 'Layer Gradient Check Failed for {}/{} parameters'.format(error_count, param_count)
49 |
50 |
51 | def test_layer(layer):
52 | P = layer.get_params()
53 | Y = layer.forward(X)
54 | T_rand = np.random.randn(*Y.shape) # random target for a multiplicative loss
55 | loss = np.sum(Y * T_rand) # loss
56 |
57 | dX = layer.backward(T_rand)
58 | dP = layer.get_grads()
59 |
60 | def fwd():
61 | layer.forget()
62 | layer.set_params(P)
63 | return np.sum(layer.forward(X) * T_rand)
64 |
65 | all_values = [X, P]
66 | backpropagated_gradients = [dX, dP]
67 | names = ['X', 'P']
68 |
69 |
70 | finite_difference_check(layer, fwd, all_values, backpropagated_gradients, names, delta, error_threshold)
71 |
72 |
73 | def test_loss(layer):
74 |
75 | exp = np.exp(np.random.random(X.shape))
76 | target = exp / np.sum(exp, axis=1, keepdims=True) # random target for a multiplicative loss
77 | loss = layer.forward(X, target)
78 |
79 | dX = layer.backward()
80 |
81 | def fwd():
82 | return layer.forward(X, target)
83 |
84 | all_values = [X]
85 | backpropagated_gradients = [dX]
86 | names = ['X']
87 |
88 | finite_difference_check(layer, fwd, all_values, backpropagated_gradients, names, delta, error_threshold)
89 |
90 |
91 | def test_activation(layer):
92 | Y = layer.forward(X)
93 | T_rand = np.random.randn(*Y.shape) # random target for a multiplicative loss
94 | loss = np.sum(Y * T_rand) # loss
95 |
96 | dX = layer.backward(T_rand)
97 |
98 | def fwd():
99 | return np.sum(layer.forward(X) * T_rand)
100 |
101 | all_values = [X]
102 | backpropagated_gradients = [dX]
103 | names = ['X']
104 |
105 | finite_difference_check(layer, fwd, all_values, backpropagated_gradients, names, delta, error_threshold)
106 |
107 |
108 | delta = 1e-4
109 | error_threshold = 1e-3
110 | time_steps = 5
111 | n_input = 3
112 | batch_size = 7
113 |
114 | X = np.random.randn(time_steps, n_input, batch_size)
115 |
116 | n_output = 20
117 | layer = Linear(n_input, n_output)
118 | test_layer(layer=layer)
119 |
120 | layer = SoftmaxCrossEntropyLoss()
121 | test_loss(layer=layer)
122 |
123 | layer = CWRNN(n_input=n_input, n_hidden=8, n_modules=4, T_max=time_steps, last_state_only=False)
124 | test_layer(layer=layer)
125 |
126 | layer = CWRNN_NORM(n_input=n_input, n_hidden=8, n_modules=4, T_max=time_steps, last_state_only=False)
127 | test_layer(layer=layer)
128 |
129 | layer = CWRNN_L1(n_input=n_input, n_hidden=8, n_modules=4, T_max=time_steps, last_state_only=False)
130 | test_layer(layer=layer)
131 |
132 | layer = TanH()
133 | test_activation(layer=layer)
134 |
135 | layer = Sigmoid()
136 | test_activation(layer=layer)
137 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def forward(model, input, target):
5 | for layer in model[:-1]:
6 | input = layer.forward(input)
7 | output = model[-1].forward(input, target)
8 | return output
9 |
10 |
11 | def backward(model):
12 | gradient = model[-1].backward()
13 | for layer in reversed(model[:-1]):
14 | gradient = layer.backward(gradient)
15 | return gradient
16 |
17 |
18 | def load_params(model, W):
19 | for layer in model:
20 | w = layer.get_params()
21 | if w is None:
22 | continue
23 | w_shape = w.shape
24 | w, W = np.split(W, [np.prod(w_shape)])
25 | layer.set_params(w.reshape(w_shape))
26 |
27 |
28 | def extract_params(model):
29 | params = []
30 | for layer in model:
31 | w = layer.get_params()
32 | if w is None:
33 | continue
34 | params.append(w)
35 | W = np.concatenate(params)
36 | return np.array(W)
37 |
38 |
39 | def extract_grads(model):
40 | grads = []
41 | for layer in model:
42 | g = layer.get_grads()
43 | if g is None:
44 | continue
45 | grads.append(g)
46 | dW = np.concatenate(grads)
47 | return np.array(dW)
48 |
49 |
50 | def forget(model):
51 | for layer in model:
52 | layer.forget()
53 |
54 |
55 | def print_info(model):
56 | for layer in model:
57 | layer.print_info()
58 |
--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from layers import *
4 | from network import *
5 | from climin import RmsProp, Adam, GradientDescent
6 | from data.mnist import loader
7 |
8 | import time
9 | import pickle
10 | import os
11 |
12 | import matplotlib.pyplot as plt
13 | plt.ion()
14 | plt.style.use('kosh')
15 | plt.figure(figsize=(12, 7))
16 |
17 |
18 | np.random.seed(np.random.randint(1213))
19 |
20 | experiment_name = 'plot_freq_softmax'
21 |
22 | permuted = False
23 |
24 | n_input = 1
25 | n_hidden = 128
26 | n_modules = 8
27 | n_output = 10
28 |
29 | batch_size = 50
30 | learning_rate = 1e-3
31 | niterations = 20000
32 | momentum = 0.9
33 |
34 | gradient_clip = (-1.0, 1.0)
35 |
36 | save_every = 1000
37 | plot_every = 100
38 |
39 | logs = {}
40 |
41 | data = loader(batch_size=batch_size, permuted=permuted)
42 |
43 | def dW(W):
44 | load_params(model, W)
45 | forget(model)
46 | inputs, targets = data.fetch_train()
47 | loss = forward(model, inputs, targets)
48 | backward(model)
49 |
50 | gradients = extract_grads(model)
51 | clipped_gradients = np.clip(gradients, gradient_clip[0], gradient_clip[1])
52 |
53 | gradient_norm = (gradients ** 2).sum() / gradients.size
54 | clipped_gradient_norm = (clipped_gradients ** 2).sum() / gradients.size
55 |
56 | logs['loss'].append(loss)
57 | logs['smooth_loss'].append(loss * 0.01 + logs['smooth_loss'][-1] * 0.99)
58 | logs['gradient_norm'].append(gradient_norm)
59 | logs['clipped_gradient_norm'].append(clipped_gradient_norm)
60 |
61 | return clipped_gradients
62 |
63 |
64 | os.system('mkdir -p results/' + experiment_name)
65 | path = 'results/' + experiment_name + '/'
66 |
67 | logs['loss'] = []
68 | logs['val_loss'] = []
69 | logs['accuracy'] = []
70 | logs['smooth_loss'] = [np.log(10)]
71 | logs['gradient_norm'] = []
72 | logs['clipped_gradient_norm'] = []
73 |
74 |
75 | model = [
76 | CWRNN_L1(n_input=n_input, n_hidden=n_hidden, n_modules=n_modules, T_max=784, last_state_only=True),
77 | Linear(n_hidden, n_output),
78 | SoftmaxCrossEntropyLoss()
79 | ]
80 |
81 | W = extract_params(model)
82 |
83 | optimizer = Adam(W, dW, learning_rate, momentum=momentum)
84 |
85 | print 'Approx. Parameters: ', W.size
86 |
87 | for i in optimizer:
88 | if i['n_iter'] > niterations:
89 | break
90 |
91 | print '\n\n'
92 | print str(data.epoch) + '\t' + str(i['n_iter']), '\t',
93 | print logs['loss'][-1], '\t',
94 | print logs['gradient_norm'][-1]
95 | print_info(model)
96 | print '\n', '----' * 20, '\n'
97 |
98 | if data.epoch_complete:
99 | inputs, labels = data.fetch_val()
100 | nsamples = inputs.shape[2]
101 | inputs = np.split(inputs, nsamples / batch_size, axis=2)
102 | labels = np.split(labels, nsamples / batch_size, axis=2)
103 | val_loss = 0
104 | for j in range(len(inputs)):
105 | forget(model)
106 | input = inputs[j]
107 | label = labels[j]
108 | val_loss += forward(model, input, label)
109 | val_loss /= len(inputs)
110 | logs['val_loss'].append(val_loss)
111 | print '..' * 20
112 | print 'validation loss: ', val_loss
113 |
114 | inputs, labels = data.fetch_test()
115 | nsamples = inputs.shape[2]
116 | inputs = np.split(inputs, nsamples / batch_size, axis=2)
117 | labels = np.split(labels, nsamples / batch_size, axis=2)
118 |
119 | correct = 0
120 | for j in range(len(inputs)):
121 | forget(model)
122 | input = inputs[j]
123 | label = labels[j]
124 | _ = forward(model, input, label)
125 | pred = model[-1].probs
126 | good = np.sum(label.argmax(axis=1) == pred.argmax(axis=1))
127 | correct += good
128 |
129 | accuracy = correct / float(nsamples)
130 | logs['accuracy'].append(accuracy)
131 | print 'accuracy: ', accuracy
132 | print '..' * 20
133 |
134 | data.epoch_complete = False
135 |
136 | plt.figure(2)
137 | plt.clf()
138 | plt.plot(logs['val_loss'], label='validation')
139 | plt.legend()
140 | plt.draw()
141 | plt.figure(3)
142 | plt.clf()
143 | plt.plot(logs['accuracy'], label='accuracy')
144 | plt.legend()
145 | plt.draw()
146 |
147 |
148 | if i['n_iter'] % save_every == 0:
149 | print 'serializing model... '
150 | f = open(path + 'iter_' + str(i['n_iter']) +'.model', 'w')
151 | pickle.dump(model, f)
152 | f.close()
153 |
154 | if i['n_iter'] % plot_every == 0:
155 | plt.figure(1)
156 | plt.clf()
157 | plt.plot(logs['smooth_loss'], label='training')
158 | plt.legend()
159 | plt.draw()
160 |
161 |
162 | print 'serializing logs... '
163 | f = open(path + 'logs.logs', 'w')
164 | pickle.dump(logs, f)
165 | f.close()
166 |
167 | print 'serializing final model... '
168 | f = open(path + 'final.model', 'w')
169 | pickle.dump(model, f)
170 | f.close()
171 |
172 | plt.savefig(path + 'loss_curve')
173 |
--------------------------------------------------------------------------------
/train_ptb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from layers import *
4 | from network import *
5 | from climin import RmsProp, Adam, GradientDescent
6 | from data.text import loader
7 |
8 | import time
9 | import pickle
10 | import os
11 |
12 | import matplotlib.pyplot as plt
13 | plt.ion()
14 | plt.style.use('kosh')
15 | plt.figure(figsize=(12, 7))
16 |
17 |
18 | np.random.seed(np.random.randint(1213))
19 |
20 | experiment_name = 'penn_random_init_noscale_0.5binaryActi_0.01WDecay'
21 |
22 | text_file = 'ptb.txt'
23 |
24 | vocabulary_size = 49
25 |
26 | n_output = n_input = vocabulary_size
27 | n_hidden = 1024
28 | n_modules = 8
29 | noutputs = vocabulary_size
30 |
31 | sequence_length = 100
32 |
33 | batch_size = 64
34 | learning_rate = 2e-3
35 | niterations = 100000
36 | momentum = 0.9
37 |
38 | forget_every = 100
39 | gradient_clip = (-1.0, 1.0)
40 |
41 | sample_every = 1000
42 | save_every = 1000
43 | plot_every = 100
44 |
45 | logs = {}
46 |
47 | data = loader('data/' + text_file, sequence_length, batch_size)
48 |
49 |
50 | def dW(W):
51 | load_params(model, W)
52 | forget(model)
53 | inputs, targets = data.fetch_train()
54 | loss = forward(model, inputs, targets)
55 | backward(model)
56 |
57 | gradients = extract_grads(model)
58 | clipped_gradients = np.clip(gradients, gradient_clip[0], gradient_clip[1])
59 |
60 | gradient_norm = (gradients ** 2).sum() / gradients.size
61 | clipped_gradient_norm = (clipped_gradients ** 2).sum() / gradients.size
62 |
63 | logs['loss'].append(loss)
64 | logs['smooth_loss'].append(loss * 0.01 + logs['smooth_loss'][-1] * 0.99)
65 | logs['gradient_norm'].append(gradient_norm)
66 | logs['clipped_gradient_norm'].append(clipped_gradient_norm)
67 |
68 | return clipped_gradients
69 |
70 |
71 | os.system('mkdir -p results/' + experiment_name)
72 | path = 'results/' + experiment_name + '/'
73 |
74 | logs['loss'] = []
75 | logs['val_loss'] = []
76 | logs['accuracy'] = []
77 | logs['smooth_loss'] = [np.log(49)]
78 | logs['gradient_norm'] = []
79 | logs['clipped_gradient_norm'] = []
80 |
81 | model = [
82 | CWRNN(n_input=n_input, n_hidden=n_hidden, n_modules=n_modules, T_max=sequence_length), # one of the CWRNN variants from layers/ (CWRNN2 is not defined there)
83 | Linear(n_hidden, n_output),
84 | SoftmaxCrossEntropyLoss()
85 | ]
86 |
87 | W = extract_params(model)
88 |
89 | optimizer = Adam(W, dW, learning_rate, momentum=momentum)
90 |
91 | print 'Approx. Parameters: ', W.size
92 |
93 | for i in optimizer:
94 | if i['n_iter'] > niterations:
95 | break
96 |
97 | print '\n\n'
98 | print str(i['n_iter']), '\t',
99 | print logs['loss'][-1], '\t',
100 | print logs['gradient_norm'][-1]
101 | print_info(model)
102 | print '\n', '----' * 20, '\n'
103 |
104 | if i['n_iter'] % sample_every == 0:
105 | forget(model)
106 | x = np.zeros((20, vocabulary_size, 1))
107 | input, _ = data.fetch_train()
108 | x[:20, :, :] = input[:20, :, 0:1]
109 | ixes = []
110 | for t in xrange(1000):
111 | forward(model, np.array(x), 1.0)
112 | p = model[-1].probs
113 | p = p[-1]
114 | ix = np.random.choice(range(vocabulary_size), p=p.ravel())
115 | x = np.zeros((1, vocabulary_size, 1))
116 | x[0, ix, 0] = 1
117 | ixes.append(ix)
118 | sample = ''.join(data.decoder.to_c[ix] for ix in ixes)
119 | print '----' * 20
120 | print sample
121 | print '----' * 20
122 | forget(model)
123 |
124 | if i['n_iter'] % save_every == 0:
125 | print 'serializing model... '
126 | f = open(path + 'iter_' + str(i['n_iter']) +'.model', 'w')
127 | pickle.dump(model, f)
128 | f.close()
129 |
130 | if i['n_iter'] % plot_every == 0:
131 | plt.clf()
132 | plt.plot(logs['smooth_loss'])
133 | plt.draw()
134 |
135 | print 'serializing logs... '
136 | f = open(path + 'logs.logs', 'w')
137 | pickle.dump(logs, f)
138 | f.close()
139 |
140 | print 'serializing final model... '
141 | f = open(path + 'final.model', 'w')
142 | pickle.dump(model, f)
143 | f.close()
144 |
145 | plt.savefig(path + 'loss_curve')
146 |
--------------------------------------------------------------------------------