├── LICENSE ├── NTT_LICENSE ├── README.md ├── cal_ppl.py ├── data.py ├── embed_regularize.py ├── finetune.py ├── generate.py ├── get_data.sh ├── locked_dropout.py ├── main.py ├── model.py ├── utils.py └── weight_drop.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Zihang Dai and Zhilin Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NTT_LICENSE: -------------------------------------------------------------------------------- 1 | SOFTWARE LICENSE AGREEMENT FOR EVALUATION 2 | 3 | This SOFTWARE EVALUATION LICENSE AGREEMENT (this "Agreement") is a legal contract between a person who uses or otherwise accesses or installs the Software (“User(s)”), and Nippon Telegraph and Telephone corporation ("NTT"). 4 | READ THE TERMS AND CONDITIONS OF THIS AGREEMENT CAREFULLY BEFORE INSTALLING OR OTHERWISE ACCESSING OR USING NTT'S PROPRIETARY SOFTWARE ACCOMPANIED BY THIS AGREEMENT (the "SOFTWARE"). THE SOFTWARE IS COPYRIGHTED AND IT IS LICENSED TO USER UNDER THIS AGREEMENT, NOT SOLD TO USER. BY INSTALLING OR OTHERWISE ACCESSING OR USING THE SOFTWARE, USER ACKNOWLEDGES THAT USER HAS READ THIS AGREEMENT, THAT USER UNDERSTANDS IT, AND THAT USER ACCEPTS AND AGREES TO BE BOUND BY ITS TERMS. IF AT ANY TIME USER IS NOT WILLING TO BE BOUND BY THE TERMS OF THIS AGREEMENT, USER SHOULD TERMINATE THE INSTALLATION PROCESS, IMMEDIATELY CEASE AND REFRAIN FROM ACCESSING OR USING THE SOFTWARE AND DELETE ANY COPIES USER MAY HAVE. THIS AGREEMENT REPRESENTS THE ENTIRE AGREEMENT BETWEEN USER AND NTT CONCERNING THE SOFTWARE. 5 | 6 | 7 | BACKGROUND 8 | A. NTT is the owner of all rights, including all patent rights, and copyrights in and to the Software and related documentation listed in Exhibit A to this Agreement. 9 | B. User wishes to obtain a royalty free license to use the Software to enable User to evaluate, and NTT wishes to grant such a license to User, pursuant and subject to the terms and conditions of this Agreement. 10 | C. As a condition to NTT's provision of the Software to User, NTT has required User to execute this Agreement. 11 | In consideration of these premises, and the mutual promises and conditions in this Agreement, the parties hereby agree as follows: 12 | 1. Grant of Evaluation License. 
NTT hereby grants to User, and User hereby accepts, under the terms and conditions of this Agreement, a royalty free, nontransferable and nonexclusive license to use the Software internally for the purposes of testing, analyzing, and evaluating the methods or mechanisms as shown in [the research paper submitted by NTT to a certain academy]. User may make a reasonable number of backup copies of the Software solely for User's internal use pursuant to the license granted in this Section 1. 13 | 2. Shipment and Installation. NTT will ship or deliver the Software by any method that NTT deems appropriate. User shall be solely responsible for proper installation of the Software. 14 | 3. Term. This Agreement is effective whichever is earlier (i) upon User’s acceptance of the Agreement, or (ii) upon User’s installing, accessing, and using the Software, even if User has not expressly accepted this Agreement. Without prejudice to any other rights, NTT may terminate this Agreement without notice to User. User may terminate this Agreement at any time by User’s decision to terminate the Agreement to NTT and ceasing use of the Software. Upon any termination or expiration of this Agreement for any reason, User agrees to uninstall the Software and destroy all copies of the Software. 15 | 4. Proprietary Rights 16 | (a) The Software is the valuable and proprietary property of NTT, and NTT shall retain exclusive title to this property both during the term and after the termination of this Agreement. Without limitation, User acknowledges that all patent rights and copyrights in the Software shall remain the exclusive property of NTT at all times. User shall use not less than reasonable care in safeguarding the confidentiality of the Software. 17 | (b) USER SHALL NOT, IN WHOLE OR IN PART, AT ANY TIME DURING THE TERM OF OR AFTER THE TERMINATION OF THIS AGREEMENT: (i) SELL, ASSIGN, LEASE, DISTRIBUTE, OR OTHERWISE TRANSFER THE SOFTWARE TO ANY THIRD PARTY; (ii) EXCEPT AS OTHERWISE PROVIDED HEREIN, COPY OR REPRODUCE THE SOFTWARE IN ANY MANNER; OR (iii) ALLOW ANY PERSON OR ENTITY TO COMMIT ANY OF THE ACTIONS DESCRIBED IN (i) THROUGH (ii) ABOVE. 18 | (c) User shall take appropriate action, by instruction, agreement, or otherwise, with respect to its employees permitted under this Agreement to have access to the Software to ensure that all of User's obligations under this Section 4 shall be satisfied. 19 | 5.  Indemnity. User shall defend, indemnify and hold harmless NTT, its agents and employees, from any loss, damage, or liability arising in connection with User's improper or unauthorized use of the Software. NTT SHALL HAVE THE SOLE RIGHT TO CONDUCT DEFEND ANY ACTTION RELATING TO THE SOFTWARE. 20 | 6. Disclaimer. THE SOFTWARE IS LICENSED TO USER "AS IS," WITHOUT ANY TRAINING, MAINTENANCE, OR SERVICE OBLIGATIONS WHATSOEVER ON THE PART OF NTT. NTT MAKES NO EXPRESS OR IMPLIED WARRANTIES OF ANY TYPE WHATSOEVER, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF MERCHANTABILITY, OF FITNESS FOR A PARTICULAR PURPOSE AND OF NON-INFRINGEMENT ON COPYRIGHT OR ANY OTHER RIGHT OF THIRD PARTIES. USER ASSUMES ALL RISKS ASSOCIATED WITH ITS USE OF THE SOFTWARE, INCLUDING WITHOUT LIMITATION RISKS RELATING TO QUALITY, PERFORMANCE, DATA LOSS, AND UTILITY IN A PRODUCTION ENVIRONMENT. 21 | 7. Limitation of Liability. 
IN NO EVENT SHALL NTT BE LIABLE TO USER OR TO ANY THIRD PARTY FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING BUT NOT LIMITED TO DAMAGES FOR PERSONAL INJURY, PROPERTY DAMAGE, LOST PROFITS, OR OTHER ECONOMIC LOSS, ARISING IN CONNECTION WITH USER'S USE OF OR INABILITY TO USE THE SOFTWARE, IN CONNECTION WITH NTT'S PROVISION OF OR FAILURE TO PROVIDE SERVICES PERTAINING TO THE SOFTWARE, OR AS A RESULT OF ANY DEFECT IN THE SOFTWARE. THIS DISCLAIMER OF LIABILITY SHALL APPLY REGARD¬LESS OF THE FORM OF ACTION THAT MAY BE BROUGHT AGAINST NTT, WHETHER IN CONTRACT OR TORT, INCLUDING WITHOUT LIMITATION ANY ACTION FOR NEGLIGENCE. USER'S SOLE REMEDY IN THE EVENT OF ANY BREACH OF THIS AGREEMENT BY NTT SHALL BE TERMINATION PURSUANT TO SECTION 3. 22 | 8. No Assignment or Sublicense. Neither this Agreement nor any right or license under this Agreement, nor the Software, may be sublicensed, assigned, or otherwise transferred by User without NTT's prior written consent. 23 | 9. General 24 | (a) If any provision, or part of a provision, of this Agreement is or becomes illegal, unenforceable, or invalidated, by operation of law or otherwise, that provision or part shall to that extent be deemed omitted, and the remainder of this Agreement shall remain in full force and effect. 25 | (b) This Agreement is the complete and exclusive statement of the agreement between the parties with respect to the subject matter hereof, and supersedes all written and oral contracts, proposals, and other communications between the parties relating to that subject matter. 26 | (c) Subject to Section 8, this Agreement shall be binding on, and shall inure to the benefit of, the respective successors and assigns of NTT and User. 27 | (d) If either party to this Agreement initiates a legal action or proceeding to enforce or interpret any part of this Agreement, the prevailing party in such action shall be entitled to recover, as an element of the costs of such action and not as damages, its attorneys' fees and other costs associated with such action or proceeding. 28 | (e) This Agreement shall be governed by and interpreted under the laws of Japan, without reference to conflicts of law principles. All disputes arising out of or in connection with this Agreement shall be finally settled by arbitration in Tokyo in accordance with the Commercial Arbitration Rules of the Japan Commercial Arbitration Association. The arbitration shall be conducted by three (3) arbitrators and in Japanese. The award rendered by the arbitrators shall be final and binding upon the parties. Judgment upon the award may be entered in any court having jurisdiction thereof. 29 | (f)   NTT shall not be liable to the User or to any third party for any delay or failure to perform NTT’s obligation set forth under this Agreement due to any cause beyond NTT’s reasonable control. 
30 |   31 | EXHIBIT A 32 | 33 | cal_ppl.py 34 | model.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Direct Output Connection for a High-Rank Language Model 2 | 3 | This repository contains the source files we used in our paper 4 | >[Direct Output Connection for a High-Rank Language Model](https://arxiv.org/abs/1808.10143) 5 | 6 | >Sho Takase, Jun Suzuki, Masaaki Nagata 7 | 8 | > Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing 9 | 10 | ## Requirements 11 | 12 | Python 3.5, PyTorch 0.2.0 13 | 14 | About 9.5 GB of VRAM (tested on a K80). 15 | 16 | ## Download the data 17 | 18 | ```./get_data.sh``` 19 | 20 | ## Train the models (to reproduce our results) 21 | 22 | ### Penn Treebank 23 | 24 | First, train the model: 25 | 26 | ```python main.py --data data/penn --dropouti 0.4 --dropoutl 0.6 --dropouth 0.225 --seed 28 --batch_size 12 --lr 20.0 --epoch 500 --nhid 960 --nhidlast 620 --emsize 280 --n_experts 15 --num4second 5 --var 0.001 --nonmono 60 --save PTB --single_gpu``` 27 | 28 | Second, finetune the model: 29 | 30 | ```python finetune.py --data data/penn --dropouti 0.4 --dropoutl 0.6 --dropouth 0.225 --seed 28 --batch_size 12 --lr 20.0 --epoch 500 --var 0.001 --nonmono 60 --save PATH_TO_FOLDER --single_gpu``` 31 | 32 | where `PATH_TO_FOLDER` is the folder created by the first step (the directory passed to `--save`, e.g. `PTB`). 33 | 34 | Third, run the evaluation: 35 | 36 | ```python cal_ppl.py --data data/penn --save PATH_TO_FOLDER/finetune_model.pt --bptt 1000``` 37 | 38 | ### WikiText-2 (Single GPU) 39 | 40 | First, train the model: 41 | 42 | ```python main.py --epochs 500 --data data/wikitext-2 --save WT2 --dropouth 0.2 --seed 1882 --n_experts 15 --num4second 5 --var 0.001 --nhid 1150 --nhidlast 650 --emsize 300 --batch_size 15 --lr 15.0 --dropoutl 0.6 --small_batch_size 5 --max_seq_len_delta 20 --dropouti 0.55 --nonmono 60 --single_gpu``` 43 | 44 | Second, finetune the model: 45 | 46 | ```python finetune.py --epochs 500 --data data/wikitext-2 --save PATH_TO_FOLDER --dropouth 0.2 --seed 1882 --var 0.001 --batch_size 15 --lr 15.0 --dropoutl 0.6 --small_batch_size 5 --max_seq_len_delta 20 --dropouti 0.55 --nonmono 60 --single_gpu``` 47 | 48 | Third, run the evaluation: 49 | 50 | ```python cal_ppl.py --data data/wikitext-2 --save PATH_TO_FOLDER/finetune_model.pt --bptt 1000``` 51 | 52 | ### Pre-trained Models 53 | 54 | [https://drive.google.com/open?id=1ug-6ISrXHEGcWTk5KIw8Ojdjuww-i-Ci](https://drive.google.com/open?id=1ug-6ISrXHEGcWTk5KIw8Ojdjuww-i-Ci) 55 | 56 | ptb, wikitext2: models used to obtain the single-model results 57 | 58 | ptb_ensemble, wikitext2_ensemble: additional trained models used to obtain the ensemble results 59 | 60 | 61 | ## Licenses 62 | 63 | Files listed in EXHIBIT A of NTT_LICENSE are covered by the NTT_LICENSE. 64 | 65 | All other files are covered by the LICENSE in this repository. 
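## Evaluating a saved model from Python

For reference, the snippet below is a minimal sketch of what `cal_ppl.py` and the `evaluate()` functions in `main.py`/`finetune.py` do: load a saved checkpoint and report perplexity on the test split. The checkpoint path `PTB/finetune_model.pt` is only an example (substitute the folder you passed to `--save`), and a CUDA device is assumed, as in the rest of this code base.

```python
import math
from argparse import Namespace

import torch
import torch.nn as nn

import data
from utils import batchify, get_batch, repackage_hidden

# Mirrors the cal_ppl.py flags; adjust the paths for your setup.
args = Namespace(data='data/penn', bptt=1000, cuda=True)

corpus = data.Corpus(args.data)
test_data = batchify(corpus.test, 1, args)   # batch size 1, as in cal_ppl.py
ntokens = len(corpus.dictionary)

model = torch.load('PTB/finetune_model.pt')  # example path: replace PTB with your --save folder
model.eval()                                 # disable dropout

total_loss = 0
hidden = model.init_hidden(1)
for i in range(0, test_data.size(0) - 1, args.bptt):
    x, y = get_batch(test_data, i, args, evaluation=True)
    log_prob, hidden = model(x, hidden)      # the model returns log-probabilities
    total_loss += nn.functional.nll_loss(log_prob.view(-1, ntokens), y.view(-1)).data * len(x)
    hidden = repackage_hidden(hidden)        # detach the hidden state between BPTT segments

print('test ppl {:8.2f}'.format(math.exp(total_loss[0] / len(test_data))))
```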
66 | 67 | -------------------------------------------------------------------------------- /cal_ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | import data 10 | import model 11 | 12 | from utils import batchify, get_batch, repackage_hidden 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model') 15 | parser.add_argument('--data', type=str, default='data/penn', 16 | help='location of the data corpus') 17 | parser.add_argument('--save', type=str,default='best.pt', 18 | help='model to use the pointer over') 19 | parser.add_argument('--cuda', action='store_false', 20 | help='use CUDA') 21 | parser.add_argument('--bptt', type=int, default=5000, 22 | help='sequence length') 23 | args = parser.parse_args() 24 | 25 | ############################################################################### 26 | # Load data 27 | ############################################################################### 28 | 29 | corpus = data.Corpus(args.data) 30 | 31 | eval_batch_size = 1 32 | test_batch_size = 1 33 | train_data = batchify(corpus.train, test_batch_size, args) 34 | val_data = batchify(corpus.valid, test_batch_size, args) 35 | test_data = batchify(corpus.test, test_batch_size, args) 36 | 37 | ############################################################################### 38 | # Build the model 39 | ############################################################################### 40 | 41 | ntokens = len(corpus.dictionary) 42 | criterion = nn.CrossEntropyLoss() 43 | 44 | def one_hot(idx, size, cuda=True): 45 | a = np.zeros((1, size), np.float32) 46 | a[0][idx] = 1 47 | v = Variable(torch.from_numpy(a)) 48 | if cuda: v = v.cuda() 49 | return v 50 | 51 | def evaluate(data_source, batch_size=10): 52 | # Turn on evaluation mode which disables dropout. 53 | model.eval() 54 | total_loss = 0 55 | ntokens = len(corpus.dictionary) 56 | hidden = model.init_hidden(batch_size) 57 | matrix_list = [] 58 | prior_total = 0 59 | for i in range(0, data_source.size(0) - 1, args.bptt): 60 | data, targets = get_batch(data_source, i, args, evaluation=True) 61 | targets = targets.view(-1) 62 | output, hidden, rnn_outs, _, prior = model(data, hidden, return_h=True) 63 | loss = nn.functional.nll_loss(output.view(-1, ntokens), targets).data 64 | total_loss += loss * len(data) 65 | hidden = repackage_hidden(hidden) 66 | prior_total += prior.sum(0).data.cpu().numpy() 67 | output_numpy = output.view(-1, ntokens).data.cpu().numpy() 68 | matrix_list.append(output_numpy) 69 | matrix = np.concatenate(matrix_list) 70 | return total_loss[0] / len(data_source) 71 | 72 | 73 | # Load the best saved model. 74 | with open(args.save, 'rb') as f: 75 | if not args.cuda: 76 | model = torch.load(f, map_location=lambda storage, loc: storage) 77 | else: 78 | model = torch.load(f) 79 | print(model) 80 | 81 | # Run on val data. 82 | val_loss = evaluate(val_data, test_batch_size) 83 | print('=' * 89) 84 | print('| End of pointer | val loss {:5.2f} | val ppl {:8.2f}'.format( 85 | val_loss, math.exp(val_loss))) 86 | print('=' * 89) 87 | 88 | # Run on test data. 
89 | test_loss = evaluate(test_data, test_batch_size) 90 | print('=' * 89) 91 | print('| End of pointer | test loss {:5.2f} | test ppl {:8.2f}'.format( 92 | test_loss, math.exp(test_loss))) 93 | print('=' * 89) 94 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from collections import Counter 5 | 6 | 7 | class Dictionary(object): 8 | def __init__(self): 9 | self.word2idx = {} 10 | self.idx2word = [] 11 | self.counter = Counter() 12 | self.total = 0 13 | 14 | def add_word(self, word): 15 | if word not in self.word2idx: 16 | self.idx2word.append(word) 17 | self.word2idx[word] = len(self.idx2word) - 1 18 | token_id = self.word2idx[word] 19 | self.counter[token_id] += 1 20 | self.total += 1 21 | return self.word2idx[word] 22 | 23 | def __len__(self): 24 | return len(self.idx2word) 25 | 26 | 27 | class Corpus(object): 28 | def __init__(self, path): 29 | self.dictionary = Dictionary() 30 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 31 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 32 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 33 | 34 | def tokenize(self, path): 35 | """Tokenizes a text file.""" 36 | assert os.path.exists(path) 37 | # Add words to the dictionary 38 | with open(path, 'r', encoding='utf-8') as f: 39 | tokens = 0 40 | for line in f: 41 | words = line.split() + ['<eos>'] 42 | tokens += len(words) 43 | for word in words: 44 | self.dictionary.add_word(word) 45 | 46 | # Tokenize file content 47 | with open(path, 'r', encoding='utf-8') as f: 48 | ids = torch.LongTensor(tokens) 49 | token = 0 50 | for line in f: 51 | words = line.split() + ['<eos>'] 52 | for word in words: 53 | ids[token] = self.dictionary.word2idx[word] 54 | token += 1 55 | 56 | return ids 57 | 58 | class SentCorpus(object): 59 | def __init__(self, path): 60 | self.dictionary = Dictionary() 61 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 62 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 63 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 64 | 65 | def tokenize(self, path): 66 | """Tokenizes a text file.""" 67 | assert os.path.exists(path) 68 | # Add words to the dictionary 69 | with open(path, 'r', encoding='utf-8') as f: 70 | tokens = 0 71 | for line in f: 72 | words = line.split() + ['<eos>'] 73 | tokens += len(words) 74 | for word in words: 75 | self.dictionary.add_word(word) 76 | 77 | # Tokenize file content 78 | sents = [] 79 | with open(path, 'r', encoding='utf-8') as f: 80 | for line in f: 81 | if not line: 82 | continue 83 | words = line.split() + ['<eos>'] 84 | sent = torch.LongTensor(len(words)) 85 | for i, word in enumerate(words): 86 | sent[i] = self.dictionary.word2idx[word] 87 | sents.append(sent) 88 | 89 | return sents 90 | 91 | class BatchSentLoader(object): 92 | def __init__(self, sents, batch_size, pad_id=0, cuda=False, volatile=False): 93 | self.sents = sents 94 | self.batch_size = batch_size 95 | self.sort_sents = sorted(sents, key=lambda x: x.size(0)) 96 | self.cuda = cuda 97 | self.volatile = volatile 98 | self.pad_id = pad_id 99 | 100 | def __next__(self): 101 | if self.idx >= len(self.sort_sents): 102 | raise StopIteration 103 | 104 | batch_size = min(self.batch_size, len(self.sort_sents)-self.idx) 105 | batch = self.sort_sents[self.idx:self.idx+batch_size] 106 | max_len = max([s.size(0) for s in batch]) 107 | tensor = torch.LongTensor(max_len, 
batch_size).fill_(self.pad_id) 108 | for i in range(len(batch)): 109 | s = batch[i] 110 | tensor[:s.size(0),i].copy_(s) 111 | if self.cuda: 112 | tensor = tensor.cuda() 113 | 114 | self.idx += batch_size 115 | 116 | return tensor 117 | 118 | next = __next__ 119 | 120 | def __iter__(self): 121 | self.idx = 0 122 | return self 123 | 124 | if __name__ == '__main__': 125 | corpus = SentCorpus('../penn') 126 | loader = BatchSentLoader(corpus.test, 10) 127 | for i, d in enumerate(loader): 128 | print(i, d.size()) 129 | -------------------------------------------------------------------------------- /embed_regularize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | def embedded_dropout(embed, words, dropout=0.1, scale=None): 7 | if dropout: 8 | mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout) 9 | mask = Variable(mask) 10 | masked_embed_weight = mask * embed.weight 11 | else: 12 | masked_embed_weight = embed.weight 13 | if scale: 14 | masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight 15 | 16 | padding_idx = embed.padding_idx 17 | if padding_idx is None: 18 | padding_idx = -1 19 | X = embed._backend.Embedding.apply(words, masked_embed_weight, 20 | padding_idx, embed.max_norm, embed.norm_type, 21 | embed.scale_grad_by_freq, embed.sparse 22 | ) 23 | return X 24 | 25 | if __name__ == '__main__': 26 | V = 50 27 | h = 4 28 | bptt = 10 29 | batch_size = 2 30 | 31 | embed = torch.nn.Embedding(V, h) 32 | 33 | words = np.random.random_integers(low=0, high=V-1, size=(batch_size, bptt)) 34 | words = torch.LongTensor(words) 35 | words = Variable(words) 36 | 37 | origX = embed(words) 38 | X = embedded_dropout(embed, words) 39 | 40 | print(origX) 41 | print(X) 42 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os, sys 4 | import math 5 | import numpy as np 6 | np.random.seed(331) 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | 11 | import data 12 | import model 13 | import os 14 | 15 | from utils import batchify, get_batch, repackage_hidden, create_exp_dir, save_checkpoint 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank/WikiText2 RNN/LSTM Language Model') 18 | parser.add_argument('--data', type=str, default='./penn/', 19 | help='location of the data corpus') 20 | parser.add_argument('--model', type=str, default='LSTM', 21 | help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)') 22 | parser.add_argument('--emsize', type=int, default=400, 23 | help='size of word embeddings') 24 | parser.add_argument('--nhid', type=int, default=1150, 25 | help='number of hidden units per layer') 26 | parser.add_argument('--nlayers', type=int, default=3, 27 | help='number of layers') 28 | parser.add_argument('--lr', type=float, default=30, 29 | help='initial learning rate') 30 | parser.add_argument('--clip', type=float, default=0.25, 31 | help='gradient clipping') 32 | parser.add_argument('--epochs', type=int, default=8000, 33 | help='upper epoch limit') 34 | parser.add_argument('--batch_size', type=int, default=80, metavar='N', 35 | help='batch size') 36 | parser.add_argument('--bptt', type=int, default=70, 37 | help='sequence length') 38 | 
parser.add_argument('--dropout', type=float, default=0.4, 39 | help='dropout applied to layers (0 = no dropout)') 40 | parser.add_argument('--dropouth', type=float, default=0.3, 41 | help='dropout for rnn layers (0 = no dropout)') 42 | parser.add_argument('--dropouti', type=float, default=0.65, 43 | help='dropout for input embedding layers (0 = no dropout)') 44 | parser.add_argument('--dropoute', type=float, default=0.1, 45 | help='dropout to remove words from embedding layer (0 = no dropout)') 46 | parser.add_argument('--dropoutl', type=float, default=-0.2, 47 | help='dropout applied to layers (0 = no dropout)') 48 | parser.add_argument('--wdrop', type=float, default=0.5, 49 | help='amount of weight dropout to apply to the RNN hidden to hidden matrix') 50 | parser.add_argument('--tied', action='store_false', 51 | help='tie the word embedding and softmax weights') 52 | parser.add_argument('--seed', type=int, default=1111, 53 | help='random seed') 54 | parser.add_argument('--nonmono', type=int, default=5, 55 | help='random seed') 56 | parser.add_argument('--cuda', action='store_false', 57 | help='use CUDA') 58 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 59 | help='report interval') 60 | parser.add_argument('--save', type=str, required=True, 61 | help='path to the directory that save the final model') 62 | parser.add_argument('--alpha', type=float, default=2, 63 | help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') 64 | parser.add_argument('--beta', type=float, default=1, 65 | help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)') 66 | parser.add_argument('--var', type=float, default=0, 67 | help='regularization for prior') 68 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 69 | help='weight decay applied to all weights') 70 | parser.add_argument('--continue_train', action='store_true', 71 | help='continue train from a checkpoint') 72 | parser.add_argument('--n_experts', type=int, default=10, 73 | help='number of experts') 74 | parser.add_argument('--small_batch_size', type=int, default=-1, 75 | help='the batch size for computation. batch_size should be divisible by small_batch_size.\ 76 | In our implementation, we compute gradients with small_batch_size multiple times, and accumulate the gradients\ 77 | until batch_size is reached. An update step is then performed.') 78 | parser.add_argument('--max_seq_len_delta', type=int, default=40, 79 | help='max sequence length') 80 | parser.add_argument('--single_gpu', default=False, action='store_true', help='use single GPU') 81 | args = parser.parse_args() 82 | 83 | print('finetune load path: {}/model.pt. '.format(args.save)) 84 | print('log save path: {}/finetune_log.txt'.format(args.save)) 85 | print('model save path: {}/finetune_model.pt'.format(args.save)) 86 | 87 | log_file = os.path.join(args.save, 'finetune_log.txt') 88 | if not args.continue_train: 89 | if os.path.exists(log_file): 90 | os.remove(log_file) 91 | 92 | def logging(s, print_=True, log_=True): 93 | if print_: 94 | print(s) 95 | if log_: 96 | with open(log_file, 'a+') as f_log: 97 | f_log.write(s + '\n') 98 | 99 | if args.dropoutl < 0: 100 | args.dropoutl = args.dropouth 101 | if args.small_batch_size < 0: 102 | args.small_batch_size = args.batch_size 103 | 104 | # Set the random seed manually for reproducibility. 
105 | torch.manual_seed(args.seed) 106 | if torch.cuda.is_available(): 107 | if not args.cuda: 108 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 109 | else: 110 | torch.cuda.manual_seed_all(args.seed) 111 | 112 | ############################################################################### 113 | # Load data 114 | ############################################################################### 115 | 116 | corpus = data.Corpus(args.data) 117 | 118 | eval_batch_size = 10 119 | test_batch_size = 1 120 | train_data = batchify(corpus.train, args.batch_size, args) 121 | val_data = batchify(corpus.valid, eval_batch_size, args) 122 | test_data = batchify(corpus.test, test_batch_size, args) 123 | 124 | ############################################################################### 125 | # Build the model 126 | ############################################################################### 127 | 128 | ntokens = len(corpus.dictionary) 129 | if args.continue_train: 130 | model = torch.load(os.path.join(args.save, 'finetune_model.pt')) 131 | else: 132 | model = torch.load(os.path.join(args.save, 'model.pt')) 133 | if args.cuda: 134 | if args.single_gpu: 135 | parallel_model = model.cuda() 136 | else: 137 | parallel_model = nn.DataParallel(model, dim=1).cuda() 138 | else: 139 | parallel_model = model 140 | total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0] for x in model.parameters()) 141 | logging('Args: {}'.format(args)) 142 | logging('Model total parameters: {}'.format(total_params)) 143 | 144 | criterion = nn.CrossEntropyLoss() 145 | 146 | ############################################################################### 147 | # Training code 148 | ############################################################################### 149 | 150 | def evaluate(data_source, batch_size=10): 151 | # Turn on evaluation mode which disables dropout. 152 | model.eval() 153 | total_loss = 0 154 | ntokens = len(corpus.dictionary) 155 | hidden = model.init_hidden(batch_size) 156 | for i in range(0, data_source.size(0) - 1, args.bptt): 157 | data, targets = get_batch(data_source, i, args, evaluation=True) 158 | targets = targets.view(-1) 159 | 160 | log_prob, hidden = parallel_model(data, hidden) 161 | loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data 162 | 163 | total_loss += len(data) * loss 164 | hidden = repackage_hidden(hidden) 165 | return total_loss[0] / len(data_source) 166 | 167 | def train(): 168 | assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size' 169 | 170 | # Turn on training mode which enables dropout. 171 | total_loss = 0 172 | start_time = time.time() 173 | ntokens = len(corpus.dictionary) 174 | hidden = [model.init_hidden(args.small_batch_size) for _ in range(args.batch_size // args.small_batch_size)] 175 | batch, i = 0, 0 176 | while i < train_data.size(0) - 1 - 1: 177 | bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2. 
178 | # Prevent excessively small or negative sequence lengths 179 | seq_len = max(5, int(np.random.normal(bptt, 5))) 180 | # There's a very small chance that it could select a very long sequence length resulting in OOM 181 | seq_len = min(seq_len, args.bptt + args.max_seq_len_delta) 182 | 183 | lr2 = optimizer.param_groups[0]['lr'] 184 | optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt 185 | model.train() 186 | data, targets = get_batch(train_data, i, args, seq_len=seq_len) 187 | 188 | optimizer.zero_grad() 189 | 190 | start, end, s_id = 0, args.small_batch_size, 0 191 | while start < args.batch_size: 192 | cur_data, cur_targets = data[:, start: end], targets[:, start: end].contiguous().view(-1) 193 | 194 | # Starting each batch, we detach the hidden state from how it was previously produced. 195 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 196 | hidden[s_id] = repackage_hidden(hidden[s_id]) 197 | 198 | log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs, prior = parallel_model(cur_data, hidden[s_id], return_h=True) 199 | raw_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), cur_targets) 200 | 201 | loss = raw_loss 202 | # Activiation Regularization 203 | loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:]) 204 | # Temporal Activation Regularization (slowness) 205 | loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:]) 206 | #regularize for prior 207 | prior_sum = prior.sum(0) 208 | cv = (prior_sum.var() * (prior_sum.size(1) - 1)).sqrt() / prior_sum.mean() 209 | loss = loss + sum(args.var * cv * cv) 210 | loss *= args.small_batch_size / args.batch_size 211 | total_loss += raw_loss.data * args.small_batch_size / args.batch_size 212 | loss.backward() 213 | 214 | s_id += 1 215 | start = end 216 | end = start + args.small_batch_size 217 | 218 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 219 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) 220 | optimizer.step() 221 | 222 | # total_loss += raw_loss.data 223 | optimizer.param_groups[0]['lr'] = lr2 224 | if batch % args.log_interval == 0 and batch > 0: 225 | cur_loss = total_loss[0] / args.log_interval 226 | elapsed = time.time() - start_time 227 | logging('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' 228 | 'loss {:5.2f} | ppl {:8.2f}'.format( 229 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 230 | elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))) 231 | total_loss = 0 232 | start_time = time.time() 233 | ### 234 | batch += 1 235 | i += seq_len 236 | 237 | # Loop over epochs. 238 | lr = args.lr 239 | stored_loss = evaluate(val_data) 240 | best_val_loss = [] 241 | # At any point you can hit Ctrl + C to break out of training early. 
242 | try: 243 | #optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 244 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 245 | if args.continue_train: 246 | optimizer_state = torch.load(os.path.join(args.save, 'finetune_optimizer.pt')) 247 | optimizer.load_state_dict(optimizer_state) 248 | 249 | for epoch in range(1, args.epochs+1): 250 | epoch_start_time = time.time() 251 | train() 252 | if 't0' in optimizer.param_groups[0]: 253 | tmp = {} 254 | for prm in model.parameters(): 255 | tmp[prm] = prm.data.clone() 256 | prm.data = optimizer.state[prm]['ax'].clone() 257 | 258 | val_loss2 = evaluate(val_data) 259 | logging('-' * 89) 260 | logging('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 261 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), 262 | val_loss2, math.exp(val_loss2))) 263 | logging('-' * 89) 264 | 265 | if val_loss2 < stored_loss: 266 | save_checkpoint(model, optimizer, args.save, finetune=True) 267 | logging('Saving Averaged!') 268 | stored_loss = val_loss2 269 | 270 | for prm in model.parameters(): 271 | prm.data = tmp[prm].clone() 272 | 273 | if (len(best_val_loss)>args.nonmono and val_loss2 > min(best_val_loss[:-args.nonmono])): 274 | logging('Done!') 275 | break 276 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 277 | #optimizer.param_groups[0]['lr'] /= 2. 278 | best_val_loss.append(val_loss2) 279 | 280 | except KeyboardInterrupt: 281 | logging('-' * 89) 282 | logging('Exiting from training early') 283 | 284 | # Load the best saved model. 285 | model = torch.load(os.path.join(args.save, 'finetune_model.pt')) 286 | parallel_model = nn.DataParallel(model, dim=1).cuda() 287 | 288 | # Run on test data. 289 | test_loss = evaluate(test_data, test_batch_size) 290 | logging('=' * 89) 291 | logging('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 292 | test_loss, math.exp(test_loss))) 293 | logging('=' * 89) 294 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Language Modeling on Penn Tree Bank 3 | # 4 | # This file generates new sentences sampled from the language model 5 | # 6 | ############################################################################### 7 | 8 | import argparse 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | 13 | import data 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch PTB Language Model') 16 | 17 | # Model parameters. 
18 | parser.add_argument('--data', type=str, default='./penn', 19 | help='location of the data corpus') 20 | parser.add_argument('--checkpoint', type=str, default='./model.pt', 21 | help='model checkpoint to use') 22 | parser.add_argument('--outf', type=str, default='generated.txt', 23 | help='output file for generated text') 24 | parser.add_argument('--words', type=int, default='1000', 25 | help='number of words to generate') 26 | parser.add_argument('--seed', type=int, default=1111, 27 | help='random seed') 28 | parser.add_argument('--cuda', action='store_true', 29 | help='use CUDA') 30 | parser.add_argument('--temperature', type=float, default=1.0, 31 | help='temperature - higher will increase diversity') 32 | parser.add_argument('--log-interval', type=int, default=100, 33 | help='reporting interval') 34 | args = parser.parse_args() 35 | 36 | # Set the random seed manually for reproducibility. 37 | torch.manual_seed(args.seed) 38 | if torch.cuda.is_available(): 39 | if not args.cuda: 40 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 41 | else: 42 | torch.cuda.manual_seed(args.seed) 43 | 44 | if args.temperature < 1e-3: 45 | parser.error("--temperature has to be greater or equal 1e-3") 46 | 47 | with open(args.checkpoint, 'rb') as f: 48 | model = torch.load(f) 49 | model.eval() 50 | 51 | if args.cuda: 52 | model.cuda() 53 | else: 54 | model.cpu() 55 | 56 | corpus = data.Corpus(args.data) 57 | ntokens = len(corpus.dictionary) 58 | hidden = model.init_hidden(1) 59 | input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) 60 | if args.cuda: 61 | input.data = input.data.cuda() 62 | 63 | with open(args.outf, 'w') as outf: 64 | for i in range(args.words): 65 | output, hidden = model(input, hidden, return_prob=True) 66 | word_weights = output.squeeze().data.div(args.temperature).exp().cpu() 67 | word_idx = torch.multinomial(word_weights, 1)[0] 68 | input.data.fill_(word_idx) 69 | word = corpus.dictionary.idx2word[word_idx] 70 | 71 | outf.write(word + ('\n' if i % 20 == 19 else ' ')) 72 | 73 | if i % args.log_interval == 0: 74 | print('| Generated {}/{} words'.format(i, args.words)) 75 | -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | cd data 3 | 4 | echo "- Downloading Penn Treebank (PTB)" 5 | mkdir -p penn 6 | cd penn 7 | wget --quiet --continue -O train.txt https://raw.githubusercontent.com/yangsaiyong/tf-adaptive-softmax-lstm-lm/master/ptb_data/ptb.train.txt 8 | wget --quiet --continue -O valid.txt https://raw.githubusercontent.com/yangsaiyong/tf-adaptive-softmax-lstm-lm/master/ptb_data/ptb.valid.txt 9 | wget --quiet --continue -O test.txt https://raw.githubusercontent.com/yangsaiyong/tf-adaptive-softmax-lstm-lm/master/ptb_data/ptb.test.txt 10 | 11 | cd .. 12 | 13 | echo "- Downloading WikiText-2 (WT2)" 14 | wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip 15 | unzip -q wikitext-2-v1.zip 16 | cd wikitext-2 17 | mv wiki.train.tokens train.txt 18 | mv wiki.valid.tokens valid.txt 19 | mv wiki.test.tokens test.txt 20 | 21 | echo "---" 22 | echo "Happy language modeling :)" 23 | 24 | cd .. 
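# For reference, the layout produced above (and expected via --data by main.py, finetune.py, and cal_ppl.py) is:
#   data/penn/train.txt  data/penn/valid.txt  data/penn/test.txt
#   data/wikitext-2/train.txt  data/wikitext-2/valid.txt  data/wikitext-2/test.txt
# e.g. pass --data data/penn or --data data/wikitext-2 when running the scripts from the repository root.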
25 | -------------------------------------------------------------------------------- /locked_dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | class LockedDropout(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x, dropout=0.5): 10 | if not self.training or not dropout: 11 | return x 12 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) 13 | # mask = Variable(m, requires_grad=False) / (1 - dropout) 14 | mask = Variable(m.div_(1 - dropout), requires_grad=False) 15 | mask = mask.expand_as(x) 16 | return mask * x 17 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | import time 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import gc 12 | 13 | import data 14 | import model 15 | 16 | from utils import batchify, get_batch, repackage_hidden, create_exp_dir, save_checkpoint 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch PennTreeBank/WikiText2 RNN/LSTM Language Model') 19 | parser.add_argument('--data', type=str, default='./penn/', 20 | help='location of the data corpus') 21 | parser.add_argument('--model', type=str, default='LSTM', 22 | help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU, SRU)') 23 | parser.add_argument('--emsize', type=int, default=400, 24 | help='size of word embeddings') 25 | parser.add_argument('--nhid', type=int, default=1150, 26 | help='number of hidden units per layer') 27 | parser.add_argument('--nhidlast', type=int, default=-1, 28 | help='number of hidden units for the last rnn layer') 29 | parser.add_argument('--nlayers', type=int, default=3, 30 | help='number of layers') 31 | parser.add_argument('--lr', type=float, default=30, 32 | help='initial learning rate') 33 | parser.add_argument('--clip', type=float, default=0.25, 34 | help='gradient clipping') 35 | parser.add_argument('--epochs', type=int, default=8000, 36 | help='upper epoch limit') 37 | parser.add_argument('--batch_size', type=int, default=20, metavar='N', 38 | help='batch size') 39 | parser.add_argument('--bptt', type=int, default=70, 40 | help='sequence length') 41 | parser.add_argument('--dropout', type=float, default=0.4, 42 | help='dropout applied to layers (0 = no dropout)') 43 | parser.add_argument('--dropouth', type=float, default=0.3, 44 | help='dropout for rnn layers (0 = no dropout)') 45 | parser.add_argument('--dropouti', type=float, default=0.65, 46 | help='dropout for input embedding layers (0 = no dropout)') 47 | parser.add_argument('--dropoute', type=float, default=0.1, 48 | help='dropout to remove words from embedding layer (0 = no dropout)') 49 | parser.add_argument('--dropoutl', type=float, default=-0.2, 50 | help='dropout applied to layers (0 = no dropout)') 51 | parser.add_argument('--wdrop', type=float, default=0.5, 52 | help='amount of weight dropout to apply to the RNN hidden to hidden matrix') 53 | parser.add_argument('--tied', action='store_false', 54 | help='tie the word embedding and softmax weights') 55 | parser.add_argument('--seed', type=int, default=1111, 56 | help='random seed') 57 | parser.add_argument('--nonmono', type=int, default=5, 58 | help='random seed') 59 | 
parser.add_argument('--cuda', action='store_false', 60 | help='use CUDA') 61 | parser.add_argument('--log-interval', type=int, default=200, metavar='N', 62 | help='report interval') 63 | parser.add_argument('--save', type=str, default='EXP', 64 | help='path to save the final model') 65 | parser.add_argument('--alpha', type=float, default=2, 66 | help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)') 67 | parser.add_argument('--beta', type=float, default=1, 68 | help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)') 69 | parser.add_argument('--var', type=float, default=0, 70 | help='regularization for prior') 71 | parser.add_argument('--wdecay', type=float, default=1.2e-6, 72 | help='weight decay applied to all weights') 73 | parser.add_argument('--continue_train', action='store_true', 74 | help='continue train from a checkpoint') 75 | parser.add_argument('--n_experts', type=int, default=10, 76 | help='number of experts') 77 | parser.add_argument('--num4second', type=int, default=0, 78 | help='the number of softmax for second layer') 79 | parser.add_argument('--num4first', type=int, default=0, 80 | help='the number of softmax for first layer') 81 | parser.add_argument('--num4embed', type=int, default=0, 82 | help='the number of softmax for embeddings') 83 | parser.add_argument('--small_batch_size', type=int, default=-1, 84 | help='the batch size for computation. batch_size should be divisible by small_batch_size.\ 85 | In our implementation, we compute gradients with small_batch_size multiple times, and accumulate the gradients\ 86 | until batch_size is reached. An update step is then performed.') 87 | parser.add_argument('--max_seq_len_delta', type=int, default=40, 88 | help='max sequence length') 89 | parser.add_argument('--single_gpu', default=False, action='store_true', 90 | help='use single GPU') 91 | args = parser.parse_args() 92 | 93 | if args.nhidlast < 0: 94 | args.nhidlast = args.emsize 95 | if args.dropoutl < 0: 96 | args.dropoutl = args.dropouth 97 | if args.small_batch_size < 0: 98 | args.small_batch_size = args.batch_size 99 | 100 | if not args.continue_train: 101 | #args.save = '{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S")) 102 | args.save = '{}'.format(args.save) 103 | create_exp_dir(args.save, scripts_to_save=['main.py', 'model.py']) 104 | 105 | def logging(s, print_=True, log_=True): 106 | if print_: 107 | print(s) 108 | if log_: 109 | with open(os.path.join(args.save, 'log.txt'), 'a+') as f_log: 110 | f_log.write(s + '\n') 111 | 112 | # Set the random seed manually for reproducibility. 
113 | np.random.seed(args.seed) 114 | torch.manual_seed(args.seed) 115 | if torch.cuda.is_available(): 116 | if not args.cuda: 117 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 118 | else: 119 | torch.cuda.manual_seed_all(args.seed) 120 | 121 | ############################################################################### 122 | # Load data 123 | ############################################################################### 124 | 125 | corpus = data.Corpus(args.data) 126 | 127 | eval_batch_size = 10 128 | test_batch_size = 1 129 | train_data = batchify(corpus.train, args.batch_size, args) 130 | val_data = batchify(corpus.valid, eval_batch_size, args) 131 | test_data = batchify(corpus.test, test_batch_size, args) 132 | 133 | ############################################################################### 134 | # Build the model 135 | ############################################################################### 136 | 137 | ntokens = len(corpus.dictionary) 138 | if args.continue_train: 139 | model = torch.load(os.path.join(args.save, 'model.pt')) 140 | else: 141 | model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nhidlast, args.nlayers, 142 | args.dropout, args.dropouth, args.dropouti, args.dropoute, args.wdrop, 143 | args.tied, args.dropoutl, args.n_experts, args.num4embed, args.num4first, args.num4second) 144 | 145 | if args.cuda: 146 | if args.single_gpu: 147 | parallel_model = model.cuda() 148 | else: 149 | parallel_model = nn.DataParallel(model, dim=1).cuda() 150 | else: 151 | parallel_model = model 152 | 153 | total_params = sum(x.data.nelement() for x in model.parameters()) 154 | logging('Args: {}'.format(args)) 155 | logging('Model total parameters: {}'.format(total_params)) 156 | 157 | criterion = nn.CrossEntropyLoss() 158 | 159 | ############################################################################### 160 | # Training code 161 | ############################################################################### 162 | 163 | def evaluate(data_source, batch_size=10): 164 | # Turn on evaluation mode which disables dropout. 165 | model.eval() 166 | total_loss = 0 167 | ntokens = len(corpus.dictionary) 168 | hidden = model.init_hidden(batch_size) 169 | for i in range(0, data_source.size(0) - 1, args.bptt): 170 | data, targets = get_batch(data_source, i, args, evaluation=True) 171 | targets = targets.view(-1) 172 | 173 | log_prob, hidden = parallel_model(data, hidden) 174 | loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data 175 | 176 | total_loss += loss * len(data) 177 | 178 | hidden = repackage_hidden(hidden) 179 | return total_loss[0] / len(data_source) 180 | 181 | 182 | def train(): 183 | assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size' 184 | 185 | # Turn on training mode which enables dropout. 186 | total_loss = 0 187 | start_time = time.time() 188 | ntokens = len(corpus.dictionary) 189 | hidden = [model.init_hidden(args.small_batch_size) for _ in range(args.batch_size // args.small_batch_size)] 190 | batch, i = 0, 0 191 | while i < train_data.size(0) - 1 - 1: 192 | bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2. 
193 | # Prevent excessively small or negative sequence lengths 194 | seq_len = max(5, int(np.random.normal(bptt, 5))) 195 | # There's a very small chance that it could select a very long sequence length resulting in OOM 196 | seq_len = min(seq_len, args.bptt + args.max_seq_len_delta) 197 | 198 | lr2 = optimizer.param_groups[0]['lr'] 199 | optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt 200 | model.train() 201 | data, targets = get_batch(train_data, i, args, seq_len=seq_len) 202 | 203 | optimizer.zero_grad() 204 | 205 | start, end, s_id = 0, args.small_batch_size, 0 206 | while start < args.batch_size: 207 | cur_data, cur_targets = data[:, start: end], targets[:, start: end].contiguous().view(-1) 208 | 209 | # Starting each batch, we detach the hidden state from how it was previously produced. 210 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 211 | hidden[s_id] = repackage_hidden(hidden[s_id]) 212 | 213 | log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs, prior = parallel_model(cur_data, hidden[s_id], return_h=True) 214 | raw_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), cur_targets) 215 | 216 | loss = raw_loss 217 | # Activiation Regularization 218 | loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:]) 219 | # Temporal Activation Regularization (slowness) 220 | loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:]) 221 | #regularize for prior 222 | prior_sum = prior.sum(0) 223 | cv = (prior_sum.var() * (prior_sum.size(1) - 1)).sqrt() / prior_sum.mean() 224 | loss = loss + sum(args.var * cv * cv) 225 | loss *= args.small_batch_size / args.batch_size 226 | total_loss += raw_loss.data * args.small_batch_size / args.batch_size 227 | loss.backward() 228 | 229 | s_id += 1 230 | start = end 231 | end = start + args.small_batch_size 232 | 233 | gc.collect() 234 | 235 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 236 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip) 237 | optimizer.step() 238 | 239 | # total_loss += raw_loss.data 240 | optimizer.param_groups[0]['lr'] = lr2 241 | if batch % args.log_interval == 0 and batch > 0: 242 | cur_loss = total_loss[0] / args.log_interval 243 | elapsed = time.time() - start_time 244 | logging('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' 245 | 'loss {:5.2f} | ppl {:8.2f}'.format( 246 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 247 | elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))) 248 | total_loss = 0 249 | start_time = time.time() 250 | ### 251 | batch += 1 252 | i += seq_len 253 | 254 | # Loop over epochs. 255 | lr = args.lr 256 | best_val_loss = [] 257 | stored_loss = 100000000 258 | 259 | # At any point you can hit Ctrl + C to break out of training early. 
260 | try: 261 | if args.continue_train: 262 | optimizer_state = torch.load(os.path.join(args.save, 'optimizer.pt')) 263 | if 't0' in optimizer_state['param_groups'][0]: 264 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 265 | else: 266 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 267 | optimizer.load_state_dict(optimizer_state) 268 | else: 269 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 270 | 271 | for epoch in range(1, args.epochs+1): 272 | epoch_start_time = time.time() 273 | train() 274 | if 't0' in optimizer.param_groups[0]: 275 | tmp = {} 276 | for prm in model.parameters(): 277 | tmp[prm] = prm.data.clone() 278 | prm.data = optimizer.state[prm]['ax'].clone() 279 | 280 | val_loss2 = evaluate(val_data) 281 | logging('-' * 89) 282 | logging('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 283 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), 284 | val_loss2, math.exp(val_loss2))) 285 | logging('-' * 89) 286 | 287 | if val_loss2 < stored_loss: 288 | save_checkpoint(model, optimizer, args.save) 289 | logging('Saving Averaged!') 290 | stored_loss = val_loss2 291 | 292 | for prm in model.parameters(): 293 | prm.data = tmp[prm].clone() 294 | 295 | else: 296 | val_loss = evaluate(val_data, eval_batch_size) 297 | logging('-' * 89) 298 | logging('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 299 | 'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), 300 | val_loss, math.exp(val_loss))) 301 | logging('-' * 89) 302 | 303 | if val_loss < stored_loss: 304 | save_checkpoint(model, optimizer, args.save) 305 | logging('Saving Normal!') 306 | stored_loss = val_loss 307 | 308 | if 't0' not in optimizer.param_groups[0] and (len(best_val_loss)>args.nonmono and val_loss > min(best_val_loss[:-args.nonmono])): 309 | logging('Switching!') 310 | optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 311 | #optimizer.param_groups[0]['lr'] /= 2. 312 | best_val_loss.append(val_loss) 313 | 314 | except KeyboardInterrupt: 315 | logging('-' * 89) 316 | logging('Exiting from training early') 317 | 318 | # Load the best saved model. 319 | model = torch.load(os.path.join(args.save, 'model.pt')) 320 | parallel_model = nn.DataParallel(model, dim=1).cuda() 321 | 322 | # Run on test data. 
323 | test_loss = evaluate(test_data, test_batch_size) 324 | logging('=' * 89) 325 | logging('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 326 | test_loss, math.exp(test_loss))) 327 | logging('=' * 89) 328 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | from embed_regularize import embedded_dropout 8 | from locked_dropout import LockedDropout 9 | from weight_drop import WeightDrop 10 | 11 | class RNNModel(nn.Module): 12 | """Container module with an encoder, a recurrent module, and a decoder.""" 13 | 14 | def __init__(self, rnn_type, ntoken, ninp, nhid, nhidlast, nlayers, 15 | dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0, 16 | tie_weights=False, ldropout=0.6, n_experts=10, num4embed=0, num4first=0, num4second=0): 17 | super(RNNModel, self).__init__() 18 | self.lockdrop = LockedDropout() 19 | self.encoder = nn.Embedding(ntoken, ninp) 20 | 21 | self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else nhidlast, 1, dropout=0) for l in range(nlayers)] 22 | if wdrop: 23 | self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns] 24 | self.rnns = torch.nn.ModuleList(self.rnns) 25 | 26 | self.all_experts = n_experts + num4embed + num4first + num4second 27 | self.prior = nn.Linear(nhidlast, self.all_experts, bias=False) 28 | self.latent = nn.Linear(nhidlast, n_experts*ninp) 29 | if num4embed > 0: 30 | self.weight4embed = nn.Linear(ninp, num4embed*ninp) 31 | if num4first > 0: 32 | self.weight4first = nn.Linear(nhid, num4first*ninp) 33 | if num4second > 0: 34 | self.weight4second = nn.Linear(nhid, num4second*ninp) 35 | self.decoder = nn.Linear(ninp, ntoken) 36 | 37 | # Optionally tie weights as in: 38 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 39 | # https://arxiv.org/abs/1608.05859 40 | # and 41 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) 42 | # https://arxiv.org/abs/1611.01462 43 | if tie_weights: 44 | #if nhid != ninp: 45 | # raise ValueError('When using the tied flag, nhid must be equal to emsize') 46 | self.decoder.weight = self.encoder.weight 47 | 48 | self.num4embed = num4embed 49 | self.num4first = num4first 50 | self.num4second = num4second 51 | self.init_weights() 52 | 53 | self.rnn_type = rnn_type 54 | self.ninp = ninp 55 | self.nhid = nhid 56 | self.nhidlast = nhidlast 57 | self.nlayers = nlayers 58 | self.dropout = dropout 59 | self.dropouti = dropouti 60 | self.dropouth = dropouth 61 | self.dropoute = dropoute 62 | self.dropoutl = ldropout 63 | self.n_experts = n_experts 64 | self.ntoken = ntoken 65 | 66 | size = 0 67 | for p in self.parameters(): 68 | size += p.nelement() 69 | print('param size: {}'.format(size)) 70 | 71 | def init_weights(self): 72 | initrange = 0.1 73 | self.encoder.weight.data.uniform_(-initrange, initrange) 74 | self.decoder.bias.data.fill_(0) 75 | self.decoder.weight.data.uniform_(-initrange, initrange) 76 | self.latent.bias.data.fill_(0) 77 | if self. 
num4embed > 0: 78 | self.weight4embed.bias.data.fill_(0) 79 | if self.num4first > 0: 80 | self.weight4first.bias.data.fill_(0) 81 | if self.num4second > 0: 82 | self.weight4second.bias.data.fill_(0) 83 | 84 | def forward(self, input, hidden, return_h=False, return_prob=False): 85 | batch_size = input.size(1) 86 | 87 | emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0) 88 | #emb = self.idrop(emb) 89 | 90 | emb = self.lockdrop(emb, self.dropouti) 91 | list4mos = [] 92 | if self.num4embed > 0: 93 | embed4mos = nn.functional.tanh(self.weight4embed(emb)) 94 | embed4mos = embed4mos.view(emb.size(0), emb.size(1), self.num4embed, self.ninp).transpose(1, 2).transpose(1, 0).contiguous() 95 | embed4mos = embed4mos.view(-1, emb.size(1), self.ninp) 96 | list4mos.extend(list(torch.chunk(embed4mos, self.num4embed, 0))) 97 | 98 | raw_output = emb 99 | new_hidden = [] 100 | #raw_output, hidden = self.rnn(emb, hidden) 101 | raw_outputs = [] 102 | outputs = [] 103 | for l, rnn in enumerate(self.rnns): 104 | current_input = raw_output 105 | raw_output, new_h = rnn(raw_output, hidden[l]) 106 | new_hidden.append(new_h) 107 | raw_outputs.append(raw_output) 108 | if l != self.nlayers - 1: 109 | #self.hdrop(raw_output) 110 | raw_output = self.lockdrop(raw_output, self.dropouth) 111 | outputs.append(raw_output) 112 | if l == 0 and self.num4first > 0: 113 | first4mos = nn.functional.tanh(self.weight4first(raw_output)) 114 | first4mos = first4mos.view(raw_output.size(0), raw_output.size(1), self.num4first, self.ninp).transpose(1, 2).transpose(1, 0).contiguous() 115 | first4mos = first4mos.view(-1, raw_output.size(1), self.ninp) 116 | list4mos.extend(list(torch.chunk(first4mos, self.num4first, 0))) 117 | if l == 1 and self.num4second > 0: 118 | second4mos = nn.functional.tanh(self.weight4second(raw_output)) 119 | second4mos = second4mos.view(raw_output.size(0), raw_output.size(1), self.num4second, self.ninp).transpose(1, 2).transpose(1, 0).contiguous() 120 | second4mos = second4mos.view(-1, raw_output.size(1), self.ninp) 121 | list4mos.extend(list(torch.chunk(second4mos, self.num4second, 0))) 122 | hidden = new_hidden 123 | 124 | output = self.lockdrop(raw_output, self.dropout) 125 | outputs.append(output) 126 | 127 | latent = nn.functional.tanh(self.latent(output)) 128 | #apply same mask to all context vec 129 | transd = latent.view(raw_output.size(0), raw_output.size(1), self.n_experts, -1).transpose(1, 2).transpose(1, 0).contiguous().view(-1, raw_output.size(1), self.ninp) 130 | list4mos.extend(list(torch.chunk(transd, self.n_experts, 0))) 131 | concated = torch.cat(list4mos, 1) 132 | dropped = self.lockdrop(concated.view(-1, raw_output.size(1), self.ninp), self.dropoutl) 133 | contextvec = dropped.view(raw_output.size(0), self.all_experts, raw_output.size(1), self.ninp).transpose(1, 2).contiguous() 134 | logit = self.decoder(contextvec.view(-1, self.ninp)) 135 | 136 | prior_logit = self.prior(output).view(-1, self.all_experts) 137 | prior = nn.functional.softmax(prior_logit) 138 | 139 | prob = nn.functional.softmax(logit.view(-1, self.ntoken)).view(-1, self.all_experts, self.ntoken) 140 | prob = (prob * prior.unsqueeze(2).expand_as(prob)).sum(1) 141 | 142 | if return_prob: 143 | model_output = prob 144 | else: 145 | log_prob = torch.log(prob.add_(1e-8)) 146 | model_output = log_prob 147 | 148 | model_output = model_output.view(-1, batch_size, self.ntoken) 149 | prior = prior.view(-1, batch_size, self.all_experts) 150 | 151 | if return_h: 152 | return model_output, hidden, 
153 |         return model_output, hidden
154 |
155 |     def init_hidden(self, bsz):
156 |         weight = next(self.parameters()).data
157 |         return [(Variable(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else self.nhidlast).zero_()),
158 |                  Variable(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else self.nhidlast).zero_()))
159 |                 for l in range(self.nlayers)]
160 |
161 | if __name__ == '__main__':
162 |     model = RNNModel('LSTM', 10, 12, 12, 12, 2)
163 |     input = Variable(torch.LongTensor(13, 9).random_(0, 10))
164 |     hidden = model.init_hidden(9)
165 |     model(input, hidden)
166 |
167 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os, shutil
2 | import torch
3 | from torch.autograd import Variable
4 |
5 | def repackage_hidden(h):
6 |     """Wraps hidden states in new Variables, to detach them from their history."""
7 |     if type(h) == Variable:
8 |         return Variable(h.data)
9 |     else:
10 |         return tuple(repackage_hidden(v) for v in h)
11 |
12 | def batchify(data, bsz, args):
13 |     # Work out how cleanly we can divide the dataset into bsz parts.
14 |     nbatch = data.size(0) // bsz
15 |     # Trim off any extra elements that wouldn't cleanly fit (remainders).
16 |     data = data.narrow(0, 0, nbatch * bsz)
17 |     # Evenly divide the data across the bsz batches.
18 |     data = data.view(bsz, -1).t().contiguous()
19 |     print(data.size())
20 |     if args.cuda:
21 |         data = data.cuda()
22 |     return data
23 |
24 | def get_batch(source, i, args, seq_len=None, evaluation=False):
25 |     seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
26 |     data = Variable(source[i:i+seq_len], volatile=evaluation)
27 |     # target = Variable(source[i+1:i+1+seq_len].view(-1))
28 |     target = Variable(source[i+1:i+1+seq_len])
29 |     return data, target
30 |
31 | def create_exp_dir(path, scripts_to_save=None):
32 |     if not os.path.exists(path):
33 |         os.mkdir(path)
34 |
35 |     print('Experiment dir : {}'.format(path))
36 |     if scripts_to_save is not None:
37 |         os.makedirs(os.path.join(path, 'scripts'), exist_ok=True)
38 |         for script in scripts_to_save:
39 |             dst_file = os.path.join(path, 'scripts', os.path.basename(script))
40 |             shutil.copyfile(script, dst_file)
41 |
42 | def save_checkpoint(model, optimizer, path, finetune=False):
43 |     if finetune:
44 |         torch.save(model, os.path.join(path, 'finetune_model.pt'))
45 |         torch.save(optimizer.state_dict(), os.path.join(path, 'finetune_optimizer.pt'))
46 |     else:
47 |         torch.save(model, os.path.join(path, 'model.pt'))
48 |         torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer.pt'))
49 |
--------------------------------------------------------------------------------
/weight_drop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 | from functools import wraps
4 |
5 | class WeightDrop(torch.nn.Module):
6 |     def __init__(self, module, weights, dropout=0, variational=False):
7 |         super(WeightDrop, self).__init__()
8 |         self.module = module
9 |         self.weights = weights
10 |         self.dropout = dropout
11 |         self.variational = variational
12 |         self._setup()
13 |
14 |     def widget_demagnetizer_y2k_edition(*args, **kwargs):
15 |         # We need to replace flatten_parameters with a nothing function
16 |         # It must be a function rather than a lambda as otherwise pickling explodes
17 |         # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
18 |         # (╯°□°)╯︵ ┻━┻
19 |         return
20 |
21 |     def _setup(self):
22 |         # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
23 |         if issubclass(type(self.module), torch.nn.RNNBase):
24 |             self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition
25 |
26 |         for name_w in self.weights:
27 |             print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
28 |             w = getattr(self.module, name_w)
29 |             del self.module._parameters[name_w]
30 |             self.module.register_parameter(name_w + '_raw', Parameter(w.data))
31 |
32 |     def _setweights(self):
33 |         for name_w in self.weights:
34 |             raw_w = getattr(self.module, name_w + '_raw')
35 |             w = None
36 |             if self.variational:
37 |                 mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
38 |                 if raw_w.is_cuda: mask = mask.cuda()
39 |                 mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
40 |                 w = mask.expand_as(raw_w) * raw_w
41 |             else:
42 |                 w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
43 |             setattr(self.module, name_w, w)
44 |
45 |     def forward(self, *args):
46 |         self._setweights()
47 |         return self.module.forward(*args)
48 |
49 | if __name__ == '__main__':
50 |     import torch
51 |     from weight_drop import WeightDrop
52 |
53 |     # Input is (seq, batch, input)
54 |     x = torch.autograd.Variable(torch.randn(2, 1, 10)).cuda()
55 |     h0 = None
56 |
57 |     ###
58 |
59 |     print('Testing WeightDrop')
60 |     print('=-=-=-=-=-=-=-=-=-=')
61 |
62 |     ###
63 |
64 |     print('Testing WeightDrop with Linear')
65 |
66 |     lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
67 |     lin.cuda()
68 |     run1 = [x.sum() for x in lin(x).data]
69 |     run2 = [x.sum() for x in lin(x).data]
70 |
71 |     print('All items should be different')
72 |     print('Run 1:', run1)
73 |     print('Run 2:', run2)
74 |
75 |     assert run1[0] != run2[0]
76 |     assert run1[1] != run2[1]
77 |
78 |     print('---')
79 |
80 |     ###
81 |
82 |     print('Testing WeightDrop with LSTM')
83 |
84 |     wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
85 |     wdrnn.cuda()
86 |
87 |     run1 = [x.sum() for x in wdrnn(x, h0)[0].data]
88 |     run2 = [x.sum() for x in wdrnn(x, h0)[0].data]
89 |
90 |     print('First timesteps should be equal, all others should differ')
91 |     print('Run 1:', run1)
92 |     print('Run 2:', run2)
93 |
94 |     # First time step, not influenced by hidden to hidden weights, should be equal
95 |     assert run1[0] == run2[0]
96 |     # Second step should not
97 |     assert run1[1] != run2[1]
98 |
99 |     print('---')
100 |
--------------------------------------------------------------------------------
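The mixture-of-softmaxes (MoS) combination in RNNModel.forward above can be hard to follow because of the reshaping around list4mos and contextvec. The standalone sketch below illustrates only the mixing step: each expert's context vector goes through the shared decoder, each expert produces its own vocabulary distribution, and the prior head supplies the mixture weights. All names and sizes here are illustrative stand-ins rather than code from the repository; in the repository the prior is computed from the nhidlast-dimensional LSTM output and the context vectors come from self.latent plus the optional weight4embed/weight4first/weight4second heads.

# Minimal sketch of the MoS mixing step (illustrative, not part of the repository).
import torch
import torch.nn as nn

seq_len, batch_size, ninp, ntoken, n_experts = 5, 4, 12, 10, 3

decoder = nn.Linear(ninp, ntoken)                 # shared output projection (optionally tied to the embedding)
prior_layer = nn.Linear(ninp, n_experts, bias=False)

contextvec = torch.randn(seq_len, batch_size, n_experts, ninp)  # one context vector per position and expert
output = torch.randn(seq_len, batch_size, ninp)                 # stands in for the final LSTM output

# Per-expert vocabulary distributions: (seq*batch, n_experts, ntoken).
logit = decoder(contextvec.view(-1, ninp))
prob = torch.softmax(logit, dim=-1).view(-1, n_experts, ntoken)

# Mixture weights from the prior head: (seq*batch, n_experts).
prior = torch.softmax(prior_layer(output).view(-1, n_experts), dim=-1)

# Convex combination over experts gives the final word distribution per position.
mixed = (prob * prior.unsqueeze(2)).sum(1)
log_prob = torch.log(mixed + 1e-8).view(seq_len, batch_size, ntoken)
print(log_prob.shape)  # torch.Size([5, 4, 10])

Because the softmax is applied per expert before averaging, the mixture can represent conditional distributions of higher rank than a single softmax over one context vector, which is the motivation for the extra expert heads in model.py.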