├── __init__.py ├── data ├── .gitkeep └── enwik8 │ └── prep_enwik8.py ├── save └── .gitkeep ├── cache └── .gitkeep ├── .gitignore ├── locked_dropout.py ├── sys_config.py ├── utils.py ├── model_save.py ├── LICENSE ├── getdata.sh ├── data.py ├── embed_regularize.py ├── generate.py ├── asgd.py ├── weight_drop.py ├── model.py ├── pointer.py ├── splitcross.py ├── finetune.py ├── README.md └── main.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /save/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cache/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /locked_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | class LockedDropout(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x, dropout=0.5): 10 | if not self.training or not dropout: 11 | return x 12 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) 13 | mask = Variable(m, requires_grad=False) / (1 - dropout) 14 | mask = mask.expand_as(x) 15 | return mask * x 16 | -------------------------------------------------------------------------------- /sys_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | # print("torch:", torch.__version__) 6 | # if torch.__version__ != '0.1.12_2': 7 | # print("Cuda:", torch.backends.cudnn.cuda) 8 | # print("CuDNN:", torch.backends.cudnn.version()) 9 | 10 | # os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 11 | # os.environ["CUDA_VISIBLE_DEVICES"]="0" 12 | 13 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | DATA_DIR = os.path.join(BASE_DIR, 'data') 16 | 17 | CACHE_DIR = os.path.join(BASE_DIR, 'cache') 18 | 19 | CKPT_DIR = os.path.join(BASE_DIR, 'save') 20 | 21 | -------------------------------------------------------------------------------- /data/enwik8/prep_enwik8.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import os 5 | import sys 6 | import zipfile 7 | 8 | if os.path.exists('train.txt'): 9 | print('Tokenized enwik8 already exists - skipping processing') 10 | sys.exit() 11 | 12 | data = zipfile.ZipFile('enwik8.zip').read('enwik8') 13 | 14 | print('Length of enwik8: {}'.format(len(data))) 15 | 16 | num_test_chars = 5000000 17 | 18 | train_data = data[: -2 * num_test_chars] 19 | valid_data = data[-2 * num_test_chars: -num_test_chars] 20 | test_data = data[-num_test_chars:] 21 | 22 | for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]: 23 | print('{} will have {} bytes'.format(fn, len(part))) 24 | print('- Tokenizing...') 25 | part_str = ' 
'.join([str(c) if c != ord('\n') else '\n' for c in part]) 26 | print('- Writing...') 27 | f = open(fn, 'w').write(part_str) 28 | f = open(fn + '.raw', 'wb').write(part) 29 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def repackage_hidden(h): 5 | """Wraps hidden states in new Tensors, 6 | to detach them from their history.""" 7 | if torch.__version__ == '0.1.12_2': 8 | from torch.autograd import Variable 9 | if type(h) == Variable: 10 | return Variable(h.data) 11 | else: 12 | return tuple(repackage_hidden(v) for v in h) 13 | else: 14 | if isinstance(h, torch.Tensor): 15 | return h.detach() 16 | else: 17 | return tuple(repackage_hidden(v) for v in h) 18 | 19 | 20 | def batchify(data, bsz, args): 21 | # Work out how cleanly we can divide the dataset into bsz parts. 22 | nbatch = data.size(0) // bsz 23 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 24 | data = data.narrow(0, 0, nbatch * bsz) 25 | # Evenly divide the data across the bsz batches. 26 | data = data.view(bsz, -1).t().contiguous() 27 | if args.cuda: 28 | data = data.cuda() 29 | return data 30 | 31 | 32 | def get_batch(source, i, args, seq_len=None, evaluation=False): 33 | seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i) 34 | data = source[i:i+seq_len] 35 | target = source[i+1:i+1+seq_len].view(-1) 36 | return data, target 37 | -------------------------------------------------------------------------------- /model_save.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def model_save(fn, model, criterion, optimizer, vocab=None, 5 | val_loss=None, val_ppl=None, config=None, epoch=None): 6 | state = {'model': model, 'criterion': criterion, 7 | 'optimizer': optimizer, 'vocab': vocab, 8 | 'val_loss': val_loss, 'val_ppl': val_ppl, 9 | 'config': config, 'epoch': epoch} 10 | with open(fn, 'wb') as f: 11 | torch.save(state, f) 12 | 13 | 14 | def model_load(fn): 15 | # global model, criterion, optimizer 16 | with open(fn, 'rb') as f: 17 | # model, criterion, optimizer, vocab, val_loss, config = torch.load(f) 18 | return torch.load(f) 19 | 20 | 21 | def model_state_save(fn, model, criterion, optimizer, vocab=None, 22 | val_loss=None, val_ppl=None, config=None, epoch=None): 23 | """ 24 | We have to save *only* the state_dicts() of all arguments in order to load the checkpoint from a different project. 25 | :return: 26 | """ 27 | state = {'model_state_dict': model.state_dict(), 28 | 'optimizer_state_dict': optimizer.state_dict(), 'vocab': vocab.__dict__, 29 | 'val_loss': val_loss, 'val_ppl': val_ppl, 30 | 'config': config, 'epoch': epoch} 31 | with open(fn, 'wb') as f: 32 | torch.save(state, f) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 
11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /getdata.sh: -------------------------------------------------------------------------------- 1 | echo "=== Acquiring datasets ===" 2 | echo "---" 3 | mkdir -p save 4 | 5 | mkdir -p data 6 | cd data 7 | 8 | echo "- Downloading WikiText-2 (WT2)" 9 | wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip 10 | unzip -q wikitext-2-v1.zip 11 | cd wikitext-2 12 | mv wiki.train.tokens train.txt 13 | mv wiki.valid.tokens valid.txt 14 | mv wiki.test.tokens test.txt 15 | cd .. 16 | 17 | echo "- Downloading WikiText-103 (WT2)" 18 | wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip 19 | unzip -q wikitext-103-v1.zip 20 | cd wikitext-103 21 | mv wiki.train.tokens train.txt 22 | mv wiki.valid.tokens valid.txt 23 | mv wiki.test.tokens test.txt 24 | cd .. 25 | 26 | echo "- Downloading enwik8 (Character)" 27 | mkdir -p enwik8 28 | cd enwik8 29 | wget --continue http://mattmahoney.net/dc/enwik8.zip 30 | python prep_enwik8.py 31 | cd .. 32 | 33 | echo "- Downloading Penn Treebank (PTB)" 34 | wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz 35 | tar -xzf simple-examples.tgz 36 | 37 | mkdir -p penn 38 | cd penn 39 | mv ../simple-examples/data/ptb.train.txt train.txt 40 | mv ../simple-examples/data/ptb.test.txt test.txt 41 | mv ../simple-examples/data/ptb.valid.txt valid.txt 42 | cd .. 43 | 44 | echo "- Downloading Penn Treebank (Character)" 45 | mkdir -p pennchar 46 | cd pennchar 47 | mv ../simple-examples/data/ptb.char.train.txt train.txt 48 | mv ../simple-examples/data/ptb.char.test.txt test.txt 49 | mv ../simple-examples/data/ptb.char.valid.txt valid.txt 50 | cd .. 
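# At this point data/ should contain wikitext-2/, wikitext-103/, enwik8/, penn/ and pennchar/,
# each with train.txt, valid.txt and test.txt (assuming the downloads above all succeeded).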
51 | 52 | rm -rf simple-examples/ 53 | 54 | echo "---" 55 | echo "Happy language modeling :)" 56 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from collections import Counter 5 | 6 | 7 | class Dictionary(object): 8 | def __init__(self): 9 | self.word2idx = {} 10 | self.idx2word = [] 11 | self.counter = Counter() 12 | self.total = 0 13 | 14 | def add_word(self, word): 15 | if word not in self.word2idx: 16 | self.idx2word.append(word) 17 | self.word2idx[word] = len(self.idx2word) - 1 18 | token_id = self.word2idx[word] 19 | self.counter[token_id] += 1 20 | self.total += 1 21 | return self.word2idx[word] 22 | 23 | def __len__(self): 24 | return len(self.idx2word) 25 | 26 | 27 | class Corpus(object): 28 | def __init__(self, path): 29 | self.dictionary = Dictionary() 30 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 31 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 32 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 33 | 34 | def tokenize(self, path): 35 | """Tokenizes a text file.""" 36 | assert os.path.exists(path), str(path) 37 | # Add words to the dictionary 38 | with open(path, 'r') as f: 39 | tokens = 0 40 | for line in f: 41 | words = line.split() + [''] 42 | tokens += len(words) 43 | for word in words: 44 | self.dictionary.add_word(word) 45 | 46 | # Tokenize file content 47 | with open(path, 'r') as f: 48 | ids = torch.LongTensor(tokens) 49 | token = 0 50 | for line in f: 51 | words = line.split() + [''] 52 | for word in words: 53 | ids[token] = self.dictionary.word2idx[word] 54 | token += 1 55 | 56 | return ids 57 | -------------------------------------------------------------------------------- /embed_regularize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class EmbeddingDropout(nn.Module): 8 | """ 9 | Embedding Layer. 10 | If embedding_dropout != 0 we apply dropout to word 'types' not 'tokens' as suggested 11 | in the paper https://arxiv.org/pdf/1512.05287.pdf. 12 | We first map the input sequences to the corresponding embeddings (from |V| -> embedding_dim) 13 | and THEN apply dropout. 
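    Concretely, forward() samples one Bernoulli(1 - embedding_dropout) keep/drop decision per
    vocabulary row, zeroes the dropped rows of the embedding weight, rescales the surviving rows
    by 1/(1 - embedding_dropout), and only then performs the lookup, so every occurrence of a
    dropped word type receives a zero vector for that forward pass.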
14 | """ 15 | 16 | def __init__(self, num_embeddings, embedding_dim, embedding_dropout=0.): 17 | super().__init__() 18 | self.num_embeddings = num_embeddings 19 | self.embedding_dim = embedding_dim 20 | self.dropoute = embedding_dropout 21 | 22 | self.embed = nn.Embedding(num_embeddings=self.num_embeddings, 23 | embedding_dim=self.embedding_dim) 24 | 25 | def forward(self, words): 26 | if self.dropoute and self.training: 27 | mask = self.embed.weight.data.new().resize_((self.embed.weight.size(0), 1)).bernoulli_( 28 | 1 - self.dropoute).expand_as( 29 | self.embed.weight) / (1 - self.dropoute) 30 | masked_embed_weight = mask * self.embed.weight 31 | else: 32 | masked_embed_weight = self.embed.weight 33 | 34 | padding_idx = self.embed.padding_idx # be careful here to use the same 'padding_idx' name 35 | if padding_idx is None: 36 | padding_idx = -1 37 | 38 | X = torch.nn.functional.embedding(words, masked_embed_weight, 39 | padding_idx, self.embed.max_norm, self.embed.norm_type, 40 | self.embed.scale_grad_by_freq, self.embed.sparse 41 | ) 42 | return X 43 | 44 | def embedded_dropout(embed, words, dropout=0.1, scale=None): 45 | if dropout: 46 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout) 47 | masked_embed_weight = mask * embed.weight 48 | else: 49 | masked_embed_weight = embed.weight 50 | if scale: 51 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 52 | 53 | padding_idx = embed.padding_idx 54 | if padding_idx is None: 55 | padding_idx = -1 56 | 57 | X = torch.nn.functional.embedding(words, masked_embed_weight, 58 | padding_idx, embed.max_norm, embed.norm_type, 59 | embed.scale_grad_by_freq, embed.sparse 60 | ) 61 | return X 62 | 63 | if __name__ == '__main__': 64 | V = 50 65 | h = 4 66 | bptt = 10 67 | batch_size = 2 68 | 69 | embed = torch.nn.Embedding(V, h) 70 | 71 | words = np.random.random_integers(low=0, high=V-1, size=(batch_size, bptt)) 72 | words = torch.LongTensor(words) 73 | 74 | origX = embed(words) 75 | X = embedded_dropout(embed, words) 76 | 77 | print(origX) 78 | print(X) 79 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Language Modeling on Penn Tree Bank 3 | # 4 | # This file generates new sentences sampled from the language model 5 | # 6 | ############################################################################### 7 | 8 | import argparse 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | 13 | import data 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch PTB Language Model') 16 | 17 | # Model parameters. 
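# Illustrative invocation (paths and filenames are placeholders; the checkpoint must be a full
# model saved with torch.save(model, f), as finetune.py does):
#   python generate.py --data data/penn --checkpoint PTB.pt --cuda --words 1000 --temperature 1.0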
18 | parser.add_argument('--data', type=str, default='./data/penn', 19 | help='location of the data corpus') 20 | parser.add_argument('--model', type=str, default='LSTM', 21 | help='type of recurrent net (LSTM, QRNN)') 22 | parser.add_argument('--checkpoint', type=str, default='./model.pt', 23 | help='model checkpoint to use') 24 | parser.add_argument('--outf', type=str, default='generated.txt', 25 | help='output file for generated text') 26 | parser.add_argument('--words', type=int, default='1000', 27 | help='number of words to generate') 28 | parser.add_argument('--seed', type=int, default=1111, 29 | help='random seed') 30 | parser.add_argument('--cuda', action='store_true', 31 | help='use CUDA') 32 | parser.add_argument('--temperature', type=float, default=1.0, 33 | help='temperature - higher will increase diversity') 34 | parser.add_argument('--log-interval', type=int, default=100, 35 | help='reporting interval') 36 | args = parser.parse_args() 37 | 38 | # Set the random seed manually for reproducibility. 39 | torch.manual_seed(args.seed) 40 | if torch.cuda.is_available(): 41 | if not args.cuda: 42 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 43 | else: 44 | torch.cuda.manual_seed(args.seed) 45 | 46 | if args.temperature < 1e-3: 47 | parser.error("--temperature has to be greater or equal 1e-3") 48 | 49 | with open(args.checkpoint, 'rb') as f: 50 | model = torch.load(f) 51 | model.eval() 52 | if args.model == 'QRNN': 53 | model.reset() 54 | 55 | if args.cuda: 56 | model.cuda() 57 | else: 58 | model.cpu() 59 | 60 | corpus = data.Corpus(args.data) 61 | ntokens = len(corpus.dictionary) 62 | hidden = model.init_hidden(1) 63 | input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) 64 | if args.cuda: 65 | input.data = input.data.cuda() 66 | 67 | with open(args.outf, 'w') as outf: 68 | for i in range(args.words): 69 | output, hidden = model(input, hidden) 70 | word_weights = output.squeeze().data.div(args.temperature).exp().cpu() 71 | word_idx = torch.multinomial(word_weights, 1)[0] 72 | input.data.fill_(word_idx) 73 | word = corpus.dictionary.idx2word[word_idx] 74 | 75 | outf.write(word + ('\n' if i % 20 == 19 else ' ')) 76 | 77 | if i % args.log_interval == 0: 78 | print('| Generated {}/{} words'.format(i, args.words)) 79 | -------------------------------------------------------------------------------- /asgd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | """ 6 | pytorch 1.2.0 ASGD 7 | """ 8 | 9 | class ASGD(Optimizer): 10 | """Implements Averaged Stochastic Gradient Descent. 11 | 12 | It has been proposed in `Acceleration of stochastic approximation by 13 | averaging`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-2) 19 | lambd (float, optional): decay term (default: 1e-4) 20 | alpha (float, optional): power for eta update (default: 0.75) 21 | t0 (float, optional): point at which to start averaging (default: 1e6) 22 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 23 | 24 | .. 
_Acceleration of stochastic approximation by averaging: 25 | http://dl.acm.org/citation.cfm?id=131098 26 | """ 27 | 28 | def __init__(self, params, lr=1e-2, lambd=1e-4, alpha=0.75, t0=1e6, weight_decay=0): 29 | if not 0.0 <= lr: 30 | raise ValueError("Invalid learning rate: {}".format(lr)) 31 | if not 0.0 <= weight_decay: 32 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 33 | 34 | defaults = dict(lr=lr, lambd=lambd, alpha=alpha, t0=t0, 35 | weight_decay=weight_decay) 36 | super(ASGD, self).__init__(params, defaults) 37 | 38 | def step(self, closure=None): 39 | """Performs a single optimization step. 40 | 41 | Arguments: 42 | closure (callable, optional): A closure that reevaluates the model 43 | and returns the loss. 44 | """ 45 | loss = None 46 | if closure is not None: 47 | loss = closure() 48 | 49 | for group in self.param_groups: 50 | for p in group['params']: 51 | if p.grad is None: 52 | continue 53 | grad = p.grad.data 54 | if grad.is_sparse: 55 | raise RuntimeError('ASGD does not support sparse gradients') 56 | state = self.state[p] 57 | 58 | # State initialization 59 | if len(state) == 0: 60 | state['step'] = 0 61 | state['eta'] = group['lr'] 62 | state['mu'] = 1 63 | state['ax'] = torch.zeros_like(p.data) 64 | 65 | state['step'] += 1 66 | 67 | if group['weight_decay'] != 0: 68 | grad = grad.add(group['weight_decay'], p.data) 69 | 70 | # decay term 71 | p.data.mul_(1 - group['lambd'] * state['eta']) 72 | 73 | # update parameter 74 | p.data.add_(-state['eta'], grad) 75 | 76 | # averaging 77 | if state['mu'] != 1: 78 | state['ax'].add_(p.data.sub(state['ax']).mul(state['mu'])) 79 | else: 80 | state['ax'].copy_(p.data) 81 | 82 | # update eta and mu 83 | state['eta'] = (group['lr'] / 84 | math.pow((1 + group['lambd'] * group['lr'] * state['step']), group['alpha'])) 85 | state['mu'] = 1 / max(1, state['step'] - group['t0']) 86 | 87 | return loss -------------------------------------------------------------------------------- /weight_drop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from functools import wraps 4 | 5 | class WeightDrop(torch.nn.Module): 6 | def __init__(self, module, weights, dropout=0, variational=False): 7 | super(WeightDrop, self).__init__() 8 | self.module = module 9 | self.weights = weights 10 | self.dropout = dropout 11 | self.variational = variational 12 | self._setup() 13 | 14 | def widget_demagnetizer_y2k_edition(*args, **kwargs): 15 | # We need to replace flatten_parameters with a nothing function 16 | # It must be a function rather than a lambda as otherwise pickling explodes 17 | # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION! 
18 | # (╯°□°)╯︵ ┻━┻ 19 | return 20 | 21 | def _setup(self): 22 | # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN 23 | if issubclass(type(self.module), torch.nn.RNNBase): 24 | self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition 25 | 26 | for name_w in self.weights: 27 | print('Applying weight drop of {} to {}'.format(self.dropout, name_w)) 28 | w = getattr(self.module, name_w) 29 | del self.module._parameters[name_w] 30 | self.module.register_parameter(name_w + '_raw', Parameter(w.data)) 31 | 32 | def _setweights(self): 33 | for name_w in self.weights: 34 | raw_w = getattr(self.module, name_w + '_raw') 35 | w = None 36 | if self.variational: 37 | mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1)) 38 | if raw_w.is_cuda: mask = mask.cuda() 39 | mask = torch.nn.functional.dropout(mask, p=self.dropout, training=self.training) 40 | # w = mask.expand_as(raw_w) * raw_w 41 | w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w) 42 | else: 43 | # w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training) 44 | w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)) 45 | setattr(self.module, name_w, w) 46 | 47 | def forward(self, *args): 48 | self._setweights() 49 | return self.module.forward(*args) 50 | 51 | if __name__ == '__main__': 52 | import torch 53 | from weight_drop import WeightDrop 54 | 55 | # Input is (seq, batch, input) 56 | x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda() 57 | h0 = None 58 | 59 | ### 60 | 61 | print('Testing WeightDrop') 62 | print('=-=-=-=-=-=-=-=-=-=') 63 | 64 | ### 65 | 66 | print('Testing WeightDrop with Linear') 67 | 68 | lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9) 69 | lin.cuda() 70 | run1 = [x.sum() for x in lin(x).data] 71 | run2 = [x.sum() for x in lin(x).data] 72 | 73 | print('All items should be different') 74 | print('Run 1:', run1) 75 | print('Run 2:', run2) 76 | 77 | assert run1[0] != run2[0] 78 | assert run1[1] != run2[1] 79 | 80 | print('---') 81 | 82 | ### 83 | 84 | print('Testing WeightDrop with LSTM') 85 | 86 | wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9) 87 | wdrnn.cuda() 88 | 89 | run1 = [x.sum() for x in wdrnn(x, h0)[0].data] 90 | run2 = [x.sum() for x in wdrnn(x, h0)[0].data] 91 | 92 | print('First timesteps should be equal, all others should differ') 93 | print('Run 1:', run1) 94 | print('Run 2:', run2) 95 | 96 | # First time step, not influenced by hidden to hidden weights, should be equal 97 | assert run1[0] == run2[0] 98 | # Second step should not 99 | assert run1[1] != run2[1] 100 | 101 | print('---') 102 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from embed_regularize import embedded_dropout 5 | from locked_dropout import LockedDropout 6 | from weight_drop import WeightDrop 7 | 8 | class AWD(nn.Module): 9 | """Container module with an encoder, a recurrent module, and a decoder.""" 10 | 11 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, 12 | dropout=0.5, dropouth=0.5, dropouti=0.5, 13 | dropoute=0.1, wdrop=0, tie_weights=False): 14 | super(AWD, self).__init__() 15 | self.lockdrop = LockedDropout() 16 | self.idrop = nn.Dropout(dropouti) 17 | self.hdrop = nn.Dropout(dropouth) 18 | self.drop = nn.Dropout(dropout) 19 | self.encoder = nn.Embedding(ntoken, ninp) 20 | assert 
rnn_type in ['LSTM', 'QRNN', 'GRU'], 'RNN type is not supported' 21 | if rnn_type == 'LSTM': 22 | self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), 1, dropout=0) for l in range(nlayers)] 23 | if wdrop: 24 | self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns] 25 | if rnn_type == 'GRU': 26 | self.rnns = [torch.nn.GRU(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else ninp, 1, dropout=0) for l in range(nlayers)] 27 | if wdrop: 28 | self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns] 29 | elif rnn_type == 'QRNN': 30 | from torchqrnn import QRNNLayer 31 | self.rnns = [QRNNLayer(input_size=ninp if l == 0 else nhid, hidden_size=nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(nlayers)] 32 | for rnn in self.rnns: 33 | rnn.linear = WeightDrop(rnn.linear, ['weight'], dropout=wdrop) 34 | print(self.rnns) 35 | self.rnns = torch.nn.ModuleList(self.rnns) 36 | self.decoder = nn.Linear(nhid, ntoken) 37 | 38 | # Optionally tie weights as in: 39 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 40 | # https://arxiv.org/abs/1608.05859 41 | # and 42 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) 43 | # https://arxiv.org/abs/1611.01462 44 | if tie_weights: 45 | #if nhid != ninp: 46 | # raise ValueError('When using the tied flag, nhid must be equal to emsize') 47 | self.decoder.weight = self.encoder.weight 48 | 49 | self.init_weights() 50 | 51 | self.rnn_type = rnn_type 52 | self.ninp = ninp 53 | self.nhid = nhid 54 | self.nlayers = nlayers 55 | self.dropout = dropout 56 | self.dropouti = dropouti 57 | self.dropouth = dropouth 58 | self.dropoute = dropoute 59 | self.tie_weights = tie_weights 60 | 61 | def reset(self): 62 | if self.rnn_type == 'QRNN': [r.reset() for r in self.rnns] 63 | 64 | def init_weights(self): 65 | initrange = 0.1 66 | self.encoder.weight.data.uniform_(-initrange, initrange) 67 | self.decoder.bias.data.fill_(0) 68 | self.decoder.weight.data.uniform_(-initrange, initrange) 69 | 70 | def forward(self, input, hidden, return_h=False): 71 | emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0) 72 | #emb = self.idrop(emb) 73 | 74 | emb = self.lockdrop(emb, self.dropouti) 75 | 76 | raw_output = emb 77 | new_hidden = [] 78 | #raw_output, hidden = self.rnn(emb, hidden) 79 | raw_outputs = [] 80 | outputs = [] 81 | 82 | for l, rnn in enumerate(self.rnns): 83 | rnn.module.flatten_parameters() # not working 84 | current_input = raw_output 85 | raw_output, new_h = rnn(raw_output, hidden[l]) 86 | new_hidden.append(new_h) 87 | raw_outputs.append(raw_output) 88 | if l != self.nlayers - 1: 89 | #self.hdrop(raw_output) 90 | raw_output = self.lockdrop(raw_output, self.dropouth) 91 | outputs.append(raw_output) 92 | hidden = new_hidden 93 | 94 | output = self.lockdrop(raw_output, self.dropout) 95 | outputs.append(output) 96 | 97 | result = output.view(output.size(0)*output.size(1), output.size(2)) 98 | if return_h: 99 | return result, hidden, raw_outputs, outputs 100 | return result, hidden 101 | 102 | def init_hidden(self, bsz): 103 | weight = next(self.parameters()).data 104 | if self.rnn_type == 'LSTM': 105 | return [(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_(), 106 | weight.new(1, 
bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_()) 107 | for l in range(self.nlayers)] 108 | elif self.rnn_type == 'QRNN' or self.rnn_type == 'GRU': 109 | return [weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_() 110 | for l in range(self.nlayers)] 111 | -------------------------------------------------------------------------------- /pointer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | import data 10 | import model 11 | 12 | from utils import batchify, get_batch, repackage_hidden 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 15 | parser.add_argument('--data', type=str, default='data/penn', 16 | help='location of the data corpus') 17 | parser.add_argument('--model', type=str, default='LSTM', 18 | help='type of recurrent net (LSTM, QRNN)') 19 | parser.add_argument('--save', type=str,default='best.pt', 20 | help='model to use the pointer over') 21 | parser.add_argument('--cuda', action='store_false', 22 | help='use CUDA') 23 | parser.add_argument('--bptt', type=int, default=5000, 24 | help='sequence length') 25 | parser.add_argument('--window', type=int, default=3785, 26 | help='pointer window length') 27 | parser.add_argument('--theta', type=float, default=0.6625523432485668, 28 | help='mix between uniform distribution and pointer softmax distribution over previous words') 29 | parser.add_argument('--lambdasm', type=float, default=0.12785920428335693, 30 | help='linear mix between only pointer (1) and only vocab (0) distribution') 31 | args = parser.parse_args() 32 | 33 | ############################################################################### 34 | # Load data 35 | ############################################################################### 36 | 37 | corpus = data.Corpus(args.data) 38 | 39 | eval_batch_size = 1 40 | test_batch_size = 1 41 | #train_data = batchify(corpus.train, args.batch_size) 42 | val_data = batchify(corpus.valid, test_batch_size, args) 43 | test_data = batchify(corpus.test, test_batch_size, args) 44 | 45 | ############################################################################### 46 | # Build the model 47 | ############################################################################### 48 | 49 | ntokens = len(corpus.dictionary) 50 | criterion = nn.CrossEntropyLoss() 51 | 52 | def one_hot(idx, size, cuda=True): 53 | a = np.zeros((1, size), np.float32) 54 | a[0][idx] = 1 55 | v = Variable(torch.from_numpy(a)) 56 | if cuda: v = v.cuda() 57 | return v 58 | 59 | def evaluate(data_source, batch_size=10, window=args.window): 60 | # Turn on evaluation mode which disables dropout. 
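    # The loop below applies the continuous cache pointer (Grave et al. 2016,
    # https://arxiv.org/abs/1612.04426) at evaluation time: the last `window` hidden states H and
    # their target words w_i are kept, attention a = softmax(theta * H @ h_t) gives a cache
    # distribution p_cache(w) = sum_i a_i * 1[w_i == w], and the final prediction is
    # p(w) = lambdasm * p_cache(w) + (1 - lambdasm) * p_vocab(w).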
61 | if args.model == 'QRNN': model.reset() 62 | model.eval() 63 | total_loss = 0 64 | ntokens = len(corpus.dictionary) 65 | hidden = model.init_hidden(batch_size) 66 | next_word_history = None 67 | pointer_history = None 68 | for i in range(0, data_source.size(0) - 1, args.bptt): 69 | if i > 0: print(i, len(data_source), math.exp(total_loss / i)) 70 | data, targets = get_batch(data_source, i, evaluation=True, args=args) 71 | output, hidden, rnn_outs, _ = model(data, hidden, return_h=True) 72 | rnn_out = rnn_outs[-1].squeeze() 73 | output_flat = output.view(-1, ntokens) 74 | ### 75 | # Fill pointer history 76 | start_idx = len(next_word_history) if next_word_history is not None else 0 77 | next_word_history = torch.cat([one_hot(t.data[0], ntokens) for t in targets]) if next_word_history is None else torch.cat([next_word_history, torch.cat([one_hot(t.data[0], ntokens) for t in targets])]) 78 | #print(next_word_history) 79 | pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0) 80 | #print(pointer_history) 81 | ### 82 | # Built-in cross entropy 83 | # total_loss += len(data) * criterion(output_flat, targets).data[0] 84 | ### 85 | # Manual cross entropy 86 | # softmax_output_flat = torch.nn.functional.softmax(output_flat) 87 | # soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1)) 88 | # entropy = -torch.log(soft) 89 | # total_loss += len(data) * entropy.mean().data[0] 90 | ### 91 | # Pointer manual cross entropy 92 | loss = 0 93 | softmax_output_flat = torch.nn.functional.softmax(output_flat) 94 | for idx, vocab_loss in enumerate(softmax_output_flat): 95 | p = vocab_loss 96 | if start_idx + idx > window: 97 | valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx] 98 | valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx] 99 | logits = torch.mv(valid_pointer_history, rnn_out[idx]) 100 | theta = args.theta 101 | ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1) 102 | ptr_dist = (ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze() 103 | lambdah = args.lambdasm 104 | p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss 105 | ### 106 | target_loss = p[targets[idx].data] 107 | loss += (-torch.log(target_loss)).data[0] 108 | total_loss += loss / batch_size 109 | ### 110 | hidden = repackage_hidden(hidden) 111 | next_word_history = next_word_history[-window:] 112 | pointer_history = pointer_history[-window:] 113 | return total_loss / len(data_source) 114 | 115 | # Load the best saved model. 116 | with open(args.save, 'rb') as f: 117 | if not args.cuda: 118 | model = torch.load(f, map_location=lambda storage, loc: storage) 119 | else: 120 | model = torch.load(f) 121 | print(model) 122 | 123 | # Run on val data. 124 | val_loss = evaluate(val_data, test_batch_size) 125 | print('=' * 89) 126 | print('| End of pointer | val loss {:5.2f} | val ppl {:8.2f}'.format( 127 | val_loss, math.exp(val_loss))) 128 | print('=' * 89) 129 | 130 | # Run on test data. 
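# evaluate() returns the average per-token loss, so the perplexities printed below are exp(loss);
# with the PTB settings quoted in the README the cache pointer lands around 52-53 test perplexity.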
131 | test_loss = evaluate(test_data, test_batch_size) 132 | print('=' * 89) 133 | print('| End of pointer | test loss {:5.2f} | test ppl {:8.2f}'.format( 134 | test_loss, math.exp(test_loss))) 135 | print('=' * 89) 136 | -------------------------------------------------------------------------------- /splitcross.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import numpy as np 7 | 8 | 9 | class SplitCrossEntropyLoss(nn.Module): 10 | r'''SplitCrossEntropyLoss calculates an approximate softmax''' 11 | def __init__(self, hidden_size, splits, verbose=False): 12 | # We assume splits is [0, split1, split2, N] where N >= |V| 13 | # For example, a vocab of 1000 words may have splits [0] + [100, 500] + [inf] 14 | super(SplitCrossEntropyLoss, self).__init__() 15 | self.hidden_size = hidden_size 16 | self.splits = [0] + splits + [100 * 1000000] 17 | self.nsplits = len(self.splits) - 1 18 | self.stats = defaultdict(list) 19 | self.verbose = verbose 20 | # Each of the splits that aren't in the head require a pretend token, we'll call them tombstones 21 | # The probability given to this tombstone is the probability of selecting an item from the represented split 22 | if self.nsplits > 1: 23 | self.tail_vectors = nn.Parameter(torch.zeros(self.nsplits - 1, hidden_size)) 24 | self.tail_bias = nn.Parameter(torch.zeros(self.nsplits - 1)) 25 | 26 | def logprob(self, weight, bias, hiddens, splits=None, softmaxed_head_res=None, verbose=False): 27 | # First we perform the first softmax on the head vocabulary and the tombstones 28 | if softmaxed_head_res is None: 29 | start, end = self.splits[0], self.splits[1] 30 | head_weight = None if end - start == 0 else weight[start:end] 31 | head_bias = None if end - start == 0 else bias[start:end] 32 | # We only add the tombstones if we have more than one split 33 | if self.nsplits > 1: 34 | head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors]) 35 | head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias]) 36 | 37 | # Perform the softmax calculation for the word vectors in the head for all splits 38 | # We need to guard against empty splits as torch.cat does not like random lists 39 | head_res = torch.nn.functional.linear(hiddens, head_weight, bias=head_bias) 40 | softmaxed_head_res = torch.nn.functional.log_softmax(head_res, dim=-1) 41 | 42 | if splits is None: 43 | splits = list(range(self.nsplits)) 44 | 45 | results = [] 46 | running_offset = 0 47 | for idx in splits: 48 | 49 | # For those targets in the head (idx == 0) we only need to return their loss 50 | if idx == 0: 51 | results.append(softmaxed_head_res[:, :-(self.nsplits - 1)]) 52 | 53 | # If the target is in one of the splits, the probability is the p(tombstone) * p(word within tombstone) 54 | else: 55 | start, end = self.splits[idx], self.splits[idx + 1] 56 | tail_weight = weight[start:end] 57 | tail_bias = bias[start:end] 58 | 59 | # Calculate the softmax for the words in the tombstone 60 | tail_res = torch.nn.functional.linear(hiddens, tail_weight, bias=tail_bias) 61 | 62 | # Then we calculate p(tombstone) * p(word in tombstone) 63 | # Adding is equivalent to multiplication in log space 64 | head_entropy = (softmaxed_head_res[:, -idx]).contiguous() 65 | tail_entropy = torch.nn.functional.log_softmax(tail_res, dim=-1) 66 | results.append(head_entropy.view(-1, 1) + tail_entropy) 67 | 68 | if 
len(results) > 1: 69 | return torch.cat(results, dim=1) 70 | return results[0] 71 | 72 | def split_on_targets(self, hiddens, targets): 73 | # Split the targets into those in the head and in the tail 74 | split_targets = [] 75 | split_hiddens = [] 76 | 77 | # Determine to which split each element belongs (for each start split value, add 1 if equal or greater) 78 | # This method appears slower at least for WT-103 values for approx softmax 79 | #masks = [(targets >= self.splits[idx]).view(1, -1) for idx in range(1, self.nsplits)] 80 | #mask = torch.sum(torch.cat(masks, dim=0), dim=0) 81 | ### 82 | # This is equally fast for smaller splits as method below but scales linearly 83 | mask = None 84 | for idx in range(1, self.nsplits): 85 | partial_mask = targets >= self.splits[idx] 86 | mask = mask + partial_mask if mask is not None else partial_mask 87 | ### 88 | #masks = torch.stack([targets] * (self.nsplits - 1)) 89 | #mask = torch.sum(masks >= self.split_starts, dim=0) 90 | for idx in range(self.nsplits): 91 | # If there are no splits, avoid costly masked select 92 | if self.nsplits == 1: 93 | split_targets, split_hiddens = [targets], [hiddens] 94 | continue 95 | # If all the words are covered by earlier targets, we have empties so later stages don't freak out 96 | if sum(len(t) for t in split_targets) == len(targets): 97 | split_targets.append([]) 98 | split_hiddens.append([]) 99 | continue 100 | # Are you in our split? 101 | tmp_mask = mask == idx 102 | split_targets.append(torch.masked_select(targets, tmp_mask)) 103 | split_hiddens.append(hiddens.masked_select(tmp_mask.unsqueeze(1).expand_as(hiddens)).view(-1, hiddens.size(1))) 104 | return split_targets, split_hiddens 105 | 106 | def forward(self, weight, bias, hiddens, targets, verbose=False): 107 | if self.verbose or verbose: 108 | for idx in sorted(self.stats): 109 | print('{}: {}'.format(idx, int(np.mean(self.stats[idx]))), end=', ') 110 | print() 111 | 112 | total_loss = None 113 | if len(hiddens.size()) > 2: hiddens = hiddens.view(-1, hiddens.size(2)) 114 | 115 | split_targets, split_hiddens = self.split_on_targets(hiddens, targets) 116 | 117 | # First we perform the first softmax on the head vocabulary and the tombstones 118 | start, end = self.splits[0], self.splits[1] 119 | head_weight = None if end - start == 0 else weight[start:end] 120 | head_bias = None if end - start == 0 else bias[start:end] 121 | 122 | # We only add the tombstones if we have more than one split 123 | if self.nsplits > 1: 124 | head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors]) 125 | head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias]) 126 | 127 | # Perform the softmax calculation for the word vectors in the head for all splits 128 | # We need to guard against empty splits as torch.cat does not like random lists 129 | combo = torch.cat([split_hiddens[i] for i in range(self.nsplits) if len(split_hiddens[i])]) 130 | ### 131 | all_head_res = torch.nn.functional.linear(combo, head_weight, bias=head_bias) 132 | softmaxed_all_head_res = torch.nn.functional.log_softmax(all_head_res, dim=-1) 133 | if self.verbose or verbose: 134 | self.stats[0].append(combo.size()[0] * head_weight.size()[0]) 135 | 136 | running_offset = 0 137 | for idx in range(self.nsplits): 138 | # If there are no targets for this split, continue 139 | if len(split_targets[idx]) == 0: continue 140 | 141 | # For those targets in the head (idx == 0) we only need to return their loss 142 | if idx == 0: 143 | 
softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])] 144 | entropy = -torch.gather(softmaxed_head_res, dim=1, index=split_targets[idx].view(-1, 1)) 145 | # If the target is in one of the splits, the probability is the p(tombstone) * p(word within tombstone) 146 | else: 147 | softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])] 148 | 149 | if self.verbose or verbose: 150 | start, end = self.splits[idx], self.splits[idx + 1] 151 | tail_weight = weight[start:end] 152 | self.stats[idx].append(split_hiddens[idx].size()[0] * tail_weight.size()[0]) 153 | 154 | # Calculate the softmax for the words in the tombstone 155 | tail_res = self.logprob(weight, bias, split_hiddens[idx], splits=[idx], softmaxed_head_res=softmaxed_head_res) 156 | 157 | # Then we calculate p(tombstone) * p(word in tombstone) 158 | # Adding is equivalent to multiplication in log space 159 | head_entropy = softmaxed_head_res[:, -idx] 160 | # All indices are shifted - if the first split handles [0,...,499] then the 500th in the second split will be 0 indexed 161 | indices = (split_targets[idx] - self.splits[idx]).view(-1, 1) 162 | # Warning: if you don't squeeze, you get an N x 1 return, which acts oddly with broadcasting 163 | tail_entropy = torch.gather(torch.nn.functional.log_softmax(tail_res, dim=-1), dim=1, index=indices).squeeze() 164 | entropy = -(head_entropy + tail_entropy) 165 | ### 166 | running_offset += len(split_hiddens[idx]) 167 | total_loss = entropy.float().sum() if total_loss is None else total_loss + entropy.float().sum() 168 | 169 | return (total_loss / len(targets)).type_as(weight) 170 | 171 | 172 | if __name__ == '__main__': 173 | np.random.seed(42) 174 | torch.manual_seed(42) 175 | if torch.cuda.is_available(): 176 | torch.cuda.manual_seed(42) 177 | 178 | V = 8 179 | H = 10 180 | N = 100 181 | E = 10 182 | 183 | embed = torch.nn.Embedding(V, H) 184 | crit = SplitCrossEntropyLoss(hidden_size=H, splits=[V // 2]) 185 | bias = torch.nn.Parameter(torch.ones(V)) 186 | optimizer = torch.optim.SGD(list(embed.parameters()) + list(crit.parameters()), lr=1) 187 | 188 | for _ in range(E): 189 | prev = torch.autograd.Variable((torch.rand(N, 1) * 0.999 * V).int().long()) 190 | x = torch.autograd.Variable((torch.rand(N, 1) * 0.999 * V).int().long()) 191 | y = embed(prev).squeeze() 192 | c = crit(embed.weight, bias, y, x.view(N)) 193 | print('Crit', c.exp().data[0]) 194 | 195 | logprobs = crit.logprob(embed.weight, bias, y[:2]).exp() 196 | print(logprobs) 197 | print(logprobs.sum(dim=1)) 198 | 199 | optimizer.zero_grad() 200 | c.backward() 201 | optimizer.step() 202 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import numpy as np 5 | np.random.seed(331) 6 | import torch 7 | import torch.nn as nn 8 | 9 | import data 10 | import model 11 | 12 | from utils import batchify, get_batch, repackage_hidden 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 15 | parser.add_argument('--data', type=str, default='data/penn/', 16 | help='location of the data corpus') 17 | parser.add_argument('--model', type=str, default='LSTM', 18 | help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)') 19 | parser.add_argument('--emsize', type=int, default=400, 20 | help='size of word embeddings') 21 | 
parser.add_argument('--nhid', type=int, default=1150, 22 | help='number of hidden units per layer') 23 | parser.add_argument('--nlayers', type=int, default=3, 24 | help='number of layers') 25 | parser.add_argument('--lr', type=float, default=30, 26 | help='initial learning rate') 27 | parser.add_argument('--clip', type=float, default=0.25, 28 | help='gradient clipping') 29 | parser.add_argument('--epochs', type=int, default=8000, 30 | help='upper epoch limit') 31 | parser.add_argument('--batch_size', type=int, default=80, metavar='N', 32 | help='batch size') 33 | parser.add_argument('--bptt', type=int, default=70, 34 | help='sequence length') 35 | parser.add_argument('--dropout', type=float, default=0.4, 36 | help='dropout applied to layers (0 = no dropout)') 37 | parser.add_argument('--dropouth', type=float, default=0.3, 38 | help='dropout for rnn layers (0 = no dropout)') 39 | parser.add_argument('--dropouti', type=float, default=0.65, 40 | help='dropout for input embedding layers (0 = no dropout)') 41 | parser.add_argument('--dropoute', type=float, default=0.1, 42 | help='dropout to remove words from embedding layer (0 = no dropout)') 43 | parser.add_argument('--wdrop', type=float, default=0.5, 44 | help='amount of weight dropout to apply to the RNN hidden to hidden matrix') 45 | parser.add_argument('--tied', action='store_false', 46 | help='tie the word embedding and softmax weights') 47 | parser.add_argument('--seed', type=int, default=1111, 48 | help='random seed') 49 | parser.add_argument('--nonmono', type=int, default=5, 50 | help='random seed') 51 | parser.add_argument('--cuda', action='store_false', 52 | help='use CUDA') 53 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 54 | help='report interval') 55 | randomhash = ''.join(str(time.time()).split('.')) 56 | parser.add_argument('--save', type=str, default=randomhash+'.pt', 57 | help='path to save the final model') 58 | parser.add_argument('--alpha', type=float, default=2, 59 | help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') 60 | parser.add_argument('--beta', type=float, default=1, 61 | help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)') 62 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 63 | help='weight decay applied to all weights') 64 | args = parser.parse_args() 65 | 66 | # Set the random seed manually for reproducibility. 
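# Note: numpy was already seeded at import time (np.random.seed(331) above); only the torch seeds
# depend on --seed. Also, --cuda and --tied use action='store_false', so both are enabled by
# default and passing the flag turns them off.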
67 | torch.manual_seed(args.seed) 68 | if torch.cuda.is_available(): 69 | if not args.cuda: 70 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 71 | else: 72 | torch.cuda.manual_seed(args.seed) 73 | 74 | ############################################################################### 75 | # Load data 76 | ############################################################################### 77 | 78 | corpus = data.Corpus(args.data) 79 | 80 | eval_batch_size = 10 81 | test_batch_size = 1 82 | train_data = batchify(corpus.train, args.batch_size, args) 83 | val_data = batchify(corpus.valid, eval_batch_size, args) 84 | test_data = batchify(corpus.test, test_batch_size, args) 85 | 86 | ############################################################################### 87 | # Build the model 88 | ############################################################################### 89 | 90 | ntokens = len(corpus.dictionary) 91 | model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.dropouth, args.dropouti, args.dropoute, args.wdrop, args.tied) 92 | if args.cuda: 93 | model.cuda() 94 | total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in model.parameters()) 95 | print('Args:', args) 96 | print('Model total parameters:', total_params) 97 | 98 | criterion = nn.CrossEntropyLoss() 99 | 100 | ############################################################################### 101 | # Training code 102 | ############################################################################### 103 | 104 | def evaluate(data_source, batch_size=10): 105 | # Turn on evaluation mode which disables dropout. 106 | if args.model == 'QRNN': model.reset() 107 | model.eval() 108 | total_loss = 0 109 | ntokens = len(corpus.dictionary) 110 | hidden = model.init_hidden(batch_size) 111 | for i in range(0, data_source.size(0) - 1, args.bptt): 112 | data, targets = get_batch(data_source, i, args, evaluation=True) 113 | output, hidden = model(data, hidden) 114 | output_flat = output.view(-1, ntokens) 115 | total_loss += len(data) * criterion(output_flat, targets).data 116 | hidden = repackage_hidden(hidden) 117 | return total_loss[0] / len(data_source) 118 | 119 | 120 | def train(): 121 | # Turn on training mode which enables dropout. 122 | if args.model == 'QRNN': model.reset() 123 | total_loss = 0 124 | start_time = time.time() 125 | ntokens = len(corpus.dictionary) 126 | hidden = model.init_hidden(args.batch_size) 127 | batch, i = 0, 0 128 | while i < train_data.size(0) - 1 - 1: 129 | bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2. 130 | # Prevent excessively small or negative sequence lengths 131 | seq_len = max(5, int(np.random.normal(bptt, 5))) 132 | # There's a very small chance that it could select a very long sequence length resulting in OOM 133 | seq_len = min(seq_len, args.bptt + 10) 134 | 135 | lr2 = optimizer.param_groups[0]['lr'] 136 | optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt 137 | model.train() 138 | data, targets = get_batch(train_data, i, args, seq_len=seq_len) 139 | 140 | # Starting each batch, we detach the hidden state from how it was previously produced. 141 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 
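        # repackage_hidden() (defined in utils.py) recursively calls .detach() on the hidden-state
        # tensors, so gradients are truncated at the boundary of each BPTT segment.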
142 | hidden = repackage_hidden(hidden) 143 | optimizer.zero_grad() 144 | 145 | output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True) 146 | raw_loss = criterion(output.view(-1, ntokens), targets) 147 | 148 | loss = raw_loss 149 | # Activiation Regularization 150 | loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:]) 151 | # Temporal Activation Regularization (slowness) 152 | loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:]) 153 | loss.backward() 154 | 155 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 156 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) 157 | optimizer.step() 158 | 159 | total_loss += raw_loss.data 160 | optimizer.param_groups[0]['lr'] = lr2 161 | if batch % args.log_interval == 0 and batch > 0: 162 | cur_loss = total_loss[0] / args.log_interval 163 | elapsed = time.time() - start_time 164 | print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' 165 | 'loss {:5.2f} | ppl {:8.2f}'.format( 166 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 167 | elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))) 168 | total_loss = 0 169 | start_time = time.time() 170 | ### 171 | batch += 1 172 | i += seq_len 173 | 174 | 175 | # Load the best saved model. 176 | with open(args.save, 'rb') as f: 177 | model = torch.load(f) 178 | 179 | 180 | # Loop over epochs. 181 | lr = args.lr 182 | stored_loss = evaluate(val_data) 183 | best_val_loss = [] 184 | # At any point you can hit Ctrl + C to break out of training early. 185 | try: 186 | #optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 187 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 188 | for epoch in range(1, args.epochs+1): 189 | epoch_start_time = time.time() 190 | train() 191 | if 't0' in optimizer.param_groups[0]: 192 | tmp = {} 193 | for prm in model.parameters(): 194 | tmp[prm] = prm.data.clone() 195 | prm.data = optimizer.state[prm]['ax'].clone() 196 | 197 | val_loss2 = evaluate(val_data) 198 | print('-' * 89) 199 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 200 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), 201 | val_loss2, math.exp(val_loss2))) 202 | print('-' * 89) 203 | 204 | if val_loss2 < stored_loss: 205 | with open(args.save, 'wb') as f: 206 | torch.save(model, f) 207 | print('Saving Averaged!') 208 | stored_loss = val_loss2 209 | 210 | for prm in model.parameters(): 211 | prm.data = tmp[prm].clone() 212 | 213 | if (len(best_val_loss)>args.nonmono and val_loss2 > min(best_val_loss[:-args.nonmono])): 214 | print('Done!') 215 | import sys 216 | sys.exit(1) 217 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 218 | #optimizer.param_groups[0]['lr'] /= 2. 219 | best_val_loss.append(val_loss2) 220 | 221 | except KeyboardInterrupt: 222 | print('-' * 89) 223 | print('Exiting from training early') 224 | 225 | # Load the best saved model. 226 | with open(args.save, 'rb') as f: 227 | model = torch.load(f) 228 | 229 | # Run on test data. 
230 | test_loss = evaluate(test_data, test_batch_size) 231 | print('=' * 89) 232 | print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 233 | test_loss, math.exp(test_loss))) 234 | print('=' * 89) 235 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AWD-LSTM with PyTorch 1.2.0 2 | I have made minor changes to the original AWD-LSTM codebase (that uses PyTorch 0.1.12 or 0.4) to make it compatible with PyTorch [1.2.0](https://pytorch.org/docs/1.2.0/). 3 | 4 | ## Software Requirements 5 | Create a conda environment to run the code: 6 | ``` 7 | conda create -n awd python=3.6 8 | source activate awd 9 | conda install pytorch==1.2.0 torchvision==0.4.0 cudatoolkit=10.0 -c pytorch 10 | conda install -c anaconda cupy 11 | pip install pynvrtc git+https://github.com/salesforce/pytorch-qrnn 12 | ``` 13 | Note that `cupy`, `pynvtrc`, and `pytorch-qrnn` are required only if you want to run a QRNN-LSTM. 14 | 15 | ------------------------------ 16 | _Original readme:_ 17 | # LSTM and QRNN Language Model Toolkit 18 | 19 | This repository contains the code used for two [Salesforce Research](https://einstein.ai/) papers: 20 | + [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182) 21 | + [An Analysis of Neural Language Modeling at Multiple Scales](https://arxiv.org/abs/1803.08240) 22 | This code was originally forked from the [PyTorch word level language modeling example](https://github.com/pytorch/examples/tree/master/word_language_model). 23 | 24 | The model comes with instructions to train: 25 | + word level language models over the Penn Treebank (PTB), [WikiText-2](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) (WT2), and [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) (WT103) datasets 26 | 27 | + character level language models over the Penn Treebank (PTBC) and Hutter Prize dataset (enwik8) 28 | 29 | The model can be composed of an LSTM or a [Quasi-Recurrent Neural Network](https://github.com/salesforce/pytorch-qrnn/) (QRNN) which is two or more times faster than the cuDNN LSTM in this setup while achieving equivalent or better accuracy. 30 | 31 | + Install PyTorch 0.4 32 | + Run `getdata.sh` to acquire the Penn Treebank and WikiText-2 datasets 33 | + Train the base model using `main.py` 34 | + (Optionally) Finetune the model using `finetune.py` 35 | + (Optionally) Apply the [continuous cache pointer](https://arxiv.org/abs/1612.04426) to the finetuned model using `pointer.py` 36 | 37 | If you use this code or our results in your research, please cite as appropriate: 38 | 39 | ``` 40 | @article{merityRegOpt, 41 | title={{Regularizing and Optimizing LSTM Language Models}}, 42 | author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard}, 43 | journal={arXiv preprint arXiv:1708.02182}, 44 | year={2017} 45 | } 46 | ``` 47 | 48 | ``` 49 | @article{merityAnalysis, 50 | title={{An Analysis of Neural Language Modeling at Multiple Scales}}, 51 | author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard}, 52 | journal={arXiv preprint arXiv:1803.08240}, 53 | year={2018} 54 | } 55 | ``` 56 | ## Update (June/13/2018) 57 | 58 | The codebase is now PyTorch 0.4 compatible for most use cases (a big shoutout to https://github.com/shawntan for a fairly comprehensive PR https://github.com/salesforce/awd-lstm-lm/pull/43). 
Mild readjustments to hyperparameters may be necessary to obtain quoted performance. If you desire exact reproducibility (or wish to run on PyTorch 0.3 or lower), we suggest using an older commit of this repository. We are still working on `pointer`, `finetune` and `generate` functionalities. 59 | 60 | ## Software Requirements 61 | 62 | Python 3 and PyTorch 0.4 are required for the current codebase. 63 | 64 | Included below are hyper parameters to get results equivalent to or better than those included in the original paper. 65 | 66 | If you need to use an earlier version of the codebase, the original code and hyper parameters are accessible at the [PyTorch==0.1.12](https://github.com/salesforce/awd-lstm-lm/tree/PyTorch%3D%3D0.1.12) release; Python 3 and PyTorch 0.1.12 are required. 67 | If you are using Anaconda, installation of PyTorch 0.1.12 can be achieved via: 68 | `conda install pytorch=0.1.12 -c soumith`. 69 | 70 | ## Experiments 71 | 72 | The codebase was modified during the writing of the paper, preventing exact reproduction due to minor differences in random seeds or similar. 73 | We have also seen exact reproduction numbers change when changing the underlying GPU. 74 | The guide below produces results largely similar to the numbers reported. 75 | 76 | For data setup, run `./getdata.sh`. 77 | This script collects the Mikolov pre-processed Penn Treebank and the WikiText-2 datasets and places them in the `data` directory. 78 | 79 | Next, decide whether to use the QRNN or the LSTM as the underlying recurrent neural network model. 80 | The QRNN is many times faster than even Nvidia's cuDNN optimized LSTM (and dozens of times faster than a naive LSTM implementation) yet achieves similar or better results than the LSTM for many word level datasets. 81 | At the time of writing, the QRNN models use the same number of parameters and are slightly deeper networks but are two to four times faster per epoch and require fewer epochs to converge. 82 | 83 | The QRNN model uses a QRNN with convolutional size 2 for the first layer, allowing the model to view discrete natural language inputs (e.g. "New York"), while all other layers use a convolutional size of 1 (see the sketch after the notes below). 84 | 85 | **Finetuning Note:** Fine-tuning modifies the original saved model `model.pt` file - if you wish to keep the original weights you must copy the file. 86 | 87 | **Pointer note:** BPTT just changes the length of the sequence pushed onto the GPU but won't impact the final result.
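To make the convolutional-size note above concrete, here is a small illustrative sketch (not part of the original codebase) of a QRNN stack whose first layer uses a convolutional window of 2 while the deeper layers use a window of 1. It assumes the `torchqrnn` package installed above exposes `QRNNLayer` with a `window` argument; the sizes are placeholders rather than the tuned hyper parameters.

```
# Illustrative sketch only: the first QRNN layer sees two-token windows (window=2),
# the remaining layers use window=1. Sizes are placeholders.
import torch
from torchqrnn import QRNNLayer

emsize, nhid, nlayers = 400, 1550, 4

layers = [
    QRNNLayer(emsize if l == 0 else nhid, nhid,
              window=2 if l == 0 else 1,
              use_cuda=False)  # naive (non-CUDA) path so the sketch runs anywhere
    for l in range(nlayers)
]

x = torch.randn(70, 20, emsize)  # (seq_len, batch, emsize)
for layer in layers:
    x, hidden = layer(x)         # each layer returns (outputs, final hidden state)
print(x.shape)                   # expected: torch.Size([70, 20, 1550])
```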
88 | 89 | ### Character level enwik8 with LSTM 90 | 91 | + `python -u main.py --epochs 50 --nlayers 3 --emsize 400 --nhid 1840 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.1 --dropouti 0.1 --dropout 0.4 --wdrop 0.2 --wdecay 1.2e-6 --bptt 200 --batch_size 128 --optimizer adam --lr 1e-3 --data data/enwik8 --save ENWIK8.pt --when 25 35` 92 | 93 | ### Character level Penn Treebank (PTB) with LSTM 94 | 95 | + `python -u main.py --epochs 500 --nlayers 3 --emsize 200 --nhid 1000 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.25 --dropouti 0.1 --dropout 0.1 --wdrop 0.5 --wdecay 1.2e-6 --bptt 150 --batch_size 128 --optimizer adam --lr 2e-3 --data data/pennchar --save PTBC.pt --when 300 400` 96 | 97 | ### Word level WikiText-103 (WT103) with QRNN 98 | 99 | + `python -u main.py --epochs 14 --nlayers 4 --emsize 400 --nhid 2500 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --wdrop 0 --wdecay 0 --bptt 140 --batch_size 60 --optimizer adam --lr 1e-3 --data data/wikitext-103 --save WT103.12hr.QRNN.pt --when 12 --model QRNN` 100 | 101 | ### Word level Penn Treebank (PTB) with LSTM 102 | 103 | The instruction below trains a PTB model that without finetuning achieves perplexities of approximately `61.2` / `58.8` (validation / testing), with finetuning achieves perplexities of approximately `58.8` / `56.5`, and with the continuous cache pointer augmentation achieves perplexities of approximately `53.2` / `52.5`. 104 | 105 | + `python main.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 500 --save PTB.pt` 106 | + `python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 500 --save PTB.pt` 107 | + `python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000` 108 | 109 | ### Word level Penn Treebank (PTB) with QRNN 110 | 111 | The instruction below trains a QRNN model that without finetuning achieves perplexities of approximately `60.6` / `58.3` (validation / testing), with finetuning achieves perplexities of approximately `59.1` / `56.7`, and with the continuous cache pointer augmentation achieves perplexities of approximately `53.4` / `52.6`. 112 | 113 | + `python -u main.py --model QRNN --batch_size 20 --clip 0.2 --wdrop 0.1 --nhid 1550 --nlayers 4 --emsize 400 --dropouth 0.3 --seed 9001 --dropouti 0.4 --epochs 550 --save PTB.pt` 114 | + `python -u finetune.py --model QRNN --batch_size 20 --clip 0.2 --wdrop 0.1 --nhid 1550 --nlayers 4 --emsize 400 --dropouth 0.3 --seed 404 --dropouti 0.4 --epochs 300 --save PTB.pt` 115 | + `python pointer.py --model QRNN --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000 --save PTB.pt` 116 | 117 | ### Word level WikiText-2 (WT2) with LSTM 118 | The instruction below trains a WT2 model that without finetuning achieves perplexities of approximately `68.7` / `65.6` (validation / testing), with finetuning achieves perplexities of approximately `67.4` / `64.7`, and with the continuous cache pointer augmentation achieves perplexities of approximately `52.2` / `50.6`.
119 | 120 | + `python main.py --epochs 750 --data data/wikitext-2 --save WT2.pt --dropouth 0.2 --seed 1882` 121 | + `python finetune.py --epochs 750 --data data/wikitext-2 --save WT2.pt --dropouth 0.2 --seed 1882` 122 | + `python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2` 123 | 124 | ### Word level WikiText-2 (WT2) with QRNN 125 | 126 | The instruction below will train a QRNN model that without finetuning achieves perplexities of approximately `69.3` / `66.8` (validation / testing), with finetuning achieves perplexities of approximately `68.5` / `65.9`, and with the continuous cache pointer augmentation achieves perplexities of approximately `53.6` / `52.1`. 127 | Better numbers are likely achievable, but the hyper parameters have not been extensively searched. These hyper parameters should serve as a good starting point, however. 128 | 129 | + `python -u main.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 40 --save WT2.pt` 130 | + `python finetune.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 40 --save WT2.pt` 131 | + `python -u pointer.py --save WT2.pt --model QRNN --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2` 132 | 133 | ## Speed 134 | 135 | For speed regarding character-level PTB and enwik8 or word-level WikiText-103, refer to the relevant paper. 136 | 137 | The default speeds for the models during training on an NVIDIA Quadro GP100: 138 | 139 | + Penn Treebank (batch size 20): LSTM takes 65 seconds per epoch, QRNN takes 28 seconds per epoch 140 | + WikiText-2 (batch size 20): LSTM takes 180 seconds per epoch, QRNN takes 90 seconds per epoch 141 | 142 | The default QRNN models can be far faster than the cuDNN LSTM model, with the speed-ups depending on how much of a bottleneck the RNN is. The majority of the model time above is now spent in softmax or optimization overhead (see [PyTorch QRNN discussion on speed](https://github.com/salesforce/pytorch-qrnn#speed)). 143 | 144 | Speeds are approximately three times slower on a K80. On a K80 or other cards with less memory you may wish to enable [the cap on the maximum sampled sequence length](https://github.com/salesforce/awd-lstm-lm/blob/ef9369d277f8326b16a9f822adae8480b6d492d0/main.py#L131) to prevent out-of-memory (OOM) errors, especially for WikiText-2. 145 | 146 | If speed is a major issue, SGD converges more quickly than our non-monotonically triggered variant of ASGD, though it achieves a worse overall perplexity. 147 | 148 | ### Details of the QRNN optimization 149 | 150 | For full details, refer to the [PyTorch QRNN repository](https://github.com/salesforce/pytorch-qrnn). 151 | 152 | ### Details of the LSTM optimization 153 | 154 | All the augmentations to the LSTM, including our variant of [DropConnect (Wan et al. 2013)](https://cs.nyu.edu/~wanli/dropc/dropc.pdf) termed weight dropping, which adds recurrent dropout, allow for the use of NVIDIA's cuDNN LSTM implementation. 155 | PyTorch will automatically use the cuDNN backend if run on CUDA with cuDNN installed. 156 | This ensures the model is fast to train even when convergence may take many hundreds of epochs.
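As a minimal sketch of the weight dropping described above (an editorial illustration, not part of the original instructions), the snippet below wraps a standard `nn.LSTM` in this repository's `WeightDrop` module from `weight_drop.py`, applying dropout to the hidden-to-hidden weight matrix during training while still allowing the cuDNN kernel to run. Sizes and the dropout rate are placeholders.

```
# Illustrative sketch: recurrent weight dropout (DropConnect) on an LSTM via the
# repository's WeightDrop wrapper. Placeholder sizes, not the tuned hyper parameters.
import torch
import torch.nn as nn
from weight_drop import WeightDrop

lstm = nn.LSTM(input_size=400, hidden_size=1150, num_layers=1)
# During training, roughly half of the hidden-to-hidden connections are zeroed
# on every forward pass.
wd_lstm = WeightDrop(lstm, ['weight_hh_l0'], dropout=0.5)

x = torch.randn(70, 20, 400)     # (seq_len, batch, emsize)
output, (h_n, c_n) = wd_lstm(x)  # used exactly like the wrapped nn.LSTM
print(output.shape)              # expected: torch.Size([70, 20, 1150])
```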
157 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | import data 10 | import model 11 | from asgd import ASGD 12 | from model_save import model_load, model_save, model_state_save 13 | from sys_config import BASE_DIR, CKPT_DIR, CACHE_DIR 14 | 15 | from utils import batchify, get_batch, repackage_hidden 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 18 | parser.add_argument('--data', type=str, default='data/wikitext-2', 19 | help='location of the data corpus') 20 | parser.add_argument('--model', type=str, default='LSTM', 21 | help='type of recurrent net (LSTM, QRNN, GRU)') 22 | parser.add_argument('--emsize', type=int, default=400, 23 | help='size of word embeddings') 24 | parser.add_argument('--nhid', type=int, default=1150, 25 | help='number of hidden units per layer') 26 | parser.add_argument('--nlayers', type=int, default=3, 27 | help='number of layers') 28 | parser.add_argument('--lr', type=float, default=30, 29 | help='initial learning rate') 30 | parser.add_argument('--clip', type=float, default=0.25, 31 | help='gradient clipping') 32 | parser.add_argument('--epochs', type=int, default=800, 33 | help='upper epoch limit') 34 | parser.add_argument('--batch_size', type=int, default=32, metavar='N', 35 | help='batch size') 36 | parser.add_argument('--bptt', type=int, default=70, 37 | help='sequence length') 38 | parser.add_argument('--dropout', type=float, default=0.4, 39 | help='dropout applied to layers (0 = no dropout)') 40 | parser.add_argument('--dropouth', type=float, default=0.3, 41 | help='dropout for rnn layers (0 = no dropout)') 42 | parser.add_argument('--dropouti', type=float, default=0.65, 43 | help='dropout for input embedding layers (0 = no dropout)') 44 | parser.add_argument('--dropoute', type=float, default=0.1, 45 | help='dropout to remove words from embedding layer (0 = no dropout)') 46 | parser.add_argument('--wdrop', type=float, default=0.5, 47 | help='amount of weight dropout to apply to the RNN hidden to hidden matrix') 48 | parser.add_argument('--seed', type=int, default=1882, 49 | help='random seed') 50 | parser.add_argument('--nonmono', type=int, default=5, 51 | help='non-monotone interval: number of most recent validation losses to ignore when deciding whether to switch to ASGD') 52 | parser.add_argument('--cuda', action='store_false', 53 | help='use CUDA') 54 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 55 | help='report interval') 56 | randomhash = ''.join(str(time.time()).split('.')) 57 | parser.add_argument('--save', type=str, default='BLIBLU' + '.pt', 58 | help='path to save the final model') 59 | parser.add_argument('--alpha', type=float, default=2, 60 | help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') 61 | parser.add_argument('--beta', type=float, default=1, 62 | help='beta slowness regularization applied on RNN activation (beta = 0 means no regularization)') 63 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 64 | help='weight decay applied to all weights') 65 | parser.add_argument('--resume', type=str, default='', 66 | help='path of model to resume') 67 | parser.add_argument('--optimizer', type=str, default='sgd', 68 | help='optimizer to use (sgd, adam)') 69 | parser.add_argument('--when', nargs="+", type=int, default=[-1], 70 | help='When (which epochs) to 
divide the learning rate by 10 - accepts multiple') 71 | parser.add_argument("-g", "--gpu", required=False, 72 | default='1', help="gpu on which this experiment runs") 73 | parser.add_argument("-server", "--server", required=False, 74 | default='ford', help="server on which this experiment runs") 75 | parser.add_argument("-asgd", "--asgd", required=False, 76 | default='True', type=lambda s: str(s).lower() in ('true', '1', 'yes'), help="whether to switch to ASGD (NT-ASGD) when validation stops improving") 77 | args = parser.parse_args() 78 | args.tied = True 79 | 80 | if args.server == 'ford': 81 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 82 | print("\nThis experiment runs on gpu {}...\n".format(args.gpu)) 83 | 84 | ############################################################################### 85 | print("torch:", torch.__version__) 86 | if torch.__version__ != '0.1.12_2': 87 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | print("Cuda:", torch.version.cuda) 89 | print("CuDNN:", torch.backends.cudnn.version()) 90 | print('device: {}'.format(device)) 91 | ############################################################################### 92 | global model, criterion, optimizer 93 | 94 | # Set the random seed manually for reproducibility. 95 | np.random.seed(args.seed) 96 | torch.manual_seed(args.seed) 97 | if torch.cuda.is_available(): 98 | args.cuda = True 99 | if not args.cuda: 100 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 101 | else: 102 | torch.cuda.manual_seed(args.seed) 103 | else: 104 | args.cuda = False 105 | print('No cuda! device is cpu :)') 106 | 107 | ############################################################################### 108 | # Load data 109 | ############################################################################### 110 | print('Base directory: {}'.format(BASE_DIR)) 111 | # fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest()) 112 | fn = 'corpus.{}'.format(args.data) 113 | fn = fn.replace('data/', '').replace('wikitext-2', 'wt2') 114 | 115 | fn_path = os.path.join(CACHE_DIR, fn) 116 | if os.path.exists(fn_path): 117 | print('Loading cached dataset...') 118 | corpus = torch.load(fn_path) 119 | else: 120 | print('Producing dataset...') 121 | datapath = os.path.join(BASE_DIR, args.data) 122 | corpus = data.Corpus(datapath) 123 | torch.save(corpus, fn_path) 124 | 125 | eval_batch_size = 10 126 | test_batch_size = 1 127 | train_data = batchify(corpus.train, args.batch_size, args) 128 | val_data = batchify(corpus.valid, eval_batch_size, args) 129 | test_data = batchify(corpus.test, test_batch_size, args) 130 | 131 | vocabulary = corpus.dictionary 132 | 133 | ############################################################################### 134 | # Build the model 135 | ############################################################################### 136 | from splitcross import SplitCrossEntropyLoss 137 | 138 | criterion = None 139 | 140 | ntokens = len(corpus.dictionary) 141 | model = model.AWD(args.model, ntokens, args.emsize, args.nhid, 142 | args.nlayers, args.dropout, args.dropouth, 143 | args.dropouti, args.dropoute, args.wdrop, args.tied) 144 | 145 | ### 146 | if args.resume: 147 | print('Resuming model ...') 148 | model, criterion, optimizer, vocab, val_loss, config = model_load(args.resume) 149 | optimizer.param_groups[0]['lr'] = args.lr 150 | model.dropouti, model.dropouth, model.dropout, args.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute 151 | if args.wdrop: 152 | from weight_drop import WeightDrop 153 | 154 | for rnn in 
model.rnns: 155 | if type(rnn) == WeightDrop: rnn.dropout = args.wdrop 156 | elif rnn.zoneout > 0: rnn.zoneout = args.wdrop 157 | ### 158 | if not criterion: 159 | splits = [] 160 | if ntokens > 500000: 161 | # One Billion 162 | # This produces fairly even matrix mults for the buckets: 163 | # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422 164 | splits = [4200, 35000, 180000] 165 | elif ntokens > 75000: 166 | # WikiText-103 167 | splits = [2800, 20000, 76000] 168 | print('Using splits {}'.format(splits)) 169 | criterion = SplitCrossEntropyLoss(args.emsize, splits=splits, verbose=False) 170 | 171 | # if torch.__version__ != '0.1.12_2': 172 | # print([(name, p.device) for name, p in model.named_parameters()]) 173 | ### 174 | if args.cuda: 175 | model = model.cuda() 176 | criterion = criterion.cuda() 177 | ### 178 | params = list(model.parameters()) + list(criterion.parameters()) 179 | trainable_parameters = [p for p in model.parameters() if p.requires_grad] 180 | total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in params if x.size()) 181 | print('Args:', args) 182 | print('Model total parameters:', total_params) 183 | 184 | 185 | ############################################################################### 186 | # Training code 187 | ############################################################################### 188 | 189 | def evaluate(data_source, batch_size=10): 190 | # Turn on evaluation mode which disables dropout. 191 | model.eval() 192 | if args.model == 'QRNN': model.reset() 193 | total_loss = 0 194 | ntokens = len(corpus.dictionary) 195 | hidden = model.init_hidden(batch_size) 196 | for i in range(0, data_source.size(0) - 1, args.bptt): 197 | data, targets = get_batch(data_source, i, args, evaluation=True) 198 | output, hidden = model(data, hidden) 199 | total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data 200 | hidden = repackage_hidden(hidden) 201 | return total_loss.item() / len(data_source) 202 | 203 | 204 | def train(): 205 | # Turn on training mode which enables dropout. 206 | if args.model == 'QRNN': model.reset() 207 | total_loss = 0 208 | start_time = time.time() 209 | ntokens = len(corpus.dictionary) 210 | hidden = model.init_hidden(args.batch_size) 211 | batch, i = 0, 0 212 | while i < train_data.size(0) - 1 - 1: 213 | bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2. 214 | # Prevent excessively small or negative sequence lengths 215 | seq_len = max(5, int(np.random.normal(bptt, 5))) 216 | # There's a very small chance that it could select a very long sequence length resulting in OOM 217 | # seq_len = min(seq_len, args.bptt + 10) 218 | 219 | lr2 = optimizer.param_groups[0]['lr'] 220 | optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt 221 | model.train() 222 | data, targets = get_batch(train_data, i, args, seq_len=seq_len) 223 | 224 | # Starting each batch, we detach the hidden state from how it was previously produced. 225 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 
226 | hidden = repackage_hidden(hidden) 227 | optimizer.zero_grad() 228 | 229 | output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True) 230 | raw_loss = criterion(model.decoder.weight, model.decoder.bias, output, targets) 231 | 232 | loss = raw_loss 233 | # Activation Regularization 234 | if args.alpha: loss = loss + sum( 235 | args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:]) 236 | # Temporal Activation Regularization (slowness) 237 | if args.beta: loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:]) 238 | loss.backward() 239 | 240 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 241 | if args.clip: torch.nn.utils.clip_grad_norm_(params, args.clip) 242 | optimizer.step() 243 | 244 | total_loss += raw_loss.data 245 | optimizer.param_groups[0]['lr'] = lr2 246 | if batch % args.log_interval == 0 and batch > 0: 247 | cur_loss = total_loss / args.log_interval 248 | elapsed = time.time() - start_time 249 | print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | ' 250 | 'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format( 251 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 252 | elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss), cur_loss / math.log(2))) 253 | total_loss = 0 254 | start_time = time.time() 255 | ### 256 | batch += 1 257 | i += seq_len 258 | 259 | #################################### 260 | if args.cuda: 261 | try: 262 | torch.cuda.empty_cache() 263 | # print('torch cuda empty cache') 264 | except: 265 | pass 266 | #################################### 267 | 268 | 269 | # Loop over epochs. 270 | lr = args.lr 271 | best_val_loss = [] 272 | stored_loss = 100000000 273 | 274 | print('Starting training......') 275 | # At any point you can hit Ctrl + C to break out of training early. 276 | try: 277 | optimizer = None 278 | # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax) 279 | if args.optimizer == 'sgd': 280 | optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay) # params not trainable params... (?) 
281 | if args.optimizer == 'adam': 282 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay) 283 | 284 | for epoch in range(1, args.epochs+1): 285 | print('Starting epoch {}'.format(epoch)) 286 | epoch_start_time = time.time() 287 | #################################### 288 | # memory debug 289 | print('Memory before train') 290 | if args.cuda: 291 | print(torch.cuda.get_device_properties(device).total_memory) 292 | print(torch.cuda.memory_cached(device)) 293 | print(torch.cuda.memory_allocated(device)) 294 | #################################### 295 | train() 296 | #################################### 297 | print('Memory after train') 298 | if args.cuda: 299 | print(torch.cuda.get_device_properties(device).total_memory) 300 | print(torch.cuda.memory_cached(device)) 301 | print(torch.cuda.memory_allocated(device)) 302 | #################################### 303 | if args.cuda: 304 | try: 305 | torch.cuda.empty_cache() 306 | # print('torch cuda empty cache') 307 | except: 308 | pass 309 | #################################### 310 | if 't0' in optimizer.param_groups[0]: # if ASGD 311 | tmp = {} 312 | for prm in model.parameters(): 313 | if prm in optimizer.state.keys(): 314 | # tmp[prm] = prm.data.clone() 315 | tmp[prm] = prm.data.detach() 316 | # tmp[prm].copy_(prm.data) 317 | # if 'ax' in optimizer.state[prm]: # added this line because of error: File "main.py", line 268, in prm.data = optimizer.state[prm]['ax'].clone() KeyError: 'ax' 318 | # prm.data = optimizer.state[prm]['ax'].clone() 319 | prm.data = optimizer.state[prm]['ax'].detach() 320 | 321 | # else: 322 | # print(prm) 323 | 324 | # prm.data = optimizer.state[prm]['ax'].clone() 325 | # prm.data = optimizer.state[prm]['ax'].detach() 326 | # prm.data.copy_(optimizer.state[prm]['ax']) 327 | 328 | val_loss2 = evaluate(val_data) 329 | print('-' * 89) 330 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 331 | 'valid ppl {:8.2f} | valid bpc {:8.3f}'.format( 332 | epoch, (time.time() - epoch_start_time), val_loss2, math.exp(val_loss2), val_loss2 / math.log(2))) 333 | print('-' * 89) 334 | 335 | if val_loss2 < stored_loss: 336 | # model_save(os.path.join(CKPT_DIR, args.save), model, criterion, optimizer, 337 | # vocabulary, val_loss2, math.exp(val_loss2), vars(args), epoch) 338 | model_state_save(os.path.join(CKPT_DIR, args.save), model, criterion, optimizer, 339 | vocabulary, val_loss2, math.exp(val_loss2), vars(args), epoch) 340 | print('Saving Averaged!') 341 | stored_loss = val_loss2 342 | 343 | # nparams = 0 344 | # nparams_in_temp_keys = 0 345 | for prm in model.parameters(): 346 | # nparams += 1 347 | if prm in tmp.keys(): 348 | # nparams_in_temp_keys += 1 349 | # prm.data = tmp[prm].clone() 350 | prm.data = tmp[prm].detach() 351 | prm.requires_grad = True 352 | # print('params {}, params in tmp keys: {}'.format(nparams, nparams_in_temp_keys)) 353 | del tmp 354 | else: 355 | print('{} model params (SGD before eval)'.format(len([prm for prm in model.parameters()]))) 356 | val_loss = evaluate(val_data, eval_batch_size) 357 | print('{} model params (SGD after eval)'.format(len([prm for prm in model.parameters()]))) 358 | print('-' * 89) 359 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 360 | 'valid ppl {:8.2f} | valid bpc {:8.3f}'.format( 361 | epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss), val_loss / math.log(2))) 362 | print('-' * 89) 363 | 364 | if val_loss < stored_loss: 365 | # model_save(os.path.join(CKPT_DIR, args.save), model, criterion, 
optimizer, 366 | # vocabulary, val_loss, math.exp(val_loss), vars(args), epoch) 367 | model_state_save(os.path.join(CKPT_DIR, args.save), model, criterion, optimizer, 368 | vocabulary, val_loss, math.exp(val_loss), vars(args), epoch) 369 | print('Saving model (new best validation)') 370 | stored_loss = val_loss 371 | 372 | if args.asgd: 373 | if args.optimizer == 'sgd' and 't0' not in optimizer.param_groups[0] and ( 374 | len(best_val_loss) > args.nonmono and val_loss > min(best_val_loss[:-args.nonmono])): 375 | # if 't0' not in optimizer.param_groups[0]: 376 | print('Switching to ASGD') 377 | # optimizer = ASGD(trainable_parameters, lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 378 | optimizer = ASGD(params, lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 379 | 380 | if epoch in args.when: 381 | print('Saving model before learning rate decreased') 382 | # model_save('{}.e{}'.format(os.path.join(CKPT_DIR, args.save), model, criterion, optimizer, 383 | # vocabulary, val_loss, math.exp(val_loss), vars(args), epoch)) 384 | model_state_save('{}.e{}'.format(os.path.join(CKPT_DIR, args.save), epoch), model, criterion, optimizer, 385 | vocabulary, val_loss, math.exp(val_loss), vars(args), epoch) 386 | print('Dividing learning rate by 10') 387 | optimizer.param_groups[0]['lr'] /= 10. 388 | 389 | best_val_loss.append(val_loss) 390 | 391 | except KeyboardInterrupt: 392 | print('-' * 89) 393 | print('Exiting from training early') 394 | 395 | # Load the best saved model. 396 | model, criterion, optimizer, vocab, val_loss, config = model_load(os.path.join(CKPT_DIR, args.save)) 397 | 398 | # Run on test data. 399 | test_loss = evaluate(test_data, test_batch_size) 400 | print('=' * 89) 401 | print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format( 402 | test_loss, math.exp(test_loss), test_loss / math.log(2))) 403 | print('=' * 89) 404 | --------------------------------------------------------------------------------