├── LICENSE ├── README.md ├── data.py ├── data └── wikitext-2 │ ├── README │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── generate_rmc.py ├── generate_rnn.py ├── pics ├── nth_results.jpg ├── rmc.png └── rmc_paper_result.png ├── relational_rnn_general.py ├── relational_rnn_models.py ├── requirements.txt ├── rnn_models.py ├── train_embeddings.py ├── train_nth_farthest.py ├── train_rmc.py └── train_rnn.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # relational-rnn-pytorch 2 | 3 | An implementation of DeepMind's [Relational Recurrent Neural Networks](https://arxiv.org/abs/1806.01822) (Santoro et al. 2018) in PyTorch. 
4 | 
5 | ![](./pics/rmc.png)
6 | ![](./pics/rmc_paper_result.png)
7 | 
8 | 
9 | The Relational Memory Core (RMC) module is originally from the [official Sonnet implementation](https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/relational_memory.py). However, it currently does not provide full language modeling benchmark code.
10 | 
11 | This repo is a port of RMC with additional comments. It features a full-fledged word language modeling benchmark vs. a traditional LSTM.
12 | 
13 | It supports any arbitrary word token-based text dataset, including WikiText-2 & WikiText-103.
14 | 
15 | Both RMC & LSTM models support [adaptive softmax](https://pytorch.org/docs/stable/nn.html#adaptivelogsoftmaxwithloss) for much lower memory usage on large-vocabulary datasets. RMC supports PyTorch's `DataParallel`, so you can easily experiment with a multi-GPU setup.
16 | 
17 | The benchmark code is hard-forked from the [official PyTorch word-language-model example](https://github.com/pytorch/examples/tree/master/word_language_model).
18 | 
19 | It also features an N-th farthest synthetic task from the paper (see below).
20 | 
21 | # Requirements
22 | PyTorch 0.4.1 or later (Tested on 1.0.0) & Python 3.6
23 | 
24 | # Examples
25 | `python train_rmc.py --cuda` for a full training & test run of RMC with GPU.
26 | 
27 | `python train_rmc.py --cuda --adaptivesoftmax --cutoffs 1000 5000 20000` if using a large-vocabulary dataset (like WikiText-103), to fit all the tensors in VRAM.
28 | 
29 | `python generate_rmc.py --cuda` for generating sentences from the trained model.
30 | 
31 | `python train_rnn.py --cuda` for a full training & test run of a traditional RNN with GPU.
32 | 
33 | All default hyperparameters of RMC & LSTM are the results of a two-week experiment using WikiText-2.
34 | 
35 | # Data Preparation
36 | Tested with WikiText-2 and WikiText-103. WikiText-2 is bundled.
37 | 
38 | Create a subfolder inside `./data` and place word-level `train.txt`, `valid.txt`, and `test.txt` inside the subfolder.
39 | 
40 | Specify `--data=(subfolder name)` and you are good to go.
41 | 
42 | The code performs tokenization at the first training run, and the corpus is saved as a `pickle` file. The code will load the `pickle` file after the first run.
43 | 
44 | # WikiText-2 Benchmark Results
45 | Both RMC & LSTM have ~11M parameters. Please refer to the training code for details on hyperparameters.
46 | 
47 | | Models | Valid Perplexity | Test Perplexity | Forward pass ms/batch (TITAN Xp) | Forward pass ms/batch (TITAN V) |
48 | |:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
49 | | LSTM (CuDNN) | 111.31 | 105.56 | 26~27 | 40~41 |
50 | | LSTM (For Loop) | Same as CuDNN | Same as CuDNN | 30~31 | 60~61 |
51 | | RMC | 112.77 | 107.21 | 110~130 | 220~230 |
52 | 
53 | RMC can reach performance comparable to the LSTM (with heavy hyperparameter search), but it turns out that the RMC is very slow. The multi-head self-attention at every time step may be the culprit here.
54 | Using LSTMCell with a for loop (which is a more "fair" benchmark for the RMC) slows down the forward pass, but it's still much faster.
55 | 
56 | Please also note that the hyperparameter setting for the RMC is a worst-case scenario in terms of speed, because it used a single memory slot (as described in the paper) and did not benefit from the row-wise weight sharing of multi-slot memory.
57 | 
58 | It is interesting to note that the speed is slower on the TITAN V than on the TITAN Xp. The reason might be that the models are relatively small and call many small linear operations frequently.
59 | 
60 | Maybe the TITAN Xp (~1,900MHz unlocked CUDA clock speed vs. the TITAN V's 1,335MHz limit) benefits from this kind of workload. Or maybe the TITAN V's CUDA kernel launch latency is higher for the ops in the model.
61 | 
62 | I'm not an expert in the details of CUDA. Please share your results!
63 | 
64 | # RMC Hyperparameter Search Results
65 | Attention parameters tend to overfit WikiText-2. Reducing the attention hyperparameters (e.g. key_size) can combat the overfitting.
66 | 
67 | Applying dropout at the output logit before the softmax (as in the LSTM model) helped prevent the overfitting.
68 | 
69 | | embed & head size | # heads | attention MLP layers | key size | dropout at output | memory slots | test ppl |
70 | |:----:|:----:|:----:|:----:|:----:|:----:|:----:|
71 | | 128 | 4 | 3 | 128 | No | 1 | 128.81 |
72 | | 128 | 4 | 3 | 128 | No | 1 | 128.81 |
73 | | 128 | 8 | 3 | 128 | No | 1 | 141.84 |
74 | | 128 | 4 | 3 | 32 | No | 1 | 123.26 |
75 | | 128 | 4 | 3 | 32 | Yes | 1 | 112.4 |
76 | | 128 | 4 | 3 | 64 | No | 1 | 124.44 |
77 | | 128 | 4 | 3 | 64 | Yes | 1 | 110.16 |
78 | | 128 | 4 | 2 | 64 | Yes | 1 | 111.67 |
79 | | 64 | 4 | 3 | 64 | Yes | 1 | 133.68 |
80 | | 64 | 4 | 3 | 32 | Yes | 1 | 135.93 |
81 | | 64 | 4 | 3 | 64 | Yes | 4 | 137.93 |
82 | | 192 | 4 | 3 | 64 | Yes | 1 | **107.21** |
83 | | 192 | 4 | 3 | 64 | Yes | 4 | 114.85 |
84 | | 256 | 4 | 3 | 256 | No | 1 | 194.73 |
85 | | 256 | 4 | 3 | 64 | Yes | 1 | 126.39 |
86 | 
87 | 
88 | # About WikiText-103
89 | The original RMC paper presents WikiText-103 results with a larger model & batch size (6 Tesla P100s, each with a batch size of 64, so a total of 384. Ouch).
90 | 
91 | Using a full softmax easily blows up the VRAM. Using `--adaptivesoftmax` is highly recommended. If using `--adaptivesoftmax`, `--cutoffs` should be properly provided. Please refer to the [original API description](https://pytorch.org/docs/stable/nn.html#adaptivelogsoftmaxwithloss).
92 | 
93 | I don't have such hardware and my resources are too limited to run the experiments. Benchmark results or any other contributions are very welcome!
94 | 
95 | # Nth Farthest Task
96 | 
97 | The objective of the task is: Given k randomly labelled (from 1 to k) D-dimensional vectors, identify which is the Nth farthest vector from vector M. (The answer is an integer from 1 to k.)
98 | 
99 | The specific task in the paper is: given 8 labelled 16-dimensional vectors, which is the Nth farthest vector from vector M? The vectors are labelled randomly, so the model has to recognise that the Mth vector is the vector labelled as M, as opposed to the vector in the Mth position in the input.
100 | 
101 | The input to the model comprises 8 40-dimensional vectors for each example. Each of these 40-dimensional vectors is structured like this (see the illustrative sketch below):
102 | 
103 | ```
104 | [(vector 1) (label: which vector is it, from 1 to 8, one-hot encoded) (N, one-hot encoded) (M, one-hot encoded)]
105 | ```
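For concreteness, here is a minimal sketch of how one such input example could be assembled. This is only an illustration of the layout above; the actual data generator lives in `train_nth_farthest.py` and may differ in details such as value ranges and label indexing.

```python
import torch

num_vectors, vector_dim = 8, 16                 # k = 8 labelled 16-dimensional vectors
vectors = torch.rand(num_vectors, vector_dim)   # the 8 input vectors

labels = torch.randperm(num_vectors)            # random labels (0-indexed here)
n = int(torch.randint(num_vectors, (1,)))       # "N" in "Nth farthest"
m = int(torch.randint(num_vectors, (1,)))       # reference label "M"

one_hot = torch.eye(num_vectors)
# each of the 8 rows: [vector (16) | label (8) | N (8) | M (8)] -> 40 dims
inputs = torch.cat([vectors,
                    one_hot[labels],
                    one_hot[n].expand(num_vectors, -1),
                    one_hot[m].expand(num_vectors, -1)], dim=1)
print(inputs.shape)  # torch.Size([8, 40])
```

The target is then the label of the Nth farthest vector from the vector labelled M, i.e. a single class index out of 8.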
106 | 
107 | #### Example
108 | 
109 | `python train_nth_farthest.py --cuda` for training and testing on the Nth Farthest Task with GPU(s).
110 | 
111 | This uses the `RelationalMemory` class in `relational_rnn_general.py`, which is a version of `relational_rnn_models.py` without the language-modelling-specific code. A minimal usage sketch of the class follows below.
112 | 
113 | Please refer to `train_nth_farthest.py` for details on hyperparameter values. These are taken from Appendix A1 in the paper, and from the Sonnet implementation when the hyperparameter values are not given in the paper.
114 | 
115 | Note: new examples are generated per epoch as in the Sonnet implementation. This seems to be consistent with the paper, which does not specify the number of examples used.
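For reference, here is a minimal usage sketch of the general-purpose `RelationalMemory` module. The hyperparameter values below are placeholders chosen only for illustration; see `train_nth_farthest.py` for the values actually used for the task.

```python
import torch
from relational_rnn_general import RelationalMemory

batch_size, seq_length, input_size = 32, 8, 40

# placeholder hyperparameters, for illustration only
rmc = RelationalMemory(mem_slots=1, head_size=64, input_size=input_size,
                       num_heads=4, num_blocks=1, gate_style='unit')

# the initial memory is an identity matrix padded (or truncated) to
# (batch, mem_slots, head_size * num_heads)
memory = rmc.initial_state(batch_size)

# batch-first sequence of input vectors
inputs = torch.randn(batch_size, seq_length, input_size)

# forward() loops over the sequence internally and returns the output of the last step
# (or of every step if return_all_outputs=True) together with the updated memory
output, memory = rmc(inputs, memory)
print(output.shape)  # torch.Size([32, 256]) -> (batch, mem_slots * head_size * num_heads)
```

`relational_rnn_models.py` wraps the same core with a token embedding encoder and an (optionally adaptive) softmax decoder for the language modelling benchmark.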
116 | 
117 | #### Experiment results
118 | 
119 | The model was trained on a single TITAN Xp GPU for a very long time, until it reached 91% test accuracy. Below are the results from 3 independent runs:
120 | ![](./pics/nth_results.jpg)
121 | 
122 | The model does break the 25% barrier if trained long enough, but the wall clock time is roughly 2~3x longer than that reported in the paper.
123 | 
124 | #### TODO
125 | 
126 | Experiment with different hyperparameters
127 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | class Dictionary(object):
7 |     def __init__(self):
8 |         self.word2idx = {}
9 |         self.idx2word = []
10 |         self.idx2count = []
11 | 
12 |     def add_word(self, word):
13 |         if word not in self.word2idx:
14 |             self.idx2word.append(word)
15 |             self.idx2count.append(1)
16 |             self.word2idx[word] = len(self.idx2word) - 1
17 |         else:
18 |             self.idx2count[self.word2idx[word]] += 1
19 |         return self.word2idx[word]
20 | 
21 |     def __len__(self):
22 |         return len(self.idx2word)
23 | 
24 | 
25 | class Corpus(object):
26 |     def __init__(self, path):
27 |         self.dictionary = Dictionary()
28 |         tokens_train = self.add_corpus(os.path.join(path, 'train.txt'))
29 |         tokens_valid = self.add_corpus(os.path.join(path, 'valid.txt'))
30 |         tokens_test = self.add_corpus(os.path.join(path, 'test.txt'))
31 | 
32 |         # sort the words by word frequency in descending order
33 |         # this is for using adaptive softmax: it assumes that the most frequent word gets index 0
34 |         idx_argsorted = np.flip(np.argsort(self.dictionary.idx2count), axis=-1)
35 | 
36 |         # re-create given the sorted ones
37 |         self.dictionary.idx2count = np.array(self.dictionary.idx2count)[idx_argsorted].tolist()
38 |         self.dictionary.idx2word = np.array(self.dictionary.idx2word)[idx_argsorted].tolist()
39 |         self.dictionary.word2idx = dict(zip(self.dictionary.idx2word,
40 |                                             np.arange(len(self.dictionary.idx2word)).tolist()))
41 | 
42 |         self.train = self.tokenize(os.path.join(path, 'train.txt'), tokens_train)
43 |         self.valid = self.tokenize(os.path.join(path, 'valid.txt'), tokens_valid)
44 |         self.test = self.tokenize(os.path.join(path, 'test.txt'), tokens_test)
45 | 
46 |     def add_corpus(self, path):
47 |         """Adds the words of a text file to the dictionary and counts its tokens."""
48 |         assert os.path.exists(path)
49 |         # Add words to the dictionary
50 |         with open(path, 'r', encoding="utf8") as f:
51 |             tokens = 0
52 |             for line in f:
53 |                 words = line.split() + ['<eos>']
54 |                 tokens += len(words)
55 |                 for word in words:
56 |                     self.dictionary.add_word(word)
57 | 
58 |         return tokens
59 | 
60 |     def tokenize(self, path, tokens):
61 |         # Tokenize file content
62 |         with open(path, 'r', encoding="utf8") as f:
63 |             ids = torch.LongTensor(tokens)
64 |             token = 0
65 |             for line in f:
66 |                 words = line.split() + ['<eos>']
67 |                 for word in words:
68 |                     ids[token] = self.dictionary.word2idx[word]
69 |                     token += 1
70 | 
71 |         return ids
72 | 
--------------------------------------------------------------------------------
/data/wikitext-2/README:
--------------------------------------------------------------------------------
1 | This is raw data from the wikitext-2 dataset.
2 | 3 | See https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/ 4 | -------------------------------------------------------------------------------- /generate_rmc.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | 3 | # This file generates new sentences sampled from the language model 4 | 5 | ############################################################################### 6 | 7 | import argparse 8 | 9 | import torch 10 | import pickle 11 | import data 12 | import os 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model') 15 | 16 | # Model parameters. 17 | parser.add_argument('--data', type=str, default='./data/wikitext-2', 18 | help='location of the data corpus') 19 | parser.add_argument('--checkpoint', type=str, default=None, 20 | help='model checkpoint to use') 21 | parser.add_argument('--outf', type=str, default='generated.txt', 22 | help='output file for generated text') 23 | parser.add_argument('--words', type=int, default='1000', 24 | help='number of words to generate') 25 | parser.add_argument('--seed', type=int, default=1111, 26 | help='random seed') 27 | parser.add_argument('--cuda', action='store_true', 28 | help='use CUDA') 29 | parser.add_argument('--temperature', type=float, default=1., 30 | help='temperature - higher will increase diversity') 31 | parser.add_argument('--log-interval', type=int, default=100, 32 | help='reporting interval') 33 | args = parser.parse_args() 34 | 35 | if args.checkpoint is None: 36 | raise ValueError("--checkpoint not provided. specify model_dump_(epoch).pt") 37 | 38 | # Set the random seed manually for reproducibility. 
39 | torch.manual_seed(args.seed) 40 | 41 | if torch.cuda.is_available(): 42 | if not args.cuda: 43 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 44 | 45 | device = torch.device("cuda" if args.cuda else "cpu") 46 | 47 | if args.temperature < 1e-3: 48 | parser.error("--temperature has to be greater or equal 1e-3") 49 | 50 | with open(args.checkpoint, 'rb') as f: 51 | model = torch.load(f).to(device) 52 | model.eval() 53 | 54 | corpus_name = os.path.basename(os.path.normpath(args.data)) 55 | corpus_filename = './data/corpus-' + str(corpus_name) + str('.pkl') 56 | if os.path.isfile(corpus_filename): 57 | print("loading pre-built " + str(corpus_name) + " corpus file...") 58 | loadfile = open(corpus_filename, 'rb') 59 | corpus = pickle.load(loadfile) 60 | loadfile.close() 61 | else: 62 | print("building " + str(corpus_name) + " corpus...") 63 | corpus = data.Corpus(args.data) 64 | # save the corpus for later 65 | savefile = open(corpus_filename, 'wb') 66 | pickle.dump(corpus, savefile) 67 | savefile.close() 68 | print("corpus saved to pickle") 69 | 70 | ntokens = len(corpus.dictionary) 71 | memory = model.module.initial_state(1, trainable=False).to(device) 72 | 73 | input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) 74 | 75 | with open(args.outf, 'w') as outf: 76 | with torch.no_grad(): # no tracking history 77 | for i in range(args.words): 78 | output, _, memory = model(input, memory, None, require_logits=True) 79 | word_weights = output.squeeze().div(args.temperature).exp().cpu() 80 | word_idx = torch.multinomial(word_weights, 1)[0] 81 | input.fill_(word_idx) 82 | word = corpus.dictionary.idx2word[word_idx] 83 | 84 | outf.write(word + ('\n' if i % 20 == 19 else ' ')) 85 | 86 | if i % args.log_interval == 0: 87 | print('| Generated {}/{} words'.format(i, args.words)) 88 | -------------------------------------------------------------------------------- /generate_rnn.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | 3 | # This file generates new sentences sampled from the language model 4 | # 5 | ############################################################################### 6 | 7 | import argparse 8 | 9 | import torch 10 | import pickle 11 | import data 12 | import os 13 | 14 | parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 Language Model') 15 | 16 | # Model parameters. 17 | parser.add_argument('--data', type=str, default='./data/wikitext-2', 18 | help='location of the data corpus') 19 | parser.add_argument('--checkpoint', type=str, default=None, 20 | help='model checkpoint to use') 21 | parser.add_argument('--outf', type=str, default='generated.txt', 22 | help='output file for generated text') 23 | parser.add_argument('--words', type=int, default='1000', 24 | help='number of words to generate') 25 | parser.add_argument('--seed', type=int, default=1111, 26 | help='random seed') 27 | parser.add_argument('--cuda', action='store_true', 28 | help='use CUDA') 29 | parser.add_argument('--temperature', type=float, default=1., 30 | help='temperature - higher will increase diversity') 31 | parser.add_argument('--log-interval', type=int, default=100, 32 | help='reporting interval') 33 | args = parser.parse_args() 34 | 35 | if args.checkpoint is None: 36 | raise ValueError("--checkpoint not provided. specify model_dump_(epoch).pt") 37 | 38 | # Set the random seed manually for reproducibility. 
39 | torch.manual_seed(args.seed) 40 | 41 | if torch.cuda.is_available(): 42 | if not args.cuda: 43 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 44 | 45 | device = torch.device("cuda" if args.cuda else "cpu") 46 | 47 | if args.temperature < 1e-3: 48 | parser.error("--temperature has to be greater or equal 1e-3") 49 | 50 | with open(args.checkpoint, 'rb') as f: 51 | model = torch.load(f).to(device) 52 | model.eval() 53 | 54 | corpus_name = os.path.basename(os.path.normpath(args.data)) 55 | corpus_filename = './data/corpus-' + str(corpus_name) + str('.pkl') 56 | if os.path.isfile(corpus_filename): 57 | print("loading pre-built " + str(corpus_name) + " corpus file...") 58 | loadfile = open(corpus_filename, 'rb') 59 | corpus = pickle.load(loadfile) 60 | loadfile.close() 61 | else: 62 | print("building " + str(corpus_name) + " corpus...") 63 | corpus = data.Corpus(args.data) 64 | # save the corpus for later 65 | savefile = open(corpus_filename, 'wb') 66 | pickle.dump(corpus, savefile) 67 | savefile.close() 68 | print("corpus saved to pickle") 69 | 70 | ntokens = len(corpus.dictionary) 71 | hidden = model.init_hidden(1) 72 | input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) 73 | 74 | with open(args.outf, 'w') as outf: 75 | with torch.no_grad(): # no tracking history 76 | for i in range(args.words): 77 | output, hidden = model(input, hidden) 78 | word_weights = output.squeeze().div(args.temperature).exp().cpu() 79 | word_idx = torch.multinomial(word_weights, 1)[0] 80 | input.fill_(word_idx) 81 | word = corpus.dictionary.idx2word[word_idx] 82 | 83 | outf.write(word + ('\n' if i % 20 == 19 else ' ')) 84 | 85 | if i % args.log_interval == 0: 86 | print('| Generated {}/{} words'.format(i, args.words)) 87 | -------------------------------------------------------------------------------- /pics/nth_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L0SG/relational-rnn-pytorch/1b16ae32988625b16b95f920b0f6fe55ed4e45e7/pics/nth_results.jpg -------------------------------------------------------------------------------- /pics/rmc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L0SG/relational-rnn-pytorch/1b16ae32988625b16b95f920b0f6fe55ed4e45e7/pics/rmc.png -------------------------------------------------------------------------------- /pics/rmc_paper_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/L0SG/relational-rnn-pytorch/1b16ae32988625b16b95f920b0f6fe55ed4e45e7/pics/rmc_paper_result.png -------------------------------------------------------------------------------- /relational_rnn_general.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | # this class largely follows the official sonnet implementation 8 | # https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/relational_memory.py 9 | 10 | 11 | class RelationalMemory(nn.Module): 12 | """ 13 | Constructs a `RelationalMemory` object. 14 | This class is same as the RMC from relational_rnn_models.py, but without language modeling-specific variables. 15 | Args: 16 | mem_slots: The total number of memory slots to use. 17 | head_size: The size of an attention head. 18 | input_size: The size of input per step. i.e. 
the dimension of each input vector 19 | num_heads: The number of attention heads to use. Defaults to 1. 20 | num_blocks: Number of times to compute attention per time step. Defaults 21 | to 1. 22 | forget_bias: Bias to use for the forget gate, assuming we are using 23 | some form of gating. Defaults to 1. 24 | input_bias: Bias to use for the input gate, assuming we are using 25 | some form of gating. Defaults to 0. 26 | gate_style: Whether to use per-element gating ('unit'), 27 | per-memory slot gating ('memory'), or no gating at all (None). 28 | Defaults to `unit`. 29 | attention_mlp_layers: Number of layers to use in the post-attention 30 | MLP. Defaults to 2. 31 | key_size: Size of vector to use for key & query vectors in the attention 32 | computation. Defaults to None, in which case we use `head_size`. 33 | name: Name of the module. 34 | 35 | # NEW flag for this class 36 | return_all_outputs: Whether the model returns outputs for each step (like seq2seq) or only the final output. 37 | Raises: 38 | ValueError: gate_style not one of [None, 'memory', 'unit']. 39 | ValueError: num_blocks is < 1. 40 | ValueError: attention_mlp_layers is < 1. 41 | """ 42 | 43 | def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, forget_bias=1., input_bias=0., 44 | gate_style='unit', attention_mlp_layers=2, key_size=None, return_all_outputs=False): 45 | super(RelationalMemory, self).__init__() 46 | 47 | ########## generic parameters for RMC ########## 48 | self.mem_slots = mem_slots 49 | self.head_size = head_size 50 | self.num_heads = num_heads 51 | self.mem_size = self.head_size * self.num_heads 52 | 53 | # a new fixed params needed for pytorch port of RMC 54 | # +1 is the concatenated input per time step : we do self-attention with the concatenated memory & input 55 | # so if the mem_slots = 1, this value is 2 56 | self.mem_slots_plus_input = self.mem_slots + 1 57 | 58 | if num_blocks < 1: 59 | raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks)) 60 | self.num_blocks = num_blocks 61 | 62 | if gate_style not in ['unit', 'memory', None]: 63 | raise ValueError( 64 | 'gate_style must be one of [\'unit\', \'memory\', None]. got: ' 65 | '{}.'.format(gate_style)) 66 | self.gate_style = gate_style 67 | 68 | if attention_mlp_layers < 1: 69 | raise ValueError('attention_mlp_layers must be >= 1. 
Got: {}.'.format( 70 | attention_mlp_layers)) 71 | self.attention_mlp_layers = attention_mlp_layers 72 | 73 | self.key_size = key_size if key_size else self.head_size 74 | 75 | ########## parameters for multihead attention ########## 76 | # value_size is same as head_size 77 | self.value_size = self.head_size 78 | # total size for query-key-value 79 | self.qkv_size = 2 * self.key_size + self.value_size 80 | self.total_qkv_size = self.qkv_size * self.num_heads # denoted as F 81 | 82 | # each head has qkv_sized linear projector 83 | # just using one big param is more efficient, rather than this line 84 | # self.qkv_projector = [nn.Parameter(torch.randn((self.qkv_size, self.qkv_size))) for _ in range(self.num_heads)] 85 | self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size) 86 | self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size]) 87 | 88 | # used for attend_over_memory function 89 | self.attention_mlp = nn.ModuleList([nn.Linear(self.mem_size, self.mem_size)] * self.attention_mlp_layers) 90 | self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) 91 | self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) 92 | 93 | ########## parameters for initial embedded input projection ########## 94 | self.input_size = input_size 95 | self.input_projector = nn.Linear(self.input_size, self.mem_size) 96 | 97 | ########## parameters for gating ########## 98 | self.num_gates = 2 * self.calculate_gate_size() 99 | self.input_gate_projector = nn.Linear(self.mem_size, self.num_gates) 100 | self.memory_gate_projector = nn.Linear(self.mem_size, self.num_gates) 101 | # trainable scalar gate bias tensors 102 | self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32)) 103 | self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32)) 104 | 105 | ########## number of outputs returned ##### 106 | self.return_all_outputs = return_all_outputs 107 | 108 | def repackage_hidden(self, h): 109 | """Wraps hidden states in new Tensors, to detach them from their history.""" 110 | # needed for truncated BPTT, called at every batch forward pass 111 | if isinstance(h, torch.Tensor): 112 | return h.detach() 113 | else: 114 | return tuple(self.repackage_hidden(v) for v in h) 115 | 116 | def initial_state(self, batch_size, trainable=False): 117 | """ 118 | Creates the initial memory. 119 | We should ensure each row of the memory is initialized to be unique, 120 | so initialize the matrix to be the identity. We then pad or truncate 121 | as necessary so that init_state is of size 122 | (batch_size, self.mem_slots, self.mem_size). 123 | Args: 124 | batch_size: The size of the batch. 125 | trainable: Whether the initial state is trainable. This is always True. 126 | Returns: 127 | init_state: A truncated or padded matrix of size 128 | (batch_size, self.mem_slots, self.mem_size). 129 | """ 130 | init_state = torch.stack([torch.eye(self.mem_slots) for _ in range(batch_size)]) 131 | 132 | # pad the matrix with zeros 133 | if self.mem_size > self.mem_slots: 134 | difference = self.mem_size - self.mem_slots 135 | pad = torch.zeros((batch_size, self.mem_slots, difference)) 136 | init_state = torch.cat([init_state, pad], -1) 137 | 138 | # truncation. 
take the first 'self.mem_size' components 139 | elif self.mem_size < self.mem_slots: 140 | init_state = init_state[:, :, :self.mem_size] 141 | 142 | return init_state 143 | 144 | def multihead_attention(self, memory): 145 | """ 146 | Perform multi-head attention from 'Attention is All You Need'. 147 | Implementation of the attention mechanism from 148 | https://arxiv.org/abs/1706.03762. 149 | Args: 150 | memory: Memory tensor to perform attention on. 151 | Returns: 152 | new_memory: New memory tensor. 153 | """ 154 | 155 | # First, a simple linear projection is used to construct queries 156 | qkv = self.qkv_projector(memory) 157 | # apply layernorm for every dim except the batch dim 158 | qkv = self.qkv_layernorm(qkv) 159 | 160 | # mem_slots needs to be dynamically computed since mem_slots got concatenated with inputs 161 | # example: self.mem_slots=10 and seq_length is 3, and then mem_slots is 10 + 1 = 11 for each 3 step forward pass 162 | # this is the same as self.mem_slots_plus_input, but defined to keep the sonnet implementation code style 163 | mem_slots = memory.shape[1] # denoted as N 164 | 165 | # split the qkv to multiple heads H 166 | # [B, N, F] => [B, N, H, F/H] 167 | qkv_reshape = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size) 168 | 169 | # [B, N, H, F/H] => [B, H, N, F/H] 170 | qkv_transpose = qkv_reshape.permute(0, 2, 1, 3) 171 | 172 | # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size] 173 | q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1) 174 | 175 | # scale q with d_k, the dimensionality of the key vectors 176 | q *= (self.key_size ** -0.5) 177 | 178 | # make it [B, H, N, N] 179 | dot_product = torch.matmul(q, k.permute(0, 1, 3, 2)) 180 | weights = F.softmax(dot_product, dim=-1) 181 | 182 | # output is [B, H, N, V] 183 | output = torch.matmul(weights, v) 184 | 185 | # [B, H, N, V] => [B, N, H, V] => [B, N, H*V] 186 | output_transpose = output.permute(0, 2, 1, 3).contiguous() 187 | new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1)) 188 | 189 | return new_memory 190 | 191 | @property 192 | def state_size(self): 193 | return [self.mem_slots, self.mem_size] 194 | 195 | @property 196 | def output_size(self): 197 | return self.mem_slots * self.mem_size 198 | 199 | def calculate_gate_size(self): 200 | """ 201 | Calculate the gate size from the gate_style. 202 | Returns: 203 | The per sample, per head parameter size of each gate. 204 | """ 205 | if self.gate_style == 'unit': 206 | return self.mem_size 207 | elif self.gate_style == 'memory': 208 | return 1 209 | else: # self.gate_style == None 210 | return 0 211 | 212 | def create_gates(self, inputs, memory): 213 | """ 214 | Create input and forget gates for this step using `inputs` and `memory`. 215 | Args: 216 | inputs: Tensor input. 217 | memory: The current state of memory. 218 | Returns: 219 | input_gate: A LSTM-like insert gate. 220 | forget_gate: A LSTM-like forget gate. 221 | """ 222 | # We'll create the input and forget gates at once. Hence, calculate double 223 | # the gate size. 
224 | 225 | # equation 8: since there is no output gate, h is just a tanh'ed m 226 | memory = torch.tanh(memory) 227 | 228 | # TODO: check this input flattening is correct 229 | # sonnet uses this, but i think it assumes time step of 1 for all cases 230 | # if inputs is (B, T, features) where T > 1, this gets incorrect 231 | # inputs = inputs.view(inputs.shape[0], -1) 232 | 233 | # fixed implementation 234 | if len(inputs.shape) == 3: 235 | if inputs.shape[1] > 1: 236 | raise ValueError( 237 | "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1") 238 | inputs = inputs.view(inputs.shape[0], -1) 239 | # matmul for equation 4 and 5 240 | # there is no output gate, so equation 6 is not implemented 241 | gate_inputs = self.input_gate_projector(inputs) 242 | gate_inputs = gate_inputs.unsqueeze(dim=1) 243 | gate_memory = self.memory_gate_projector(memory) 244 | else: 245 | raise ValueError("input shape of create_gate function is 2, expects 3") 246 | 247 | # this completes the equation 4 and 5 248 | gates = gate_memory + gate_inputs 249 | gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2) 250 | input_gate, forget_gate = gates 251 | assert input_gate.shape[2] == forget_gate.shape[2] 252 | 253 | # to be used for equation 7 254 | input_gate = torch.sigmoid(input_gate + self.input_bias) 255 | forget_gate = torch.sigmoid(forget_gate + self.forget_bias) 256 | 257 | return input_gate, forget_gate 258 | 259 | def attend_over_memory(self, memory): 260 | """ 261 | Perform multiheaded attention over `memory`. 262 | Args: 263 | memory: Current relational memory. 264 | Returns: 265 | The attended-over memory. 266 | """ 267 | for _ in range(self.num_blocks): 268 | attended_memory = self.multihead_attention(memory) 269 | 270 | # Add a skip connection to the multiheaded attention's input. 271 | memory = self.attended_memory_layernorm(memory + attended_memory) 272 | 273 | # add a skip connection to the attention_mlp's input. 274 | attention_mlp = memory 275 | for i, l in enumerate(self.attention_mlp): 276 | attention_mlp = self.attention_mlp[i](attention_mlp) 277 | attention_mlp = F.relu(attention_mlp) 278 | memory = self.attended_memory_layernorm2(memory + attention_mlp) 279 | 280 | return memory 281 | 282 | def forward_step(self, inputs, memory, treat_input_as_matrix=False): 283 | """ 284 | Forward step of the relational memory core. 285 | Args: 286 | inputs: Tensor input. 287 | memory: Memory output from the previous time step. 288 | treat_input_as_matrix: Optional, whether to treat `input` as a sequence 289 | of matrices. Default to False, in which case the input is flattened 290 | into a vector. 291 | Returns: 292 | output: This time step's output. 293 | next_memory: The next version of memory to use. 294 | """ 295 | 296 | if treat_input_as_matrix: 297 | # keep (Batch, Seq, ...) dim (0, 1), flatten starting from dim 2 298 | inputs = inputs.view(inputs.shape[0], inputs.shape[1], -1) 299 | # apply linear layer for dim 2 300 | inputs_reshape = self.input_projector(inputs) 301 | else: 302 | # keep (Batch, ...) 
dim (0), flatten starting from dim 1 303 | inputs = inputs.view(inputs.shape[0], -1) 304 | # apply linear layer for dim 1 305 | inputs = self.input_projector(inputs) 306 | # unsqueeze the time step to dim 1 307 | inputs_reshape = inputs.unsqueeze(dim=1) 308 | 309 | memory_plus_input = torch.cat([memory, inputs_reshape], dim=1) 310 | next_memory = self.attend_over_memory(memory_plus_input) 311 | 312 | # cut out the concatenated input vectors from the original memory slots 313 | n = inputs_reshape.shape[1] 314 | next_memory = next_memory[:, :-n, :] 315 | 316 | if self.gate_style == 'unit' or self.gate_style == 'memory': 317 | # these gates are sigmoid-applied ones for equation 7 318 | input_gate, forget_gate = self.create_gates(inputs_reshape, memory) 319 | # equation 7 calculation 320 | next_memory = input_gate * torch.tanh(next_memory) 321 | next_memory += forget_gate * memory 322 | 323 | output = next_memory.view(next_memory.shape[0], -1) 324 | 325 | return output, next_memory 326 | 327 | def forward(self, inputs, memory): 328 | # Starting each batch, we detach the hidden state from how it was previously produced. 329 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 330 | memory = self.repackage_hidden(memory) 331 | 332 | # for loop implementation of (entire) recurrent forward pass of the model 333 | # inputs is batch first [batch, seq], and output logit per step is [batch, vocab] 334 | # so the concatenated logits are [seq * batch, vocab] 335 | 336 | # targets are flattened [seq, batch] => [seq * batch], so the dimension is correct 337 | 338 | logits = [] 339 | # shape[1] is seq_lenth T 340 | for idx_step in range(inputs.shape[1]): 341 | logit, memory = self.forward_step(inputs[:, idx_step], memory) 342 | logits.append(logit) 343 | logits = torch.cat(logits) 344 | 345 | if self.return_all_outputs: 346 | return logits, memory 347 | else: 348 | return logit, memory 349 | 350 | 351 | 352 | 353 | # ########## DEBUG: unit test code ########## 354 | # input_size = 44 355 | # seq_length = 1 356 | # batch_size = 32 357 | # model = RelationalMemory(mem_slots=10, head_size=20, input_size=input_size, num_tokens=66, num_heads=8, num_blocks=1, forget_bias=1., input_bias=0.) 358 | # model_memory = model.initial_state(batch_size=batch_size) 359 | # 360 | # # random input 361 | # random_input = torch.randn((32, seq_length, input_size)) 362 | # # random targets 363 | # random_targets = torch.randn((32, seq_length, input_size)) 364 | # 365 | # # take a one step forward 366 | # logit, next_memory = model(random_input, model_memory, random_targets, treat_input_as_matrix=True) 367 | -------------------------------------------------------------------------------- /relational_rnn_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | # this class largely follows the official sonnet implementation 7 | # https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/relational_memory.py 8 | 9 | 10 | class RelationalMemory(nn.Module): 11 | """ 12 | Constructs a `RelationalMemory` object. 13 | Args: 14 | mem_slots: The total number of memory slots to use. 15 | head_size: The size of an attention head. 16 | input_size: The size of input per step. i.e. the dimension of each input vector 17 | num_heads: The number of attention heads to use. Defaults to 1. 18 | num_blocks: Number of times to compute attention per time step. 
Defaults 19 | to 1. 20 | forget_bias: Bias to use for the forget gate, assuming we are using 21 | some form of gating. Defaults to 1. 22 | input_bias: Bias to use for the input gate, assuming we are using 23 | some form of gating. Defaults to 0. 24 | gate_style: Whether to use per-element gating ('unit'), 25 | per-memory slot gating ('memory'), or no gating at all (None). 26 | Defaults to `unit`. 27 | attention_mlp_layers: Number of layers to use in the post-attention 28 | MLP. Defaults to 2. 29 | key_size: Size of vector to use for key & query vectors in the attention 30 | computation. Defaults to None, in which case we use `head_size`. 31 | name: Name of the module. 32 | Raises: 33 | ValueError: gate_style not one of [None, 'memory', 'unit']. 34 | ValueError: num_blocks is < 1. 35 | ValueError: attention_mlp_layers is < 1. 36 | """ 37 | 38 | def __init__(self, mem_slots, head_size, input_size, num_tokens, num_heads=1, num_blocks=1, forget_bias=1., 39 | input_bias=0., 40 | gate_style='unit', attention_mlp_layers=2, key_size=None, use_adaptive_softmax=False, cutoffs=None): 41 | super(RelationalMemory, self).__init__() 42 | 43 | ########## generic parameters for RMC ########## 44 | self.mem_slots = mem_slots 45 | self.head_size = head_size 46 | self.num_heads = num_heads 47 | self.mem_size = self.head_size * self.num_heads 48 | 49 | # a new fixed params needed for pytorch port of RMC 50 | # +1 is the concatenated input per time step : we do self-attention with the concatenated memory & input 51 | # so if the mem_slots = 1, this value is 2 52 | self.mem_slots_plus_input = self.mem_slots + 1 53 | 54 | if num_blocks < 1: 55 | raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks)) 56 | self.num_blocks = num_blocks 57 | 58 | if gate_style not in ['unit', 'memory', None]: 59 | raise ValueError( 60 | 'gate_style must be one of [\'unit\', \'memory\', None]. got: ' 61 | '{}.'.format(gate_style)) 62 | self.gate_style = gate_style 63 | 64 | if attention_mlp_layers < 1: 65 | raise ValueError('attention_mlp_layers must be >= 1. 
Got: {}.'.format( 66 | attention_mlp_layers)) 67 | self.attention_mlp_layers = attention_mlp_layers 68 | 69 | self.key_size = key_size if key_size else self.head_size 70 | 71 | ########## parameters for multihead attention ########## 72 | # value_size is same as head_size 73 | self.value_size = self.head_size 74 | # total size for query-key-value 75 | self.qkv_size = 2 * self.key_size + self.value_size 76 | self.total_qkv_size = self.qkv_size * self.num_heads # denoted as F 77 | 78 | # each head has qkv_sized linear projector 79 | # just using one big param is more efficient, rather than this line 80 | # self.qkv_projector = [nn.Parameter(torch.randn((self.qkv_size, self.qkv_size))) for _ in range(self.num_heads)] 81 | self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size) 82 | self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size]) 83 | 84 | # used for attend_over_memory function 85 | self.attention_mlp = nn.ModuleList([nn.Linear(self.mem_size, self.mem_size)] * self.attention_mlp_layers) 86 | self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) 87 | self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size]) 88 | 89 | ########## parameters for initial embedded input projection ########## 90 | self.input_size = input_size 91 | self.input_projector = nn.Linear(self.input_size, self.mem_size) 92 | 93 | ########## parameters for gating ########## 94 | self.num_gates = 2 * self.calculate_gate_size() 95 | self.input_gate_projector = nn.Linear(self.mem_size, self.num_gates) 96 | self.memory_gate_projector = nn.Linear(self.mem_size, self.num_gates) 97 | # trainable scalar gate bias tensors 98 | self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32)) 99 | self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32)) 100 | 101 | ########## parameters for token-to-embed & output-to-token logit for softmax 102 | self.dropout = nn.Dropout() 103 | self.num_tokens = num_tokens 104 | self.token_to_input_encoder = nn.Embedding(self.num_tokens, self.input_size) 105 | 106 | # needs 2 linear layers for tying weights for embedding layers 107 | # first match the "output" of the RMC to input_size, which is the embed dim 108 | self.output_to_embed_decoder = nn.Linear(self.mem_slots * self.mem_size, self.input_size) 109 | self.use_adaptive_softmax = use_adaptive_softmax 110 | if not self.use_adaptive_softmax: 111 | # then, this layer's weight can be tied to the embedding layer 112 | self.embed_to_logit_decoder = nn.Linear(self.input_size, self.num_tokens) 113 | 114 | # tie embedding weights of encoder & decoder 115 | self.embed_to_logit_decoder.weight = self.token_to_input_encoder.weight 116 | 117 | ########## loss function 118 | self.criterion = nn.CrossEntropyLoss() 119 | else: 120 | # use adaptive softmax from the self.input_size logits, instead of the tied embed weights above 121 | self.criterion_adaptive = nn.AdaptiveLogSoftmaxWithLoss(self.input_size, self.num_tokens, 122 | cutoffs=cutoffs) 123 | 124 | def repackage_hidden(self, h): 125 | """Wraps hidden states in new Tensors, to detach them from their history.""" 126 | # needed for truncated BPTT, called at every batch forward pass 127 | if isinstance(h, torch.Tensor): 128 | return h.detach() 129 | else: 130 | return tuple(self.repackage_hidden(v) for v in h) 131 | 132 | def initial_state(self, batch_size, trainable=False): 133 | """ 134 | Creates the initial memory. 
135 | We should ensure each row of the memory is initialized to be unique, 136 | so initialize the matrix to be the identity. We then pad or truncate 137 | as necessary so that init_state is of size 138 | (batch_size, self.mem_slots, self.mem_size). 139 | Args: 140 | batch_size: The size of the batch. 141 | trainable: Whether the initial state is trainable. This is always True. 142 | Returns: 143 | init_state: A truncated or padded matrix of size 144 | (batch_size, self.mem_slots, self.mem_size). 145 | """ 146 | init_state = torch.stack([torch.eye(self.mem_slots) for _ in range(batch_size)]) 147 | 148 | # pad the matrix with zeros 149 | if self.mem_size > self.mem_slots: 150 | difference = self.mem_size - self.mem_slots 151 | pad = torch.zeros((batch_size, self.mem_slots, difference)) 152 | init_state = torch.cat([init_state, pad], -1) 153 | 154 | # truncation. take the first 'self.mem_size' components 155 | elif self.mem_size < self.mem_slots: 156 | init_state = init_state[:, :, :self.mem_size] 157 | 158 | return init_state 159 | 160 | def multihead_attention(self, memory): 161 | """ 162 | Perform multi-head attention from 'Attention is All You Need'. 163 | Implementation of the attention mechanism from 164 | https://arxiv.org/abs/1706.03762. 165 | Args: 166 | memory: Memory tensor to perform attention on. 167 | Returns: 168 | new_memory: New memory tensor. 169 | """ 170 | 171 | # First, a simple linear projection is used to construct queries 172 | qkv = self.qkv_projector(memory) 173 | # apply layernorm for every dim except the batch dim 174 | qkv = self.qkv_layernorm(qkv) 175 | 176 | # mem_slots needs to be dynamically computed since mem_slots got concatenated with inputs 177 | # example: self.mem_slots=10 and seq_length is 3, and then mem_slots is 10 + 1 = 11 for each 3 step forward pass 178 | # this is the same as self.mem_slots_plus_input, but defined to keep the sonnet implementation code style 179 | mem_slots = memory.shape[1] # denoted as N 180 | 181 | # split the qkv to multiple heads H 182 | # [B, N, F] => [B, N, H, F/H] 183 | qkv_reshape = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size) 184 | 185 | # [B, N, H, F/H] => [B, H, N, F/H] 186 | qkv_transpose = qkv_reshape.permute(0, 2, 1, 3) 187 | 188 | # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size] 189 | q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1) 190 | 191 | # scale q with d_k, the dimensionality of the key vectors 192 | q *= (self.key_size ** -0.5) 193 | 194 | # make it [B, H, N, N] 195 | dot_product = torch.matmul(q, k.permute(0, 1, 3, 2)) 196 | weights = F.softmax(dot_product, dim=-1) 197 | 198 | # output is [B, H, N, V] 199 | output = torch.matmul(weights, v) 200 | 201 | # [B, H, N, V] => [B, N, H, V] => [B, N, H*V] 202 | output_transpose = output.permute(0, 2, 1, 3).contiguous() 203 | new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1)) 204 | 205 | return new_memory 206 | 207 | @property 208 | def state_size(self): 209 | return [self.mem_slots, self.mem_size] 210 | 211 | @property 212 | def output_size(self): 213 | return self.mem_slots * self.mem_size 214 | 215 | def calculate_gate_size(self): 216 | """ 217 | Calculate the gate size from the gate_style. 218 | Returns: 219 | The per sample, per head parameter size of each gate. 
220 | """ 221 | if self.gate_style == 'unit': 222 | return self.mem_size 223 | elif self.gate_style == 'memory': 224 | return 1 225 | else: # self.gate_style == None 226 | return 0 227 | 228 | def create_gates(self, inputs, memory): 229 | """ 230 | Create input and forget gates for this step using `inputs` and `memory`. 231 | Args: 232 | inputs: Tensor input. 233 | memory: The current state of memory. 234 | Returns: 235 | input_gate: A LSTM-like insert gate. 236 | forget_gate: A LSTM-like forget gate. 237 | """ 238 | # We'll create the input and forget gates at once. Hence, calculate double 239 | # the gate size. 240 | 241 | # equation 8: since there is no output gate, h is just a tanh'ed m 242 | memory = torch.tanh(memory) 243 | 244 | # TODO: check this input flattening is correct 245 | # sonnet uses this, but i think it assumes time step of 1 for all cases 246 | # if inputs is (B, T, features) where T > 1, this gets incorrect 247 | # inputs = inputs.view(inputs.shape[0], -1) 248 | 249 | # fixed implementation 250 | if len(inputs.shape) == 3: 251 | if inputs.shape[1] > 1: 252 | raise ValueError( 253 | "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1") 254 | inputs = inputs.view(inputs.shape[0], -1) 255 | # matmul for equation 4 and 5 256 | # there is no output gate, so equation 6 is not implemented 257 | gate_inputs = self.input_gate_projector(inputs) 258 | gate_inputs = gate_inputs.unsqueeze(dim=1) 259 | gate_memory = self.memory_gate_projector(memory) 260 | else: 261 | raise ValueError("input shape of create_gate function is 2, expects 3") 262 | 263 | # this completes the equation 4 and 5 264 | gates = gate_memory + gate_inputs 265 | gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2) 266 | input_gate, forget_gate = gates 267 | assert input_gate.shape[2] == forget_gate.shape[2] 268 | 269 | # to be used for equation 7 270 | input_gate = torch.sigmoid(input_gate + self.input_bias) 271 | forget_gate = torch.sigmoid(forget_gate + self.forget_bias) 272 | 273 | return input_gate, forget_gate 274 | 275 | def attend_over_memory(self, memory): 276 | """ 277 | Perform multiheaded attention over `memory`. 278 | Args: 279 | memory: Current relational memory. 280 | Returns: 281 | The attended-over memory. 282 | """ 283 | for _ in range(self.num_blocks): 284 | attended_memory = self.multihead_attention(memory) 285 | 286 | # Add a skip connection to the multiheaded attention's input. 287 | memory = self.attended_memory_layernorm(memory + attended_memory) 288 | 289 | # add a skip connection to the attention_mlp's input. 290 | attention_mlp = memory 291 | for i, l in enumerate(self.attention_mlp): 292 | attention_mlp = self.attention_mlp[i](attention_mlp) 293 | attention_mlp = F.relu(attention_mlp) 294 | memory = self.attended_memory_layernorm2(memory + attention_mlp) 295 | 296 | return memory 297 | 298 | def forward_step(self, inputs, memory, treat_input_as_matrix=False): 299 | """ 300 | Forward step of the relational memory core. 301 | Args: 302 | inputs: Tensor input. 303 | memory: Memory output from the previous time step. 304 | treat_input_as_matrix: Optional, whether to treat `input` as a sequence 305 | of matrices. Default to False, in which case the input is flattened 306 | into a vector. 307 | Returns: 308 | output: This time step's output. 309 | next_memory: The next version of memory to use. 
310 | """ 311 | 312 | # first embed the tokens into vectors 313 | inputs_embed = self.dropout(self.token_to_input_encoder(inputs)) 314 | 315 | if treat_input_as_matrix: 316 | # keep (Batch, Seq, ...) dim (0, 1), flatten starting from dim 2 317 | inputs_embed = inputs_embed.view(inputs_embed.shape[0], inputs_embed.shape[1], -1) 318 | # apply linear layer for dim 2 319 | inputs_reshape = self.input_projector(inputs_embed) 320 | else: 321 | # keep (Batch, ...) dim (0), flatten starting from dim 1 322 | inputs_embed = inputs_embed.view(inputs_embed.shape[0], -1) 323 | # apply linear layer for dim 1 324 | inputs_embed = self.input_projector(inputs_embed) 325 | # unsqueeze the time step to dim 1 326 | inputs_reshape = inputs_embed.unsqueeze(dim=1) 327 | 328 | memory_plus_input = torch.cat([memory, inputs_reshape], dim=1) 329 | next_memory = self.attend_over_memory(memory_plus_input) 330 | 331 | # cut out the concatenated input vectors from the original memory slots 332 | n = inputs_reshape.shape[1] 333 | next_memory = next_memory[:, :-n, :] 334 | 335 | if self.gate_style == 'unit' or self.gate_style == 'memory': 336 | # these gates are sigmoid-applied ones for equation 7 337 | input_gate, forget_gate = self.create_gates(inputs_reshape, memory) 338 | # equation 7 calculation 339 | next_memory = input_gate * torch.tanh(next_memory) 340 | next_memory += forget_gate * memory 341 | 342 | output = next_memory.view(next_memory.shape[0], -1) 343 | 344 | # decode output to logit 345 | output_embed = self.output_to_embed_decoder(output) 346 | # TODO: this dropout is not mentioned in the paper. it's to match word-language-model dropout use case 347 | output_embed = self.dropout(output_embed) 348 | 349 | if not self.use_adaptive_softmax: 350 | logit = self.embed_to_logit_decoder(output_embed) 351 | else: 352 | logit = output_embed 353 | 354 | return logit, next_memory 355 | 356 | def forward(self, inputs, memory, targets, require_logits=False): 357 | # Starting each batch, we detach the hidden state from how it was previously produced. 358 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 
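        # A minimal sketch of what the detaching amounts to (assuming repackage_hidden
        # mirrors the helper of the same name in train_rnn.py): the values are kept but
        # the autograd history is dropped, e.g.
        #   memory = memory.detach()                        # single tensor
        #   hidden = tuple(h.detach() for h in hidden)      # tuple of tensors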
359 | memory = self.repackage_hidden(memory) 360 | 361 | # for loop implementation of (entire) recurrent forward pass of the model 362 | # inputs is batch first [batch, seq], and output logit per step is [batch, vocab] 363 | # so the concatenated logits are [seq * batch, vocab] 364 | 365 | # targets are flattened [seq, batch] => [seq * batch], so the dimension is correct 366 | 367 | logits = [] 368 | # shape[1] is seq_lenth T 369 | for idx_step in range(inputs.shape[1]): 370 | logit, memory = self.forward_step(inputs[:, idx_step], memory) 371 | logits.append(logit) 372 | # concat the output from list(seq_length) of [batch, vocab] to [seq * batch, vocab] 373 | logits = torch.cat(logits) 374 | 375 | if targets is not None: 376 | if not self.use_adaptive_softmax: 377 | # calculate loss inside this forward pass for more even VRAM usage of DataParallel 378 | loss = self.criterion(logits, targets) 379 | else: 380 | # calculate the loss using adaptive softmax 381 | _, loss = self.criterion_adaptive(logits, targets) 382 | else: 383 | loss = None 384 | 385 | # the forward pass only returns loss, because returning logits causes uneven VRAM usage of DataParallel 386 | # logits are provided only for sampling stage 387 | if not require_logits: 388 | return loss, memory 389 | else: 390 | return logits, loss, memory 391 | 392 | 393 | 394 | # ########## DEBUG: unit test code ########## 395 | # input_size = 44 396 | # seq_length = 1 397 | # batch_size = 32 398 | # model = RelationalMemory(mem_slots=10, head_size=20, input_size=input_size, num_tokens=66, num_heads=8, num_blocks=1, forget_bias=1., input_bias=0.) 399 | # model_memory = model.initial_state(batch_size=batch_size) 400 | # 401 | # # random input 402 | # random_input = torch.randn((32, seq_length, input_size)) 403 | # # random targets 404 | # random_targets = torch.randn((32, seq_length, input_size)) 405 | # 406 | # # take a one step forward 407 | # logit, next_memory = model(random_input, model_memory, random_targets, treat_input_as_matrix=True) 408 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | -------------------------------------------------------------------------------- /rnn_models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class RNNModel(nn.Module): 6 | """Container module with an encoder, a recurrent module, and a decoder.""" 7 | 8 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False, use_cudnn_version=True, 9 | use_adaptive_softmax=False, cutoffs=None): 10 | super(RNNModel, self).__init__() 11 | self.use_cudnn_version = use_cudnn_version 12 | self.drop = nn.Dropout(dropout) 13 | self.encoder = nn.Embedding(ntoken, ninp) 14 | if use_cudnn_version: 15 | if rnn_type in ['LSTM', 'GRU']: 16 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 17 | else: 18 | try: 19 | nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type] 20 | except KeyError: 21 | raise ValueError("""An invalid option for `--model` was supplied, 22 | options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""") 23 | self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout) 24 | else: 25 | if rnn_type in ['LSTM', 'GRU']: 26 | rnn_type = str(rnn_type) + 'Cell' 27 | rnn_modulelist = [] 28 | for i in range(nlayers): 29 | 
rnn_modulelist.append(getattr(nn, rnn_type)(ninp, nhid)) 30 | if i < nlayers - 1: 31 | rnn_modulelist.append(nn.Dropout(dropout)) 32 | self.rnn = nn.ModuleList(rnn_modulelist) 33 | else: 34 | raise ValueError("non-cudnn version of (RNNCell) is not implemented. use LSTM or GRU instead") 35 | 36 | if not use_adaptive_softmax: 37 | self.use_adaptive_softmax = use_adaptive_softmax 38 | self.decoder = nn.Linear(nhid, ntoken) 39 | # Optionally tie weights as in: 40 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 41 | # https://arxiv.org/abs/1608.05859 42 | # and 43 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) 44 | # https://arxiv.org/abs/1611.01462 45 | if tie_weights: 46 | if nhid != ninp: 47 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 48 | self.decoder.weight = self.encoder.weight 49 | else: 50 | # simple linear layer of nhid output size. used for adaptive softmax after 51 | # directly applying softmax at the hidden states is a bad idea 52 | self.decoder_adaptive = nn.Linear(nhid, nhid) 53 | self.use_adaptive_softmax = use_adaptive_softmax 54 | self.cutoffs = cutoffs 55 | if tie_weights: 56 | print("Warning: if using adaptive softmax, tie_weights cannot be applied. Ignored.") 57 | 58 | self.init_weights() 59 | 60 | self.rnn_type = rnn_type 61 | self.nhid = nhid 62 | self.nlayers = nlayers 63 | 64 | def init_weights(self): 65 | initrange = 0.1 66 | self.encoder.weight.data.uniform_(-initrange, initrange) 67 | if not self.use_adaptive_softmax: 68 | self.decoder.bias.data.zero_() 69 | self.decoder.weight.data.uniform_(-initrange, initrange) 70 | 71 | def forward(self, input, hidden): 72 | emb = self.drop(self.encoder(input)) 73 | if self.use_cudnn_version: 74 | output, hidden = self.rnn(emb, hidden) 75 | else: 76 | # for loop implementation with RNNCell 77 | layer_input = emb 78 | new_hidden = [[], []] 79 | for idx_layer in range(0, self.nlayers + 1, 2): 80 | output = [] 81 | hx, cx = hidden[0][int(idx_layer / 2)], hidden[1][int(idx_layer / 2)] 82 | for idx_step in range(input.shape[0]): 83 | hx, cx = self.rnn[idx_layer](layer_input[idx_step], (hx, cx)) 84 | output.append(hx) 85 | output = torch.stack(output) 86 | if idx_layer + 1 < self.nlayers: 87 | output = self.rnn[idx_layer + 1](output) 88 | layer_input = output 89 | new_hidden[0].append(hx) 90 | new_hidden[1].append(cx) 91 | new_hidden[0] = torch.stack(new_hidden[0]) 92 | new_hidden[1] = torch.stack(new_hidden[1]) 93 | hidden = tuple(new_hidden) 94 | 95 | output = self.drop(output) 96 | 97 | if not self.use_adaptive_softmax: 98 | decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2))) 99 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 100 | else: 101 | decoded = self.decoder_adaptive(output.view(output.size(0) * output.size(1), output.size(2))) 102 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 103 | 104 | def init_hidden(self, bsz): 105 | weight = next(self.parameters()) 106 | if self.rnn_type == 'LSTM' or self.rnn_type == 'LSTMCell': 107 | return (weight.new_zeros(self.nlayers, bsz, self.nhid), 108 | weight.new_zeros(self.nlayers, bsz, self.nhid)) 109 | else: 110 | return weight.new_zeros(self.nlayers, bsz, self.nhid) 111 | -------------------------------------------------------------------------------- /train_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Template 
to use Relational RNN module 3 | to predict a scalar from a sequence of embeddings, 4 | e.g. a sentence. 5 | 6 | Input: fixed-length sequence of `num_words` words, 7 | each represented by a `num_embedding_dims` dimensional embedding. 8 | 9 | Output: A scalar. 10 | 11 | Author: Jessica Yung 12 | August 2018 13 | 14 | Relational Memory Core implementation mostly written by Sang-gil Lee, adapted by Jessica Yung. 15 | """ 16 | import torch 17 | import torch.nn as nn 18 | import matplotlib.pyplot as plt 19 | import numpy as np 20 | 21 | from relational_rnn_general import RelationalMemory 22 | 23 | # network params 24 | learning_rate = 1e-3 25 | num_epochs = 50 26 | # dtype = torch.float 27 | 28 | # data params 29 | # Input = seq of `num_words` words, embedding for each word has `num_embedding_dims` dims 30 | num_words = 10 31 | num_embedding_dims = 5 32 | input_size = num_embedding_dims 33 | # Predicting a scalar 34 | output_size = 1 35 | 36 | num_examples = 20 37 | test_size = 0.2 38 | num_train = int((1 - test_size) * num_examples) 39 | batch_size = 4 40 | 41 | #################### 42 | # Generate data 43 | #################### 44 | 45 | X = torch.rand((num_examples, num_words, num_embedding_dims)) 46 | # Predicting a scalar per example 47 | y = torch.rand((num_examples, output_size)) 48 | 49 | X_train = X[:num_train] 50 | X_test = X[num_train:] 51 | y_train = y[:num_train] 52 | y_test = y[num_train:] 53 | 54 | 55 | class RMCArguments: 56 | def __init__(self): 57 | self.memslots = 1 58 | self.headsize = 3 59 | self.numheads = 4 60 | self.input_size = input_size # dimensions per timestep 61 | self.numheads = 4 62 | self.numblocks = 1 63 | self.forgetbias = 1. 64 | self.inputbias = 0. 65 | self.attmlplayers = 3 66 | self.batch_size = batch_size 67 | self.clip = 0.1 68 | 69 | 70 | args = RMCArguments() 71 | 72 | device = torch.device("cpu") 73 | 74 | 75 | #################### 76 | # Build model 77 | #################### 78 | 79 | class RRNN(nn.Module): 80 | def __init__(self, batch_size): 81 | super(RRNN, self).__init__() 82 | self.memory_size_per_row = args.headsize * args.numheads 83 | self.relational_memory = RelationalMemory(mem_slots=args.memslots, head_size=args.headsize, 84 | input_size=args.input_size, 85 | num_heads=args.numheads, num_blocks=args.numblocks, 86 | forget_bias=args.forgetbias, 87 | input_bias=args.inputbias) 88 | # Map from memory to logits (categorical predictions) 89 | self.out = nn.Linear(self.memory_size_per_row, output_size) 90 | 91 | def forward(self, input, memory): 92 | logit, memory = self.relational_memory(input, memory) 93 | out = self.out(logit) 94 | 95 | return out, memory 96 | 97 | 98 | model = RRNN(batch_size).to(device) 99 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 100 | 101 | print("Model built, total trainable params: " + str(total_params)) 102 | 103 | 104 | def get_batch(X, y, batch_num, device, batch_size=32, batch_first=True): 105 | if not batch_first: 106 | raise NotImplementedError 107 | start = batch_num * batch_size 108 | end = (batch_num + 1) * batch_size 109 | return X[start:end].to(device), y[start:end].to(device) 110 | 111 | 112 | loss_fn = torch.nn.MSELoss() 113 | 114 | optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate) 115 | 116 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.5, patience=5) 117 | 118 | num_batches = int(len(X_train) / batch_size) 119 | num_test_batches = int(len(X_test) / batch_size) 120 | 121 | memory = 
model.relational_memory.initial_state(args.batch_size, trainable=True).to(device) 122 | 123 | hist = np.zeros(num_epochs) 124 | 125 | 126 | def accuracy_score(y_pred, y_true): 127 | return np.array(y_pred == y_true).sum() * 1.0 / len(y_true) 128 | 129 | 130 | #################### 131 | # Train model 132 | #################### 133 | 134 | for t in range(num_epochs): 135 | epoch_loss = np.zeros(num_batches) 136 | # epoch_acc = np.zeros(num_batches) 137 | epoch_test_loss = np.zeros(num_test_batches) 138 | # epoch_test_acc = np.zeros(num_test_batches) 139 | for i in range(num_batches): 140 | data, targets = get_batch(X_train, y_train, i, device=device, batch_size=batch_size) 141 | model.zero_grad() 142 | 143 | # forward pass 144 | # replace "_" with "memory" if you want to make the RNN stateful 145 | y_pred, memory = model(data, memory) 146 | 147 | loss = loss_fn(y_pred, targets) 148 | loss = torch.mean(loss) 149 | # y_pred = torch.argmax(y_pred, dim=1) 150 | # acc = accuracy_score(y_pred, targets) 151 | epoch_loss[i] = loss 152 | # epoch_acc[i] = acc 153 | 154 | # Zero out gradient, else they will accumulate between epochs 155 | optimiser.zero_grad() 156 | 157 | # backward pass 158 | loss.backward() 159 | 160 | # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 161 | 162 | # update parameters 163 | optimiser.step() 164 | 165 | # test examples 166 | hist[t] = np.mean(epoch_loss).item() 167 | if t % 10 == 0: 168 | print("train: ", y_pred.squeeze().detach().cpu().numpy(), targets.squeeze().detach().cpu().numpy()) 169 | for i in range(num_test_batches): 170 | with torch.no_grad(): 171 | data, targets = get_batch(X_test, y_test, i, device=device, batch_size=batch_size) 172 | ytest_pred, memory = model(data, memory) 173 | 174 | test_loss = loss_fn(ytest_pred, targets) 175 | test_loss = torch.mean(test_loss) 176 | # ytest_pred = torch.argmax(ytest_pred, dim=1) 177 | # test_acc = accuracy_score(ytest_pred, targets) 178 | epoch_test_loss[i] = test_loss 179 | # epoch_test_acc[i] = test_acc 180 | 181 | if t % 10 == 0: 182 | # print(epoch_test_loss) 183 | # print(epoch_test_acc) 184 | print("Epoch {} train loss: {}".format(t, np.mean(epoch_loss).item())) 185 | print("Epoch {} test loss: {}".format(t, np.mean(epoch_test_loss).item())) 186 | # print("Epoch {} train acc: {:.2f}".format(t, np.mean(epoch_acc).item())) 187 | # print("Epoch {} test acc: {:.2f}".format(t, np.mean(epoch_test_acc).item())) 188 | print("test: ", ytest_pred.squeeze().detach().cpu().numpy(), targets.squeeze().detach().cpu().numpy()) 189 | 190 | #################### 191 | # Plot losses 192 | #################### 193 | 194 | plt.plot(hist, label="Training loss") 195 | plt.legend() 196 | plt.show() 197 | 198 | """ 199 | # TODO: visualise preds 200 | plt.plot(y_pred.detach().numpy(), label="Preds") 201 | plt.plot(y_train.detach().numpy(), label="Data") 202 | plt.legend() 203 | plt.show() 204 | """ 205 | -------------------------------------------------------------------------------- /train_nth_farthest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of 'Nth Farthest' task 3 | as defined in Santoro, Faulkner and Raposo et al., 2018 4 | (Relational recurrent neural networks, https://arxiv.org/abs/1806.01822) 5 | 6 | Note: The training data is re-generated each epoch as in the 7 | Sonnet implementation. This avoids overfitting but means that the 8 | experiments may take longer.
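Task in brief (see get_example below): each example is a set of `num_vectors` random
vectors; the input also encodes, one-hot, a reference index m and a rank n, and the
target is the label of the vector that is the nth farthest (Euclidean distance) from
vector m.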
9 | 10 | Author: Jessica Yung 11 | August 2018 12 | 13 | Relational Memory Core implementation mostly written by Sang-gil Lee, adapted by Jessica Yung. 14 | """ 15 | import torch 16 | import torch.nn as nn 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | from argparse import ArgumentParser 20 | 21 | from relational_rnn_general import RelationalMemory 22 | 23 | parser = ArgumentParser() 24 | 25 | # Model parameters. 26 | parser.add_argument('--cuda', action='store_true', 27 | help='use CUDA') 28 | 29 | parse_args = parser.parse_args() 30 | 31 | if torch.cuda.is_available(): 32 | if not parse_args.cuda: 33 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 34 | 35 | device = torch.device("cuda" if parse_args.cuda else "cpu") 36 | 37 | # network params 38 | learning_rate = 1e-4 39 | num_epochs = 10000000 40 | dtype = torch.float 41 | mlp_size = 256 42 | 43 | # data params 44 | num_vectors = 8 45 | num_dims = 16 46 | batch_size = 1600 47 | num_batches = 6 # set batches per epoch because we are generating data from scratch each time 48 | num_test_examples = 3200 49 | 50 | #################### 51 | # Generate data 52 | #################### 53 | 54 | # For each example 55 | input_size = num_dims + num_vectors * 3 56 | 57 | 58 | def one_hot_encode(array, num_dims=8): 59 | one_hot = np.zeros((len(array), num_dims)) 60 | for i in range(len(array)): 61 | one_hot[i, array[i]] = 1 62 | return one_hot 63 | 64 | 65 | def get_example(num_vectors, num_dims): 66 | input_size = num_dims + num_vectors * 3 67 | n = np.random.choice(num_vectors, 1) # nth farthest from target vector 68 | labels = np.random.choice(num_vectors, num_vectors, replace=False) 69 | m_index = np.random.choice(num_vectors, 1) # m comes after the m_index-th vector 70 | m = labels[m_index] 71 | 72 | # Vectors sampled from U(-1,1) 73 | vectors = np.random.rand(num_vectors, num_dims) * 2 - 1 74 | target_vector = vectors[m_index] 75 | dist_from_target = np.linalg.norm(vectors - target_vector, axis=1) 76 | X_single = np.zeros((num_vectors, input_size)) 77 | X_single[:, :num_dims] = vectors 78 | labels_onehot = one_hot_encode(labels, num_dims=num_vectors) 79 | X_single[:, num_dims:num_dims + num_vectors] = labels_onehot 80 | nm_onehot = np.reshape(one_hot_encode([n, m], num_dims=num_vectors), -1) 81 | X_single[:, num_dims + num_vectors:] = np.tile(nm_onehot, (num_vectors, 1)) 82 | y_single = labels[np.argsort(dist_from_target)[-(n + 1)]] 83 | 84 | return X_single, y_single 85 | 86 | 87 | def get_examples(num_examples, num_vectors, num_dims, device): 88 | X = np.zeros((num_examples, num_vectors, input_size)) 89 | y = np.zeros(num_examples) 90 | for i in range(num_examples): 91 | X_single, y_single = get_example(num_vectors, num_dims) 92 | X[i, :] = X_single 93 | y[i] = y_single 94 | 95 | X = torch.Tensor(X).to(device) 96 | y = torch.LongTensor(y).to(device) 97 | 98 | return X, y 99 | 100 | 101 | X_test, y_test = get_examples(num_test_examples, num_vectors, num_dims, device) 102 | 103 | 104 | class RMCArguments: 105 | def __init__(self): 106 | self.memslots = 8 107 | self.numheads = 8 108 | self.headsize = int(2048 / (self.numheads * self.memslots)) 109 | self.input_size = input_size # dimensions per timestep 110 | self.numblocks = 1 111 | self.forgetbias = 1. 112 | self.inputbias = 0. 
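        # note: headsize above is derived so that mem_slots * num_heads * head_size
        # (the flattened memory row fed to the MLP, memory_size_per_row below) stays at 2048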
113 | self.attmlplayers = 2 114 | self.batch_size = batch_size 115 | self.clip = 0.1 116 | 117 | 118 | args = RMCArguments() 119 | 120 | 121 | #################### 122 | # Build model 123 | #################### 124 | 125 | class RRNN(nn.Module): 126 | def __init__(self, mlp_size): 127 | super(RRNN, self).__init__() 128 | self.mlp_size = mlp_size 129 | self.memory_size_per_row = args.headsize * args.numheads * args.memslots 130 | self.relational_memory = RelationalMemory(mem_slots=args.memslots, head_size=args.headsize, 131 | input_size=args.input_size, 132 | num_heads=args.numheads, num_blocks=args.numblocks, 133 | forget_bias=args.forgetbias, input_bias=args.inputbias) 134 | # Map from memory to logits (categorical predictions) 135 | self.mlp = nn.Sequential( 136 | nn.Linear(self.memory_size_per_row, self.mlp_size), 137 | nn.ReLU(), 138 | nn.Linear(self.mlp_size, self.mlp_size), 139 | nn.ReLU(), 140 | nn.Linear(self.mlp_size, self.mlp_size), 141 | nn.ReLU(), 142 | nn.Linear(self.mlp_size, self.mlp_size), 143 | nn.ReLU() 144 | ) 145 | self.out = nn.Linear(self.mlp_size, num_vectors) 146 | self.softmax = nn.Softmax(dim=1) 147 | 148 | def forward(self, input, memory): 149 | logit, memory = self.relational_memory(input, memory) 150 | mlp = self.mlp(logit) 151 | out = self.out(mlp) 152 | out = self.softmax(out) 153 | 154 | return out, memory 155 | 156 | 157 | model = RRNN(mlp_size).to(device) 158 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 159 | 160 | print("Model built, total trainable params: " + str(total_params)) 161 | 162 | 163 | def get_batch(X, y, batch_num, batch_size=32, batch_first=True): 164 | if not batch_first: 165 | raise NotImplementedError 166 | start = batch_num * batch_size 167 | end = (batch_num + 1) * batch_size 168 | return X[start:end], y[start:end] 169 | 170 | 171 | loss_fn = torch.nn.CrossEntropyLoss() 172 | 173 | optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate) 174 | 175 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', factor=0.5, patience=5, min_lr=8e-5) 176 | 177 | # num_batches = int(len(X_train) / batch_size) 178 | num_test_batches = int(len(X_test) / batch_size) 179 | 180 | memory = model.relational_memory.initial_state(args.batch_size, trainable=True).to(device) 181 | 182 | hist = np.zeros(num_epochs) 183 | hist_acc = np.zeros(num_epochs) 184 | test_hist = np.zeros(num_epochs) 185 | test_hist_acc = np.zeros(num_epochs) 186 | 187 | 188 | def accuracy_score(y_pred, y_true): 189 | return np.array(y_pred == y_true).sum() * 1.0 / len(y_true) 190 | 191 | 192 | #################### 193 | # Train model 194 | #################### 195 | 196 | for t in range(num_epochs): 197 | epoch_loss = np.zeros(num_batches) 198 | epoch_acc = np.zeros(num_batches) 199 | epoch_test_loss = np.zeros(num_test_batches) 200 | epoch_test_acc = np.zeros(num_test_batches) 201 | for i in range(num_batches): 202 | data, targets = get_examples(batch_size, num_vectors, num_dims, device) 203 | model.zero_grad() 204 | 205 | # forward pass 206 | # replace "_" with "memory" if you want to make the RNN stateful 207 | y_pred, _ = model(data, memory) 208 | 209 | loss = loss_fn(y_pred, targets) 210 | loss = torch.mean(loss) 211 | y_pred = torch.argmax(y_pred, dim=1) 212 | acc = accuracy_score(y_pred, targets) 213 | epoch_loss[i] = loss 214 | epoch_acc[i] = acc 215 | 216 | # Zero out gradient, else they will accumulate between epochs 217 | optimiser.zero_grad() 218 | 219 | # backward pass 220 | loss.backward() 221 | 222 | # this 
helps prevent exploding gradient in RNNs 223 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) 224 | 225 | # update parameters 226 | optimiser.step() 227 | 228 | # test examples 229 | for i in range(num_test_batches): 230 | with torch.no_grad(): 231 | data, targets = get_batch(X_test, y_test, i, batch_size=batch_size) 232 | ytest_pred, _ = model(data, memory) 233 | 234 | test_loss = loss_fn(ytest_pred, targets) 235 | test_loss = torch.mean(test_loss) 236 | ytest_pred = torch.argmax(ytest_pred, dim=1) 237 | test_acc = accuracy_score(ytest_pred, targets) 238 | epoch_test_loss[i] = test_loss 239 | epoch_test_acc[i] = test_acc 240 | 241 | loss = np.mean(epoch_loss) 242 | acc = np.mean(epoch_acc) 243 | test_loss = np.mean(epoch_test_loss) 244 | test_acc = np.mean(epoch_test_acc) 245 | 246 | hist[t] = loss 247 | hist_acc[t] = acc 248 | test_hist[t] = test_loss 249 | test_hist_acc[t] = test_acc 250 | 251 | if t % 10 == 0: 252 | print("Epoch {} train loss: {}".format(t, loss)) 253 | print("Epoch {} test loss: {}".format(t, test_loss)) 254 | print("Epoch {} train acc: {:.2f}".format(t, acc)) 255 | print("Epoch {} test acc: {:.2f}".format(t, test_acc)) 256 | 257 | #################### 258 | # Plot losses 259 | #################### 260 | 261 | plt.plot(hist, label="Training loss") 262 | plt.plot(test_hist, label="Test loss") 263 | plt.legend() 264 | plt.title("Cross entropy loss") 265 | plt.show() 266 | 267 | # Plot accuracy 268 | plt.plot(hist_acc, label="Training accuracy") 269 | plt.plot(test_hist_acc, label="Test accuracy") 270 | plt.title("Accuracy") 271 | plt.legend() 272 | plt.show() 273 | -------------------------------------------------------------------------------- /train_rmc.py: -------------------------------------------------------------------------------- 1 | # copypasta from main.py of pytorch word_language_model code 2 | # coding: utf-8 3 | import argparse 4 | import time 5 | import math 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.onnx 10 | import datetime 11 | import shutil 12 | import pickle 13 | import data 14 | from relational_rnn_models import RelationalMemory 15 | 16 | # is it faster? 
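# cudnn.benchmark lets cuDNN auto-tune and cache the fastest kernels for the input
# shapes it sees; it usually helps when batch size and sequence length stay fixed
# (as with the fixed --bptt chunks here), but can add overhead if shapes keep changing.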
17 | torch.backends.cudnn.benchmark = True 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model') 20 | # hyperparams for text data 21 | parser.add_argument('--data', type=str, default='./data/wikitext-2', 22 | help='location of the data corpus') 23 | parser.add_argument('--emsize', type=int, default=192, 24 | help='size of word embeddings') 25 | 26 | # NEW!: hyperparams for relational memory core (RMC) 27 | parser.add_argument('--memslots', type=int, default=1, 28 | help='number of memory slots of the relational memory core') 29 | parser.add_argument('--headsize', type=int, default=192, 30 | help='size of the each head for multihead attention') 31 | parser.add_argument('--numheads', type=int, default=4, 32 | help='total number of heads for multihead attention') 33 | parser.add_argument('--numblocks', type=int, default=1, 34 | help='Number of times to compute attention per time step') 35 | parser.add_argument('--forgetbias', type=float, default=1., 36 | help='Bias to use for the forget gate, assuming we are using some form of gating') 37 | parser.add_argument('--inputbias', type=float, default=0., 38 | help='Bias to use for the input gate, assuming we are using some form of gating') 39 | parser.add_argument('--gatestyle', type=str, default='unit', 40 | help='Whether to use per-element gating (\'unit\'), per-memory slot gating (\'memory\'), or no gating at all (None).') 41 | parser.add_argument('--attmlplayers', type=int, default=3, 42 | help='Number of layers to use in the post-attention MLP') 43 | parser.add_argument('--keysize', type=int, default=64, 44 | help='Size of vector to use for key & query vectors in the attention' 45 | 'computation. Defaults to None, in which case we use `head_size`') 46 | # parameters for adaptive softmax 47 | parser.add_argument('--adaptivesoftmax', action='store_true', 48 | help='use adaptive softmax during hidden state to output logits.' 49 | 'it uses less memory by approximating softmax of large vocabulary.') 50 | parser.add_argument('--cutoffs', nargs="*", type=int, default=[10000, 50000, 100000], 51 | help='cutoff values for adaptive softmax. list of integers.' 
52 | 'optimal values are based on word frequencey and vocabulary size of the dataset.') 53 | 54 | # other hyperparams for general RNN mechanics 55 | parser.add_argument('--lr', type=float, default=0.001, 56 | help='initial learning rate') 57 | parser.add_argument('--clip', type=float, default=0.1, 58 | help='gradient clipping') 59 | parser.add_argument('--epochs', type=int, default=100, 60 | help='upper epoch limit') 61 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 62 | help='batch size') 63 | parser.add_argument('--bptt', type=int, default=100, 64 | help='sequence length') 65 | # dropout of RMC is hard-bound to 0.5 at the embedding layer 66 | # parser.add_argument('--dropout', type=float, default=0.2, 67 | # help='dropout applied to layers (0 = no dropout)') 68 | # embed weight tying is set always to true 69 | # parser.add_argument('--tied', action='store_true', 70 | # help='tie the word embedding and softmax weights') 71 | parser.add_argument('--seed', type=int, default=1111, 72 | help='random seed') 73 | parser.add_argument('--cuda', action='store_true', 74 | help='use CUDA') 75 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 76 | help='report interval') 77 | parser.add_argument('--onnx-export', type=str, default='', 78 | help='path to export the final model in onnx format') 79 | parser.add_argument('--resume', type=int, default=None, 80 | help='if specified with the 1-indexed global epoch, loads the checkpoint and resumes training') 81 | 82 | # experiment name for this run 83 | parser.add_argument('--name', type=str, default=None, 84 | help='name for this experiment. generates folder with the name if specified.') 85 | 86 | args = parser.parse_args() 87 | 88 | # Set the random seed manually for reproducibility. 89 | torch.manual_seed(args.seed) 90 | 91 | if torch.cuda.is_available(): 92 | if not args.cuda: 93 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 94 | 95 | device = torch.device("cuda" if args.cuda else "cpu") 96 | ############################################################################### 97 | # Load data 98 | ############################################################################### 99 | corpus_name = os.path.basename(os.path.normpath(args.data)) 100 | corpus_filename = './data/corpus-' + str(corpus_name) + str('.pkl') 101 | if os.path.isfile(corpus_filename): 102 | print("loading pre-built " + str(corpus_name) + " corpus file...") 103 | loadfile = open(corpus_filename, 'rb') 104 | corpus = pickle.load(loadfile) 105 | loadfile.close() 106 | else: 107 | print("building " + str(corpus_name) + " corpus...") 108 | corpus = data.Corpus(args.data) 109 | # save the corpus for later 110 | savefile = open(corpus_filename, 'wb') 111 | pickle.dump(corpus, savefile) 112 | savefile.close() 113 | print("corpus saved to pickle") 114 | 115 | 116 | # Starting from sequential data, batchify arranges the dataset into columns. 117 | # For instance, with the alphabet as the sequence and batch size 4, we'd get 118 | # ┌ a g m s ┐ 119 | # │ b h n t │ 120 | # │ c i o u │ 121 | # │ d j p v │ 122 | # │ e k q w │ 123 | # └ f l r x ┘. 124 | # These columns are treated as independent by the model, which means that the 125 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient 126 | # batch processing. 127 | 128 | def batchify(data, bsz): 129 | # Work out how cleanly we can divide the dataset into bsz parts. 
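    # e.g. a corpus of 1,000,003 tokens with bsz=64 gives nbatch = 15625: the trailing
    # 3 tokens are dropped and the result below is a [15625, 64] tensor.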
130 | nbatch = data.size(0) // bsz 131 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 132 | data = data.narrow(0, 0, nbatch * bsz) 133 | # Evenly divide the data across the bsz batches. 134 | data = data.view(bsz, -1).t().contiguous() 135 | return data.to(device) 136 | 137 | 138 | eval_batch_size = 32 139 | train_data = batchify(corpus.train, args.batch_size) 140 | val_data = batchify(corpus.valid, eval_batch_size) 141 | test_data = batchify(corpus.test, eval_batch_size) 142 | 143 | # create folder for current experiments 144 | # name: args.name + current time 145 | # includes: entire scripts for faithful reproduction, train & test logs 146 | folder_name = str(datetime.datetime.now())[:-7] 147 | if args.name is not None: 148 | folder_name = str(args.name) + ' ' + folder_name 149 | 150 | os.mkdir(folder_name) 151 | for file in os.listdir(os.getcwd()): 152 | if file.endswith(".py"): 153 | shutil.copy2(file, os.path.join(os.getcwd(), folder_name)) 154 | logger_train = open(os.path.join(os.getcwd(), folder_name, 'train_log.txt'), 'w+') 155 | logger_test = open(os.path.join(os.getcwd(), folder_name, 'test_log.txt'), 'w+') 156 | 157 | # save args to logger 158 | logger_train.write(str(args) + '\n') 159 | 160 | # define saved model file location 161 | savepath = os.path.join(os.getcwd(), folder_name) 162 | 163 | ############################################################################### 164 | # Build the model 165 | ############################################################################### 166 | 167 | ntokens = len(corpus.dictionary) 168 | print("vocabulary size (ntokens): " + str(ntokens)) 169 | if args.adaptivesoftmax: 170 | print("Adaptive Softmax is on: the performance depends on cutoff values. check if the cutoff is properly set") 171 | print("Cutoffs: " + str(args.cutoffs)) 172 | if args.cutoffs[-1] > ntokens: 173 | raise ValueError("the last element of cutoff list must be lower than vocab size of the dataset") 174 | 175 | model = RelationalMemory(mem_slots=args.memslots, head_size=args.headsize, input_size=args.emsize, num_tokens=ntokens, 176 | num_heads=args.numheads, num_blocks=args.numblocks, forget_bias=args.forgetbias, 177 | input_bias=args.inputbias, attention_mlp_layers=args.attmlplayers, key_size=args.keysize, 178 | use_adaptive_softmax=args.adaptivesoftmax, cutoffs=args.cutoffs).to(device) 179 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 180 | model = nn.DataParallel(model) 181 | 182 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 183 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5) 184 | 185 | ############################################################################### 186 | # Load the model checkpoint if specified and restore the global & best epoch 187 | ############################################################################### 188 | if args.resume is not None: 189 | print("--resume detected. loading checkpoint...") 190 | global_epoch = args.resume if args.resume is not None else 0 191 | best_epoch = args.resume if args.resume is not None else 0 192 | if args.resume is not None: 193 | loadpath = os.path.join(os.getcwd(), "model_{}.pt".format(args.resume)) 194 | if not os.path.isfile(loadpath): 195 | raise FileNotFoundError( 196 | "model_{}.pt not found. 
place the model checkpoint file to the current working directory.".format( 197 | args.resume)) 198 | checkpoint = torch.load(loadpath) 199 | model.load_state_dict(checkpoint["state_dict"]) 200 | optimizer.load_state_dict(checkpoint["optimizer"]) 201 | scheduler.load_state_dict(checkpoint["scheduler"]) 202 | global_epoch = checkpoint["global_epoch"] 203 | best_epoch = checkpoint["best_epoch"] 204 | 205 | print("model built, total trainable params: " + str(total_params)) 206 | 207 | 208 | ############################################################################### 209 | # Training code 210 | ############################################################################### 211 | 212 | # get_batch subdivides the source data into chunks of length args.bptt. 213 | # If source is equal to the example output of the batchify function, with 214 | # a bptt-limit of 2, we'd get the following two Variables for i = 0: 215 | # ┌ a g m s ┐ ┌ b h n t ┐ 216 | # └ b h n t ┘ └ c i o u ┘ 217 | # Note that despite the name of the function, the subdivison of data is not 218 | # done along the batch dimension (i.e. dimension 1), since that was handled 219 | # by the batchify function. The chunks are along dimension 0, corresponding 220 | # to the seq_len dimension in the LSTM. 221 | 222 | 223 | def get_batch(source, i): 224 | seq_len = min(args.bptt, len(source) - 1 - i) 225 | data = source[i:i + seq_len] 226 | target = source[i + 1:i + 1 + seq_len].view(-1) 227 | return data, target 228 | 229 | 230 | def evaluate(data_source): 231 | # Turn on evaluation mode which disables dropout. 232 | model.eval() 233 | total_loss = 0. 234 | ntokens = len(corpus.dictionary) 235 | memory = model.module.initial_state(eval_batch_size, trainable=False).to(device) 236 | 237 | with torch.no_grad(): 238 | for i in range(0, data_source.size(0) - 1, args.bptt): 239 | data, targets = get_batch(data_source, i) 240 | data = torch.t(data) 241 | 242 | loss, memory = model(data, memory, targets) 243 | loss = torch.mean(loss) 244 | 245 | # data has shape [T * B, N] 246 | total_loss += args.bptt * loss.item() 247 | 248 | return total_loss / len(data_source) 249 | 250 | 251 | def train(): 252 | # Turn on training mode which enables dropout. 253 | model.train() 254 | total_loss = 0. 255 | forward_elapsed_time = 0. 256 | start_time = time.time() 257 | ntokens = len(corpus.dictionary) 258 | # in RMC, "hidden state" is called "memory" instead. so use the name "memory" 259 | memory = model.module.initial_state(args.batch_size, trainable=True).to(device) 260 | 261 | for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): 262 | data, targets = get_batch(train_data, i) 263 | # transpose the data to [batch, seq] 264 | data = torch.t(data) 265 | 266 | # synchronize cuda for a proper speed benchmark 267 | torch.cuda.synchronize() 268 | 269 | forward_start_time = time.time() 270 | model.zero_grad() 271 | 272 | # the forward pass of RMC just returns loss and does not return logits (DataParallel code optimization) 273 | loss, memory = model(data, memory, targets) 274 | loss = torch.mean(loss) 275 | total_loss += loss.item() 276 | 277 | # synchronize cuda for a proper speed benchmark 278 | torch.cuda.synchronize() 279 | 280 | forward_elapsed = time.time() - forward_start_time 281 | forward_elapsed_time += forward_elapsed 282 | 283 | loss.backward() 284 | 285 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 
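        # clip_grad_norm_ rescales all gradients in place so that their combined L2 norm
        # is at most args.clip (0.1 by default here); if the norm is already below the
        # threshold, the gradients are left untouched.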
286 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 287 | 288 | optimizer.step() 289 | 290 | if batch % args.log_interval == 0 and batch > 0: 291 | cur_loss = total_loss / args.log_interval 292 | elapsed = time.time() - start_time 293 | printlog = '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | forward ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.format( 294 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 295 | elapsed * 1000 / args.log_interval, forward_elapsed_time * 1000 / args.log_interval, 296 | cur_loss, math.exp(cur_loss)) 297 | # print and save the log 298 | print(printlog) 299 | logger_train.write(printlog + '\n') 300 | logger_train.flush() 301 | total_loss = 0. 302 | # reset timer 303 | start_time = time.time() 304 | forward_start_time = time.time() 305 | forward_elapsed_time = 0. 306 | 307 | 308 | def export_onnx(path, batch_size, seq_len): 309 | print('The model is also exported in ONNX format at {}'. 310 | format(os.path.realpath(args.onnx_export))) 311 | model.eval() 312 | dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device) 313 | hidden = model.init_hidden(batch_size) 314 | torch.onnx.export(model, (dummy_input, hidden), path) 315 | 316 | 317 | # Loop over epochs. 318 | best_val_loss = None 319 | 320 | # At any point you can hit Ctrl + C to break out of training early. 321 | try: 322 | print("training started...") 323 | if global_epoch > args.epochs: 324 | raise ValueError("global_epoch is higher than args.epochs when resuming training.") 325 | for epoch in range(global_epoch + 1, args.epochs + 1): 326 | global_epoch += 1 327 | epoch_start_time = time.time() 328 | train() 329 | val_loss = evaluate(val_data) 330 | 331 | print('-' * 89) 332 | testlog = '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(epoch, ( 333 | time.time() - epoch_start_time), val_loss, math.exp(val_loss)) 334 | print(testlog) 335 | logger_test.write(testlog + '\n') 336 | logger_test.flush() 337 | print('-' * 89) 338 | 339 | scheduler.step(val_loss) 340 | 341 | # Save the model if the validation loss is the best we've seen so far. 342 | # model_{} contains state_dict and other states, model_dump_{} contains all the dependencies for generate_rmc.py 343 | if not best_val_loss or val_loss < best_val_loss: 344 | try: 345 | os.remove(os.path.join(savepath, "model_{}.pt".format(best_epoch))) 346 | os.remove(os.path.join(savepath, "model_dump_{}.pt").format(best_epoch)) 347 | except FileNotFoundError: 348 | pass 349 | best_epoch = global_epoch 350 | torch.save(model, os.path.join(savepath, "model_dump_{}.pt".format(global_epoch))) 351 | with open(os.path.join(savepath, "model_{}.pt".format(global_epoch)), 'wb') as f: 352 | optimizer_state = optimizer.state_dict() 353 | scheduler_state = scheduler.state_dict() 354 | torch.save({"state_dict": model.state_dict(), 355 | "optimizer": optimizer_state, 356 | "scheduler": scheduler_state, 357 | "global_epoch": global_epoch, 358 | "best_epoch": best_epoch}, f) 359 | best_val_loss = val_loss 360 | else: 361 | pass 362 | 363 | except KeyboardInterrupt: 364 | print('-' * 89) 365 | print('Exiting from training early: loading checkpoint from the best epoch {}...'.format(best_epoch)) 366 | 367 | # Load the best saved model. 368 | with open(os.path.join(savepath, "model_{}.pt".format(best_epoch)), 'rb') as f: 369 | checkpoint = torch.load(f) 370 | model.load_state_dict(checkpoint["state_dict"]) 371 | 372 | # Run on test data. 
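# reminder: the reported perplexity below is exp(cross-entropy), so e.g. a test loss
# of 5.00 corresponds to ppl ~ 148.4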
373 | test_loss = evaluate(test_data) 374 | 375 | print('=' * 89) 376 | testlog = '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 377 | test_loss, math.exp(test_loss)) 378 | print(testlog) 379 | logger_test.write(testlog + '\n') 380 | logger_test.flush() 381 | print('=' * 89) 382 | 383 | if len(args.onnx_export) > 0: 384 | # Export the model in ONNX format. 385 | export_onnx(args.onnx_export, batch_size=1, seq_len=args.bptt) 386 | -------------------------------------------------------------------------------- /train_rnn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import argparse 3 | import time 4 | import math 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.onnx 9 | import datetime 10 | import shutil 11 | import pickle 12 | import data 13 | import rnn_models 14 | 15 | # is it faster? 16 | torch.backends.cudnn.benchmark = True 17 | 18 | # same hyperparameter scheme as word-language-model 19 | parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model') 20 | parser.add_argument('--data', type=str, default='./data/wikitext-2', 21 | help='location of the data corpus') 22 | parser.add_argument('--model', type=str, default='LSTM', 23 | help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)') 24 | parser.add_argument('--emsize', type=int, default=300, 25 | help='size of word embeddings') 26 | parser.add_argument('--nhid', type=int, default=300, 27 | help='number of hidden units per layer') 28 | parser.add_argument('--nlayers', type=int, default=1, 29 | help='number of layers') 30 | parser.add_argument('--lr', type=float, default=0.001, 31 | help='initial learning rate') 32 | parser.add_argument('--clip', type=float, default=0.1, 33 | help='gradient clipping') 34 | parser.add_argument('--epochs', type=int, default=100, 35 | help='upper epoch limit') 36 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 37 | help='batch size') 38 | parser.add_argument('--bptt', type=int, default=100, 39 | help='sequence length') 40 | parser.add_argument('--dropout', type=float, default=0.5, 41 | help='dropout applied to layers (0 = no dropout)') 42 | parser.add_argument('--tied', action='store_true', default=True, 43 | help='tie the word embedding and softmax weights') 44 | parser.add_argument('--seed', type=int, default=1111, 45 | help='random seed') 46 | parser.add_argument('--cuda', action='store_true', 47 | help='use CUDA') 48 | parser.add_argument('--cudnn', action='store_true', 49 | help='use cudnn optimized version. i.e. use RNN instead of RNNCell with for loop') 50 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 51 | help='report interval') 52 | parser.add_argument('--save', type=str, default='model.pt', 53 | help='path to save the final model') 54 | parser.add_argument('--onnx-export', type=str, default='', 55 | help='path to export the final model in onnx format') 56 | parser.add_argument('--resume', type=int, default=None, 57 | help='if specified with the 1-indexed global epoch, loads the checkpoint and resumes training') 58 | 59 | # parameters for adaptive softmax 60 | parser.add_argument('--adaptivesoftmax', action='store_true', 61 | help='use adaptive softmax during hidden state to output logits.' 62 | 'it uses less memory by approximating softmax of large vocabulary.') 63 | parser.add_argument('--cutoffs', nargs="*", type=int, default=[10000, 50000, 100000], 64 | help='cutoff values for adaptive softmax. 
list of integers.' 65 | 'optimal values are based on word frequencey and vocabulary size of the dataset.') 66 | 67 | # experiment name for this run 68 | parser.add_argument('--name', type=str, default=None, 69 | help='name for this experiment. generates folder with the name if specified.') 70 | 71 | args = parser.parse_args() 72 | 73 | # Set the random seed manually for reproducibility. 74 | torch.manual_seed(args.seed) 75 | 76 | if torch.cuda.is_available(): 77 | if not args.cuda: 78 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 79 | 80 | device = torch.device("cuda" if args.cuda else "cpu") 81 | ############################################################################### 82 | # Load data 83 | ############################################################################### 84 | corpus_name = os.path.basename(os.path.normpath(args.data)) 85 | corpus_filename = './data/corpus-' + str(corpus_name) + str('.pkl') 86 | if os.path.isfile(corpus_filename): 87 | print("loading pre-built " + str(corpus_name) + " corpus file...") 88 | loadfile = open(corpus_filename, 'rb') 89 | corpus = pickle.load(loadfile) 90 | loadfile.close() 91 | else: 92 | print("building " + str(corpus_name) + " corpus...") 93 | corpus = data.Corpus(args.data) 94 | # save the corpus for later 95 | savefile = open(corpus_filename, 'wb') 96 | pickle.dump(corpus, savefile) 97 | savefile.close() 98 | print("corpus saved to pickle") 99 | 100 | 101 | # Starting from sequential data, batchify arranges the dataset into columns. 102 | # For instance, with the alphabet as the sequence and batch size 4, we'd get 103 | # ┌ a g m s ┐ 104 | # │ b h n t │ 105 | # │ c i o u │ 106 | # │ d j p v │ 107 | # │ e k q w │ 108 | # └ f l r x ┘. 109 | # These columns are treated as independent by the model, which means that the 110 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient 111 | # batch processing. 112 | 113 | def batchify(data, bsz): 114 | # Work out how cleanly we can divide the dataset into bsz parts. 115 | nbatch = data.size(0) // bsz 116 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 117 | data = data.narrow(0, 0, nbatch * bsz) 118 | # Evenly divide the data across the bsz batches. 
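    # view(bsz, -1) gives [bsz, nbatch]; the transpose below yields the time-major
    # [nbatch, bsz] layout that the seq-first RNN and the get_batch slicing expect.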
119 | data = data.view(bsz, -1).t().contiguous() 120 | return data.to(device) 121 | 122 | 123 | eval_batch_size = 32 124 | train_data = batchify(corpus.train, args.batch_size) 125 | val_data = batchify(corpus.valid, eval_batch_size) 126 | test_data = batchify(corpus.test, eval_batch_size) 127 | 128 | # create folder for current experiments 129 | # name: args.name + current time 130 | # includes: entire scripts for faithful reproduction, train & test logs 131 | folder_name = str(datetime.datetime.now())[:-7] 132 | if args.name is not None: 133 | folder_name = str(args.name) + ' ' + folder_name 134 | 135 | os.mkdir(folder_name) 136 | for file in os.listdir(os.getcwd()): 137 | if file.endswith(".py"): 138 | shutil.copy2(file, os.path.join(os.getcwd(), folder_name)) 139 | logger_train = open(os.path.join(os.getcwd(), folder_name, 'train_log.txt'), 'w+') 140 | logger_test = open(os.path.join(os.getcwd(), folder_name, 'test_log.txt'), 'w+') 141 | 142 | # save args to logger 143 | logger_train.write(str(args) + '\n') 144 | 145 | # define saved model file location 146 | savepath = os.path.join(os.getcwd(), folder_name) 147 | 148 | ############################################################################### 149 | # Build the model 150 | ############################################################################### 151 | 152 | ntokens = len(corpus.dictionary) 153 | print("vocabulary size (ntokens): " + str(ntokens)) 154 | if args.adaptivesoftmax: 155 | print("Adaptive Softmax is on: the performance depends on cutoff values. check if the cutoff is properly set") 156 | print("Cutoffs: " + str(args.cutoffs)) 157 | if args.cutoffs[-1] > ntokens: 158 | raise ValueError("the last element of cutoff list must be lower than vocab size of the dataset") 159 | criterion_adaptive = nn.AdaptiveLogSoftmaxWithLoss(args.nhid, ntokens, cutoffs=args.cutoffs).to(device) 160 | else: 161 | criterion = nn.CrossEntropyLoss() 162 | 163 | model = rnn_models.RNNModel(args.model, ntokens, args.emsize, args.nhid, 164 | args.nlayers, args.dropout, args.tied, 165 | use_cudnn_version=args.cudnn, use_adaptive_softmax=args.adaptivesoftmax, 166 | cutoffs=args.cutoffs).to(device) 167 | 168 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 169 | print("model built, total trainable params: " + str(total_params)) 170 | if not args.cudnn: 171 | print( 172 | "--cudnn is set to False. the model will use RNNCell with for loop, instead of cudnn-optimzed RNN API. Expect a minor slowdown.") 173 | 174 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 175 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5) 176 | 177 | ############################################################################### 178 | # Load the model checkpoint if specified and restore the global & best epoch 179 | ############################################################################### 180 | if args.resume is not None: 181 | print("--resume detected. loading checkpoint...") 182 | global_epoch = args.resume if args.resume is not None else 0 183 | best_epoch = args.resume if args.resume is not None else 0 184 | if args.resume is not None: 185 | loadpath = os.path.join(os.getcwd(), "model_{}.pt".format(args.resume)) 186 | if not os.path.isfile(loadpath): 187 | raise FileNotFoundError( 188 | "model_{}.pt not found. 
place the model checkpoint file to the current working directory.".format( 189 | args.resume)) 190 | checkpoint = torch.load(loadpath) 191 | model.load_state_dict(checkpoint["state_dict"]) 192 | optimizer.load_state_dict(checkpoint["optimizer"]) 193 | scheduler.load_state_dict(checkpoint["scheduler"]) 194 | global_epoch = checkpoint["global_epoch"] 195 | best_epoch = checkpoint["best_epoch"] 196 | 197 | print("model built, total trainable params: " + str(total_params)) 198 | 199 | 200 | ############################################################################### 201 | # Training code 202 | ############################################################################### 203 | 204 | def repackage_hidden(h): 205 | """Wraps hidden states in new Tensors, to detach them from their history.""" 206 | if isinstance(h, torch.Tensor): 207 | return h.detach() 208 | else: 209 | return tuple(repackage_hidden(v) for v in h) 210 | 211 | 212 | # get_batch subdivides the source data into chunks of length args.bptt. 213 | # If source is equal to the example output of the batchify function, with 214 | # a bptt-limit of 2, we'd get the following two Variables for i = 0: 215 | # ┌ a g m s ┐ ┌ b h n t ┐ 216 | # └ b h n t ┘ └ c i o u ┘ 217 | # Note that despite the name of the function, the subdivison of data is not 218 | # done along the batch dimension (i.e. dimension 1), since that was handled 219 | # by the batchify function. The chunks are along dimension 0, corresponding 220 | # to the seq_len dimension in the LSTM. 221 | 222 | 223 | def get_batch(source, i): 224 | seq_len = min(args.bptt, len(source) - 1 - i) 225 | data = source[i:i + seq_len] 226 | target = source[i + 1:i + 1 + seq_len].view(-1) 227 | return data, target 228 | 229 | 230 | def evaluate(data_source): 231 | # Turn on evaluation mode which disables dropout. 232 | model.eval() 233 | total_loss = 0. 234 | ntokens = len(corpus.dictionary) 235 | hidden = model.init_hidden(eval_batch_size) 236 | with torch.no_grad(): 237 | for i in range(0, data_source.size(0) - 1, args.bptt): 238 | data, targets = get_batch(data_source, i) 239 | output, hidden = model(data, hidden) 240 | if not args.adaptivesoftmax: 241 | loss = criterion(output.view(-1, ntokens), targets) 242 | else: 243 | _, loss = criterion_adaptive(output.view(-1, args.nhid), targets) 244 | total_loss += len(data) * loss.item() 245 | hidden = repackage_hidden(hidden) 246 | return total_loss / len(data_source) 247 | 248 | 249 | def train(): 250 | # Turn on training mode which enables dropout. 251 | model.train() 252 | total_loss = 0. 253 | forward_elapsed_time = 0. 254 | start_time = time.time() 255 | ntokens = len(corpus.dictionary) 256 | hidden = model.init_hidden(args.batch_size) 257 | for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): 258 | data, targets = get_batch(train_data, i) 259 | 260 | # synchronize cuda for a proper speed benchmark 261 | torch.cuda.synchronize() 262 | 263 | # Starting each batch, we detach the hidden state from how it was previously produced. 264 | # If we didn't, the model would try backpropagating all the way to start of the dataset. 
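        # In effect this truncates backpropagation through time to windows of args.bptt
        # (100 by default) steps: gradients flow within a chunk but not across chunks.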
265 | forward_start_time = time.time() 266 | 267 | hidden = repackage_hidden(hidden) 268 | model.zero_grad() 269 | 270 | output, hidden = model(data, hidden) 271 | if not args.adaptivesoftmax: 272 | loss = criterion(output.view(-1, ntokens), targets) 273 | else: 274 | _, loss = criterion_adaptive(output.view(-1, args.nhid), targets) 275 | total_loss += loss.item() 276 | 277 | # synchronize cuda for a proper speed benchmark 278 | torch.cuda.synchronize() 279 | 280 | forward_elapsed = time.time() - forward_start_time 281 | forward_elapsed_time += forward_elapsed 282 | 283 | loss.backward() 284 | 285 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 286 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 287 | 288 | optimizer.step() 289 | 290 | if batch % args.log_interval == 0 and batch > 0: 291 | cur_loss = total_loss / args.log_interval 292 | elapsed = time.time() - start_time 293 | printlog = '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | forward ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.format( 294 | epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'], 295 | elapsed * 1000 / args.log_interval, forward_elapsed_time * 1000 / args.log_interval, 296 | cur_loss, math.exp(cur_loss)) 297 | # print and save the log 298 | print(printlog) 299 | logger_train.write(printlog + '\n') 300 | logger_train.flush() 301 | total_loss = 0. 302 | # reset timer 303 | start_time = time.time() 304 | forward_start_time = time.time() 305 | forward_elapsed_time = 0. 306 | 307 | 308 | def export_onnx(path, batch_size, seq_len): 309 | print('The model is also exported in ONNX format at {}'. 310 | format(os.path.realpath(args.onnx_export))) 311 | model.eval() 312 | dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device) 313 | hidden = model.init_hidden(batch_size) 314 | torch.onnx.export(model, (dummy_input, hidden), path) 315 | 316 | 317 | # Loop over epochs. 318 | best_val_loss = None 319 | 320 | # At any point you can hit Ctrl + C to break out of training early. 321 | try: 322 | print("training started...") 323 | if global_epoch > args.epochs: 324 | raise ValueError("global_epoch is higher than args.epochs when resuming training.") 325 | for epoch in range(global_epoch + 1, args.epochs + 1): 326 | global_epoch += 1 327 | epoch_start_time = time.time() 328 | train() 329 | val_loss = evaluate(val_data) 330 | 331 | print('-' * 89) 332 | testlog = '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(epoch, ( 333 | time.time() - epoch_start_time), val_loss, math.exp(val_loss)) 334 | print(testlog) 335 | logger_test.write(testlog + '\n') 336 | logger_test.flush() 337 | print('-' * 89) 338 | 339 | scheduler.step(val_loss) 340 | 341 | # Save the model if the validation loss is the best we've seen so far. 
342 | # model_{} contains state_dict and other states, model_dump_{} contains all the dependencies for generate_rmc.py 343 | if not best_val_loss or val_loss < best_val_loss: 344 | try: 345 | os.remove(os.path.join(savepath, "model_{}.pt".format(best_epoch))) 346 | os.remove(os.path.join(savepath, "model_dump_{}.pt").format(best_epoch)) 347 | except FileNotFoundError: 348 | pass 349 | best_epoch = global_epoch 350 | torch.save(model, os.path.join(savepath, "model_dump_{}.pt".format(global_epoch))) 351 | with open(os.path.join(savepath, "model_{}.pt".format(global_epoch)), 'wb') as f: 352 | optimizer_state = optimizer.state_dict() 353 | scheduler_state = scheduler.state_dict() 354 | torch.save({"state_dict": model.state_dict(), 355 | "optimizer": optimizer_state, 356 | "scheduler": scheduler_state, 357 | "global_epoch": global_epoch, 358 | "best_epoch": best_epoch}, f) 359 | best_val_loss = val_loss 360 | else: 361 | pass 362 | 363 | except KeyboardInterrupt: 364 | print('-' * 89) 365 | print('Exiting from training early: loading checkpoint from the best epoch {}...'.format(best_epoch)) 366 | 367 | # Load the best saved model. 368 | with open(os.path.join(savepath, "model_{}.pt".format(best_epoch)), 'rb') as f: 369 | checkpoint = torch.load(f) 370 | model.load_state_dict(checkpoint["state_dict"]) 371 | # after load the rnn params are not a continuous chunk of memory 372 | # this makes them a continuous chunk, and will speed up forward pass 373 | if args.cudnn: 374 | model.rnn.flatten_parameters() 375 | 376 | # Run on test data. 377 | test_loss = evaluate(test_data) 378 | 379 | print('=' * 89) 380 | testlog = '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format( 381 | test_loss, math.exp(test_loss)) 382 | print(testlog) 383 | logger_test.write(testlog + '\n') 384 | logger_test.flush() 385 | print('=' * 89) 386 | 387 | if len(args.onnx_export) > 0: 388 | # Export the model in ONNX format. 389 | export_onnx(args.onnx_export, batch_size=1, seq_len=args.bptt) 390 | --------------------------------------------------------------------------------