├── LICENSE
├── README.md
├── __init__.py
├── batch_normalization_LSTM.py
├── data
│   ├── processed
│   │   ├── test.pt
│   │   └── training.pt
│   └── raw
│       ├── t10k-images-idx3-ubyte
│       ├── t10k-labels-idx1-ubyte
│       ├── train-images-idx3-ubyte
│       └── train-labels-idx1-ubyte
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 JiKang Nie

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# An Implementation of Batch-Normalized LSTM in PyTorch

*Tim Cooijmans et al.*, [Recurrent Batch Normalization (arXiv:1603.09025)](https://arxiv.org/abs/1603.09025)

Forked from [sysuNie](https://github.com/sysuNie/batch_normalized_LSTM).

Modified to be compatible with PyTorch 1.0.0.

# To use:

```python
import torch
from batch_normalization_LSTM import BNLSTMCell, LSTM

# batch_first=True means the input is (batch, seq_len, input_size);
# max_length is the number of timesteps for which separate BN statistics are kept.
model = LSTM(cell_class=BNLSTMCell, input_size=28, hidden_size=512,
             batch_first=True, max_length=152)

if __name__ == "__main__":
    size = 28
    dummy = torch.rand(300, 2, size)   # 300 sequences, 2 timesteps, 28 features each
    out = model(dummy)
    print(model)
    print(out[0])
```
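
The wrapper's `forward` also accepts an optional `length` tensor with one entry per sequence; timesteps at or beyond a sequence's length keep the previous hidden and cell state instead of being updated, so the returned final state corresponds to each sequence's last real timestep. A minimal sketch for padded, variable-length input (the sizes here are arbitrary):

```python
import torch
from batch_normalization_LSTM import BNLSTMCell, LSTM

model = LSTM(cell_class=BNLSTMCell, input_size=28, hidden_size=512,
             batch_first=True, max_length=152)

padded = torch.zeros(3, 5, 28)          # 3 padded sequences, up to 5 timesteps
length = torch.LongTensor([5, 3, 2])    # true length of each sequence
output, (h_n, c_n) = model(padded, length)
print(h_n.shape)                        # (num_layers, batch, hidden_size)
```

Note that `output` is returned time-major, i.e. with shape `(seq_len, batch, hidden_size)`, even when `batch_first=True`; only the input is expected batch-first.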
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/__init__.py
--------------------------------------------------------------------------------
/batch_normalization_LSTM.py:
--------------------------------------------------------------------------------
"""Implementation of batch-normalized LSTM."""
import torch
from torch import nn
from torch.nn import functional, init


class SeparatedBatchNorm1d(nn.Module):
    """
    A batch normalization module which keeps its running mean
    and variance separately per timestep.
    """

    def __init__(self, num_features, max_length, eps=1e-5, momentum=0.1,
                 affine=True):
        """
        Most parts are copied from
        torch.nn.modules.batchnorm._BatchNorm.
        """
        super(SeparatedBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.max_length = max_length
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if self.affine:
            self.weight = nn.Parameter(torch.FloatTensor(num_features))
            self.bias = nn.Parameter(torch.FloatTensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # One pair of running statistics per timestep.
        for i in range(max_length):
            self.register_buffer(
                'running_mean_{}'.format(i), torch.zeros(num_features))
            self.register_buffer(
                'running_var_{}'.format(i), torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.max_length):
            running_mean_i = getattr(self, 'running_mean_{}'.format(i))
            running_var_i = getattr(self, 'running_var_{}'.format(i))
            running_mean_i.zero_()
            running_var_i.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input_):
        if input_.size(1) != self.running_mean_0.nelement():
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input_.size(1), self.num_features))

    def forward(self, input_, time):
        self._check_input_dim(input_)
        # Timesteps beyond max_length reuse the statistics of the last timestep.
        if time >= self.max_length:
            time = self.max_length - 1
        running_mean = getattr(self, 'running_mean_{}'.format(time))
        running_var = getattr(self, 'running_var_{}'.format(time))
        return functional.batch_norm(
            input=input_, running_mean=running_mean, running_var=running_var,
            weight=self.weight, bias=self.bias, training=self.training,
            momentum=self.momentum, eps=self.eps)

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' max_length={max_length}, affine={affine})'
                .format(name=self.__class__.__name__, **self.__dict__))
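
# Illustrative sketch (not part of the original module): SeparatedBatchNorm1d is
# called with an explicit timestep index so that each recurrent step gets its own
# running statistics, e.g.
#
#     bn = SeparatedBatchNorm1d(num_features=8, max_length=10)
#     x = torch.rand(4, 8)     # (batch, features) activations at one timestep
#     y = bn(x, time=3)        # normalize with the statistics kept for step 3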

class LSTMCell(nn.Module):
    """A basic LSTM cell."""

    def __init__(self, input_size, hidden_size, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        # The input-to-hidden weight matrix is initialized orthogonally.
        init.orthogonal_(self.weight_ih.data)
        # The hidden-to-hidden weight matrix is initialized as an identity
        # matrix repeated for the four gates.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            init.constant_(self.bias.data, val=0)

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        # Gate order: forget, input, output, cell candidate.
        f, i, o, g = torch.split(wh_b + wi, self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        return h_1, c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)

class BNLSTMCell(nn.Module):
    """A BN-LSTM cell."""

    def __init__(self, input_size, hidden_size, max_length, use_bias=True):
        super(BNLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        # BN parameters
        self.bn_ih = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_hh = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_c = SeparatedBatchNorm1d(
            num_features=hidden_size, max_length=max_length)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        # The input-to-hidden weight matrix is initialized orthogonally.
        init.orthogonal_(self.weight_ih.data)
        # The hidden-to-hidden weight matrix is initialized as an identity
        # matrix repeated for the four gates.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            init.constant_(self.bias.data, val=0)
        # Initialization of BN parameters.
        self.bn_ih.reset_parameters()
        self.bn_hh.reset_parameters()
        self.bn_c.reset_parameters()
        self.bn_ih.bias.data.fill_(0)
        self.bn_hh.bias.data.fill_(0)
        self.bn_ih.weight.data.fill_(0.1)
        self.bn_hh.weight.data.fill_(0.1)
        self.bn_c.weight.data.fill_(0.1)

    def forward(self, input_, hx, time):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
            time: The current timestep value, which is used to
                get appropriate running statistics.
        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh = torch.mm(h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        bn_wh = self.bn_hh(wh, time=time)
        bn_wi = self.bn_ih(wi, time=time)
        f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
                                 self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
        return h_1, c_1


class LSTM(nn.Module):
    """A module that runs multiple steps of LSTM."""

    def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
                 use_bias=True, batch_first=False, dropout=0, **kwargs):
        super(LSTM, self).__init__()
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout

        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()

    def get_cell(self, layer):
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    @staticmethod
    def _forward_rnn(cell, input_, length, hx):
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            if isinstance(cell, BNLSTMCell):
                h_next, c_next = cell(input_=input_[time], hx=hx, time=time)
            else:
                h_next, c_next = cell(input_=input_[time], hx=hx)
            # Keep the previous state for sequences that have already ended.
            mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            mask = mask.to(h_next.device)
            h_next = h_next * mask + hx[0] * (1 - mask)
            c_next = c_next * mask + hx[1] * (1 - mask)
            hx_next = (h_next, c_next)
            output.append(h_next)
            hx = hx_next
        output = torch.stack(output, 0)
        return output, hx

    def forward(self, input_, length=None, hx=None):
        if self.batch_first:
            input_ = input_.transpose(0, 1)
        max_time, batch_size, _ = input_.size()
        if length is None:
            length = torch.LongTensor([max_time] * batch_size)
        if hx is None:
            hx = input_.data.new(batch_size, self.hidden_size).zero_()
            hx = (hx, hx)
        h_n = []
        c_n = []
        layer_output = None
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            layer_output, (layer_h_n, layer_c_n) = LSTM._forward_rnn(
                cell=cell, input_=input_, length=length, hx=hx)
            input_ = self.dropout_layer(layer_output)
            h_n.append(layer_h_n)
            c_n.append(layer_c_n)
        output = layer_output
        h_n = torch.stack(h_n, 0)
        c_n = torch.stack(c_n, 0)
        return output, (h_n, c_n)
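
# Minimal smoke test (an added sketch, not part of the original module): run one
# BN-LSTM cell step by hand and one full forward pass through the LSTM wrapper.
if __name__ == "__main__":
    batch, seq_len, n_in, n_hidden = 4, 10, 28, 32

    cell = BNLSTMCell(input_size=n_in, hidden_size=n_hidden, max_length=seq_len)
    h0 = torch.zeros(batch, n_hidden)
    c0 = torch.zeros(batch, n_hidden)
    x_t = torch.rand(batch, n_in)
    h1, c1 = cell(x_t, (h0, c0), time=0)   # the cell needs the timestep index
    print('BNLSTMCell step:', h1.shape, c1.shape)

    rnn = LSTM(cell_class=BNLSTMCell, input_size=n_in, hidden_size=n_hidden,
               batch_first=True, max_length=seq_len)
    out, (h_n, c_n) = rnn(torch.rand(batch, seq_len, n_in))
    print('LSTM output:', out.shape, 'final hidden:', h_n.shape)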
--------------------------------------------------------------------------------
/data/processed/test.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/data/processed/test.pt
--------------------------------------------------------------------------------
/data/processed/training.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/data/processed/training.pt
--------------------------------------------------------------------------------
/data/raw/t10k-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/data/raw/t10k-images-idx3-ubyte
--------------------------------------------------------------------------------
/data/raw/t10k-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/data/raw/t10k-labels-idx1-ubyte
--------------------------------------------------------------------------------
/data/raw/train-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/data/raw/train-images-idx3-ubyte
--------------------------------------------------------------------------------
/data/raw/train-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h-jia/batch_normalized_LSTM/4bea527397eaf39298df43e77eb16c293c781b98/data/raw/train-labels-idx1-ubyte
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import argparse

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torchvision import datasets, transforms

from batch_normalization_LSTM import LSTMCell, BNLSTMCell, LSTM
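
# Example invocations (assumed command lines, based on the flags defined below):
#
#     python train.py --model lstm
#     python train.py --model bnlstm --hidden-size 512 --cuda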

parser = argparse.ArgumentParser(description='PyTorch train BNLSTM model on MNIST dataset')

parser.add_argument('--data', type=str, default='data/',
                    help='The path to save MNIST dataset')
parser.add_argument('--model', required=True, choices=['lstm', 'bnlstm'],
                    help='The name of a model to use')
parser.add_argument('--save', type=str, default='log/',
                    help='The path to save model files')
parser.add_argument('--hidden-size', type=int, default=1000,
                    help='The number of hidden units')
parser.add_argument('--batch-size', type=int, default=128,
                    help='The size of each batch')
parser.add_argument('--epoches', type=int, default=20,
                    help='The number of training epochs')
parser.add_argument('--cuda', default=False, action='store_true',
                    help='Whether to train on the GPU')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
args = parser.parse_args()


torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root=args.data, train=True,
                               transform=transform, download=True)
test_dataset = datasets.MNIST(root=args.data, train=False,
                              transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True, **kwargs)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=args.batch_size,
                         shuffle=False, **kwargs)

model_name = args.model
if model_name == 'bnlstm':
    model = LSTM(cell_class=BNLSTMCell, input_size=28,
                 hidden_size=args.hidden_size, batch_first=True, max_length=784)
elif model_name == 'lstm':
    model = LSTM(cell_class=LSTMCell, input_size=28,
                 hidden_size=args.hidden_size, batch_first=True)
else:
    raise ValueError('Unknown model: {}'.format(model_name))

fc = nn.Linear(in_features=args.hidden_size, out_features=10)
criterion = nn.CrossEntropyLoss()
if args.cuda:
    model = model.cuda()
    fc = fc.cuda()
params = list(model.parameters()) + list(fc.parameters())
optimizer = optim.SGD(params=params, lr=1e-3, momentum=0.9)


def compute_loss(data, label):
    # Feed each 28x28 MNIST image to the LSTM as a sequence of 28 rows.
    data = data.view(data.size(0), 28, 28)
    h0 = data.new(data.size(0), args.hidden_size).normal_(0, 0.1)
    c0 = data.new(data.size(0), args.hidden_size).normal_(0, 0.1)
    hx = (h0, c0)

    _, (h_n, _) = model(input_=data, hx=hx)
    logits = fc(h_n[0])
    loss = criterion(input=logits, target=label)
    return loss


for epoch in range(args.epoches):
    for i, (images, labels) in enumerate(train_loader):
        if args.cuda:
            images = images.cuda()
            labels = labels.cuda()
        optimizer.zero_grad()
        loss = compute_loss(data=images, label=labels)
        loss.backward()
        clip_grad_norm_(parameters=params, max_norm=1)
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f'
                  % (epoch + 1, args.epoches, i + 1,
                     len(train_dataset) // args.batch_size, loss.item()))
--------------------------------------------------------------------------------