├── fig
│   ├── digit3.png
│   ├── image3.png
│   ├── NTMCell.png
│   ├── alien_wave.png
│   ├── qz_t_digit3.png
│   ├── qz_t_wave0.png
│   ├── qz_t_wave1.png
│   ├── digit3_series.png
│   ├── wave_transitions.png
│   ├── digit3_transition.png
│   ├── digit7_transition.png
│   ├── digit9_transition.png
│   └── digit_transitions.png
├── README.md
└── kLSTM.py

/fig/digit3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit3.png
--------------------------------------------------------------------------------
/fig/image3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/image3.png
--------------------------------------------------------------------------------
/fig/NTMCell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/NTMCell.png
--------------------------------------------------------------------------------
/fig/alien_wave.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/alien_wave.png
--------------------------------------------------------------------------------
/fig/qz_t_digit3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/qz_t_digit3.png
--------------------------------------------------------------------------------
/fig/qz_t_wave0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/qz_t_wave0.png
--------------------------------------------------------------------------------
/fig/qz_t_wave1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/qz_t_wave1.png
--------------------------------------------------------------------------------
/fig/digit3_series.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit3_series.png
--------------------------------------------------------------------------------
/fig/wave_transitions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/wave_transitions.png
--------------------------------------------------------------------------------
/fig/digit3_transition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit3_transition.png
--------------------------------------------------------------------------------
/fig/digit7_transition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit7_transition.png
--------------------------------------------------------------------------------
/fig/digit9_transition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit9_transition.png
--------------------------------------------------------------------------------
/fig/digit_transitions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HHTseng/MarkovRNNs/HEAD/fig/digit_transitions.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Markov Recurrent Neural Networks

This repository is the PyTorch implementation of [Markov Recurrent Neural Networks](https://github.com/NCTUMLlab/Che-Yu-Kuo-MarkovRNN.git), with two temporal datasets as a quick demonstration.

## Architecture
- **Paper**: [Markov Recurrent Neural Networks (MRNN)](https://ieeexplore.ieee.org/document/8517074)


- **Heuristic explanation:**
  MRNN is a deep learning model for time series, with applications such as NLP, stock price prediction, or gravitational wave detection. The main idea is to run several *parallel* [RNNs](https://en.wikipedia.org/wiki/Recurrent_neural_network) [(LSTMs)](http://colah.github.io/posts/2015-08-Understanding-LSTMs/) that learn the temporal dependence of the data simultaneously. If the data has complex temporal structure (behaviour), a single RNN may not be enough to capture the pattern. *k* parallel RNNs (*k = 1, 2, 3, ...*) read the same input signal at the same time, each learning a different character of the data. A latent variable *z* (also learned by the network) then determines, at each time step, when and which LSTM should be listened to for the learning task (see the figures of *q_z(t)* & *z(t)* below). The selection by *z* is itself a stochastic process modelling transitions between the *k* LSTMs with the Markov property, hence the name MRNN. A minimal sketch of this selection mechanism follows this list.

- **Note:**
  The transition variable *z* between the *k* LSTMs can be regarded as an [attention mechanism](https://nlp.stanford.edu/pubs/emnlp15_attn.pdf) over the individual LSTM hidden states.
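
The sketch below illustrates only the selection step described above; it is not taken from `kLSTM.py`, and the tensor names and sizes are illustrative. Transition logits computed from the current input and the previous hidden state give the probabilities *q_z(t)* over the *k* LSTMs; a Gumbel-softmax sample *z(t)* then (softly) picks one of the *k* candidate hidden states:

```
import torch
import torch.nn.functional as F

N, n, m, k = 8, 28, 64, 4                           # batch, input dim, hidden dim, number of parallel LSTMs
W_xz, W_hz = torch.randn(n, k), torch.randn(m, k)   # transition ("encoder") weights
x_t, h_prev = torch.randn(N, n), torch.randn(N, m)  # input and mixed hidden state at time t
H_t = torch.randn(N, k, m)                          # candidate hidden states from the k parallel LSTMs

logit_z = x_t @ W_xz + h_prev @ W_hz        # (N, k) transition logits
q_z = F.softmax(logit_z, dim=1)             # q_z(t): probability of choosing each LSTM
z = F.gumbel_softmax(logit_z, tau=1.0)      # z(t): (approximately one-hot) sample during training
h_t = torch.einsum('nkm,nk->nm', H_t, z)    # keep the hidden state of the selected LSTM
```
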
## Datasets

- [**MNIST**](https://en.wikipedia.org/wiki/MNIST_database) with each image viewed in series as sequential input:


- **Artificial alien signals**: imagine we want to recognize radio signals sent by aliens, as in [SETI](https://setiathome.berkeley.edu/). I generated two kinds of waveforms for the Markov RNN to distinguish:



## Prerequisites
- [Python 3.6](https://www.python.org/)
- [Jupyter notebook](https://jupyter.org/)
- [PyTorch 1.0](https://pytorch.org/)
- [Numpy 1.15.0](http://www.numpy.org/)
- [Sklearn 0.20.2](https://scikit-learn.org/stable/)
- [Matplotlib](https://matplotlib.org/)


## Usage
No installation is required beyond the prerequisites. The file `kLSTM.py` contains all the modules needed to run the Markov RNN. Two examples are provided as Jupyter notebooks:
```
MRNN_MNIST.ipynb
MRNN_detect_alien_signal.ipynb
```
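
For orientation, here is a minimal sketch (not taken from the notebooks; the hyperparameters are illustrative) of how the modules in `kLSTM.py` can be wired together:

```
import torch
from kLSTM import LSTM, LSTMCell

# e.g. sequences of length T = 28 with 28 features per step (MNIST viewed in series),
# 10 output classes, k = 4 parallel LSTMs
model = LSTM(LSTMCell, input_size=28, hidden_size=64, output_size=10,
             num_layers=1, k_cells=4, dropout_prob=0.1)

x = torch.randn(32, 28, 28)                                # (batch, time, features)
output, z_T, qz_T, h_T, state = model(x, tau=1.0, is_training=True)
# output: (32, 10) class scores; z_T, qz_T: (32, 28, 4) transition samples / probabilities
```
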
## Results & Interpretations
1. **Take *k = 4* LSTMs for MNIST.**

   - In the figure of *q_z(t)*, the horizontal axis is time and the vertical axis indexes the LSTMs; the color palette indicates the probability of each LSTM being used.

   - In the figure of *z(t)*, yellow indicates which LSTM was actually used (sampled by Gumbel softmax).


   - Overall probability of each LSTM being chosen. **[Left]: digit 7**, **[right]: digit 9**



2. **Take *k = 4* LSTMs for alien-signal (binary) classification.**
   - Non-alien signal
     - The figures of *q_z(t)* & *z(t)* show slight periodicity.



   - Alien signal (maybe saying hi, maybe an invasion)
     - The figures of *q_z(t)* & *z(t)* show the 4 LSTMs being used to detect the irregular alien waveform.


   - Overall probability of each LSTM being chosen. **[Left]: non-alien signal**, **[right]: alien signal**




## Improvements
This PyTorch implementation is extended so that the Markov RNN can have more than one hidden layer in every LSTM, whereas the original [TensorFlow version](https://github.com/NCTUMLlab/Che-Yu-Kuo-MarkovRNN.git) was restricted to a single hidden layer.
--------------------------------------------------------------------------------
/kLSTM.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import init
from PIL import Image
import torch.nn.functional as F
import math


def plot_image_series(images, ncols, nrows, photow, photoh, marl, mart, marr, marb, padding):
    """
    Make a contact sheet from a group of images:

    images   A list of images (numpy arrays)

    ncols    Number of columns in the contact sheet
    nrows    Number of rows in the contact sheet
    photow   The width of the photo thumbs in pixels
    photoh   The height of the photo thumbs in pixels

    marl     The left margin in pixels
    mart     The top margin in pixels
    marr     The right margin in pixels
    marb     The bottom margin in pixels

    padding  The padding between images in pixels

    Returns a PIL image object.
    """

    # Read in all images and resize appropriately
    imgs = [Image.fromarray(images[i]).resize((photow, photoh)) for i in range(len(images))]

    # Calculate the size of the output image, based on the
    # photo thumb sizes, margins, and padding
    marw = marl + marr
    marh = mart + marb

    padw = (ncols - 1) * padding
    padh = (nrows - 1) * padding
    isize = (ncols * photow + marw + padw, nrows * photoh + marh + padh)

    # Create the new image. The background doesn't have to be white
    white = (255, 255, 255)
    inew = Image.new('RGB', isize, white)

    # Insert each thumb:
    for irow in range(nrows):
        for icol in range(ncols):
            left = marl + icol * (photow + padding)
            right = left + photow
            upper = mart + irow * (photoh + padding)
            lower = upper + photoh
            bbox = (left, upper, right, lower)
            try:
                img = imgs.pop(0)
            except IndexError:      # ran out of images; leave the remaining slots blank
                break
            inew.paste(img, bbox)
    return inew


class LSTMCell(nn.Module):
    """A basic Markov (k-parallel) LSTM cell."""
    def __init__(self, input_size, hidden_size, k_cells, is_training=True, use_bias=True):
        super(LSTMCell, self).__init__()
        self.n = input_size
        self.m = hidden_size
        self.k = k_cells                      # number of parallel RNN cells
        self.is_training = is_training
        self.W_xz = nn.Parameter(torch.FloatTensor(self.n, self.k))   # affine transform x_t -> (z_1, z_2, ..., z_K)
        self.W_hz = nn.Parameter(torch.FloatTensor(self.m, self.k))   # affine transform h -> (z_1, z_2, ..., z_K)

        self.W_x_4gates = nn.Parameter(torch.FloatTensor(self.n, 4 * self.m * self.k))
        self.W_h_4gates = nn.Parameter(torch.FloatTensor(self.m, 4 * self.m * self.k))

        self.use_bias = use_bias
        if use_bias:
            self.b_xz = nn.Parameter(torch.FloatTensor(self.k))                   # size = (k,)
            # self.b_hz = nn.Parameter(torch.FloatTensor(self.k))                 # size = (k,)
            self.b_4gates = nn.Parameter(torch.FloatTensor(4 * self.m * self.k))  # size = (4 * m * k,)
        else:
            self.register_parameter('b_xz', None)
            self.register_parameter('b_4gates', None)

        self.reset_parameters()
        print("LSTM cell parameters reset!")


    def reset_parameters(self):
        stdv1 = 1. / math.sqrt(self.W_xz.data.size(1))
        self.W_xz.data.uniform_(-stdv1, stdv1)

        stdv2 = 1. / math.sqrt(self.W_hz.data.size(1))
        self.W_hz.data.uniform_(-stdv2, stdv2)

        stdv3 = 1. / math.sqrt(self.W_x_4gates.data.size(1))
        self.W_x_4gates.data.uniform_(-stdv3, stdv3)

        stdv4 = 1. / math.sqrt(self.W_h_4gates.data.size(1))
        self.W_h_4gates.data.uniform_(-stdv4, stdv4)

        if self.b_xz is not None:
            self.b_xz.data.uniform_(-stdv1, stdv1)

        if self.b_4gates is not None:
            self.b_4gates.data.uniform_(-stdv4, stdv4)

    def forward(self, x_t, tau, s_t, is_training):
        """ Args:
            x_t: input at time step t, a (batch, input_size) tensor containing the input features.
            tau: annealing temperature for the Gumbel softmax.
            s_t: state at time step t, a tuple (h_0, c_0) containing the previous hidden
                 and cell state, each of size (batch, hidden_size).
            is_training: whether the cell runs in training mode.
        Returns:
            z, q_z, h_t, c_t: the transition sample, the transition probabilities,
            and the next hidden and cell state.
        """
""" 119 | 120 | h_0, c_0 = s_t # state s_t = (h_t, c_t) with size h_0 = size c_0 = (N, m) 121 | N = h_0.size(0) # batch size 122 | self.tau = tau 123 | self.is_training = is_training 124 | 125 | # expand (repeat) bias for batch processing 126 | batch_b_xz = (self.b_xz.unsqueeze(0).expand(N, *self.b_xz.size())) # size = (N, k) 127 | # batch_b_hz = (self.b_hz.unsqueeze(0).expand(N, *self.b_hz.size())) # size = (N, k) 128 | batch_b_4gates = (self.b_4gates.unsqueeze(0).expand(N, *self.b_4gates.size())) # size = (N, 4* m* k) 129 | 130 | ''' logit encoder: logit_z = (W1 * x_t + b1) + (W2 * h_{t-1} + b2) in R^K ''' 131 | logit_z = torch.addmm(batch_b_xz, x_t, self.W_xz) + torch.mm(h_0, self.W_hz) 132 | 133 | # probability q_z(x_t, h_{t-1}) in R^K 134 | q_z = F.softmax(logit_z, dim=1) 135 | 136 | if self.is_training is True: 137 | z = F.gumbel_softmax(logit_z, self.tau, hard=False, eps=1e-10) 138 | else: 139 | if q_z.is_cuda: 140 | z = torch.cuda.FloatTensor(N, self.k).zero_() # create a GPU zero tensor for 1-hot 141 | else: 142 | z = torch.FloatTensor(N, self.k).zero_() # create a CPU zero tensor for 1-hot 143 | 144 | z.scatter_(1, torch.max(q_z, dim=1)[1].view(N, 1), 1) # find which position is max 145 | 146 | # k LSTM's part 147 | A = torch.mm(x_t, self.W_x_4gates) # (W_x_4gates * x + b_4gates) , output dim = 4* m* k 148 | B = torch.addmm(batch_b_4gates, h_0, self.W_h_4gates) # (W_h_4gates * h_0 + b_4gates), output dim = 4* m* k 149 | ''' this step B is a bit strange? because it may breakdown the independence of the k LSTMs''' 150 | 151 | # 4 gates in LSTM 152 | I_gate, G_gate, F_gate, O_gate = torch.split(A + B, split_size_or_sections=(self.m * self.k), dim=1) 153 | 154 | # Expand c_0 -> C_0 has dim=k 155 | C_0 = c_0.repeat(1, self.k) 156 | 157 | '''for K LSTMs: C_t = (c^1_t, c^2_t, .... , c^K_t ) 158 | H_t = (h^1_t, h^2_t, .... 
        ''' for the K LSTMs:  C_t = (c^1_t, c^2_t, ..., c^K_t)
                              H_t = (h^1_t, h^2_t, ..., h^K_t)   [vectorized over the K RNNs] '''
        C_t = torch.mul(torch.sigmoid(F_gate), C_0) + torch.mul(torch.sigmoid(I_gate), torch.tanh(G_gate))  # new cell states of the K LSTMs
        H_t = torch.mul(torch.sigmoid(O_gate), torch.tanh(C_t))                                             # new hidden states of the K LSTMs

        # reshape C_t & H_t to dim = (N, k, m)  (hidden_size = m)
        C_t = C_t.view([N, self.k, self.m])
        H_t = H_t.view([N, self.k, self.m])

        # weight the K LSTMs by z (like a K-fold ensemble) to obtain a single cell & hidden state (h_t, c_t)
        h_t = torch.einsum('nkm,nk->nm', (H_t, z))   # size = (N, m)
        c_t = torch.einsum('nkm,nk->nm', (C_t, z))   # size = (N, m)

        return z, q_z, h_t, c_t

    def __repr__(self):
        s = '{name}({n}, {m})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

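
# Example (illustrative; not part of the original module): a single MarkovRNN cell step,
# assuming batch N = 2, input size n = 5, hidden size m = 3 and k = 4 parallel LSTMs.
#
#   cell = LSTMCell(input_size=5, hidden_size=3, k_cells=4)
#   x_t = torch.randn(2, 5)
#   h0, c0 = torch.zeros(2, 3), torch.zeros(2, 3)
#   z, q_z, h_t, c_t = cell(x_t, tau=1.0, s_t=(h0, c0), is_training=True)
#   # z, q_z have shape (2, 4); h_t, c_t have shape (2, 3)
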
class LSTM(nn.Module):
    """A module that runs multiple steps of the Markov LSTM."""
    def __init__(self, cell_class, input_size, hidden_size, output_size,
                 num_layers=1, k_cells=2, use_bias=True, dropout_prob=0.):
        super(LSTM, self).__init__()
        self.cell_class = cell_class
        self.n = input_size
        self.m = hidden_size
        self.k = k_cells                    # number of parallel RNN cells per layer
        self.output_size = output_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.dropout_prob = dropout_prob

        self.fc = nn.Linear(self.m, self.output_size)   # final linear layer: last hidden state -> output

        for layer in range(num_layers):
            layer_input_size = self.n if layer == 0 else self.m
            cell = cell_class(input_size=layer_input_size, hidden_size=self.m, k_cells=self.k, use_bias=self.use_bias)
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout_prob)
        self.reset_parameters()

    def get_cell(self, layer):
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    @staticmethod
    def _forward_rnn(cell, x, tau, length, s_t, is_training):
        T = x.size(1)   # time is the second dimension

        z_T = []        # z_T = (z_1, z_2, ...): all transition samples z_t over time
        qz_T = []       # qz_T = (q_z(1), q_z(2), ...): all transition probabilities over time
        h_T = []        # h_T = (h_1, h_2, ...): all hidden states h_t over time
        for t in range(T):
            # one time step
            z, q_z, h_t, c_t = cell(x[:, t, :], tau, s_t, is_training)

            # mask out time steps beyond each sequence's length
            time_mask = (t < length).float().unsqueeze(1).expand_as(h_t)
            h_t, c_t = h_t * time_mask + s_t[0] * (1 - time_mask), c_t * time_mask + s_t[1] * (1 - time_mask)
            s_t = (h_t, c_t)   # state t = (h_t, c_t)

            h_T.append(h_t)
            z_T.append(z)
            qz_T.append(q_z)

        z_T = torch.stack(z_T, 0).transpose_(0, 1)    # [transpose to batch first], size = (N, T, k)
        qz_T = torch.stack(qz_T, 0).transpose_(0, 1)  # [transpose to batch first], size = (N, T, k)
        h_T = torch.stack(h_T, 0).transpose_(0, 1)    # [transpose to batch first], size = (N, T, m)

        return z_T, qz_T, h_T, s_t

    def forward(self, x, tau, length=None, s_t=None, is_training=True):

        # batch is assumed to be the first dimension of the input x
        N, T, n = x.size()
        ''' N = batch size, T = total time steps, n = input feature dimension '''

        # limit on the temporal length of the RNN input
        if length is None:
            length = Variable(torch.LongTensor([T] * N))
            if x.is_cuda:
                device = x.get_device()
                length = length.cuda(device)
        if s_t is None:
            # zero initialization of the state
            s_t = Variable(x.data.new(N, self.m).zero_())
            s_t = (s_t, s_t)

        all_layer_h_t = []
        all_layer_c_t = []
        layer_h_T = None

        # stacking the LSTM layers (depth)
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)   # get the cell of this layer
            layer_z_T, layer_qz_T, layer_h_T, (layer_h_t, layer_c_t) = LSTM._forward_rnn(cell, x, tau,
                                                                                         length, s_t, is_training)

            ''' x = data input if layer == 0; x = hidden states of the previous layer if layer > 0 '''
            x = self.dropout_layer(layer_h_T)
            all_layer_h_t.append(layer_h_t)
            all_layer_c_t.append(layer_c_t)

        all_layer_h_t = torch.stack(all_layer_h_t, 0)
        all_layer_c_t = torch.stack(all_layer_c_t, 0)

        output = self.fc(layer_h_t)

        return output, layer_z_T, layer_qz_T, layer_h_T, (all_layer_h_t, all_layer_c_t)

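
if __name__ == '__main__':
    # Minimal smoke test (added for illustration; not part of the original repository).
    # One forward pass on random data with assumed sizes: batch 4, T = 10 time steps,
    # 8 input features, k = 3 parallel LSTMs, hidden size 16, 2 output classes.
    model = LSTM(LSTMCell, input_size=8, hidden_size=16, output_size=2,
                 num_layers=2, k_cells=3, dropout_prob=0.1)
    x = torch.randn(4, 10, 8)
    output, z_T, qz_T, h_T, (h_n, c_n) = model(x, tau=1.0, is_training=True)
    print(output.shape, z_T.shape, qz_T.shape, h_T.shape)   # (4, 2) (4, 10, 3) (4, 10, 3) (4, 10, 16)
--------------------------------------------------------------------------------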