"""Batch-normalized LSTM with PyTorch.

An implementation of "Recurrent Batch Normalization" by Cooijmans et al.
(https://arxiv.org/abs/1603.09025): batch normalization is applied to the
input-to-hidden and hidden-to-hidden transforms of an LSTM cell, and to the
new cell state before it feeds the output gate.

Requirements: pytorch >= 0.4, python 3.x
"""
import torch
from torch import nn


class BNLSTMCell(nn.Module):
    """One time step of an LSTM cell with recurrent batch normalization.

    BN is applied separately to ``x @ W_ih`` and ``h @ W_hh``. Those two
    BatchNorm1d layers use ``affine=False`` so the single shared ``bias``
    parameter is the only learned shift; ``bn_c`` (affine) normalizes the
    new cell state before the output-gate tanh.
    """

    def __init__(self, input_size, hidden_size):
        super(BNLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Weights are stored transposed relative to nn.LSTM, shaped
        # (in_features, 4 * hidden); gate order along the last dim is
        # i, f, o, g (see the split in forward()).
        self.weight_ih = nn.Parameter(torch.Tensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(torch.Tensor(hidden_size, 4 * hidden_size))
        self.bias = nn.Parameter(torch.zeros(4 * hidden_size))

        # affine=False: the normalized pre-activations share `self.bias`.
        self.bn_ih = nn.BatchNorm1d(4 * self.hidden_size, affine=False)
        self.bn_hh = nn.BatchNorm1d(4 * self.hidden_size, affine=False)
        self.bn_c = nn.BatchNorm1d(self.hidden_size)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weights: orthogonal blocks for i/f/o, scaled identity for g.

        Slices go through ``.data`` deliberately — in-place init on an
        autograd view of a Parameter would be rejected.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.weight_hh.data[:, :self.hidden_size])
        nn.init.orthogonal_(self.weight_hh.data[:, self.hidden_size:2 * self.hidden_size])
        nn.init.orthogonal_(self.weight_hh.data[:, 2 * self.hidden_size:3 * self.hidden_size])
        nn.init.eye_(self.weight_hh.data[:, 3 * self.hidden_size:])
        # Slightly-contractive identity for the candidate (g) block.
        self.weight_hh.data[:, 3 * self.hidden_size:] *= 0.95

    def forward(self, input, hx):
        """Advance one time step.

        Args:
            input: tensor of shape (batch, input_size).
            hx: tuple ``(h, c)``, each of shape (batch, hidden_size).

        Returns:
            Tuple ``(new_h, new_c)``, each of shape (batch, hidden_size).
        """
        h, c = hx
        ih = torch.matmul(input, self.weight_ih)
        hh = torch.matmul(h, self.weight_hh)
        bn_ih = self.bn_ih(ih)
        bn_hh = self.bn_hh(hh)
        hidden = bn_ih + bn_hh + self.bias

        i, f, o, g = torch.split(hidden, split_size_or_sections=self.hidden_size, dim=1)
        new_c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        new_h = torch.sigmoid(o) * torch.tanh(self.bn_c(new_c))
        return (new_h, new_c)


class BNLSTM(nn.Module):
    """Single-layer (optionally bidirectional) LSTM built from BNLSTMCell.

    Mirrors the one-layer ``nn.LSTM`` interface: ``forward`` returns
    ``(output, (h_n, c_n))`` where ``h_n``/``c_n`` have shape
    (num_directions, batch, hidden_size).
    """

    def __init__(self, input_size, hidden_size, batch_first=False, bidirectional=False):
        super(BNLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.bidirectional = bidirectional

        self.lstm_f = BNLSTMCell(input_size, hidden_size)
        if bidirectional:
            self.lstm_b = BNLSTMCell(input_size, hidden_size)
        # Learned initial states, shape (num_directions, 1, hidden);
        # repeated over the batch dimension at forward time.
        num_directions = 2 if self.bidirectional else 1
        self.h0 = nn.Parameter(torch.Tensor(num_directions, 1, self.hidden_size))
        self.c0 = nn.Parameter(torch.Tensor(num_directions, 1, self.hidden_size))
        nn.init.normal_(self.h0, mean=0, std=0.1)
        nn.init.normal_(self.c0, mean=0, std=0.1)

    def forward(self, input, hx=None):
        """Run the sequence through the recurrence.

        Args:
            input: (seq, batch, input_size), or (batch, seq, input_size)
                when ``batch_first=True``.
            hx: optional ``(h0, c0)`` pair, each shaped
                (num_directions, batch, hidden_size); defaults to the
                learned initial state.

        Returns:
            ``(output, (h_n, c_n))``; ``output`` has hidden_size (or
            2 * hidden_size if bidirectional) features per step, and
            ``h_n``/``c_n`` are (num_directions, batch, hidden_size).
        """
        if not self.batch_first:
            input = input.transpose(0, 1)  # work in (batch, seq, dim)
        batch_size, seq_len, dim = input.size()
        # `is not None` rather than truthiness: explicit and future-proof.
        if hx is not None:
            init_state = hx
        else:
            init_state = (self.h0.repeat(1, batch_size, 1),
                          self.c0.repeat(1, batch_size, 1))

        # Forward direction.
        hiddens_f = []
        hx = (init_state[0][0], init_state[1][0])
        for t in range(seq_len):
            hx = self.lstm_f(input[:, t, :], hx)
            hiddens_f.append(hx[0])
        final_hx_f = hx
        hiddens_f = torch.stack(hiddens_f, 1)

        if self.bidirectional:
            # Backward direction: iterate the sequence reversed, then flip
            # the collected outputs back into chronological order.
            hiddens_b = []
            hx = (init_state[0][1], init_state[1][1])
            for t in range(seq_len - 1, -1, -1):
                hx = self.lstm_b(input[:, t, :], hx)
                hiddens_b.append(hx[0])
            final_hx_b = hx
            hiddens_b.reverse()
            hiddens_b = torch.stack(hiddens_b, 1)

        if self.bidirectional:
            hiddens = torch.cat([hiddens_f, hiddens_b], -1)
            hx = (torch.stack([final_hx_f[0], final_hx_b[0]], 0),
                  torch.stack([final_hx_f[1], final_hx_b[1]], 0))
        else:
            hiddens = hiddens_f
            # Bug fix: the original used unsqueeze(1) on the cell state,
            # yielding (batch, 1, hidden) instead of (1, batch, hidden) —
            # inconsistent with h_n and with the bidirectional branch.
            hx = (final_hx_f[0].unsqueeze(0), final_hx_f[1].unsqueeze(0))
        if not self.batch_first:
            hiddens = hiddens.transpose(0, 1)
        return hiddens, hx


if __name__ == '__main__':
    # Smoke test: forward + backward on random data (CPU).
    lstm = BNLSTM(input_size=10, hidden_size=7, batch_first=False, bidirectional=False)
    input = torch.randn(3, 11, 10)
    o, h = lstm(input)
    o = torch.sum(o)
    o.backward()