├── fig
│   ├── digit3.png
│   ├── image3.png
│   ├── NTMCell.png
│   ├── alien_wave.png
│   ├── qz_t_digit3.png
│   ├── qz_t_wave0.png
│   ├── qz_t_wave1.png
│   ├── digit3_series.png
│   ├── wave_transitions.png
│   ├── digit3_transition.png
│   ├── digit7_transition.png
│   ├── digit9_transition.png
│   └── digit_transitions.png
├── README.md
└── kLSTM.py
/fig/digit3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit3.png
--------------------------------------------------------------------------------
/fig/image3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/image3.png
--------------------------------------------------------------------------------
/fig/NTMCell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/NTMCell.png
--------------------------------------------------------------------------------
/fig/alien_wave.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/alien_wave.png
--------------------------------------------------------------------------------
/fig/qz_t_digit3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/qz_t_digit3.png
--------------------------------------------------------------------------------
/fig/qz_t_wave0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/qz_t_wave0.png
--------------------------------------------------------------------------------
/fig/qz_t_wave1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/qz_t_wave1.png
--------------------------------------------------------------------------------
/fig/digit3_series.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit3_series.png
--------------------------------------------------------------------------------
/fig/wave_transitions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/wave_transitions.png
--------------------------------------------------------------------------------
/fig/digit3_transition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit3_transition.png
--------------------------------------------------------------------------------
/fig/digit7_transition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit7_transition.png
--------------------------------------------------------------------------------
/fig/digit9_transition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit9_transition.png
--------------------------------------------------------------------------------
/fig/digit_transitions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit_transitions.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Markov Recurrent Neural Networks
2 |
3 | This repository is a PyTorch implementation of [Markov Recurrent Neural Networks](https://github.com/NCTUMLlab/Che-Yu-Kuo-MarkovRNN.git), with two temporal datasets as a quick demonstration.
4 |
5 | ## Architecture
6 | - **Paper**: [Markov Recurrent Neural Networks (MRNN)](https://ieeexplore.ieee.org/document/8517074)
7 |
8 |
9 |
10 |
11 | - **Heuristic explanation:**
12 | MRNN is a deep learning model for time series tasks such as NLP, stock price prediction, or gravitational wave detection. The main idea is to create several *parallel* [RNNs](https://en.wikipedia.org/wiki/Recurrent_neural_network) [(LSTMs)](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) that learn the temporal dependence of the data simultaneously. If the data has complex temporal structure (behaviour), a single RNN may not be enough to capture the pattern. *k* parallel RNNs (*k = 1, 2, 3, ...*) read the same input signal at the same time, each learning a different aspect of the data. A latent variable *z* (also learned by the network) then determines when and which LSTM should be listened to for the learning task (see the figures of *q_z(t)* & *z(t)* below). The selection mechanism governed by *z* is itself a stochastic process modeling transitions between the *k* LSTMs based on the Markov property, hence the name MRNN. A minimal sketch of this selection step is given after the note below.
13 |
14 | - **Note:**
15 | The transition variable *z* between the *k* LSTMs can be regarded as an [attention mechanism](https://nlp.stanford.edu/pubs/emnlp15_attn.pdf) over the individual LSTM hidden states.
16 |
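A minimal sketch of the selection mechanism, simplified from `kLSTM.py` (biases omitted; tensor names and sizes are illustrative):

```python
import torch
import torch.nn.functional as F

# N = batch size, k = number of parallel LSTMs, m = hidden size, n = input features
N, k, m, n = 32, 4, 64, 28
x_t = torch.randn(N, n)                     # input at time step t
h_prev = torch.randn(N, m)                  # previous (mixed) hidden state
H_t = torch.randn(N, k, m)                  # candidate hidden states produced by the k LSTMs
W_xz, W_hz = torch.randn(n, k), torch.randn(m, k)

logit_z = x_t @ W_xz + h_prev @ W_hz        # transition logits over the k LSTMs
q_z = F.softmax(logit_z, dim=1)             # q_z(t): probability of attending to each LSTM
z = F.gumbel_softmax(logit_z, tau=1.0)      # z(t): (soft) sample of which LSTM to listen to
h_t = torch.einsum('nkm,nk->nm', H_t, z)    # mix the k candidate states according to z
```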
17 |
18 | ## Datasets
19 |
20 | - [**MNIST**](https://en.wikipedia.org/wiki/MNIST_database) viewed in series as sequential input:
21 |
22 |
23 | - **Artificial alien signals**: imagine we want to recognize radio signals sent by aliens from the sky, as in [SETI](https://setiathome.berkeley.edu/); here I generated two kinds of waveforms for the Markov RNN to distinguish (a hypothetical generator sketch follows below):
24 |
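For illustration only, a generator for the two signal classes might look like the sketch below (hypothetical; the notebook's actual waveforms may differ):

```python
import numpy as np

def make_signal(alien, T=128, rng=np.random):
    """Hypothetical example of the two waveform classes; not necessarily the notebook's generator."""
    t = np.linspace(0, 4 * np.pi, T)
    if not alien:
        # regular, periodic background wave plus a little noise
        return np.sin(t) + 0.1 * rng.randn(T)
    # irregular "alien" wave: a random mixture of frequencies and phases
    freqs = rng.uniform(0.5, 5.0, size=3)
    phases = rng.uniform(0, 2 * np.pi, size=3)
    return sum(np.sin(f * t + p) for f, p in zip(freqs, phases)) / 3 + 0.1 * rng.randn(T)
```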
25 |
26 |
27 | ## Prerequisites
28 | - [Python 3.6](https://www.python.org/)
29 | - [Jupyter notebook](https://jupyter.org/)
30 | - [PyTorch 1.0](https://pytorch.org/)
31 | - [Numpy 1.15.0](http://www.numpy.org/)
32 | - [Sklearn 0.20.2](https://scikit-learn.org/stable/)
33 | - [Matplotlib](https://matplotlib.org/)
34 |
35 |
36 | ## Usage
37 | No installation is required beyond the prerequisites. The file `kLSTM.py` contains all the modules needed to run the Markov RNN. Two examples are provided as Jupyter notebooks (a minimal usage sketch follows the listing):
38 | ```
39 | MRNN_MNIST.ipynb
40 | MRNN_detect_alien_signal.ipynb
41 | ```
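
The snippet below is a minimal sketch of driving the `LSTM` module in `kLSTM.py` directly, outside the notebooks (all hyperparameter values are illustrative):

```python
import torch
from kLSTM import LSTM, LSTMCell

# k=4 parallel LSTM cells, one hidden layer, 10 output classes (values are illustrative)
model = LSTM(LSTMCell, input_size=28, hidden_size=64, output_size=10,
             num_layers=1, k_cells=4)

x = torch.randn(8, 28, 28)    # (batch, time steps, input features), e.g. an image read row by row
output, z_T, qz_T, h_T, state = model(x, tau=1.0, is_training=True)
print(output.shape, z_T.shape)    # torch.Size([8, 10]) torch.Size([8, 28, 4])
```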
42 |
43 | ## Results & Interpretations
44 | 1. **Take *k=4* LSTMs for MNIST.**
45 |
46 | - In the figure of *q_z(t)*, the horizontal axis is time and the vertical axis indicates which LSTM is attended to; the color palette shows the probability of each LSTM being used (a plotting sketch is given at the end of this section).
47 |
48 | - In the figure of *z(t)*, the yellow color indicates which LSTM was actually used (selected by Gumbel-softmax sampling).
49 |
50 |
51 | - Overall probability of each LSTM being chosen. **[Left]: digit 7**, **[Right]: digit 9**
52 |
53 |
54 |
55 | 2. **Take *k=4* LSTMs for alien signal (binary) classification.**
56 | - Non-alien signal
57 | - The figures of *q_z(t)* & *z(t)* show slight periodicity.
58 |
59 |
60 |
61 | - Alien signal (maybe saying hi, or an invasion)
62 | - The figures of *q_z(t)* & *z(t)* show all 4 LSTMs being used to detect the irregular alien waveform.
63 |
64 |
65 | - Overall probability of each LSTM being chosen. **[Left]: non-alien signal**, **[Right]: alien signal**
66 |
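The transition plots above can be reproduced from the model outputs roughly as follows (a rough sketch; the notebooks may style the figures differently). Here `model` and `x` are as in the usage sketch:

```python
import matplotlib.pyplot as plt

model.eval()
output, z_T, qz_T, h_T, _ = model(x, tau=0.5, is_training=False)   # z_T, qz_T: (batch, T, k)

fig, axes = plt.subplots(2, 1, sharex=True)
axes[0].imshow(qz_T[0].detach().cpu().numpy().T, aspect='auto')    # rows: LSTM index, columns: time
axes[0].set_ylabel('$q_z(t)$')
axes[1].imshow(z_T[0].detach().cpu().numpy().T, aspect='auto')     # one-hot selection at evaluation
axes[1].set_ylabel('$z(t)$')
axes[1].set_xlabel('time step $t$')
plt.show()
```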
67 |
68 |
69 |
70 | ## Improvements
71 | This PyTorch implementation extends the Markov RNN so that it can stack more than one hidden LSTM layer, whereas the original [TensorFlow version](https://github.com/NCTUMLlab/Che-Yu-Kuo-MarkovRNN.git) was restricted to a single hidden layer.
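
For example, a two-layer Markov RNN can be instantiated as in this sketch (hyperparameter values are illustrative):

```python
from kLSTM import LSTM, LSTMCell

# two stacked Markov-LSTM layers, each mixing k=4 parallel LSTM cells
model = LSTM(LSTMCell, input_size=28, hidden_size=64, output_size=10,
             num_layers=2, k_cells=4, dropout_prob=0.2)
```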
72 |
--------------------------------------------------------------------------------
/kLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import init
5 | from PIL import Image
6 | import torch.nn.functional as F
7 | import math
8 |
9 | def plot_image_series(images, ncols, nrows, photow, photoh, marl, mart, marr, marb, padding):
10 | """
11 |     Make a contact sheet (image grid) from a list of image arrays:
12 |
13 |     images    A list of image arrays, one per thumbnail
14 |
15 | ncols Number of columns in the contact sheet
16 | nrows Number of rows in the contact sheet
17 | photow The width of the photo thumbs in pixels
18 | photoh The height of the photo thumbs in pixels
19 |
20 | marl The left margin in pixels
21 | mart The top margin in pixels
22 | marr The right margin in pixels
23 |     marb      The bottom margin in pixels
24 |
25 | padding The padding between images in pixels
26 |
27 | returns a PIL image object.
28 | """
29 |
30 | # Read in all images and resize appropriately
31 | imgs = [Image.fromarray(images[i]).resize((photow,photoh)) for i in range(len(images))]
32 | # imgs = np_img.fromarray()
33 |
34 | # Calculate the size of the output image, based on the
35 | # photo thumb sizes, margins, and padding
36 | marw = marl+marr
37 | marh = mart+ marb
38 |
39 | padw = (ncols-1)*padding
40 | padh = (nrows-1)*padding
41 | isize = (ncols*photow+marw+padw,nrows*photoh+marh+padh)
42 |
43 | # Create the new image. The background doesn't have to be white
44 | white = (255,255,255)
45 | inew = Image.new('RGB',isize,white)
46 |
47 | # Insert each thumb:
48 | for irow in range(nrows):
49 | for icol in range(ncols):
50 | left = marl + icol*(photow+padding)
51 | right = left + photow
52 | upper = mart + irow*(photoh+padding)
53 | lower = upper + photoh
54 | bbox = (left,upper,right,lower)
55 | try:
56 | img = imgs.pop(0)
57 |             except IndexError:  # ran out of images to place
58 | break
59 | inew.paste(img,bbox)
60 | return inew
61 |
62 |
63 |
64 | class LSTMCell(nn.Module):
65 | """A basic LSTM cell."""
66 | def __init__(self, input_size, hidden_size, k_cells, is_training=True, use_bias=True):
67 | super(LSTMCell, self).__init__()
68 | self.n = input_size
69 | self.m = hidden_size
70 | self.k = k_cells # number of RNN cells
71 | self.is_training = is_training
72 | self.W_xz = nn.Parameter(torch.FloatTensor(self.n, self.k)) # affine transform x_t -> (z_1, z_2, ..., z_K)
73 | self.W_hz = nn.Parameter(torch.FloatTensor(self.m, self.k)) # affine transform h -> (z_1, z_2, ..., z_K)
74 |
75 | self.W_x_4gates = nn.Parameter(torch.FloatTensor(self.n, 4 * self.m * self.k))
76 | self.W_h_4gates = nn.Parameter(torch.FloatTensor(self.m, 4 * self.m * self.k))
77 |
78 | self.use_bias = use_bias
79 | if use_bias:
80 | self.b_xz = nn.Parameter(torch.FloatTensor(self.k)) # size = (1, k)
81 | # self.b_hz = nn.Parameter(torch.FloatTensor(self.k)) # size = (1, k)
82 | self.b_4gates = nn.Parameter(torch.FloatTensor(4 * self.m * self.k)) # size = (1, 4 * m * k)
83 | else:
84 |             self.register_parameter('b_xz', None); self.register_parameter('b_4gates', None)  # keep attributes defined so reset_parameters() can check them
85 |
86 | self.reset_parameters()
87 |         print("LSTM cell parameters reset!")
88 |
89 |
90 | def reset_parameters(self):
91 | stdv1 = 1. / math.sqrt(self.W_xz.data.size(1))
92 | self.W_xz.data.uniform_(-stdv1, stdv1)
93 |
94 | stdv2 = 1. / math.sqrt(self.W_hz.data.size(1))
95 | self.W_hz.data.uniform_(-stdv2, stdv2)
96 |
97 | stdv3 = 1. / math.sqrt(self.W_x_4gates.data.size(1))
98 | self.W_x_4gates.data.uniform_(-stdv3, stdv3)
99 |
100 | stdv4 = 1. / math.sqrt(self.W_h_4gates.data.size(1))
101 | self.W_h_4gates.data.uniform_(-stdv4, stdv4)
102 |
103 | if self.b_xz is not None:
104 | self.b_xz.data.uniform_(-stdv1, stdv1)
105 |
106 | if self.b_4gates is not None:
107 | self.b_4gates.data.uniform_(-stdv4, stdv4)
108 |
109 | def forward(self, x_t, tau, s_t, is_training):
110 | """ Args:
111 | x_t: input at time step t = (batch, input_size) tensor containing input features.
112 | tau: annealing temperature for Gumbel softmax
113 | s_t: state at time step t (h_0, c_0), which contains the initial hidden
114 | and cell state, where the size of both states is
115 | (batch, hidden_size).
116 | is_training: indicating status of training or not.
117 | Returns:
118 | state = (h_t, c_t): Tensors containing the next hidden and cell state. """
119 |
120 | h_0, c_0 = s_t # state s_t = (h_t, c_t) with size h_0 = size c_0 = (N, m)
121 | N = h_0.size(0) # batch size
122 | self.tau = tau
123 | self.is_training = is_training
124 |
125 | # expand (repeat) bias for batch processing
126 | batch_b_xz = (self.b_xz.unsqueeze(0).expand(N, *self.b_xz.size())) # size = (N, k)
127 | # batch_b_hz = (self.b_hz.unsqueeze(0).expand(N, *self.b_hz.size())) # size = (N, k)
128 | batch_b_4gates = (self.b_4gates.unsqueeze(0).expand(N, *self.b_4gates.size())) # size = (N, 4* m* k)
129 |
130 | ''' logit encoder: logit_z = (W1 * x_t + b1) + (W2 * h_{t-1} + b2) in R^K '''
131 | logit_z = torch.addmm(batch_b_xz, x_t, self.W_xz) + torch.mm(h_0, self.W_hz)
132 |
133 | # probability q_z(x_t, h_{t-1}) in R^K
134 | q_z = F.softmax(logit_z, dim=1)
135 |
136 | if self.is_training is True:
137 | z = F.gumbel_softmax(logit_z, self.tau, hard=False, eps=1e-10)
138 | else:
139 | if q_z.is_cuda:
140 | z = torch.cuda.FloatTensor(N, self.k).zero_() # create a GPU zero tensor for 1-hot
141 | else:
142 | z = torch.FloatTensor(N, self.k).zero_() # create a CPU zero tensor for 1-hot
143 |
144 | z.scatter_(1, torch.max(q_z, dim=1)[1].view(N, 1), 1) # find which position is max
145 |
146 | # k LSTM's part
147 | A = torch.mm(x_t, self.W_x_4gates) # (W_x_4gates * x + b_4gates) , output dim = 4* m* k
148 | B = torch.addmm(batch_b_4gates, h_0, self.W_h_4gates) # (W_h_4gates * h_0 + b_4gates), output dim = 4* m* k
149 |         ''' note: this step B looks a bit strange, since it may break the independence of the k LSTMs '''
150 |
151 | # 4 gates in LSTM
152 | I_gate, G_gate, F_gate, O_gate = torch.split(A + B, split_size_or_sections=(self.m * self.k), dim=1)
153 |
154 | # Expand c_0 -> C_0 has dim=k
155 | C_0 = c_0.repeat(1, self.k)
156 |
157 | '''for K LSTMs: C_t = (c^1_t, c^2_t, .... , c^K_t )
158 | H_t = (h^1_t, h^2_t, .... , h^K_t ) [vectorized version for K RNNs] '''
159 | C_t = torch.mul(torch.sigmoid(F_gate), C_0) + torch.mul(torch.sigmoid(I_gate), torch.tanh(G_gate)) # C_t = new c of K dim
160 | H_t = torch.mul(torch.sigmoid(O_gate), torch.tanh(C_t)) # H_t = new h of K dim
161 |
162 | # reshape C_t & H_t has dim = (N, k, m) (hidden_size = m)
163 | C_t = C_t.view([N, self.k, self.m])
164 | H_t = H_t.view([N, self.k, self.m])
165 |
166 | # sum over K LSTMs (like K ensemble) to become 1 LSTM cell & hidden state: (h_t , c_t)
167 | h_t = torch.einsum('nkm,nk->nm', (H_t, z)) # size= (batch, output-dim), no time
168 | c_t = torch.einsum('nkm,nk->nm', (C_t, z)) # size= (batch, output-dim), no time
169 |
170 | return z, q_z, h_t, c_t
171 |
172 | def __repr__(self):
173 |         s = '{name}({n}, {m}, k_cells={k})'  # self.n = input size, self.m = hidden size, self.k = number of parallel cells
174 | return s.format(name=self.__class__.__name__, **self.__dict__)
175 |
176 |
177 | class LSTM(nn.Module):
178 | """A module that runs multiple steps of LSTM."""
179 | def __init__(self, cell_class, input_size, hidden_size, output_size,
180 | num_layers=1, k_cells=2, use_bias=True, dropout_prob=0.):
181 | super(LSTM, self).__init__()
182 | self.cell_class = cell_class
183 | self.n = input_size
184 | self.m = hidden_size
185 | self.k = k_cells # number of RNN cells
186 | self.output_size = output_size
187 | self.num_layers = num_layers
188 | self.use_bias = use_bias
189 | self.dropout_prob = dropout_prob
190 |
191 | self.fc = nn.Linear(self.m, self.output_size) # output = CNN embedding latent variables
192 |
193 | for layer in range(num_layers):
194 | layer_input_size = self.n if layer == 0 else self.m
195 | cell = cell_class(input_size=layer_input_size, hidden_size=self.m, k_cells=self.k, use_bias=self.use_bias)
196 | setattr(self, 'cell_{}'.format(layer), cell)
197 | self.dropout_layer = nn.Dropout(dropout_prob)
198 | self.reset_parameters()
199 |
200 | def get_cell(self, layer):
201 | return getattr(self, 'cell_{}'.format(layer))
202 |
203 | def reset_parameters(self):
204 | for layer in range(self.num_layers):
205 | cell = self.get_cell(layer)
206 | cell.reset_parameters()
207 |
208 | @staticmethod
209 | def _forward_rnn(cell, x, tau, length, s_t, is_training):
210 | T = x.size(1) # time is the second dimension
211 |
212 | z_T = [] # z_T = (z_1, z_2, .... , ) collecting all z's over time
213 | qz_T = []
214 | h_T = [] # h_T = (h_1, h_2, .... , ) collecting all hidden state h_t over time
215 | S_T = [] # S_T = [ (h_1, c_1), (h_2, c_2), .... , ] collecting all (h_t, c_t) over time
216 | # output = []
217 | for t in range(T):
218 | # one time step
219 | z, q_z, h_t, c_t = cell(x[:, t, :], tau, s_t, is_training)
220 |
221 | # to bound time steps of time sequences
222 | time_mask = (t < length).float().unsqueeze(1).expand_as(h_t)
223 | h_t, c_t = h_t * time_mask + s_t[0] * (1 - time_mask), c_t * time_mask + s_t[1] * (1 - time_mask)
224 | s_t = (h_t, c_t) # state t = (h_t, c_t)
225 |
226 | h_T.append(h_t)
227 | z_T.append(z)
228 | qz_T.append(q_z)
229 |
230 | z_T = torch.stack(z_T, 0).transpose_(0, 1) # [transpose to batch first], size=(N, T, output-dim)
231 | qz_T = torch.stack(qz_T, 0).transpose_(0, 1) # [transpose to batch first], size=(N, T, output-dim)
232 | h_T = torch.stack(h_T, 0).transpose_(0, 1) # [transpose to batch first], size=(N, T, output-dim)
233 |
234 | return z_T, qz_T, h_T, s_t
235 |
236 | def forward(self, x, tau, length=None, s_t=None, is_training=True):
237 |
238 | # batch is assumed first dimension of input x
239 | N, T, n = x.size()
240 | ''' N = batch size, T = total time steps, n = input feature dimension '''
241 |
242 | # RNN input temporal length limit
243 | if length is None:
244 | length = Variable(torch.LongTensor([T] * N))
245 | if x.is_cuda:
246 | device = x.get_device()
247 | length = length.cuda(device)
248 | if s_t is None:
249 | # put an initialization
250 | s_t = Variable(x.data.new(N, self.m).zero_())
251 | s_t = (s_t, s_t)
252 |
253 | all_layer_h_t = []
254 | all_layer_c_t = []
255 | layer_h_T = None
256 |
257 | # creating depth of LSTMs
258 | for layer in range(self.num_layers):
259 | cell = self.get_cell(layer) # get the cell of certain layer
260 | layer_z_T, layer_qz_T, layer_h_T, (layer_h_t, layer_c_t) = LSTM._forward_rnn(cell, x, tau,
261 | length, s_t, is_training)
262 |
263 | ''' x=data input if layer=0; x=hidden units if layer > 0 '''
264 | x = self.dropout_layer(layer_h_T)
265 | all_layer_h_t.append(layer_h_t)
266 | all_layer_c_t.append(layer_c_t)
267 |
268 | all_layer_h_t = torch.stack(all_layer_h_t, 0)
269 | all_layer_c_t = torch.stack(all_layer_c_t, 0)
270 |
271 | output = self.fc(layer_h_t)
272 |
273 | return output, layer_z_T, layer_qz_T, layer_h_T, (all_layer_h_t, all_layer_c_t)
274 |
275 |
--------------------------------------------------------------------------------