├── .gitignore ├── CODEOWNERS ├── LICENSE ├── README.md ├── data.py ├── data └── enwik8 │ └── prep_enwik8.py ├── embed_regularize.py ├── finetune.py ├── generate.py ├── getdata.sh ├── locked_dropout.py ├── main.py ├── model.py ├── pointer.py ├── splitcross.py ├── utils.py └── weight_drop.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSTM and QRNN Language Model Toolkit 2 | 3 | This repository contains the code used for two [Salesforce Research](https://einstein.ai/) papers: 4 | + [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182) 5 | + [An Analysis of Neural Language Modeling at Multiple Scales](https://arxiv.org/abs/1803.08240) 6 | This code was originally forked from the [PyTorch word level language modeling example](https://github.com/pytorch/examples/tree/master/word_language_model). 
7 | 
8 | The model comes with instructions to train: 
9 | + word level language models over the Penn Treebank (PTB), [WikiText-2](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) (WT2), and [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) (WT103) datasets 
10 | 
11 | + character level language models over the Penn Treebank (PTBC) and Hutter Prize dataset (enwik8) 
12 | 
13 | The model can be composed of an LSTM or a [Quasi-Recurrent Neural Network](https://github.com/salesforce/pytorch-qrnn/) (QRNN), which is two or more times faster than the cuDNN LSTM in this setup while achieving equivalent or better accuracy. 
14 | 
15 | + Install PyTorch 0.4 
16 | + Run `getdata.sh` to acquire the Penn Treebank and WikiText-2 datasets 
17 | + Train the base model using `main.py` 
18 | + (Optionally) Finetune the model using `finetune.py` 
19 | + (Optionally) Apply the [continuous cache pointer](https://arxiv.org/abs/1612.04426) to the finetuned model using `pointer.py` 
20 | 
21 | If you use this code or our results in your research, please cite as appropriate: 
22 | 
23 | ``` 
24 | @article{merityRegOpt, 
25 | title={{Regularizing and Optimizing LSTM Language Models}}, 
26 | author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard}, 
27 | journal={arXiv preprint arXiv:1708.02182}, 
28 | year={2017} 
29 | } 
30 | ``` 
31 | 
32 | ``` 
33 | @article{merityAnalysis, 
34 | title={{An Analysis of Neural Language Modeling at Multiple Scales}}, 
35 | author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard}, 
36 | journal={arXiv preprint arXiv:1803.08240}, 
37 | year={2018} 
38 | } 
39 | ``` 
40 | ## Update (June/13/2018) 
41 | 
42 | The codebase is now PyTorch 0.4 compatible for most use cases (a big shoutout to https://github.com/shawntan for a fairly comprehensive PR https://github.com/salesforce/awd-lstm-lm/pull/43). Mild readjustments to hyperparameters may be necessary to obtain quoted performance. If you desire exact reproducibility (or wish to run on PyTorch 0.3 or lower), we suggest using an older commit of this repository. We are still working on `pointer`, `finetune` and `generate` functionalities. 
43 | 
44 | ## Software Requirements 
45 | 
46 | Python 3 and PyTorch 0.4 are required for the current codebase. 
47 | 
48 | Included below are hyper parameters to get results equivalent to or better than those included in the original paper. 
49 | 
50 | If you need to use an earlier version of the codebase, the original code and hyper parameters are accessible at the [PyTorch==0.1.12](https://github.com/salesforce/awd-lstm-lm/tree/PyTorch%3D%3D0.1.12) release, which requires Python 3 and PyTorch 0.1.12. 
51 | If you are using Anaconda, PyTorch 0.1.12 can be installed via: 
52 | `conda install pytorch=0.1.12 -c soumith`. 
53 | 
54 | ## Experiments 
55 | 
56 | The codebase was modified during the writing of the paper, preventing exact reproduction due to minor differences in random seeds or similar. 
57 | We have also seen exact reproduction numbers change when changing the underlying GPU. 
58 | The guide below produces results largely similar to the numbers reported. 
59 | 
60 | For data setup, run `./getdata.sh`. 
61 | This script collects the Mikolov pre-processed Penn Treebank and the WikiText-2 datasets and places them in the `data` directory. 
62 | 
63 | Next, decide whether to use the QRNN or the LSTM as the underlying recurrent neural network model.
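In the commands below this choice is simply the `--model` flag passed to `main.py` (and matched in `finetune.py` and `pointer.py`). As a minimal illustration, with the full recommended hyper parameters given in the sections further down, two otherwise identical PTB runs would differ only in that flag:

+ `python main.py --data data/penn --save PTB.pt` (LSTM is the default)
+ `python main.py --data data/penn --save PTB.pt --model QRNN`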
64 | The QRNN is many times faster than even NVIDIA's cuDNN-optimized LSTM (and dozens of times faster than a naive LSTM implementation) yet achieves similar or better results than the LSTM for many word level datasets. 
65 | At the time of writing, the QRNN models use the same number of parameters and are slightly deeper networks but are two to four times faster per epoch and require fewer epochs to converge. 
66 | 
67 | The QRNN model uses a QRNN with convolutional size 2 for the first layer, allowing the model to view discrete natural language inputs (e.g. "New York"), while all other layers use a convolutional size of 1. 
68 | 
69 | **Finetuning Note:** Fine-tuning modifies the original saved model `model.pt` file; if you wish to keep the original weights, you must copy the file first. 
70 | 
71 | **Pointer Note:** BPTT just changes the length of the sequence pushed onto the GPU but won't impact the final result. 
72 | 
73 | ### Character level enwik8 with LSTM 
74 | 
75 | + `python -u main.py --epochs 50 --nlayers 3 --emsize 400 --nhid 1840 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.1 --dropouti 0.1 --dropout 0.4 --wdrop 0.2 --wdecay 1.2e-6 --bptt 200 --batch_size 128 --optimizer adam --lr 1e-3 --data data/enwik8 --save ENWIK8.pt --when 25 35` 
76 | 
77 | ### Character level Penn Treebank (PTB) with LSTM 
78 | 
79 | + `python -u main.py --epochs 500 --nlayers 3 --emsize 200 --nhid 1000 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.25 --dropouti 0.1 --dropout 0.1 --wdrop 0.5 --wdecay 1.2e-6 --bptt 150 --batch_size 128 --optimizer adam --lr 2e-3 --data data/pennchar --save PTBC.pt --when 300 400` 
80 | 
81 | ### Word level WikiText-103 (WT103) with QRNN 
82 | 
83 | + `python -u main.py --epochs 14 --nlayers 4 --emsize 400 --nhid 2500 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.1 --dropouti 0.1 --dropout 0.1 --wdrop 0 --wdecay 0 --bptt 140 --batch_size 60 --optimizer adam --lr 1e-3 --data data/wikitext-103 --save WT103.12hr.QRNN.pt --when 12 --model QRNN` 
84 | 
85 | ### Word level Penn Treebank (PTB) with LSTM 
86 | 
87 | The instruction below trains a PTB model that without finetuning achieves perplexities of approximately `61.2` / `58.8` (validation / testing), with finetuning achieves perplexities of approximately `58.8` / `56.5`, and with the continuous cache pointer augmentation achieves perplexities of approximately `53.2` / `52.5`. 
88 | 
89 | + `python main.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 500 --save PTB.pt` 
90 | + `python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --dropouth 0.25 --seed 141 --epoch 500 --save PTB.pt` 
91 | + `python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000` 
92 | 
93 | ### Word level Penn Treebank (PTB) with QRNN 
94 | 
95 | The instruction below trains a QRNN model that without finetuning achieves perplexities of approximately `60.6` / `58.3` (validation / testing), with finetuning achieves perplexities of approximately `59.1` / `56.7`, and with the continuous cache pointer augmentation achieves perplexities of approximately `53.4` / `52.6`.
96 | 
97 | + `python -u main.py --model QRNN --batch_size 20 --clip 0.2 --wdrop 0.1 --nhid 1550 --nlayers 4 --emsize 400 --dropouth 0.3 --seed 9001 --dropouti 0.4 --epochs 550 --save PTB.pt` 
98 | + `python -u finetune.py --model QRNN --batch_size 20 --clip 0.2 --wdrop 0.1 --nhid 1550 --nlayers 4 --emsize 400 --dropouth 0.3 --seed 404 --dropouti 0.4 --epochs 300 --save PTB.pt` 
99 | + `python pointer.py --model QRNN --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000 --save PTB.pt` 
100 | 
101 | ### Word level WikiText-2 (WT2) with LSTM 
102 | The instruction below trains a WT2 model that without finetuning achieves perplexities of approximately `68.7` / `65.6` (validation / testing), with finetuning achieves perplexities of approximately `67.4` / `64.7`, and with the continuous cache pointer augmentation achieves perplexities of approximately `52.2` / `50.6`. 
103 | 
104 | + `python main.py --epochs 750 --data data/wikitext-2 --save WT2.pt --dropouth 0.2 --seed 1882` 
105 | + `python finetune.py --epochs 750 --data data/wikitext-2 --save WT2.pt --dropouth 0.2 --seed 1882` 
106 | + `python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2` 
107 | 
108 | ### Word level WikiText-2 (WT2) with QRNN 
109 | 
110 | The instruction below will train a QRNN model that without finetuning achieves perplexities of approximately `69.3` / `66.8` (validation / testing), with finetuning achieves perplexities of approximately `68.5` / `65.9`, and with the continuous cache pointer augmentation achieves perplexities of approximately `53.6` / `52.1`. 
111 | Better numbers are likely achievable as the hyper parameters have not been extensively searched; these hyper parameters should nonetheless serve as a good starting point. 
112 | 
113 | + `python -u main.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 40 --save WT2.pt` 
114 | + `python finetune.py --epochs 500 --data data/wikitext-2 --clip 0.25 --dropouti 0.4 --dropouth 0.2 --nhid 1550 --nlayers 4 --seed 4002 --model QRNN --wdrop 0.1 --batch_size 40 --save WT2.pt` 
115 | + `python -u pointer.py --save WT2.pt --model QRNN --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2` 
116 | 
117 | ## Speed 
118 | 
119 | For speed regarding character-level PTB and enwik8 or word-level WikiText-103, refer to the relevant paper. 
120 | 
121 | The default speeds for the models during training on an NVIDIA Quadro GP100: 
122 | 
123 | + Penn Treebank (batch size 20): LSTM takes 65 seconds per epoch, QRNN takes 28 seconds per epoch 
124 | + WikiText-2 (batch size 20): LSTM takes 180 seconds per epoch, QRNN takes 90 seconds per epoch 
125 | 
126 | The default QRNN models can be far faster than the cuDNN LSTM model, with the speed-ups depending on how much of a bottleneck the RNN is. The majority of the model time above is now spent in softmax or optimization overhead (see [PyTorch QRNN discussion on speed](https://github.com/salesforce/pytorch-qrnn#speed)). 
127 | 
128 | Speeds are approximately three times slower on a K80. On a K80 or other cards with less memory you may wish to enable [the cap on the maximum sampled sequence length](https://github.com/salesforce/awd-lstm-lm/blob/ef9369d277f8326b16a9f822adae8480b6d492d0/main.py#L131) to prevent out-of-memory (OOM) errors, especially for WikiText-2.
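The cap referenced above is a single line in the training loop of `main.py` (commented out by default). As a rough sketch of the surrounding logic in `train()`:

```python
# A sequence length is sampled around --bptt for each batch (occasionally halved).
bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
seq_len = max(5, int(np.random.normal(bptt, 5)))
# Re-enable this cap on low-memory GPUs so a rare long draw cannot trigger OOM.
seq_len = min(seq_len, args.bptt + 10)
```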
129 | 130 | If speed is a major issue, SGD converges more quickly than our non-monotonically triggered variant of ASGD though achieves a worse overall perplexity. 131 | 132 | ### Details of the QRNN optimization 133 | 134 | For full details, refer to the [PyTorch QRNN repository](https://github.com/salesforce/pytorch-qrnn). 135 | 136 | ### Details of the LSTM optimization 137 | 138 | All the augmentations to the LSTM, including our variant of [DropConnect (Wan et al. 2013)](https://cs.nyu.edu/~wanli/dropc/dropc.pdf) termed weight dropping which adds recurrent dropout, allow for the use of NVIDIA's cuDNN LSTM implementation. 139 | PyTorch will automatically use the cuDNN backend if run on CUDA with cuDNN installed. 140 | This ensures the model is fast to train even when convergence may take many hundreds of epochs. 141 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from collections import Counter 5 | 6 | 7 | class Dictionary(object): 8 | def __init__(self): 9 | self.word2idx = {} 10 | self.idx2word = [] 11 | self.counter = Counter() 12 | self.total = 0 13 | 14 | def add_word(self, word): 15 | if word not in self.word2idx: 16 | self.idx2word.append(word) 17 | self.word2idx[word] = len(self.idx2word) - 1 18 | token_id = self.word2idx[word] 19 | self.counter[token_id] += 1 20 | self.total += 1 21 | return self.word2idx[word] 22 | 23 | def __len__(self): 24 | return len(self.idx2word) 25 | 26 | 27 | class Corpus(object): 28 | def __init__(self, path): 29 | self.dictionary = Dictionary() 30 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 31 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 32 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 33 | 34 | def tokenize(self, path): 35 | """Tokenizes a text file.""" 36 | assert os.path.exists(path) 37 | # Add words to the dictionary 38 | with open(path, 'r') as f: 39 | tokens = 0 40 | for line in f: 41 | words = line.split() + [''] 42 | tokens += len(words) 43 | for word in words: 44 | self.dictionary.add_word(word) 45 | 46 | # Tokenize file content 47 | with open(path, 'r') as f: 48 | ids = torch.LongTensor(tokens) 49 | token = 0 50 | for line in f: 51 | words = line.split() + [''] 52 | for word in words: 53 | ids[token] = self.dictionary.word2idx[word] 54 | token += 1 55 | 56 | return ids 57 | -------------------------------------------------------------------------------- /data/enwik8/prep_enwik8.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import os 5 | import sys 6 | import zipfile 7 | 8 | if os.path.exists('train.txt'): 9 | print('Tokenized enwik8 already exists - skipping processing') 10 | sys.exit() 11 | 12 | data = zipfile.ZipFile('enwik8.zip').read('enwik8') 13 | 14 | print('Length of enwik8: {}'.format(len(data))) 15 | 16 | num_test_chars = 5000000 17 | 18 | train_data = data[: -2 * num_test_chars] 19 | valid_data = data[-2 * num_test_chars: -num_test_chars] 20 | test_data = data[-num_test_chars:] 21 | 22 | for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]: 23 | print('{} will have {} bytes'.format(fn, len(part))) 24 | print('- Tokenizing...') 25 | part_str = ' '.join([str(c) if c != ord('\n') else '\n' for c in part]) 26 | print('- Writing...') 27 | f = open(fn, 'w').write(part_str) 28 | f = open(fn + 
'.raw', 'wb').write(part) 29 | -------------------------------------------------------------------------------- /embed_regularize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | def embedded_dropout(embed, words, dropout=0.1, scale=None): 6 | if dropout: 7 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout) 8 | masked_embed_weight = mask * embed.weight 9 | else: 10 | masked_embed_weight = embed.weight 11 | if scale: 12 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 13 | 14 | padding_idx = embed.padding_idx 15 | if padding_idx is None: 16 | padding_idx = -1 17 | 18 | X = torch.nn.functional.embedding(words, masked_embed_weight, 19 | padding_idx, embed.max_norm, embed.norm_type, 20 | embed.scale_grad_by_freq, embed.sparse 21 | ) 22 | return X 23 | 24 | if __name__ == '__main__': 25 | V = 50 26 | h = 4 27 | bptt = 10 28 | batch_size = 2 29 | 30 | embed = torch.nn.Embedding(V, h) 31 | 32 | words = np.random.random_integers(low=0, high=V-1, size=(batch_size, bptt)) 33 | words = torch.LongTensor(words) 34 | 35 | origX = embed(words) 36 | X = embedded_dropout(embed, words) 37 | 38 | print(origX) 39 | print(X) 40 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import numpy as np 5 | np.random.seed(331) 6 | import torch 7 | import torch.nn as nn 8 | 9 | import data 10 | import model 11 | 12 | from utils import batchify, get_batch, repackage_hidden 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 15 | parser.add_argument('--data', type=str, default='data/penn/', 16 | help='location of the data corpus') 17 | parser.add_argument('--model', type=str, default='LSTM', 18 | help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)') 19 | parser.add_argument('--emsize', type=int, default=400, 20 | help='size of word embeddings') 21 | parser.add_argument('--nhid', type=int, default=1150, 22 | help='number of hidden units per layer') 23 | parser.add_argument('--nlayers', type=int, default=3, 24 | help='number of layers') 25 | parser.add_argument('--lr', type=float, default=30, 26 | help='initial learning rate') 27 | parser.add_argument('--clip', type=float, default=0.25, 28 | help='gradient clipping') 29 | parser.add_argument('--epochs', type=int, default=8000, 30 | help='upper epoch limit') 31 | parser.add_argument('--batch_size', type=int, default=80, metavar='N', 32 | help='batch size') 33 | parser.add_argument('--bptt', type=int, default=70, 34 | help='sequence length') 35 | parser.add_argument('--dropout', type=float, default=0.4, 36 | help='dropout applied to layers (0 = no dropout)') 37 | parser.add_argument('--dropouth', type=float, default=0.3, 38 | help='dropout for rnn layers (0 = no dropout)') 39 | parser.add_argument('--dropouti', type=float, default=0.65, 40 | help='dropout for input embedding layers (0 = no dropout)') 41 | parser.add_argument('--dropoute', type=float, default=0.1, 42 | help='dropout to remove words from embedding layer (0 = no dropout)') 43 | parser.add_argument('--wdrop', type=float, default=0.5, 44 | help='amount of weight dropout to apply to the RNN hidden to hidden matrix') 45 | parser.add_argument('--tied', action='store_false', 46 | 
help='tie the word embedding and softmax weights') 47 | parser.add_argument('--seed', type=int, default=1111, 48 | help='random seed') 49 | parser.add_argument('--nonmono', type=int, default=5, 50 | help='random seed') 51 | parser.add_argument('--cuda', action='store_false', 52 | help='use CUDA') 53 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 54 | help='report interval') 55 | randomhash = ''.join(str(time.time()).split('.')) 56 | parser.add_argument('--save', type=str, default=randomhash+'.pt', 57 | help='path to save the final model') 58 | parser.add_argument('--alpha', type=float, default=2, 59 | help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') 60 | parser.add_argument('--beta', type=float, default=1, 61 | help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)') 62 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 63 | help='weight decay applied to all weights') 64 | args = parser.parse_args() 65 | 66 | # Set the random seed manually for reproducibility. 67 | torch.manual_seed(args.seed) 68 | if torch.cuda.is_available(): 69 | if not args.cuda: 70 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 71 | else: 72 | torch.cuda.manual_seed(args.seed) 73 | 74 | ############################################################################### 75 | # Load data 76 | ############################################################################### 77 | 78 | corpus = data.Corpus(args.data) 79 | 80 | eval_batch_size = 10 81 | test_batch_size = 1 82 | train_data = batchify(corpus.train, args.batch_size, args) 83 | val_data = batchify(corpus.valid, eval_batch_size, args) 84 | test_data = batchify(corpus.test, test_batch_size, args) 85 | 86 | ############################################################################### 87 | # Build the model 88 | ############################################################################### 89 | 90 | ntokens = len(corpus.dictionary) 91 | model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.dropouth, args.dropouti, args.dropoute, args.wdrop, args.tied) 92 | if args.cuda: 93 | model.cuda() 94 | total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in model.parameters()) 95 | print('Args:', args) 96 | print('Model total parameters:', total_params) 97 | 98 | criterion = nn.CrossEntropyLoss() 99 | 100 | ############################################################################### 101 | # Training code 102 | ############################################################################### 103 | 104 | def evaluate(data_source, batch_size=10): 105 | # Turn on evaluation mode which disables dropout. 106 | if args.model == 'QRNN': model.reset() 107 | model.eval() 108 | total_loss = 0 109 | ntokens = len(corpus.dictionary) 110 | hidden = model.init_hidden(batch_size) 111 | for i in range(0, data_source.size(0) - 1, args.bptt): 112 | data, targets = get_batch(data_source, i, args, evaluation=True) 113 | output, hidden = model(data, hidden) 114 | output_flat = output.view(-1, ntokens) 115 | total_loss += len(data) * criterion(output_flat, targets).data 116 | hidden = repackage_hidden(hidden) 117 | return total_loss[0] / len(data_source) 118 | 119 | 120 | def train(): 121 | # Turn on training mode which enables dropout. 
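# Note on the loop below: each batch's sequence length is drawn from a normal
# distribution centred on --bptt (with the mean occasionally halved), and the
# learning rate is rescaled by seq_len / args.bptt so that shorter sequences do
# not receive a disproportionately large share of the updates.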
122 | if args.model == 'QRNN': model.reset() 123 | total_loss = 0 124 | start_time = time.time() 125 | ntokens = len(corpus.dictionary) 126 | hidden = model.init_hidden(args.batch_size) 127 | batch, i = 0, 0 128 | while i < train_data.size(0) - 1 - 1: 129 | bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2. 130 | # Prevent excessively small or negative sequence lengths 131 | seq_len = max(5, int(np.random.normal(bptt, 5))) 132 | # There's a very small chance that it could select a very long sequence length resulting in OOM 133 | seq_len = min(seq_len, args.bptt + 10) 134 | 135 | lr2 = optimizer.param_groups[0]['lr'] 136 | optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt 137 | model.train() 138 | data, targets = get_batch(train_data, i, args, seq_len=seq_len) 139 | 140 | # Starting each batch, we detach the hidden state from how it was previously produced. 141 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 142 | hidden = repackage_hidden(hidden) 143 | optimizer.zero_grad() 144 | 145 | output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True) 146 | raw_loss = criterion(output.view(-1, ntokens), targets) 147 | 148 | loss = raw_loss 149 | # Activiation Regularization 150 | loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:]) 151 | # Temporal Activation Regularization (slowness) 152 | loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:]) 153 | loss.backward() 154 | 155 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 156 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) 157 | optimizer.step() 158 | 159 | total_loss += raw_loss.data 160 | optimizer.param_groups[0]['lr'] = lr2 161 | if batch % args.log_interval == 0 and batch > 0: 162 | cur_loss = total_loss[0] / args.log_interval 163 | elapsed = time.time() - start_time 164 | print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' 165 | 'loss {:5.2f} | ppl {:8.2f}'.format( 166 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 167 | elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))) 168 | total_loss = 0 169 | start_time = time.time() 170 | ### 171 | batch += 1 172 | i += seq_len 173 | 174 | 175 | # Load the best saved model. 176 | with open(args.save, 'rb') as f: 177 | model = torch.load(f) 178 | 179 | 180 | # Loop over epochs. 181 | lr = args.lr 182 | stored_loss = evaluate(val_data) 183 | best_val_loss = [] 184 | # At any point you can hit Ctrl + C to break out of training early. 
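# Finetuning runs ASGD with t0=0, so the optimizer keeps an averaged copy of every
# parameter in optimizer.state[prm]['ax']. Before each validation pass those averaged
# weights are swapped in (and the raw weights restored afterwards); training stops via
# sys.exit once the validation loss is worse than the best value seen more than
# --nonmono epochs ago.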
185 | try: 186 | #optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 187 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 188 | for epoch in range(1, args.epochs+1): 189 | epoch_start_time = time.time() 190 | train() 191 | if 't0' in optimizer.param_groups[0]: 192 | tmp = {} 193 | for prm in model.parameters(): 194 | tmp[prm] = prm.data.clone() 195 | prm.data = optimizer.state[prm]['ax'].clone() 196 | 197 | val_loss2 = evaluate(val_data) 198 | print('-' * 89) 199 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 200 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), 201 | val_loss2, math.exp(val_loss2))) 202 | print('-' * 89) 203 | 204 | if val_loss2 < stored_loss: 205 | with open(args.save, 'wb') as f: 206 | torch.save(model, f) 207 | print('Saving Averaged!') 208 | stored_loss = val_loss2 209 | 210 | for prm in model.parameters(): 211 | prm.data = tmp[prm].clone() 212 | 213 | if (len(best_val_loss)>args.nonmono and val_loss2 > min(best_val_loss[:-args.nonmono])): 214 | print('Done!') 215 | import sys 216 | sys.exit(1) 217 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 218 | #optimizer.param_groups[0]['lr'] /= 2. 219 | best_val_loss.append(val_loss2) 220 | 221 | except KeyboardInterrupt: 222 | print('-' * 89) 223 | print('Exiting from training early') 224 | 225 | # Load the best saved model. 226 | with open(args.save, 'rb') as f: 227 | model = torch.load(f) 228 | 229 | # Run on test data. 230 | test_loss = evaluate(test_data, test_batch_size) 231 | print('=' * 89) 232 | print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 233 | test_loss, math.exp(test_loss))) 234 | print('=' * 89) 235 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Language Modeling on Penn Tree Bank 3 | # 4 | # This file generates new sentences sampled from the language model 5 | # 6 | ############################################################################### 7 | 8 | import argparse 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | 13 | import data 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch PTB Language Model') 16 | 17 | # Model parameters. 18 | parser.add_argument('--data', type=str, default='./data/penn', 19 | help='location of the data corpus') 20 | parser.add_argument('--model', type=str, default='LSTM', 21 | help='type of recurrent net (LSTM, QRNN)') 22 | parser.add_argument('--checkpoint', type=str, default='./model.pt', 23 | help='model checkpoint to use') 24 | parser.add_argument('--outf', type=str, default='generated.txt', 25 | help='output file for generated text') 26 | parser.add_argument('--words', type=int, default='1000', 27 | help='number of words to generate') 28 | parser.add_argument('--seed', type=int, default=1111, 29 | help='random seed') 30 | parser.add_argument('--cuda', action='store_true', 31 | help='use CUDA') 32 | parser.add_argument('--temperature', type=float, default=1.0, 33 | help='temperature - higher will increase diversity') 34 | parser.add_argument('--log-interval', type=int, default=100, 35 | help='reporting interval') 36 | args = parser.parse_args() 37 | 38 | # Set the random seed manually for reproducibility. 
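# Both the CPU generator and, when --cuda is given, the GPU generator are seeded so
# that the multinomial sampling below is repeatable for a fixed checkpoint and temperature.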
39 | torch.manual_seed(args.seed) 40 | if torch.cuda.is_available(): 41 | if not args.cuda: 42 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 43 | else: 44 | torch.cuda.manual_seed(args.seed) 45 | 46 | if args.temperature < 1e-3: 47 | parser.error("--temperature has to be greater or equal 1e-3") 48 | 49 | with open(args.checkpoint, 'rb') as f: 50 | model = torch.load(f) 51 | model.eval() 52 | if args.model == 'QRNN': 53 | model.reset() 54 | 55 | if args.cuda: 56 | model.cuda() 57 | else: 58 | model.cpu() 59 | 60 | corpus = data.Corpus(args.data) 61 | ntokens = len(corpus.dictionary) 62 | hidden = model.init_hidden(1) 63 | input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) 64 | if args.cuda: 65 | input.data = input.data.cuda() 66 | 67 | with open(args.outf, 'w') as outf: 68 | for i in range(args.words): 69 | output, hidden = model(input, hidden) 70 | word_weights = output.squeeze().data.div(args.temperature).exp().cpu() 71 | word_idx = torch.multinomial(word_weights, 1)[0] 72 | input.data.fill_(word_idx) 73 | word = corpus.dictionary.idx2word[word_idx] 74 | 75 | outf.write(word + ('\n' if i % 20 == 19 else ' ')) 76 | 77 | if i % args.log_interval == 0: 78 | print('| Generated {}/{} words'.format(i, args.words)) 79 | -------------------------------------------------------------------------------- /getdata.sh: -------------------------------------------------------------------------------- 1 | echo "=== Acquiring datasets ===" 2 | echo "---" 3 | mkdir -p save 4 | 5 | mkdir -p data 6 | cd data 7 | 8 | echo "- Downloading WikiText-2 (WT2)" 9 | wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip 10 | unzip -q wikitext-2-v1.zip 11 | cd wikitext-2 12 | mv wiki.train.tokens train.txt 13 | mv wiki.valid.tokens valid.txt 14 | mv wiki.test.tokens test.txt 15 | cd .. 16 | 17 | echo "- Downloading WikiText-103 (WT2)" 18 | wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip 19 | unzip -q wikitext-103-v1.zip 20 | cd wikitext-103 21 | mv wiki.train.tokens train.txt 22 | mv wiki.valid.tokens valid.txt 23 | mv wiki.test.tokens test.txt 24 | cd .. 25 | 26 | echo "- Downloading enwik8 (Character)" 27 | mkdir -p enwik8 28 | cd enwik8 29 | wget --continue http://mattmahoney.net/dc/enwik8.zip 30 | python prep_enwik8.py 31 | cd .. 32 | 33 | echo "- Downloading Penn Treebank (PTB)" 34 | wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz 35 | tar -xzf simple-examples.tgz 36 | 37 | mkdir -p penn 38 | cd penn 39 | mv ../simple-examples/data/ptb.train.txt train.txt 40 | mv ../simple-examples/data/ptb.test.txt test.txt 41 | mv ../simple-examples/data/ptb.valid.txt valid.txt 42 | cd .. 43 | 44 | echo "- Downloading Penn Treebank (Character)" 45 | mkdir -p pennchar 46 | cd pennchar 47 | mv ../simple-examples/data/ptb.char.train.txt train.txt 48 | mv ../simple-examples/data/ptb.char.test.txt test.txt 49 | mv ../simple-examples/data/ptb.char.valid.txt valid.txt 50 | cd .. 
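# At this point data/ should contain penn/, pennchar/, wikitext-2/, wikitext-103/ and
# enwik8/, each holding the train.txt, valid.txt and test.txt files expected by --data.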
51 | 52 | rm -rf simple-examples/ 53 | 54 | echo "---" 55 | echo "Happy language modeling :)" 56 | -------------------------------------------------------------------------------- /locked_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | class LockedDropout(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x, dropout=0.5): 10 | if not self.training or not dropout: 11 | return x 12 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) 13 | mask = Variable(m, requires_grad=False) / (1 - dropout) 14 | mask = mask.expand_as(x) 15 | return mask * x 16 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | import data 9 | import model 10 | 11 | from utils import batchify, get_batch, repackage_hidden 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 14 | parser.add_argument('--data', type=str, default='data/penn/', 15 | help='location of the data corpus') 16 | parser.add_argument('--model', type=str, default='LSTM', 17 | help='type of recurrent net (LSTM, QRNN, GRU)') 18 | parser.add_argument('--emsize', type=int, default=400, 19 | help='size of word embeddings') 20 | parser.add_argument('--nhid', type=int, default=1150, 21 | help='number of hidden units per layer') 22 | parser.add_argument('--nlayers', type=int, default=3, 23 | help='number of layers') 24 | parser.add_argument('--lr', type=float, default=30, 25 | help='initial learning rate') 26 | parser.add_argument('--clip', type=float, default=0.25, 27 | help='gradient clipping') 28 | parser.add_argument('--epochs', type=int, default=8000, 29 | help='upper epoch limit') 30 | parser.add_argument('--batch_size', type=int, default=80, metavar='N', 31 | help='batch size') 32 | parser.add_argument('--bptt', type=int, default=70, 33 | help='sequence length') 34 | parser.add_argument('--dropout', type=float, default=0.4, 35 | help='dropout applied to layers (0 = no dropout)') 36 | parser.add_argument('--dropouth', type=float, default=0.3, 37 | help='dropout for rnn layers (0 = no dropout)') 38 | parser.add_argument('--dropouti', type=float, default=0.65, 39 | help='dropout for input embedding layers (0 = no dropout)') 40 | parser.add_argument('--dropoute', type=float, default=0.1, 41 | help='dropout to remove words from embedding layer (0 = no dropout)') 42 | parser.add_argument('--wdrop', type=float, default=0.5, 43 | help='amount of weight dropout to apply to the RNN hidden to hidden matrix') 44 | parser.add_argument('--seed', type=int, default=1111, 45 | help='random seed') 46 | parser.add_argument('--nonmono', type=int, default=5, 47 | help='random seed') 48 | parser.add_argument('--cuda', action='store_false', 49 | help='use CUDA') 50 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 51 | help='report interval') 52 | randomhash = ''.join(str(time.time()).split('.')) 53 | parser.add_argument('--save', type=str, default=randomhash+'.pt', 54 | help='path to save the final model') 55 | parser.add_argument('--alpha', type=float, default=2, 56 | help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') 57 | 
parser.add_argument('--beta', type=float, default=1, 58 | help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)') 59 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 60 | help='weight decay applied to all weights') 61 | parser.add_argument('--resume', type=str, default='', 62 | help='path of model to resume') 63 | parser.add_argument('--optimizer', type=str, default='sgd', 64 | help='optimizer to use (sgd, adam)') 65 | parser.add_argument('--when', nargs="+", type=int, default=[-1], 66 | help='When (which epochs) to divide the learning rate by 10 - accepts multiple') 67 | args = parser.parse_args() 68 | args.tied = True 69 | 70 | # Set the random seed manually for reproducibility. 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | if torch.cuda.is_available(): 74 | if not args.cuda: 75 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 76 | else: 77 | torch.cuda.manual_seed(args.seed) 78 | 79 | ############################################################################### 80 | # Load data 81 | ############################################################################### 82 | 83 | def model_save(fn): 84 | with open(fn, 'wb') as f: 85 | torch.save([model, criterion, optimizer], f) 86 | 87 | def model_load(fn): 88 | global model, criterion, optimizer 89 | with open(fn, 'rb') as f: 90 | model, criterion, optimizer = torch.load(f) 91 | 92 | import os 93 | import hashlib 94 | fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest()) 95 | if os.path.exists(fn): 96 | print('Loading cached dataset...') 97 | corpus = torch.load(fn) 98 | else: 99 | print('Producing dataset...') 100 | corpus = data.Corpus(args.data) 101 | torch.save(corpus, fn) 102 | 103 | eval_batch_size = 10 104 | test_batch_size = 1 105 | train_data = batchify(corpus.train, args.batch_size, args) 106 | val_data = batchify(corpus.valid, eval_batch_size, args) 107 | test_data = batchify(corpus.test, test_batch_size, args) 108 | 109 | ############################################################################### 110 | # Build the model 111 | ############################################################################### 112 | 113 | from splitcross import SplitCrossEntropyLoss 114 | criterion = None 115 | 116 | ntokens = len(corpus.dictionary) 117 | model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.dropouth, args.dropouti, args.dropoute, args.wdrop, args.tied) 118 | ### 119 | if args.resume: 120 | print('Resuming model ...') 121 | model_load(args.resume) 122 | optimizer.param_groups[0]['lr'] = args.lr 123 | model.dropouti, model.dropouth, model.dropout, args.dropoute = args.dropouti, args.dropouth, args.dropout, args.dropoute 124 | if args.wdrop: 125 | from weight_drop import WeightDrop 126 | for rnn in model.rnns: 127 | if type(rnn) == WeightDrop: rnn.dropout = args.wdrop 128 | elif rnn.zoneout > 0: rnn.zoneout = args.wdrop 129 | ### 130 | if not criterion: 131 | splits = [] 132 | if ntokens > 500000: 133 | # One Billion 134 | # This produces fairly even matrix mults for the buckets: 135 | # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422 136 | splits = [4200, 35000, 180000] 137 | elif ntokens > 75000: 138 | # WikiText-103 139 | splits = [2800, 20000, 76000] 140 | print('Using', splits) 141 | criterion = SplitCrossEntropyLoss(args.emsize, splits=splits, verbose=False) 142 | ### 143 | if args.cuda: 144 | model = model.cuda() 145 | criterion = 
criterion.cuda() 146 | ### 147 | params = list(model.parameters()) + list(criterion.parameters()) 148 | total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in params if x.size()) 149 | print('Args:', args) 150 | print('Model total parameters:', total_params) 151 | 152 | ############################################################################### 153 | # Training code 154 | ############################################################################### 155 | 156 | def evaluate(data_source, batch_size=10): 157 | # Turn on evaluation mode which disables dropout. 158 | model.eval() 159 | if args.model == 'QRNN': model.reset() 160 | total_loss = 0 161 | ntokens = len(corpus.dictionary) 162 | hidden = model.init_hidden(batch_size) 163 | for i in range(0, data_source.size(0) - 1, args.bptt): 164 | data, targets = get_batch(data_source, i, args, evaluation=True) 165 | output, hidden = model(data, hidden) 166 | total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data 167 | hidden = repackage_hidden(hidden) 168 | return total_loss.item() / len(data_source) 169 | 170 | 171 | def train(): 172 | # Turn on training mode which enables dropout. 173 | if args.model == 'QRNN': model.reset() 174 | total_loss = 0 175 | start_time = time.time() 176 | ntokens = len(corpus.dictionary) 177 | hidden = model.init_hidden(args.batch_size) 178 | batch, i = 0, 0 179 | while i < train_data.size(0) - 1 - 1: 180 | bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2. 181 | # Prevent excessively small or negative sequence lengths 182 | seq_len = max(5, int(np.random.normal(bptt, 5))) 183 | # There's a very small chance that it could select a very long sequence length resulting in OOM 184 | # seq_len = min(seq_len, args.bptt + 10) 185 | 186 | lr2 = optimizer.param_groups[0]['lr'] 187 | optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt 188 | model.train() 189 | data, targets = get_batch(train_data, i, args, seq_len=seq_len) 190 | 191 | # Starting each batch, we detach the hidden state from how it was previously produced. 192 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 193 | hidden = repackage_hidden(hidden) 194 | optimizer.zero_grad() 195 | 196 | output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True) 197 | raw_loss = criterion(model.decoder.weight, model.decoder.bias, output, targets) 198 | 199 | loss = raw_loss 200 | # Activiation Regularization 201 | if args.alpha: loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:]) 202 | # Temporal Activation Regularization (slowness) 203 | if args.beta: loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:]) 204 | loss.backward() 205 | 206 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
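# Note that `params` includes any SplitCrossEntropyLoss parameters (the tail
# 'tombstone' vectors used for large vocabularies) as well as the model weights,
# so both are clipped and optimized together.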
207 | if args.clip: torch.nn.utils.clip_grad_norm_(params, args.clip) 208 | optimizer.step() 209 | 210 | total_loss += raw_loss.data 211 | optimizer.param_groups[0]['lr'] = lr2 212 | if batch % args.log_interval == 0 and batch > 0: 213 | cur_loss = total_loss.item() / args.log_interval 214 | elapsed = time.time() - start_time 215 | print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | ' 216 | 'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format( 217 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 218 | elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss), cur_loss / math.log(2))) 219 | total_loss = 0 220 | start_time = time.time() 221 | ### 222 | batch += 1 223 | i += seq_len 224 | 225 | # Loop over epochs. 226 | lr = args.lr 227 | best_val_loss = [] 228 | stored_loss = 100000000 229 | 230 | # At any point you can hit Ctrl + C to break out of training early. 231 | try: 232 | optimizer = None 233 | # Ensure the optimizer is optimizing params, which includes both the model's weights as well as the criterion's weight (i.e. Adaptive Softmax) 234 | if args.optimizer == 'sgd': 235 | optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay) 236 | if args.optimizer == 'adam': 237 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay) 238 | for epoch in range(1, args.epochs+1): 239 | epoch_start_time = time.time() 240 | train() 241 | if 't0' in optimizer.param_groups[0]: 242 | tmp = {} 243 | for prm in model.parameters(): 244 | tmp[prm] = prm.data.clone() 245 | prm.data = optimizer.state[prm]['ax'].clone() 246 | 247 | val_loss2 = evaluate(val_data) 248 | print('-' * 89) 249 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 250 | 'valid ppl {:8.2f} | valid bpc {:8.3f}'.format( 251 | epoch, (time.time() - epoch_start_time), val_loss2, math.exp(val_loss2), val_loss2 / math.log(2))) 252 | print('-' * 89) 253 | 254 | if val_loss2 < stored_loss: 255 | model_save(args.save) 256 | print('Saving Averaged!') 257 | stored_loss = val_loss2 258 | 259 | for prm in model.parameters(): 260 | prm.data = tmp[prm].clone() 261 | 262 | else: 263 | val_loss = evaluate(val_data, eval_batch_size) 264 | print('-' * 89) 265 | print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 266 | 'valid ppl {:8.2f} | valid bpc {:8.3f}'.format( 267 | epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss), val_loss / math.log(2))) 268 | print('-' * 89) 269 | 270 | if val_loss < stored_loss: 271 | model_save(args.save) 272 | print('Saving model (new best validation)') 273 | stored_loss = val_loss 274 | 275 | if args.optimizer == 'sgd' and 't0' not in optimizer.param_groups[0] and (len(best_val_loss)>args.nonmono and val_loss > min(best_val_loss[:-args.nonmono])): 276 | print('Switching to ASGD') 277 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 278 | 279 | if epoch in args.when: 280 | print('Saving model before learning rate decreased') 281 | model_save('{}.e{}'.format(args.save, epoch)) 282 | print('Dividing learning rate by 10') 283 | optimizer.param_groups[0]['lr'] /= 10. 284 | 285 | best_val_loss.append(val_loss) 286 | 287 | except KeyboardInterrupt: 288 | print('-' * 89) 289 | print('Exiting from training early') 290 | 291 | # Load the best saved model. 292 | model_load(args.save) 293 | 294 | # Run on test data. 
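# Losses are average negative log-likelihoods in nats, so ppl = exp(loss) and,
# for the character-level datasets, bpc = loss / ln(2).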
295 | test_loss = evaluate(test_data, test_batch_size) 296 | print('=' * 89) 297 | print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format( 298 | test_loss, math.exp(test_loss), test_loss / math.log(2))) 299 | print('=' * 89) 300 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from embed_regularize import embedded_dropout 5 | from locked_dropout import LockedDropout 6 | from weight_drop import WeightDrop 7 | 8 | class RNNModel(nn.Module): 9 | """Container module with an encoder, a recurrent module, and a decoder.""" 10 | 11 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0, tie_weights=False): 12 | super(RNNModel, self).__init__() 13 | self.lockdrop = LockedDropout() 14 | self.idrop = nn.Dropout(dropouti) 15 | self.hdrop = nn.Dropout(dropouth) 16 | self.drop = nn.Dropout(dropout) 17 | self.encoder = nn.Embedding(ntoken, ninp) 18 | assert rnn_type in ['LSTM', 'QRNN', 'GRU'], 'RNN type is not supported' 19 | if rnn_type == 'LSTM': 20 | self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), 1, dropout=0) for l in range(nlayers)] 21 | if wdrop: 22 | self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns] 23 | if rnn_type == 'GRU': 24 | self.rnns = [torch.nn.GRU(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else ninp, 1, dropout=0) for l in range(nlayers)] 25 | if wdrop: 26 | self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns] 27 | elif rnn_type == 'QRNN': 28 | from torchqrnn import QRNNLayer 29 | self.rnns = [QRNNLayer(input_size=ninp if l == 0 else nhid, hidden_size=nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(nlayers)] 30 | for rnn in self.rnns: 31 | rnn.linear = WeightDrop(rnn.linear, ['weight'], dropout=wdrop) 32 | print(self.rnns) 33 | self.rnns = torch.nn.ModuleList(self.rnns) 34 | self.decoder = nn.Linear(nhid, ntoken) 35 | 36 | # Optionally tie weights as in: 37 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 38 | # https://arxiv.org/abs/1608.05859 39 | # and 40 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 
2016) 41 | # https://arxiv.org/abs/1611.01462 42 | if tie_weights: 43 | #if nhid != ninp: 44 | # raise ValueError('When using the tied flag, nhid must be equal to emsize') 45 | self.decoder.weight = self.encoder.weight 46 | 47 | self.init_weights() 48 | 49 | self.rnn_type = rnn_type 50 | self.ninp = ninp 51 | self.nhid = nhid 52 | self.nlayers = nlayers 53 | self.dropout = dropout 54 | self.dropouti = dropouti 55 | self.dropouth = dropouth 56 | self.dropoute = dropoute 57 | self.tie_weights = tie_weights 58 | 59 | def reset(self): 60 | if self.rnn_type == 'QRNN': [r.reset() for r in self.rnns] 61 | 62 | def init_weights(self): 63 | initrange = 0.1 64 | self.encoder.weight.data.uniform_(-initrange, initrange) 65 | self.decoder.bias.data.fill_(0) 66 | self.decoder.weight.data.uniform_(-initrange, initrange) 67 | 68 | def forward(self, input, hidden, return_h=False): 69 | emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0) 70 | #emb = self.idrop(emb) 71 | 72 | emb = self.lockdrop(emb, self.dropouti) 73 | 74 | raw_output = emb 75 | new_hidden = [] 76 | #raw_output, hidden = self.rnn(emb, hidden) 77 | raw_outputs = [] 78 | outputs = [] 79 | for l, rnn in enumerate(self.rnns): 80 | current_input = raw_output 81 | raw_output, new_h = rnn(raw_output, hidden[l]) 82 | new_hidden.append(new_h) 83 | raw_outputs.append(raw_output) 84 | if l != self.nlayers - 1: 85 | #self.hdrop(raw_output) 86 | raw_output = self.lockdrop(raw_output, self.dropouth) 87 | outputs.append(raw_output) 88 | hidden = new_hidden 89 | 90 | output = self.lockdrop(raw_output, self.dropout) 91 | outputs.append(output) 92 | 93 | result = output.view(output.size(0)*output.size(1), output.size(2)) 94 | if return_h: 95 | return result, hidden, raw_outputs, outputs 96 | return result, hidden 97 | 98 | def init_hidden(self, bsz): 99 | weight = next(self.parameters()).data 100 | if self.rnn_type == 'LSTM': 101 | return [(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_(), 102 | weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_()) 103 | for l in range(self.nlayers)] 104 | elif self.rnn_type == 'QRNN' or self.rnn_type == 'GRU': 105 | return [weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_() 106 | for l in range(self.nlayers)] 107 | -------------------------------------------------------------------------------- /pointer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | import data 10 | import model 11 | 12 | from utils import batchify, get_batch, repackage_hidden 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 15 | parser.add_argument('--data', type=str, default='data/penn', 16 | help='location of the data corpus') 17 | parser.add_argument('--model', type=str, default='LSTM', 18 | help='type of recurrent net (LSTM, QRNN)') 19 | parser.add_argument('--save', type=str,default='best.pt', 20 | help='model to use the pointer over') 21 | parser.add_argument('--cuda', action='store_false', 22 | help='use CUDA') 23 | parser.add_argument('--bptt', type=int, default=5000, 24 | help='sequence length') 25 | parser.add_argument('--window', type=int, default=3785, 26 | 
help='pointer window length') 27 | parser.add_argument('--theta', type=float, default=0.6625523432485668, 28 | help='mix between uniform distribution and pointer softmax distribution over previous words') 29 | parser.add_argument('--lambdasm', type=float, default=0.12785920428335693, 30 | help='linear mix between only pointer (1) and only vocab (0) distribution') 31 | args = parser.parse_args() 32 | 33 | ############################################################################### 34 | # Load data 35 | ############################################################################### 36 | 37 | corpus = data.Corpus(args.data) 38 | 39 | eval_batch_size = 1 40 | test_batch_size = 1 41 | #train_data = batchify(corpus.train, args.batch_size) 42 | val_data = batchify(corpus.valid, test_batch_size, args) 43 | test_data = batchify(corpus.test, test_batch_size, args) 44 | 45 | ############################################################################### 46 | # Build the model 47 | ############################################################################### 48 | 49 | ntokens = len(corpus.dictionary) 50 | criterion = nn.CrossEntropyLoss() 51 | 52 | def one_hot(idx, size, cuda=True): 53 | a = np.zeros((1, size), np.float32) 54 | a[0][idx] = 1 55 | v = Variable(torch.from_numpy(a)) 56 | if cuda: v = v.cuda() 57 | return v 58 | 59 | def evaluate(data_source, batch_size=10, window=args.window): 60 | # Turn on evaluation mode which disables dropout. 61 | if args.model == 'QRNN': model.reset() 62 | model.eval() 63 | total_loss = 0 64 | ntokens = len(corpus.dictionary) 65 | hidden = model.init_hidden(batch_size) 66 | next_word_history = None 67 | pointer_history = None 68 | for i in range(0, data_source.size(0) - 1, args.bptt): 69 | if i > 0: print(i, len(data_source), math.exp(total_loss / i)) 70 | data, targets = get_batch(data_source, i, evaluation=True, args=args) 71 | output, hidden, rnn_outs, _ = model(data, hidden, return_h=True) 72 | rnn_out = rnn_outs[-1].squeeze() 73 | output_flat = output.view(-1, ntokens) 74 | ### 75 | # Fill pointer history 76 | start_idx = len(next_word_history) if next_word_history is not None else 0 77 | next_word_history = torch.cat([one_hot(t.data[0], ntokens) for t in targets]) if next_word_history is None else torch.cat([next_word_history, torch.cat([one_hot(t.data[0], ntokens) for t in targets])]) 78 | #print(next_word_history) 79 | pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0) 80 | #print(pointer_history) 81 | ### 82 | # Built-in cross entropy 83 | # total_loss += len(data) * criterion(output_flat, targets).data[0] 84 | ### 85 | # Manual cross entropy 86 | # softmax_output_flat = torch.nn.functional.softmax(output_flat) 87 | # soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1)) 88 | # entropy = -torch.log(soft) 89 | # total_loss += len(data) * entropy.mean().data[0] 90 | ### 91 | # Pointer manual cross entropy 92 | loss = 0 93 | softmax_output_flat = torch.nn.functional.softmax(output_flat) 94 | for idx, vocab_loss in enumerate(softmax_output_flat): 95 | p = vocab_loss 96 | if start_idx + idx > window: 97 | valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx] 98 | valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx] 99 | logits = torch.mv(valid_pointer_history, rnn_out[idx]) 100 | theta = args.theta 101 | ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1) 102 | ptr_dist = 
(ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze()
103 |                 lambdah = args.lambdasm
104 |                 p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
105 |             ###
106 |             target_loss = p[targets[idx].data]
107 |             loss += (-torch.log(target_loss)).item()
108 |         total_loss += loss / batch_size
109 |         ###
110 |         hidden = repackage_hidden(hidden)
111 |         next_word_history = next_word_history[-window:]
112 |         pointer_history = pointer_history[-window:]
113 |     return total_loss / len(data_source)
114 |
115 | # Load the best saved model.
116 | with open(args.save, 'rb') as f:
117 |     if not args.cuda:
118 |         model = torch.load(f, map_location=lambda storage, loc: storage)
119 |     else:
120 |         model = torch.load(f)
121 | print(model)
122 |
123 | # Run on val data.
124 | val_loss = evaluate(val_data, test_batch_size)
125 | print('=' * 89)
126 | print('| End of pointer | val loss {:5.2f} | val ppl {:8.2f}'.format(
127 |     val_loss, math.exp(val_loss)))
128 | print('=' * 89)
129 |
130 | # Run on test data.
131 | test_loss = evaluate(test_data, test_batch_size)
132 | print('=' * 89)
133 | print('| End of pointer | test loss {:5.2f} | test ppl {:8.2f}'.format(
134 |     test_loss, math.exp(test_loss)))
135 | print('=' * 89)
136 |
--------------------------------------------------------------------------------
/splitcross.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import numpy as np
7 |
8 |
9 | class SplitCrossEntropyLoss(nn.Module):
10 |     r'''SplitCrossEntropyLoss calculates an approximate softmax'''
11 |     def __init__(self, hidden_size, splits, verbose=False):
12 |         # We assume splits is [0, split1, split2, N] where N >= |V|
13 |         # For example, a vocab of 1000 words may have splits [0] + [100, 500] + [inf]
14 |         super(SplitCrossEntropyLoss, self).__init__()
15 |         self.hidden_size = hidden_size
16 |         self.splits = [0] + splits + [100 * 1000000]
17 |         self.nsplits = len(self.splits) - 1
18 |         self.stats = defaultdict(list)
19 |         self.verbose = verbose
20 |         # Each of the splits that aren't in the head requires a pretend token; we'll call them tombstones
21 |         # The probability given to this tombstone is the probability of selecting an item from the represented split
22 |         if self.nsplits > 1:
23 |             self.tail_vectors = nn.Parameter(torch.zeros(self.nsplits - 1, hidden_size))
24 |             self.tail_bias = nn.Parameter(torch.zeros(self.nsplits - 1))
25 |
26 |     def logprob(self, weight, bias, hiddens, splits=None, softmaxed_head_res=None, verbose=False):
27 |         # First we perform the first softmax on the head vocabulary and the tombstones
28 |         if softmaxed_head_res is None:
29 |             start, end = self.splits[0], self.splits[1]
30 |             head_weight = None if end - start == 0 else weight[start:end]
31 |             head_bias = None if end - start == 0 else bias[start:end]
32 |             # We only add the tombstones if we have more than one split
33 |             if self.nsplits > 1:
34 |                 head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors])
35 |                 head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias])
36 |
37 |             # Perform the softmax calculation for the word vectors in the head for all splits
38 |             # We need to guard against empty splits as torch.cat does not like random lists
39 |             head_res = torch.nn.functional.linear(hiddens, head_weight, bias=head_bias)
40 |             softmaxed_head_res = torch.nn.functional.log_softmax(head_res, dim=-1)
41 |
42 |         if splits is None:
43 |             splits = list(range(self.nsplits))
44 |
45 |         results = []
46 |         running_offset = 0
47 |         for idx in splits:
48 |
49 |             # For those targets in the head (idx == 0) we only need to return their loss
50 |             if idx == 0:
51 |                 results.append(softmaxed_head_res[:, :-(self.nsplits - 1)])
52 |
53 |             # If the target is in one of the splits, the probability is p(tombstone) * p(word within tombstone)
54 |             else:
55 |                 start, end = self.splits[idx], self.splits[idx + 1]
56 |                 tail_weight = weight[start:end]
57 |                 tail_bias = bias[start:end]
58 |
59 |                 # Calculate the softmax for the words in the tombstone
60 |                 tail_res = torch.nn.functional.linear(hiddens, tail_weight, bias=tail_bias)
61 |
62 |                 # Then we calculate p(tombstone) * p(word in tombstone)
63 |                 # Adding is equivalent to multiplication in log space
64 |                 head_entropy = (softmaxed_head_res[:, -idx]).contiguous()
65 |                 tail_entropy = torch.nn.functional.log_softmax(tail_res, dim=-1)
66 |                 results.append(head_entropy.view(-1, 1) + tail_entropy)
67 |
68 |         if len(results) > 1:
69 |             return torch.cat(results, dim=1)
70 |         return results[0]
71 |
72 |     def split_on_targets(self, hiddens, targets):
73 |         # Split the targets into those in the head and in the tail
74 |         split_targets = []
75 |         split_hiddens = []
76 |
77 |         # Determine to which split each element belongs (for each start split value, add 1 if equal or greater)
78 |         # This method appears slower at least for WT-103 values for approx softmax
79 |         #masks = [(targets >= self.splits[idx]).view(1, -1) for idx in range(1, self.nsplits)]
80 |         #mask = torch.sum(torch.cat(masks, dim=0), dim=0)
81 |         ###
82 |         # This is equally fast for smaller splits as the method below but scales linearly
83 |         mask = None
84 |         for idx in range(1, self.nsplits):
85 |             partial_mask = targets >= self.splits[idx]
86 |             mask = mask + partial_mask if mask is not None else partial_mask
87 |         ###
88 |         #masks = torch.stack([targets] * (self.nsplits - 1))
89 |         #mask = torch.sum(masks >= self.split_starts, dim=0)
90 |         for idx in range(self.nsplits):
91 |             # If there are no splits, avoid costly masked select
92 |             if self.nsplits == 1:
93 |                 split_targets, split_hiddens = [targets], [hiddens]
94 |                 continue
95 |             # If all the words are covered by earlier targets, we have empties so later stages don't freak out
96 |             if sum(len(t) for t in split_targets) == len(targets):
97 |                 split_targets.append([])
98 |                 split_hiddens.append([])
99 |                 continue
100 |             # Are you in our split?
101 |             tmp_mask = mask == idx
102 |             split_targets.append(torch.masked_select(targets, tmp_mask))
103 |             split_hiddens.append(hiddens.masked_select(tmp_mask.unsqueeze(1).expand_as(hiddens)).view(-1, hiddens.size(1)))
104 |         return split_targets, split_hiddens
105 |
106 |     def forward(self, weight, bias, hiddens, targets, verbose=False):
107 |         if self.verbose or verbose:
108 |             for idx in sorted(self.stats):
109 |                 print('{}: {}'.format(idx, int(np.mean(self.stats[idx]))), end=', ')
110 |             print()
111 |
112 |         total_loss = None
113 |         if len(hiddens.size()) > 2: hiddens = hiddens.view(-1, hiddens.size(2))
114 |
115 |         split_targets, split_hiddens = self.split_on_targets(hiddens, targets)
116 |
117 |         # First we perform the first softmax on the head vocabulary and the tombstones
118 |         start, end = self.splits[0], self.splits[1]
119 |         head_weight = None if end - start == 0 else weight[start:end]
120 |         head_bias = None if end - start == 0 else bias[start:end]
121 |
122 |         # We only add the tombstones if we have more than one split
123 |         if self.nsplits > 1:
124 |             head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors])
125 |             head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias])
126 |
127 |         # Perform the softmax calculation for the word vectors in the head for all splits
128 |         # We need to guard against empty splits as torch.cat does not like random lists
129 |         combo = torch.cat([split_hiddens[i] for i in range(self.nsplits) if len(split_hiddens[i])])
130 |         ###
131 |         all_head_res = torch.nn.functional.linear(combo, head_weight, bias=head_bias)
132 |         softmaxed_all_head_res = torch.nn.functional.log_softmax(all_head_res, dim=-1)
133 |         if self.verbose or verbose:
134 |             self.stats[0].append(combo.size()[0] * head_weight.size()[0])
135 |
136 |         running_offset = 0
137 |         for idx in range(self.nsplits):
138 |             # If there are no targets for this split, continue
139 |             if len(split_targets[idx]) == 0: continue
140 |
141 |             # For those targets in the head (idx == 0) we only need to return their loss
142 |             if idx == 0:
143 |                 softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])]
144 |                 entropy = -torch.gather(softmaxed_head_res, dim=1, index=split_targets[idx].view(-1, 1))
145 |             # If the target is in one of the splits, the probability is p(tombstone) * p(word within tombstone)
146 |             else:
147 |                 softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])]
148 |
149 |                 if self.verbose or verbose:
150 |                     start, end = self.splits[idx], self.splits[idx + 1]
151 |                     tail_weight = weight[start:end]
152 |                     self.stats[idx].append(split_hiddens[idx].size()[0] * tail_weight.size()[0])
153 |
154 |                 # Calculate the softmax for the words in the tombstone
155 |                 tail_res = self.logprob(weight, bias, split_hiddens[idx], splits=[idx], softmaxed_head_res=softmaxed_head_res)
156 |
157 |                 # Then we calculate p(tombstone) * p(word in tombstone)
158 |                 # Adding is equivalent to multiplication in log space
159 |                 head_entropy = softmaxed_head_res[:, -idx]
160 |                 # All indices are shifted - if the first split handles [0,...,499] then the 500th in the second split will be 0 indexed
161 |                 indices = (split_targets[idx] - self.splits[idx]).view(-1, 1)
162 |                 # Warning: if you don't squeeze, you get an N x 1 return, which acts oddly with broadcasting
163 |                 tail_entropy = torch.gather(torch.nn.functional.log_softmax(tail_res, dim=-1), dim=1, index=indices).squeeze()
164 |                 entropy = -(head_entropy + tail_entropy)
165 |             ###
166 |             running_offset += len(split_hiddens[idx])
167 |             total_loss = entropy.float().sum() if total_loss is None else total_loss + entropy.float().sum()
168 |
169 |         return (total_loss / len(targets)).type_as(weight)
170 |
171 |
172 | if __name__ == '__main__':
173 |     np.random.seed(42)
174 |     torch.manual_seed(42)
175 |     if torch.cuda.is_available():
176 |         torch.cuda.manual_seed(42)
177 |
178 |     V = 8
179 |     H = 10
180 |     N = 100
181 |     E = 10
182 |
183 |     embed = torch.nn.Embedding(V, H)
184 |     crit = SplitCrossEntropyLoss(hidden_size=H, splits=[V // 2])
185 |     bias = torch.nn.Parameter(torch.ones(V))
186 |     optimizer = torch.optim.SGD(list(embed.parameters()) + list(crit.parameters()), lr=1)
187 |
188 |     for _ in range(E):
189 |         prev = torch.autograd.Variable((torch.rand(N, 1) * 0.999 * V).int().long())
190 |         x = torch.autograd.Variable((torch.rand(N, 1) * 0.999 * V).int().long())
191 |         y = embed(prev).squeeze()
192 |         c = crit(embed.weight, bias, y, x.view(N))
193 |         print('Crit', c.exp().item())
194 |
195 |         logprobs = crit.logprob(embed.weight, bias, y[:2]).exp()
196 |         print(logprobs)
197 |         print(logprobs.sum(dim=1))
198 |
199 |         optimizer.zero_grad()
200 |         c.backward()
201 |         optimizer.step()
202 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def repackage_hidden(h):
5 |     """Wraps hidden states in new Tensors,
6 |     to detach them from their history."""
7 |     if isinstance(h, torch.Tensor):
8 |         return h.detach()
9 |     else:
10 |         return tuple(repackage_hidden(v) for v in h)
11 |
12 |
13 | def batchify(data, bsz, args):
14 |     # Work out how cleanly we can divide the dataset into bsz parts.
15 |     nbatch = data.size(0) // bsz
16 |     # Trim off any extra elements that wouldn't cleanly fit (remainders).
17 |     data = data.narrow(0, 0, nbatch * bsz)
18 |     # Evenly divide the data across the bsz batches.
19 |     data = data.view(bsz, -1).t().contiguous()
20 |     if args.cuda:
21 |         data = data.cuda()
22 |     return data
23 |
24 |
25 | def get_batch(source, i, args, seq_len=None, evaluation=False):
26 |     seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
27 |     data = source[i:i+seq_len]
28 |     target = source[i+1:i+1+seq_len].view(-1)
29 |     return data, target
30 |
--------------------------------------------------------------------------------
/weight_drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from functools import wraps
4 |
5 | class WeightDrop(torch.nn.Module):
6 |     def __init__(self, module, weights, dropout=0, variational=False):
7 |         super(WeightDrop, self).__init__()
8 |         self.module = module
9 |         self.weights = weights
10 |         self.dropout = dropout
11 |         self.variational = variational
12 |         self._setup()
13 |
14 |     def widget_demagnetizer_y2k_edition(*args, **kwargs):
15 |         # We need to replace flatten_parameters with a nothing function
16 |         # It must be a function rather than a lambda as otherwise pickling explodes
17 |         # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
18 |         # (╯°□°)╯︵ ┻━┻
19 |         return
20 |
21 |     def _setup(self):
22 |         # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
23 |         if issubclass(type(self.module), torch.nn.RNNBase):
24 |             self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition
25 |
26 |         for name_w in self.weights:
27 |             print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
28 |             w = getattr(self.module, name_w)
29 |             del self.module._parameters[name_w]
30 |             self.module.register_parameter(name_w + '_raw', Parameter(w.data))
31 |
32 |     def _setweights(self):
33 |         for name_w in self.weights:
34 |             raw_w = getattr(self.module, name_w + '_raw')
35 |             w = None
36 |             if self.variational:
37 |                 mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
38 |                 if raw_w.is_cuda: mask = mask.cuda()
39 |                 mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
40 |                 w = mask.expand_as(raw_w) * raw_w
41 |             else:
42 |                 w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
43 |             setattr(self.module, name_w, w)
44 |
45 |     def forward(self, *args):
46 |         self._setweights()
47 |         return self.module.forward(*args)
48 |
49 | if __name__ == '__main__':
50 |     import torch
51 |     from weight_drop import WeightDrop
52 |
53 |     # Input is (seq, batch, input)
54 |     x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda()
55 |     h0 = None
56 |
57 |     ###
58 |
59 |     print('Testing WeightDrop')
60 |     print('=-=-=-=-=-=-=-=-=-=')
61 |
62 |     ###
63 |
64 |     print('Testing WeightDrop with Linear')
65 |
66 |     lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
67 |     lin.cuda()
68 |     run1 = [x.sum() for x in lin(x).data]
69 |     run2 = [x.sum() for x in lin(x).data]
70 |
71 |     print('All items should be different')
72 |     print('Run 1:', run1)
73 |     print('Run 2:', run2)
74 |
75 |     assert run1[0] != run2[0]
76 |     assert run1[1] != run2[1]
77 |
78 |     print('---')
79 |
80 |     ###
81 |
82 |     print('Testing WeightDrop with LSTM')
83 |
84 |     wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
85 |     wdrnn.cuda()
86 |
87 |     run1 = [x.sum() for x in wdrnn(x, h0)[0].data]
88 |     run2 = [x.sum() for x in wdrnn(x, h0)[0].data]
89 |
90 |     print('First timesteps should be equal, all others should differ')
91 |     print('Run 1:', run1)
92 |     print('Run 2:', run2)
93 |
94 |     # First time step, not influenced by hidden to hidden weights, should be equal
95 |     assert run1[0] == run2[0]
96 |     # Second step should not
97 |     assert run1[1] != run2[1]
98 |
99 |     print('---')
100 |
--------------------------------------------------------------------------------
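Taken together, the modules above cover the three pieces a training step needs: `weight_drop.WeightDrop` applies DropConnect to an RNN's recurrent weights, `splitcross.SplitCrossEntropyLoss` computes the split (approximate) softmax against the tied embedding weights, and `utils.repackage_hidden` detaches the hidden state between BPTT windows. The sketch below shows how they compose. It is not code from this repository: it assumes PyTorch 0.4, uses random token ids in place of real batches, and the sizes, the single split point, and the dropout value are all made up for illustration.

```
import torch
import torch.nn as nn

from splitcross import SplitCrossEntropyLoss
from utils import repackage_hidden
from weight_drop import WeightDrop

# Illustrative sizes only; real runs take these from the training scripts
vocab, emsize, nhid, bptt, batch_size = 100, 32, 32, 10, 4

embed = nn.Embedding(vocab, emsize)
# DropConnect on the hidden-to-hidden weights of a single-layer LSTM
rnn = WeightDrop(nn.LSTM(emsize, nhid), ['weight_hh_l0'], dropout=0.5)
# Split softmax with one made-up split point; the decoder weight is tied to the embedding
criterion = SplitCrossEntropyLoss(nhid, splits=[vocab // 2])
bias = nn.Parameter(torch.zeros(vocab))

params = list(embed.parameters()) + list(rnn.parameters()) + list(criterion.parameters()) + [bias]
optimizer = torch.optim.SGD(params, lr=0.1)

hidden = None
for step in range(3):
    # Random token ids stand in for a (bptt, batch_size) slice from get_batch
    data = (torch.rand(bptt, batch_size) * vocab).long()
    targets = (torch.rand(bptt * batch_size) * vocab).long()

    # The criterion flattens the (seq, batch, nhid) output itself
    output, hidden = rnn(embed(data), hidden)
    loss = criterion(embed.weight, bias, output, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Detach the hidden state so gradients do not flow across BPTT windows
    hidden = repackage_hidden(hidden)
    print('step', step, 'loss', loss.item())
```

Because the head softmax is shared across all splits and each tail softmax is only evaluated for the targets that actually fall in that split, the decoder cost stays close to the head size rather than the full vocabulary.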