├── README.md
├── data.py
├── data
│   └── penn
│       ├── test.txt
│       ├── train.txt
│       └── valid.txt
├── generate.py
├── rev_main.py
└── reversible.py

/README.md:
--------------------------------------------------------------------------------
# Reversible Recurrent Neural Network
PyTorch implementation of Reversible Recurrent Neural Networks. The Penn Treebank data used for language modelling is included in the `data` folder.
## Requirements
* pytorch
* tqdm
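## Usage
Train a language model on the bundled Penn Treebank data with `rev_main.py`, then sample text from the saved checkpoint with `generate.py`. The values below are simply the scripts' defaults (run either script with `--help` for all options); add `--cuda` to run on a GPU:

    python rev_main.py --data ./data/penn --epochs 40 --save model.pt
    python generate.py --checkpoint ./model.pt --outf generated.txt --words 1000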
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os
import torch

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids
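
# Example usage sketch (the path matches the training script's --data default):
#
#   corpus = Corpus('./data/penn')
#   len(corpus.dictionary)    # vocabulary size, including the '<eos>' token
#   corpus.train.size()       # 1-D LongTensor holding every training token id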

--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
###############################################################################
# Language Modeling on Penn Tree Bank
#
# This file generates new sentences sampled from the language model
#
###############################################################################

import argparse

import torch
from torch.autograd import Variable

import data

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

# Model parameters.
parser.add_argument('--data', type=str, default='./data/penn',
                    help='location of the data corpus')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
                    help='model checkpoint to use')
parser.add_argument('--outf', type=str, default='generated.txt',
                    help='output file for generated text')
parser.add_argument('--words', type=int, default=1000,
                    help='number of words to generate')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--temperature', type=float, default=1.0,
                    help='temperature - higher will increase diversity')
parser.add_argument('--log-interval', type=int, default=100,
                    help='reporting interval')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)

if args.temperature < 1e-3:
    parser.error("--temperature has to be greater than or equal to 1e-3")

with open(args.checkpoint, 'rb') as f:
    model = torch.load(f)
model.eval()

if args.cuda:
    model.cuda()
else:
    model.cpu()

corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1)
# A single random token id; the reversible model expects a 1-D batch of ids per step.
input = Variable(torch.rand(1).mul(ntokens).long(), volatile=True)
if args.cuda:
    input.data = input.data.cuda()

with open(args.outf, 'w') as outf:
    for i in range(args.words):
        output, hidden = model(input, hidden)
        # Scale the logits by the temperature and exponentiate; the result is used
        # as unnormalized sampling weights (lower temperature = sharper distribution).
        word_weights = output.squeeze().data.div(args.temperature).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        input.data.fill_(word_idx)
        word = corpus.dictionary.idx2word[word_idx]

        outf.write(word + ('\n' if i % 20 == 19 else ' '))

        if i % args.log_interval == 0:
            print('| Generated {}/{} words'.format(i, args.words))

--------------------------------------------------------------------------------
/rev_main.py:
--------------------------------------------------------------------------------
import argparse
import time
import math
from functools import reduce
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable, grad
from tqdm import tqdm

import data
import reversible

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/penn',
                    help='location of the data corpus')
parser.add_argument('--log', type=str, default='./log.txt')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)')
parser.add_argument('--emsize', type=int, default=200,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=200,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=0.01,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='model.pt',
                    help='path to save the final model')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)

###############################################################################
# Load data
###############################################################################

corpus = data.Corpus(args.data)

def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data

eval_batch_size = 10
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
model = reversible.RevRNNModel(ntokens, args.emsize, args.nhid, args.tied)
if args.cuda:
    model.cuda()

criterion = nn.CrossEntropyLoss()
reg_criterion = nn.L1Loss()
###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
    """Wraps hidden states in new Variables, to detach them from their history."""
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)

def attach(h):
    if type(h) == Variable:
        h.requires_grad = True
        return h
    else:
        return tuple(attach(v) for v in h)

def detach(h):
    if type(h) == Variable:
        h.volatile = True
        h.requires_grad = False
        return h
    else:
        return tuple(detach(v) for v in h)

def flatten(input_list):
    output = ()
    for item in input_list:
        output += flatten(item) if type(item) == list or type(item) == tuple else (item,)
    return output

def group(input_list):
    # Group a flat sequence into consecutive pairs, e.g. (a, b, c, d) -> ((a, b), (c, d)).
    return tuple(zip(*[iter(input_list)]*2))

def get_batch(source, i, evaluation=False):
    #seq_len = min(args.bptt, len(source) - 1 - i)
    seq_len = 1
    data = Variable(source[i], volatile=evaluation)
    target = Variable(source[i+1].view(-1))
    return data, target


def evaluate(data_source):
    # Turn on evaluation mode.
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, i, evaluation=True)
        output, hidden = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss[0] / len(data_source)


def train():
    # Turn on training mode.
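    # Sketch of the reversible training scheme implemented below:
    #   1. Forward pass: step through args.bptt time steps, accumulating the
    #      prediction loss but keeping only the final cell states.
    #   2. Reverse pass: walk the same steps backwards, using model.reconstruct()
    #      to recover each previous state from the current one, re-running the
    #      forward step from the reconstructed state, and chaining the state
    #      gradients from step to step with torch.autograd.grad().
    # Parameter gradients accumulate along the way, so the intermediate
    # activations of the unrolled sequence never have to be stored.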
    model.train()
    total_loss = 0
    total_loss_error = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    cell_states = model.init_hidden(args.batch_size)
    #optimizer = optim.Adam(model.parameters(), lr = lr)
    with open(args.log, 'w') as f:
        f.write("")
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        cell_states = repackage_hidden(cell_states)
        #cell_states = model.init_hidden(args.batch_size)
        state_grad = model.init_hidden(args.batch_size)
        #optimizer.zero_grad()
        model.zero_grad()
        forward_loss = 0
        backward_loss = 0

        # Forward pass: accumulate the loss, keeping only the final cell states.
        for j in range(0, args.bptt):
            data, targets = get_batch(train_data, i+j)

            output, cell_states = model(data, cell_states)
            #import pdb; pdb.set_trace()
            loss = criterion(output.view(-1, ntokens), targets)/args.bptt
            forward_loss += loss.data
            del output, data, targets
            cell_states = repackage_hidden(cell_states)
        new_cell_states = repackage_hidden(cell_states)
        cell_states = repackage_hidden(cell_states)

        # Reverse pass: reconstruct earlier states and backpropagate step by step.
        for j in range(args.bptt-1, -1, -1):
            data, targets = get_batch(train_data, i+j)
            output, old_cell_states = model.reconstruct(data, new_cell_states)

            old_cell_states = attach(repackage_hidden(old_cell_states))
            output, new_cell_states_recon = model(data, old_cell_states)
            loss = criterion(output.view(-1, ntokens), targets)/args.bptt
            backward_loss += loss.data
            #import pdb; pdb.set_trace()
            #reg_loss = reduce((lambda x,y: x+y),(5*torch.mean(torch.pow(p, 2)) for p in model.parameters() if p.ndimension() == 2))
            #loss += reg_loss + reduce((lambda x,y: x+y),(torch.mean(torch.pow((x-y),2)) for x,y in zip(flatten(new_cell_states_recon), flatten(detach(new_cell_states)))))

            # Differentiate the step loss and the recomputed states with respect to
            # the reconstructed previous states, seeding the state outputs with the
            # gradient carried back from the later time step; only_inputs=False also
            # accumulates parameter gradients into .grad.
            state_grad = grad((loss,)+flatten(new_cell_states_recon), flatten(old_cell_states), (None,)+flatten(state_grad), only_inputs=False)
            del loss, data, targets
            new_cell_states = repackage_hidden(old_cell_states)
            state_grad = repackage_hidden(state_grad)
            del old_cell_states


        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        # print(i, forward_loss, backward_loss)
        total_loss_error = (forward_loss - backward_loss)**2
        parameter_norm = 0
        # print(total_loss_error)
        # import pdb; pdb.set_trace()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        # Manual SGD update (no optimizer object is used).
        for p in model.parameters():
            #import pdb; pdb.set_trace()
            #print(p, p.grad)
            if p.ndimension() == 2:
                parameter_norm += torch.mean(torch.pow(p, 2))
            p.data.add_(-lr, p.grad.data)
        #optimizer.step()
        total_loss += backward_loss
        #cur_loss = total_loss[0] / args.log_interval
        #print(parameter_norm.data.max(), total_loss_error.max(), math.exp(cur_loss))
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            string = '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '\
                     'loss {:5.2f} | ppl {:8.2f} | norm {} | error {} |'.format(
                         epoch, batch, len(train_data) // args.bptt, lr,
                         elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss), parameter_norm.data.max(), total_loss_error.max())
            print(string)
            with open(args.log, 'a') as f:
                f.write(string+'\n')
            total_loss = 0
            start_time = time.time()
    #print(parameter_norm.data.max(), total_loss_error.max())

# Loop over epochs.
lr = args.lr
best_val_loss = None

# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in range(1, args.epochs+1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            lr /= 4.0
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model.
with open(args.save, 'rb') as f:
    model = torch.load(f)

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

--------------------------------------------------------------------------------
/reversible.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.autograd import Variable

from time import time


class RevRNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, tie_weights=False):
        super(RevRNNModel, self).__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn1 = RevLSTMCell(ninp, nhid)
        self.rnn2 = RevLSTMCell(nhid, nhid)
        self.rnn = MultiRNNCell([self.rnn1, self.rnn2])
        #self.rnn = RevLSTMCell(ninp, nhid)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.encoder(input)
        output, hidden = self.rnn(emb, hidden)
        decoded = self.decoder(output)
        return decoded, hidden

    def reconstruct(self, input, states):
        emb = self.encoder(input)
        output, old_cell_states = self.rnn.reconstruct(emb, states)
        return output, old_cell_states

    def init_hidden(self, bsz):
        return self.rnn.init_hidden(bsz)


class MultiRNNCell(nn.Module):
    """Stacks several reversible cells into a multi-layer recurrent cell."""

    def __init__(self, cells):
        super(MultiRNNCell, self).__init__()
        # Register the cells as submodules so their parameters are always tracked.
        self._cells = nn.ModuleList(cells)

    def forward(self, input, cell_states):
        current_input = input
        new_cell_states = []
        for i, cell in enumerate(self._cells):
            #import pdb; pdb.set_trace()
            current_input, state = cell(current_input, cell_states[i])
            new_cell_states.append(state)
        return current_input, new_cell_states

    def reconstruct(self, input, new_cell_states):
        current_input = input
        cell_states = []
        for i, cell in enumerate(self._cells):
            current_input, state = cell.reconstruct(current_input, new_cell_states[i])
            cell_states.append(state)
        return current_input, cell_states

    def init_hidden(self, bsz):
        return tuple(x.init_hidden(bsz) for x in self._cells)


class RevLSTMCell(nn.Module):
    """Reversible LSTM-style cell: the previous (c, h) state can be recovered
    exactly from the new state and the current input (see reconstruct)."""

    def __init__(self, input_size, hidden_size):
        super(RevLSTMCell, self).__init__()
        self.f_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.i_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.uc_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.u_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.r_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.uh_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.hidden_size = hidden_size
        self.init_weights()

    def forward(self, x, state):
        c, h = state
        #import pdb; pdb.set_trace()
        concat1 = torch.cat((x, h), dim=-1)
        f = self.sigmoid(self.f_layer(concat1))
        i = self.sigmoid(self.i_layer(concat1))
        c_ = self.tanh(self.uc_layer(concat1))
        new_c = (f*c + i*c_)/2

        concat2 = torch.cat((x, new_c), dim=-1)
        u = self.sigmoid(self.u_layer(concat2))
        r = self.sigmoid(self.r_layer(concat2))
        h_ = self.tanh(self.uh_layer(concat2))
        new_h = (r*h + u*h_)/2
        return new_h, (new_c, new_h)

    def reconstruct(self, x, state):
        new_c, new_h = state

        # Recompute u, r and h_ from (x, new_c) exactly as in forward(), then invert
        # new_h = (r*h + u*h_)/2 to recover the previous hidden state.
        concat2 = torch.cat((x, new_c), dim=-1)
        u = self.sigmoid(self.u_layer(concat2))
        r = self.sigmoid(self.r_layer(concat2))
        h_ = self.tanh(self.uh_layer(concat2))
        h = (2*new_h - u*h_)/(r+1e-64)

        concat1 = torch.cat((x, h), dim=-1)
        f = self.sigmoid(self.f_layer(concat1))
        i = self.sigmoid(self.i_layer(concat1))
        c_ = self.tanh(self.uc_layer(concat1))
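        # f, i and c_ recomputed from (x, h) now match the values used in forward(),
        # so the cell state inverts the same way; the 1e-64 terms only guard against
        # a gate that is exactly zero.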
        c = (2*new_c - i*c_)/(f+1e-64)

        return h, (c, h)

    def init_weights(self):
        for parameter in self.parameters():
            if parameter.ndimension() == 2:
                nn.init.xavier_uniform(parameter, gain=0.01)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return (Variable(weight.new(bsz, self.hidden_size).zero_()),
                Variable(weight.new(bsz, self.hidden_size).zero_()))


class RevLSTMCell2(nn.Module):
    """Reversible cell variant with additive ReLU updates."""

    def __init__(self, input_size, hidden_size):
        super(RevLSTMCell2, self).__init__()
        self.f_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.i_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.u_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.r_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.hidden_size = hidden_size
        self.init_weights()

    def forward(self, x, state):
        c, h = state
        #import pdb; pdb.set_trace()
        concat1 = torch.cat((x, h), dim=-1)
        f = self.relu(self.f_layer(concat1))
        i = self.relu(self.i_layer(concat1))
        new_c = (c - f + i)/2

        concat2 = torch.cat((x, new_c), dim=-1)
        u = self.relu(self.u_layer(concat2))
        r = self.relu(self.r_layer(concat2))
        new_h = (h - r + u)/2
        return new_h, (new_c, new_h)

    def reconstruct(self, x, state):
        new_c, new_h = state

        concat2 = torch.cat((x, new_c), dim=-1)
        u = self.relu(self.u_layer(concat2))
        r = self.relu(self.r_layer(concat2))
        h = (2*new_h - u + r)

        concat1 = torch.cat((x, h), dim=-1)
        f = self.relu(self.f_layer(concat1))
        i = self.relu(self.i_layer(concat1))
        c = (2*new_c - i + f)

        return h, (c, h)

    def init_weights(self):
        for parameter in self.parameters():
            if parameter.ndimension() == 2:
                nn.init.xavier_uniform(parameter, gain=1)


    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return (Variable(weight.new(bsz, self.hidden_size).zero_()),
                Variable(weight.new(bsz, self.hidden_size).zero_()))


class RevLSTMCell3(nn.Module):
    """Reversible cell variant with additive-coupling (residual-style) updates."""

    def __init__(self, input_size, hidden_size):
        super(RevLSTMCell3, self).__init__()
        self.c1_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.c2_layer = nn.Linear(hidden_size, hidden_size, bias=False)
        self.h1_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2_layer = nn.Linear(hidden_size, hidden_size, bias=False)
        #self.c_batchnorm = nn.BatchNorm1d(hidden_size)
        #self.h_batchnorm = nn.BatchNorm1d(hidden_size)
        self.relu = nn.ReLU()
        self.hidden_size = hidden_size
        self.init_weights()

    def forward(self, x, state):
        c, h = state
        #import pdb; pdb.set_trace()
        concat1 = torch.cat((x, h), dim=-1)
        c_ = self.relu(self.c1_layer(concat1))
        c_ = self.c2_layer(c_)
        new_c = (c + c_)

        concat2 = torch.cat((x, new_c), dim=-1)
        h_ = self.relu(self.h1_layer(concat2))
        h_ = self.h2_layer(h_)
        new_h = (h - h_)
        return new_h, (new_c, new_h)

    def reconstruct(self, x, state):
        new_c, new_h = state

        concat2 = torch.cat((x, new_c), dim=-1)
        h_ = self.relu(self.h1_layer(concat2))
        h_ = self.h2_layer(h_)
        h = (new_h + h_)

        concat1 = torch.cat((x, h), dim=-1)
        c_ = self.relu(self.c1_layer(concat1))
        c_ = self.c2_layer(c_)
        c = (new_c - c_)

        return h, (c, h)

    def init_weights(self):
        for parameter in self.parameters():
            if parameter.ndimension() == 2:
                nn.init.xavier_uniform(parameter, gain=1)


    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        return (Variable(weight.new(bsz, self.hidden_size).zero_()),
                Variable(weight.new(bsz, self.hidden_size).zero_()))
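
# A minimal round-trip smoke test sketch: run one forward step of RevLSTMCell and
# then reconstruct(), and check that the original (c, h) state is recovered up to
# floating-point error. The sizes below are arbitrary illustrative values.
if __name__ == "__main__":
    torch.manual_seed(0)
    cell = RevLSTMCell(input_size=4, hidden_size=8)
    x = Variable(torch.randn(2, 4))             # batch of 2, input size 4
    c0, h0 = cell.init_hidden(2)
    _, (c1, h1) = cell(x, (c0, h0))             # one forward step
    _, (c_rec, h_rec) = cell.reconstruct(x, (c1, h1))
    print('max |c - c_rec|:', (c0 - c_rec).abs().max())
    print('max |h - h_rec|:', (h0 - h_rec).abs().max())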

--------------------------------------------------------------------------------