├── .gitignore
├── README.md
├── bnlstm.py
└── train_mnist.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
*.pyc
data

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Recurrent Batch Normalization

PyTorch implementation of Recurrent Batch Normalization proposed by [Cooijmans et al. (2017)](https://arxiv.org/abs/1603.09025).
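## Usage

A minimal sketch of how the module is meant to be wired up (the hidden size, batch size, and sequence length are illustrative, and the snippet uses the same pre-0.4 `Variable` API as the rest of the code):

```python
import torch
from torch.autograd import Variable

from bnlstm import LSTM, BNLSTMCell

# One-layer BN-LSTM over sequences of up to 784 steps
# (e.g. pixel-by-pixel MNIST).
model = LSTM(cell_class=BNLSTMCell, input_size=1, hidden_size=100,
             batch_first=True, max_length=784)

x = Variable(torch.randn(32, 784, 1))  # (batch, time, input_size)
output, (h_n, c_n) = model(input_=x)
# output holds the hidden state at every timestep; h_n and c_n are the
# final hidden and cell states of shape (num_layers, batch, hidden_size).
```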

--------------------------------------------------------------------------------
/bnlstm.py:
--------------------------------------------------------------------------------
"""Implementation of batch-normalized LSTM."""
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional, init


class SeparatedBatchNorm1d(nn.Module):

    """
    A batch normalization module which keeps its running mean
    and variance separately per timestep.
    """

    def __init__(self, num_features, max_length, eps=1e-5, momentum=0.1,
                 affine=True):
        """
        Most parts are copied from
        torch.nn.modules.batchnorm._BatchNorm.
        """

        super(SeparatedBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.max_length = max_length
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if self.affine:
            self.weight = nn.Parameter(torch.FloatTensor(num_features))
            self.bias = nn.Parameter(torch.FloatTensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        # One pair of running statistics per timestep.
        for i in range(max_length):
            self.register_buffer(
                'running_mean_{}'.format(i), torch.zeros(num_features))
            self.register_buffer(
                'running_var_{}'.format(i), torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(self.max_length):
            running_mean_i = getattr(self, 'running_mean_{}'.format(i))
            running_var_i = getattr(self, 'running_var_{}'.format(i))
            running_mean_i.zero_()
            running_var_i.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input_):
        if input_.size(1) != self.running_mean_0.nelement():
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input_.size(1), self.num_features))

    def forward(self, input_, time):
        self._check_input_dim(input_)
        # Timesteps beyond max_length reuse the statistics of the last
        # timestep.
        if time >= self.max_length:
            time = self.max_length - 1
        running_mean = getattr(self, 'running_mean_{}'.format(time))
        running_var = getattr(self, 'running_var_{}'.format(time))
        return functional.batch_norm(
            input=input_, running_mean=running_mean, running_var=running_var,
            weight=self.weight, bias=self.bias, training=self.training,
            momentum=self.momentum, eps=self.eps)

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' max_length={max_length}, affine={affine})'
                .format(name=self.__class__.__name__, **self.__dict__))


class LSTMCell(nn.Module):

    """A basic LSTM cell."""

    def __init__(self, input_size, hidden_size, use_bias=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """

        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """

        # The input-to-hidden weight matrix is initialized orthogonally.
        init.orthogonal(self.weight_ih.data)
        # The hidden-to-hidden weight matrix is initialized as an identity
        # matrix, repeated once per gate.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        self.weight_hh.data.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            init.constant(self.bias.data, val=0)

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).

        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """

        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        # Gate order: forget, input, output, cell candidate.
        f, i, o, g = torch.split(wh_b + wi,
                                 split_size=self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(c_1)
        return h_1, c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)
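

# For reference, the BN-LSTM recurrence of Cooijmans et al. (2017) that the
# cell below implements, where BN(.) is batch normalization with separate
# per-timestep statistics (see SeparatedBatchNorm1d above):
#
#   gates_t = BN(W_hh h_{t-1}) + BN(W_ih x_t) + b
#   f_t, i_t, o_t, g_t = split(gates_t)
#   c_t = sigmoid(f_t) * c_{t-1} + sigmoid(i_t) * tanh(g_t)
#   h_t = sigmoid(o_t) * tanh(BN(c_t))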


class BNLSTMCell(nn.Module):

    """A BN-LSTM cell."""

    def __init__(self, input_size, hidden_size, max_length, use_bias=True):

        super(BNLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        # BN parameters
        self.bn_ih = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_hh = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_c = SeparatedBatchNorm1d(
            num_features=hidden_size, max_length=max_length)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """

        # The input-to-hidden weight matrix is initialized orthogonally.
        init.orthogonal(self.weight_ih.data)
        # The hidden-to-hidden weight matrix is initialized as an identity
        # matrix.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        self.weight_hh.data.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            init.constant(self.bias.data, val=0)
        # Initialization of BN parameters: the batch-norm biases stay at
        # zero and the scales are set to 0.1, as recommended in the paper.
        self.bn_ih.reset_parameters()
        self.bn_hh.reset_parameters()
        self.bn_c.reset_parameters()
        self.bn_ih.bias.data.fill_(0)
        self.bn_hh.bias.data.fill_(0)
        self.bn_ih.weight.data.fill_(0.1)
        self.bn_hh.weight.data.fill_(0.1)
        self.bn_c.weight.data.fill_(0.1)

    def forward(self, input_, hx, time):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
            time: The current timestep value, which is used to
                get appropriate running statistics.

        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """

        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh = torch.mm(h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        # Normalize the hidden-to-hidden and input-to-hidden projections
        # separately, using the running statistics of this timestep.
        bn_wh = self.bn_hh(wh, time=time)
        bn_wi = self.bn_ih(wi, time=time)
        f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
                                 split_size=self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
        return h_1, c_1
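

# A minimal single-step usage sketch for the cell above (the batch size,
# hidden size, and max_length are illustrative; the LSTM wrapper below
# performs this loop over timesteps):
#
#   cell = BNLSTMCell(input_size=1, hidden_size=100, max_length=784)
#   h_0 = Variable(torch.zeros(32, 100))
#   c_0 = Variable(torch.zeros(32, 100))
#   x_t = Variable(torch.randn(32, 1))
#   h_1, c_1 = cell(input_=x_t, hx=(h_0, c_0), time=0)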


class LSTM(nn.Module):

    """A module that runs multiple steps of LSTM."""

    def __init__(self, cell_class, input_size, hidden_size, num_layers=1,
                 use_bias=True, batch_first=False, dropout=0, **kwargs):
        super(LSTM, self).__init__()
        self.cell_class = cell_class
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.use_bias = use_bias
        self.batch_first = batch_first
        self.dropout = dropout

        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size
            cell = cell_class(input_size=layer_input_size,
                              hidden_size=hidden_size,
                              **kwargs)
            setattr(self, 'cell_{}'.format(layer), cell)
        self.dropout_layer = nn.Dropout(dropout)
        self.reset_parameters()

    def get_cell(self, layer):
        return getattr(self, 'cell_{}'.format(layer))

    def reset_parameters(self):
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            cell.reset_parameters()

    @staticmethod
    def _forward_rnn(cell, input_, length, hx):
        max_time = input_.size(0)
        output = []
        for time in range(max_time):
            if isinstance(cell, BNLSTMCell):
                h_next, c_next = cell(input_=input_[time], hx=hx, time=time)
            else:
                h_next, c_next = cell(input_=input_[time], hx=hx)
            # For timesteps past a sequence's length, carry the previous
            # state forward unchanged.
            mask = (time < length).float().unsqueeze(1).expand_as(h_next)
            h_next = h_next*mask + hx[0]*(1 - mask)
            c_next = c_next*mask + hx[1]*(1 - mask)
            hx_next = (h_next, c_next)
            output.append(h_next)
            hx = hx_next
        output = torch.stack(output, 0)
        return output, hx

    def forward(self, input_, length=None, hx=None):
        if self.batch_first:
            input_ = input_.transpose(0, 1)
        max_time, batch_size, _ = input_.size()
        if length is None:
            length = Variable(torch.LongTensor([max_time] * batch_size))
            if input_.is_cuda:
                device = input_.get_device()
                length = length.cuda(device)
        if hx is None:
            # Create the default initial states on the same device and with
            # the same dtype as the model parameters.
            weight = next(self.parameters()).data
            hx = (Variable(nn.init.xavier_uniform(
                      weight.new(self.num_layers, batch_size,
                                 self.hidden_size))),
                  Variable(nn.init.xavier_uniform(
                      weight.new(self.num_layers, batch_size,
                                 self.hidden_size))))
        h_n = []
        c_n = []
        for layer in range(self.num_layers):
            cell = self.get_cell(layer)
            hx_layer = (hx[0][layer, :, :], hx[1][layer, :, :])
            layer_output, (layer_h_n, layer_c_n) = LSTM._forward_rnn(
                cell=cell, input_=input_, length=length, hx=hx_layer)
            # Dropout is applied to each layer's output before it is fed to
            # the next layer; the last layer's output is returned without
            # dropout.
            input_ = self.dropout_layer(layer_output)
            h_n.append(layer_h_n)
            c_n.append(layer_c_n)
        output = layer_output
        if self.batch_first:
            output = output.transpose(0, 1)
        h_n = torch.stack(h_n, 0)
        c_n = torch.stack(c_n, 0)
        return output, (h_n, c_n)

--------------------------------------------------------------------------------
/train_mnist.py:
--------------------------------------------------------------------------------
"""Train the model on the MNIST dataset."""
import argparse
import os
from datetime import datetime
from functools import partial

import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pycrayon import CrayonClient

from bnlstm import LSTM, LSTMCell, BNLSTMCell


def transform_flatten(tensor):
    # Flatten a (1, 28, 28) image into a (784, 1) pixel sequence.
    return tensor.view(-1, 1).contiguous()


def transform_permute(tensor, perm):
    # Reorder the pixel sequence with a fixed permutation (permuted MNIST).
    return tensor.index_select(0, perm)
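

# Example invocation (the paths and hyperparameter values are illustrative;
# see the argparse definitions at the bottom of this file for all options):
#
#   python train_mnist.py --data ./data --model bnlstm --save ./save \
#       --hidden-size 100 --batch-size 64 --max-iter 10000 --gpu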


def main():
    data_path = args.data
    model_name = args.model
    save_dir = args.save
    hidden_size = args.hidden_size
    pmnist = args.pmnist
    batch_size = args.batch_size
    max_iter = args.max_iter
    use_gpu = args.gpu

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if pmnist:
        perm = torch.randperm(784)
    else:
        perm = torch.arange(0, 784).long()
    train_dataset = datasets.MNIST(
        root=data_path, train=True,
        transform=transforms.Compose([transforms.ToTensor(),
                                      transform_flatten,
                                      partial(transform_permute, perm=perm)]),
        download=True)
    valid_dataset = datasets.MNIST(
        root=data_path, train=False,
        transform=transforms.Compose([transforms.ToTensor(),
                                      transform_flatten,
                                      partial(transform_permute, perm=perm)]),
        download=True)

    tb_client = CrayonClient()
    tb_xp_name = '{}-{}'.format(datetime.now().strftime("%y%m%d-%H%M%S"),
                                save_dir)
    tb_xp_train = tb_client.create_experiment('{}/train'.format(tb_xp_name))
    tb_xp_valid = tb_client.create_experiment('{}/valid'.format(tb_xp_name))

    if model_name == 'bnlstm':
        model = LSTM(cell_class=BNLSTMCell, input_size=1,
                     hidden_size=hidden_size, batch_first=True,
                     max_length=784)
    elif model_name == 'lstm':
        model = LSTM(cell_class=LSTMCell, input_size=1,
                     hidden_size=hidden_size, batch_first=True)
    else:
        raise ValueError('Unknown model name: {}'.format(model_name))
    fc = nn.Linear(in_features=hidden_size, out_features=10)
    loss_fn = nn.CrossEntropyLoss()
    params = list(model.parameters()) + list(fc.parameters())
    optimizer = optim.RMSprop(params=params, lr=1e-3, momentum=0.9)

    def compute_loss_accuracy(data, label):
        hx = None
        if not pmnist:
            # For sequential (non-permuted) MNIST, start from a small random
            # state; the shape is (num_layers, batch, hidden_size) to match
            # what LSTM.forward expects.
            h0 = Variable(data.data.new(1, data.size(0), hidden_size)
                          .normal_(0, 0.1))
            c0 = Variable(data.data.new(1, data.size(0), hidden_size)
                          .normal_(0, 0.1))
            hx = (h0, c0)
        _, (h_n, _) = model(input_=data, hx=hx)
        # Classify from the final hidden state of the last layer.
        logits = fc(h_n[-1])
        loss = loss_fn(input=logits, target=label)
        accuracy = (logits.max(1)[1] == label).float().mean()
        return loss, accuracy

    if use_gpu:
        model.cuda()
        fc.cuda()

    iter_cnt = 0
    valid_loader = DataLoader(dataset=valid_dataset,
                              batch_size=batch_size,
                              shuffle=True, pin_memory=True)
    while iter_cnt < max_iter:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True, pin_memory=True)
        for train_batch in train_loader:
            train_data, train_label = train_batch
            train_data = Variable(train_data)
            train_label = Variable(train_label)
            if use_gpu:
                train_data = train_data.cuda()
                train_label = train_label.cuda()
            model.train(True)
            model.zero_grad()
            train_loss, train_accuracy = compute_loss_accuracy(
                data=train_data, label=train_label)
            train_loss.backward()
            clip_grad_norm(parameters=params, max_norm=1)
            optimizer.step()
            tb_xp_train.add_scalar_dict(
                data={'loss': train_loss.data[0],
                      'accuracy': train_accuracy.data[0]},
                step=iter_cnt)

            if iter_cnt % 50 == 49:
                for valid_batch in valid_loader:
                    valid_data, valid_label = valid_batch
                    # Dirty, but a simple way to fetch a single
                    # validation batch.
                    break
                valid_data = Variable(valid_data, volatile=True)
                valid_label = Variable(valid_label, volatile=True)
                if use_gpu:
                    valid_data = valid_data.cuda()
                    valid_label = valid_label.cuda()
                model.train(False)
                valid_loss, valid_accuracy = compute_loss_accuracy(
                    data=valid_data, label=valid_label)
                tb_xp_valid.add_scalar_dict(
                    data={'loss': valid_loss.data[0],
                          'accuracy': valid_accuracy.data[0]},
                    step=iter_cnt)
                save_path = '{}/{}'.format(save_dir, iter_cnt)
                torch.save(model, save_path)
            iter_cnt += 1
            if iter_cnt == max_iter:
                break


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train the model on the MNIST dataset.')
    parser.add_argument('--data', required=True,
                        help='The path to download the MNIST dataset to, or '
                             'the path where the dataset is already located')
    parser.add_argument('--model', required=True, choices=['lstm', 'bnlstm'],
                        help='The name of the model to use')
    parser.add_argument('--save', required=True,
                        help='The path to save model files')
    parser.add_argument('--hidden-size', required=True, type=int,
                        help='The number of hidden units')
    parser.add_argument('--pmnist', default=False, action='store_true',
                        help='If set, use the permuted-MNIST dataset')
    parser.add_argument('--batch-size', required=True, type=int,
                        help='The size of each batch')
    parser.add_argument('--max-iter', required=True, type=int,
                        help='The maximum iteration count')
    parser.add_argument('--gpu', default=False, action='store_true',
                        help='If set, train on the GPU')
    args = parser.parse_args()
    main()

--------------------------------------------------------------------------------