├── requirements.txt
├── README.md
├── data.py
├── .gitignore
├── ran.py
├── generate.py
├── model.py
└── main.py

/requirements.txt:
--------------------------------------------------------------------------------
torch
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Recurrent Additive Networks

**Note:** This code is not up to date; please refer to the implementation by the original authors: https://github.com/kentonl/ran

---

This is a PyTorch implementation of Recurrent Additive Networks (RAN) by Kenton Lee, Omer Levy, and Luke Zettlemoyer:

http://www.kentonl.com/pub/llz.2017.pdf

The RAN model is implemented in `ran.py`.

Code for running Penn Tree Bank (PTB) experiments is taken from:

https://github.com/pytorch/examples/tree/master/word_language_model

To run the PTB experiments, clone this repository:

```
git clone https://github.com/bheinzerling/ran
```

and then run:

```
cd ran
python main.py --cuda --emsize 256 --nhid 1024 --dropout 0.5 --epochs 100 --nlayers 1 --batch-size 512 --model RAN
```

This should result in a test set perplexity that roughly agrees with the RAN (tanh) result reported in the paper:

```
End of training | test loss 4.78 | test ppl 119.40
```

Better results can be achieved with smaller batch sizes, e.g. with batch size 40:

```
End of training | test loss 4.45 | test ppl 85.24
```

batch size 20:

```
| End of training | test loss 4.42 | test ppl 83.42
```

batch size 10:

```
| End of training | test loss 4.41 | test ppl 82.62
```

batch size 5:

```
| End of training | test loss 4.49 | test ppl 89.21
```
--------------------------------------------------------------------------------

/data.py:
--------------------------------------------------------------------------------
import os
import torch


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids
--------------------------------------------------------------------------------
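For orientation, here is a minimal sketch of how `data.Corpus` is consumed by `main.py` and `generate.py`. The `./data/penn` path is the default `--data` location and is assumed to contain `train.txt`, `valid.txt`, and `test.txt`:

```python
import data

# Build the vocabulary and encode each split as a flat LongTensor of word ids;
# main.py and generate.py construct the corpus the same way.
corpus = data.Corpus('./data/penn')

ntokens = len(corpus.dictionary)     # vocabulary size, used to size the model
print(ntokens, corpus.train.size())  # vocab size and training token count
```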
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------

/ran.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn._functions.rnn import Recurrent, StackedRNN


class RAN(nn.Module):

    def __init__(self, input_size, hidden_size, nlayers=1, dropout=0.5):
        super().__init__()
        if nlayers > 1:
            raise NotImplementedError("TODO: nlayers > 1")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nlayers = nlayers
        self.dropout = dropout

        self.w_cx = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.w_ic = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.w_ix = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.w_fc = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.w_fx = nn.Parameter(torch.Tensor(hidden_size, input_size))

        self.b_cx = nn.Parameter(torch.Tensor(hidden_size))
        self.b_ic = nn.Parameter(torch.Tensor(hidden_size))
        self.b_ix = nn.Parameter(torch.Tensor(hidden_size))
        self.b_fc = nn.Parameter(torch.Tensor(hidden_size))
        self.b_fx = nn.Parameter(torch.Tensor(hidden_size))

        self.weights = self.w_cx, self.w_ic, self.w_ix, self.w_fc, self.w_fx
        for w in self.weights:
            init.xavier_uniform(w)

        self.biases = self.b_cx, self.b_ic, self.b_ix, self.b_fc, self.b_fx
        for b in self.biases:
            b.data.fill_(0)

    def forward(self, input, hidden):
        layer = (Recurrent(RANCell), )
        func = StackedRNN(layer, self.nlayers, dropout=self.dropout)
        hidden, output = func(input, hidden, ((self.weights, self.biases), ))
        return output, hidden

def RANCell(input, hidden, weights, biases):
    w_cx, w_ic, w_ix, w_fc, w_fx = weights
    b_cx, b_ic, b_ix, b_fc, b_fx = biases

    ctilde_t = F.linear(input, w_cx, b_cx)
    i_t = F.sigmoid(F.linear(hidden, w_ic, b_ic) + F.linear(input, w_ix, b_ix))
    f_t = F.sigmoid(F.linear(hidden, w_fc, b_fc) + F.linear(input, w_fx, b_fx))
    c_t = i_t * ctilde_t + f_t * hidden
    h_t = F.tanh(c_t)

    return h_t
--------------------------------------------------------------------------------
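For reference, the update computed by `RANCell` can be written out explicitly. This is a direct transcription of the code above, with x_t the input and h_{t-1} the recurrent state fed back by `Recurrent`; note that the state carried between steps here is h_t, whereas the paper carries the content layer c_t and only emits h_t = g(c_t):

```latex
\begin{aligned}
\tilde{c}_t &= W_{cx} x_t + b_{cx} \\
i_t &= \sigma\left(W_{ic} h_{t-1} + b_{ic} + W_{ix} x_t + b_{ix}\right) \\
f_t &= \sigma\left(W_{fc} h_{t-1} + b_{fc} + W_{fx} x_t + b_{fx}\right) \\
c_t &= i_t \odot \tilde{c}_t + f_t \odot h_{t-1} \\
h_t &= \tanh(c_t)
\end{aligned}
```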
/generate.py:
--------------------------------------------------------------------------------
###############################################################################
# Language Modeling on Penn Tree Bank
#
# This file generates new sentences sampled from the language model
#
###############################################################################

import argparse

import torch
from torch.autograd import Variable

import data

parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

# Model parameters.
parser.add_argument('--data', type=str, default='./data/penn',
                    help='location of the data corpus')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
                    help='model checkpoint to use')
parser.add_argument('--outf', type=str, default='generated.txt',
                    help='output file for generated text')
parser.add_argument('--words', type=int, default=1000,
                    help='number of words to generate')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--temperature', type=float, default=1.0,
                    help='temperature - higher will increase diversity')
parser.add_argument('--log-interval', type=int, default=100,
                    help='reporting interval')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)

if args.temperature < 1e-3:
    parser.error("--temperature has to be greater than or equal to 1e-3")

with open(args.checkpoint, 'rb') as f:
    model = torch.load(f)
model.eval()

if args.cuda:
    model.cuda()
else:
    model.cpu()

corpus = data.Corpus(args.data)
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1)
input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
if args.cuda:
    input.data = input.data.cuda()

with open(args.outf, 'w') as outf:
    for i in range(args.words):
        output, hidden = model(input, hidden)
        word_weights = output.squeeze().data.div(args.temperature).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        input.data.fill_(word_idx)
        word = corpus.dictionary.idx2word[word_idx]

        outf.write(word + ('\n' if i % 20 == 19 else ' '))

        if i % args.log_interval == 0:
            print('| Generated {}/{} words'.format(i, args.words))
--------------------------------------------------------------------------------
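The sampling loop in `generate.py` divides the logits by the temperature, exponentiates, and draws from a multinomial distribution. A small self-contained sketch of that step; the five-word logit vector is made up for illustration:

```python
import torch

# Hypothetical logits over a five-word vocabulary for one time step.
logits = torch.FloatTensor([2.0, 1.0, 0.5, 0.1, -1.0])

def sample_word(logits, temperature=1.0):
    # Lower temperature sharpens the distribution; higher temperature
    # increases diversity, mirroring the loop in generate.py.
    weights = logits.div(temperature).exp()
    return torch.multinomial(weights, 1)[0]

print(sample_word(logits, temperature=0.5))
print(sample_word(logits, temperature=2.0))
```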
/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable

from ran import RAN


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type == "RAN":
            self.rnn = RAN(ninp, nhid, nlayers, dropout=dropout)
        elif rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RAN', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        init.xavier_uniform(self.encoder.weight)
        self.decoder.bias.data.fill_(0)
        init.xavier_uniform(self.decoder.weight)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()),
                    Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()))
        else:
            return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())
--------------------------------------------------------------------------------
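A minimal sketch of how `RNNModel` is driven, mirroring the calls made in `main.py`; the sizes are illustrative, and it assumes the same pre-0.4, `Variable`-based PyTorch API that `ran.py` relies on:

```python
import torch
from torch.autograd import Variable

from model import RNNModel

ntokens, bptt, bsz = 10000, 35, 20
lm = RNNModel('RAN', ntokens, ninp=256, nhid=256, nlayers=1, dropout=0.5)

hidden = lm.init_hidden(bsz)
# A batch of random token ids with shape (sequence length, batch size),
# built the same way generate.py builds its seed input.
tokens = Variable(torch.rand(bptt, bsz).mul(ntokens).long())

output, hidden = lm(tokens, hidden)
print(output.size())  # (bptt, bsz, ntokens): one logit vector per position
```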
/main.py:
--------------------------------------------------------------------------------
import argparse
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable

import data
import model

parser = argparse.ArgumentParser(
    description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/penn',
                    help='location of the data corpus')
parser.add_argument(
    '--model', type=str, default='LSTM',
    help='type of recurrent net (RAN, RNN_TANH, RNN_RELU, LSTM, GRU)')
parser.add_argument('--emsize', type=int, default=200,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=200,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=2,
                    help='number of layers')
parser.add_argument('--lr', type=float, default=20,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')
parser.add_argument('--batch-size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--tied', action='store_true',
                    help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='model.pt',
                    help='path to save the final model')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print(
            "WARNING: You have a CUDA device, "
            "so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)

###############################################################################
# Load data
###############################################################################

corpus = data.Corpus(args.data)


def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data


eval_batch_size = 10
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
model = model.RNNModel(
    args.model, ntokens, args.emsize, args.nhid,
    args.nlayers, args.dropout, args.tied)
if args.cuda:
    model.cuda()

criterion = nn.CrossEntropyLoss()

###############################################################################
# Training code
###############################################################################


def repackage_hidden(h):
    """Wraps hidden states in new Variables,
    to detach them from their history."""
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)


def get_batch(source, i, evaluation=False):
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = Variable(source[i:i+seq_len], volatile=evaluation)
    target = Variable(source[i+1:i+1+seq_len].view(-1))
    return data, target


def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        data, targets = get_batch(data_source, i, evaluation=True)
        output, hidden = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss[0] / len(data_source)


def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it
        # was previously produced. If we didn't, the model would try
        # backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem
        # in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss[0] / args.log_interval
            elapsed = time.time() - start_time
            print(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | '
                'ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // args.bptt, lr,
                    elapsed * 1000 / args.log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

# Loop over epochs.
lr = args.lr
best_val_loss = None

# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in range(1, args.epochs+1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print(
            '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
            'valid ppl {:8.2f}'.format(
                epoch, (time.time() - epoch_start_time),
                val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen
            # in the validation dataset.
            lr /= 4.0
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model.
with open(args.save, 'rb') as f:
    model = torch.load(f)

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)
--------------------------------------------------------------------------------
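To make the data layout concrete, here is a small worked example of the shapes produced by the `batchify` and `get_batch` logic in `main.py`, using a toy run of 26 token ids in place of the PTB corpus:

```python
import torch

# Toy stand-ins for corpus.train and the command-line arguments.
ids = torch.arange(0, 26).long()   # 26 token ids, as data.Corpus would produce
bsz, bptt = 4, 3

# Same logic as batchify() in main.py, without the CUDA branch.
nbatch = ids.size(0) // bsz                      # 6 full columns
data = ids.narrow(0, 0, nbatch * bsz)            # drop the 2 leftover tokens
data = data.view(bsz, -1).t().contiguous()       # shape (nbatch, bsz) = (6, 4)

# Same logic as get_batch(): targets are the inputs shifted by one token.
i = 0
seq_len = min(bptt, len(data) - 1 - i)
inputs = data[i:i + seq_len]                     # shape (3, 4)
targets = data[i + 1:i + 1 + seq_len].view(-1)   # shape (12,), flat for the loss

print(data.size(), inputs.size(), targets.size())
```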