├── LICENSE
├── README.md
├── getdata.sh
├── prep_text8.py
├── pytorch
│   ├── .DS_Store
│   ├── README.md
│   ├── data_utils.py
│   ├── eval.py
│   ├── mem_transformer.py
│   ├── run_enwik8_base.sh
│   ├── run_enwik8_large.sh
│   ├── run_lm1b_base.sh
│   ├── run_lm1b_large.sh
│   ├── run_text8_base.sh
│   ├── run_text8_large.sh
│   ├── run_wt103_base.sh
│   ├── run_wt103_large.sh
│   ├── train.py
│   └── utils
│       ├── adaptive_softmax.py
│       ├── data_parallel.py
│       ├── exp_utils.py
│       ├── log_uniform_sampler.py
│       ├── proj_adaptive_softmax.py
│       └── vocabulary.py
└── tf
    ├── README.md
    ├── avg_checkpoints.py
    ├── data_utils.py
    ├── gpu_utils.py
    ├── model.py
    ├── scripts
    │   ├── enwik8_base_gpu.sh
    │   ├── enwik8_large_tpu.sh
    │   ├── lm1b_base_gpu.sh
    │   ├── lm1b_large_tpu.sh
    │   ├── text8_base_gpu.sh
    │   ├── text8_large_tpu.sh
    │   ├── wt103_base_gpu.sh
    │   └── wt103_large_tpu.sh
    ├── sota
    │   ├── download.sh
    │   ├── enwik8.sh
    │   ├── lm1b.sh
    │   ├── text8.sh
    │   └── wt103.sh
    ├── tpu_estimator.py
    ├── train.py
    ├── train_gpu.py
    └── vocabulary.py

/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.
      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

This repository contains the code in both **PyTorch** and **TensorFlow** for our paper
>[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](http://arxiv.org/abs/1901.02860)

>Zihang Dai\*, Zhilin Yang\*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)

>Preprint 2018

## TensorFlow

- The source code is in the `tf/` folder, supporting (1) single-node multi-gpu training, and (2) multi-host TPU training.
- Besides the source code, we also provide pretrained TensorFlow models with the state-of-the-art (SoTA) performance reported in the paper.
- Please refer to `tf/README.md` for details.

## PyTorch

- The source code is in the `pytorch/` folder, supporting single-node multi-gpu training via the module `nn.DataParallel`.
- Please refer to `pytorch/README.md` for details.

## Results

Transformer-XL achieves new state-of-the-art results on multiple language modeling benchmarks. Transformer-XL is also the first to break through the 1.0 barrier on char-level language modeling. Below is a summary.

Method | enwik8 | text8 | One Billion Word | WT-103 | PTB (w/o finetuning)
-- | -- | -- | -- | -- | --
Previous Best | 1.06 | 1.13 | 23.7 | 20.5 | 55.5
Transformer-XL | **0.99** | **1.08** | **21.8** | **18.3** | **54.5**



## Acknowledgement

A large portion of the `getdata.sh` script comes from the [awd-lstm](https://github.com/salesforce/awd-lstm-lm/) repo. Happy Language Modeling :)

--------------------------------------------------------------------------------
/getdata.sh:
--------------------------------------------------------------------------------

echo "=== Acquiring datasets ==="
echo "---"

mkdir -p data
cd data

if [[ ! -d 'wikitext-2' ]]; then
    echo "- Downloading WikiText-2 (WT2)"
    wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
    unzip -q wikitext-2-v1.zip
    cd wikitext-2
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd ..
fi

echo "- Downloading WikiText-103 (WT103)"
if [[ ! -d 'wikitext-103' ]]; then
    wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
    unzip -q wikitext-103-v1.zip
    cd wikitext-103
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd ..
fi

echo "- Downloading enwik8 (Character)"
if [[ ! -d 'enwik8' ]]; then
    mkdir -p enwik8
    cd enwik8
    wget --continue http://mattmahoney.net/dc/enwik8.zip
    wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
    python3 prep_enwik8.py
    cd ..
fi

echo "- Downloading text8 (Character)"
if [[ ! -d 'text8' ]]; then
    mkdir -p text8
    cd text8
    wget --continue http://mattmahoney.net/dc/text8.zip
    python ../../prep_text8.py
    cd ..
fi

echo "- Downloading Penn Treebank (PTB)"
if [[ ! -d 'penn' ]]; then
    wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
    tar -xzf simple-examples.tgz

    mkdir -p penn
    cd penn
    mv ../simple-examples/data/ptb.train.txt train.txt
    mv ../simple-examples/data/ptb.test.txt test.txt
    mv ../simple-examples/data/ptb.valid.txt valid.txt
    cd ..
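    # The character-level PTB split below reuses the same simple-examples
    # archive extracted above; only the ptb.char.* files differ.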

    echo "- Downloading Penn Treebank (Character)"
    mkdir -p pennchar
    cd pennchar
    mv ../simple-examples/data/ptb.char.train.txt train.txt
    mv ../simple-examples/data/ptb.char.test.txt test.txt
    mv ../simple-examples/data/ptb.char.valid.txt valid.txt
    cd ..

    rm -rf simple-examples/
fi

echo "- Downloading 1B words"

if [[ ! -d 'one-billion-words' ]]; then
    mkdir -p one-billion-words
    cd one-billion-words

    wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
    tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz

    path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
    cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
    cat ${path}/news.en.heldout-00000-of-00050 > test.txt

    wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt

    cd ..
fi

echo "---"
echo "Happy language modeling :)"

--------------------------------------------------------------------------------
/prep_text8.py:
--------------------------------------------------------------------------------

#!/usr/bin/env python
# coding=utf-8

import os
import sys
import zipfile

from io import open

if os.path.exists('train.txt'):
    print('Tokenized text8 already exists - skipping processing')
    sys.exit()

zipfile.ZipFile('text8.zip').extractall()
data = open('text8', 'r', encoding='utf-8').read()

print('Length of text8: {}'.format(len(data)))

num_test_chars = 5000000

train_data = data[: -2 * num_test_chars]
valid_data = data[-2 * num_test_chars: -num_test_chars]
test_data = data[-num_test_chars:]

for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]:
    print('{} will have {} bytes'.format(fn, len(part)))
    print('- Tokenizing...')
    # Change space ' ' to underscore '_'
    part_str = ' '.join(['_' if c == ' ' else c for c in part.strip()])
    print('- Writing...')
    with open(fn, 'w', encoding='utf-8') as f:
        f.write(part_str)
    with open(fn + '.raw', 'w', encoding='utf-8') as f:
        f.write(part)

--------------------------------------------------------------------------------
/pytorch/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TimDettmers/transformer-xl/44781ed21dbaec88b280f74d9ae2877f52b492a5/pytorch/.DS_Store

--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------

## Introduction

This directory contains our PyTorch implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our PyTorch codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts:
- `*large.sh` are for the SoTA setting with large models which might not be directly runnable on a local GPU machine.
- `*base.sh` are for the base models which can be run on a few GPUs.

The PyTorch implementation produces similar results to the TF codebase under the same settings in our preliminary experiments.
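To make the calling convention concrete, here is a minimal sketch of how the segment-level recurrence is driven from outside the model. It mirrors the loop in `eval.py`; `eval_iter` stands for any iterator returned by `data_utils.Corpus.get_iterator`, and `model.pt` is the checkpoint written by `train.py` (the snippet is illustrative, not a complete script):

```python
import torch

# Load a trained checkpoint the same way eval.py does.
with open('model.pt', 'rb') as f:
    model = torch.load(f)
model.eval()

total_loss, total_len = 0., 0
mems = tuple()  # empty memory before the first segment
with torch.no_grad():
    for data, target, seq_len in eval_iter:
        # The model returns (loss, *new_mems); new_mems caches this segment's
        # hidden states so that the next segment can attend to them.
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        total_loss += seq_len * loss.mean().item()
        total_len += seq_len
print('avg loss {:.3f}'.format(total_loss / total_len))
```

Relatedly, the `--batch_chunk` option described under *Other options* below is plain gradient accumulation. Conceptually (a sketch only; the actual implementation in `train.py` additionally keeps a separate recurrence memory per chunk):

```python
# Split the batch along the batch dimension (dim 1, since batches are
# [seq_len, bsz]) and accumulate scaled gradients chunk by chunk.
for d, t in zip(torch.chunk(data, batch_chunk, dim=1),
                torch.chunk(target, batch_chunk, dim=1)):
    chunk_loss = model(d, t)[0].float().mean() / batch_chunk
    chunk_loss.backward()  # gradients accumulate across chunks
optimizer.step()
model.zero_grad()
```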

## Prerequisite

- PyTorch 0.4: `conda install pytorch torchvision -c pytorch`


## Data Preparation

`bash getdata.sh`

## Training and Evaluation

#### Replicate the "bpc = 1.06" result on `enwik8` with a 12-layer Transformer-XL

- Make sure the machine has **4 GPUs**, each with **at least 11G memory**

- Training

  `bash run_enwik8_base.sh train --work_dir PATH_TO_WORK_DIR`

- Evaluation

  `bash run_enwik8_base.sh eval --work_dir PATH_TO_WORK_DIR`



#### Replicate the "PPL = 24.03" result on `wikitext-103` with Transformer-XL

- Make sure the machine has **4 GPUs**, each with **at least 11G memory**

- Training

  `bash run_wt103_base.sh train --work_dir PATH_TO_WORK_DIR`

- Evaluation

  `bash run_wt103_base.sh eval --work_dir PATH_TO_WORK_DIR`



#### Other options:

- `--batch_chunk`: this option allows one to trade speed for memory. For `batch_chunk > 1`, the program splits each training batch into `batch_chunk` sub-batches and performs the forward and backward passes on each sub-batch sequentially, with the gradients accumulated and divided by `batch_chunk`. Hence, memory usage decreases roughly in proportion to `batch_chunk`, at the cost of a proportional increase in computation time (see the sketch in the Introduction above).
- `--div_val`: when using adaptive softmax and embedding, the embedding dimension is divided by `div_val` from bin $i$ to bin $i+1$. This saves both GPU memory and the parameter budget.
- `--fp16` and `--dynamic-loss-scale`: Run in pseudo-fp16 mode (fp16 storage, fp32 math) with dynamic loss scaling.
  - Note: to explore the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/).
- To see the performance without the recurrence mechanism, simply set `mem_len=0` in all your scripts.
- To see the performance of a standard Transformer without relative positional encodings or the recurrence mechanism, use `attn_type=2` and `mem_len=0`.


#### Other datasets:

- `Text8` character-level language modeling: check out `run_text8_base.sh`
- `lm1b` word-level language modeling: check out `run_lm1b_base.sh`

--------------------------------------------------------------------------------
/pytorch/data_utils.py:
--------------------------------------------------------------------------------

import os, sys
import glob

from collections import Counter, OrderedDict
import numpy as np
import torch

from utils.vocabulary import Vocab

class LMOrderedIterator(object):
    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
        """
            data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device

        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, self.n_step * bsz)

        # Evenly divide the data across the bsz batches.
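        # After .view(bsz, -1).t(), self.data has shape [n_step, bsz]: each
        # column is one contiguous token stream, so consecutive batches read
        # consecutive segments of the same bsz streams.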
        self.data = data.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        if bptt is None: bptt = self.bptt
        seq_len = min(bptt, self.data.size(0) - 1 - i)

        end_idx = i + seq_len
        beg_idx = max(0, i - self.ext_len)

        data = self.data[beg_idx:end_idx]
        target = self.data[i+1:i+1+seq_len]

        return data, target, seq_len

    def get_fixlen_iter(self, start=0):
        for i in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(i)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()


class LMShuffledIterator(object):
    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
        """
            data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
            else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    def stream_iterator(self, sent_stream):
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
                            streams[i][:n_new]
                        target[n_filled:n_filled+n_new, i] = \
                            streams[i][1:n_new+1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    valid_batch = False
                    break

            if not valid_batch:
                return

            data = data.to(self.device)
            target = target.to(self.device)

            yield data, target, self.bptt

            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch


class LMMultiFileIterator(LMShuffledIterator):
    def __init__(self, paths, vocab, bsz, bptt, device='cpu',
                 ext_len=None, shuffle=False):

        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch


class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter


def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, 'cache.pt')
    if os.path.exists(fn):
        print('Loading cached dataset...')
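        # cache.pt holds the entire pickled Corpus (vocab + encoded tensors),
        # so subsequent runs skip tokenization entirely.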
        corpus = torch.load(fn)
    else:
        print('Producing dataset {}...'.format(dataset))
        kwargs = {}
        if dataset in ['wt103', 'wt2']:
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = False
        elif dataset == 'ptb':
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = True
        elif dataset == 'lm1b':
            kwargs['special'] = []
            kwargs['lower_case'] = False
            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
        elif dataset in ['enwik8', 'text8']:
            pass

        corpus = Corpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='unit test')
    parser.add_argument('--datadir', type=str, default='../data/text8',
                        help='location of the data corpus')
    parser.add_argument('--dataset', type=str, default='text8',
                        choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
                        help='dataset name')
    args = parser.parse_args()

    corpus = get_lm_corpus(args.datadir, args.dataset)
    print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))

--------------------------------------------------------------------------------
/pytorch/eval.py:
--------------------------------------------------------------------------------

# coding: utf-8
import argparse
import time
import math
import os, sys

import torch

from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import get_logger

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
                    choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                    help='dataset name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous hidden states (memory)')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'

device = torch.device("cuda" if args.cuda else "cpu")

# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)

# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
    model = torch.load(f)
model.backward_compatible()
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len

# Run on test data.
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None

def format_log(loss, split):
    if args.dataset in ['enwik8', 'text8']:
        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
            split, loss, loss / math.log(2))
    else:
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
    return log_str

log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logging('=' * 100)
logging(log_str)
logging('=' * 100)

--------------------------------------------------------------------------------
/pytorch/run_enwik8_base.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 2048 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
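    # Evaluation uses a much longer memory than training (mem_len 2100
    # vs. 512): cached states need no gradients, so a larger attention
    # context is affordable at eval time and typically improves bpc.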
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_enwik8_large.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 24 \
        --d_model 1024 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 3072 \
        --dropout 0.15 \
        --dropatt 0.15 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 4000 \
        --max_step 400000 \
        --tgt_len 768 \
        --mem_len 768 \
        --eval_tgt_len 128 \
        --batch_size 64 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 128 \
        --mem_len 3800 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_lm1b_base.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --adaptive \
        --n_layer 18 \
        --d_model 1024 \
        --div_val 4 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 4096 \
        --dropout 0.0 \
        --dropatt 0.0 \
        --optim adam \
        --warmup_step 20000 \
        --max_step 500000 \
        --lr 0.00025 \
        --tgt_len 32 \
        --mem_len 32 \
        --eval_tgt_len 32 \
        --batch_size 224 \
        --multi_gpu \
        --gpu0_bsz 32 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --batch_size 64 \
        --tgt_len 32 \
        --mem_len 128 \
        --split test \
        --same_length \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_lm1b_large.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --adaptive \
        --div_val 4 \
        --n_layer 24 \
        --d_model 1280 \
        --n_head 16 \
        --d_head 80 \
        --d_inner 8192 \
        --dropout 0.05 \
        --dropatt 0.05 \
        --optim adam \
        --warmup_step 30000 \
        --max_step 1200000 \
        --lr 0.00025 \
        --tgt_len 32 \
        --mem_len 32 \
        --eval_tgt_len 32 \
        --batch_size 512 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
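    # --same_length gives every position an equal attention span at eval
    # time; all eval configurations in this repo enable it.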
    python eval.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --batch_size 8 \
        --tgt_len 32 \
        --mem_len 128 \
        --split test \
        --same_length \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_text8_base.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 2048 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_text8_large.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --n_layer 24 \
        --d_model 1024 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 3072 \
        --dropout 0.15 \
        --dropatt 0.15 \
        --optim adam \
        --lr 0.00025 \
        --tgt_len 768 \
        --mem_len 768 \
        --eval_tgt_len 128 \
        --batch_size 64 \
        --max_step 400000 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --tgt_len 128 \
        --mem_len 3800 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_wt103_base.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --adaptive \
        --n_layer 16 \
        --d_model 410 \
        --n_head 10 \
        --d_head 41 \
        --d_inner 2100 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 200000 \
        --tgt_len 150 \
        --mem_len 150 \
        --eval_tgt_len 150 \
        --batch_size 60 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
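    # Compared to training, tgt_len shrinks (150 -> 64) while mem_len grows
    # (150 -> 640): shorter prediction segments backed by a longer cache
    # give each target position more usable context.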
    python eval.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --tgt_len 64 \
        --mem_len 640 \
        --clamp_len 400 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/run_wt103_large.sh:
--------------------------------------------------------------------------------

#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --adaptive \
        --div_val 4 \
        --n_layer 18 \
        --d_model 1024 \
        --n_head 16 \
        --d_head 64 \
        --d_inner 4096 \
        --dropout 0.2 \
        --dropatt 0.2 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 16000 \
        --max_step 4000000 \
        --tgt_len 384 \
        --mem_len 384 \
        --eval_tgt_len 128 \
        --batch_size 128 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --tgt_len 128 \
        --mem_len 1600 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi

--------------------------------------------------------------------------------
/pytorch/utils/adaptive_softmax.py:
--------------------------------------------------------------------------------

from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveLogSoftmax(nn.Module):
    def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
        super(AdaptiveLogSoftmax, self).__init__()

        cutoffs = list(cutoffs)

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):

            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
        self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.keep_order = keep_order


    def forward(self, hidden, target, weight, bias, keep_order=False):
        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        head_weight = torch.cat(
            [weight[:self.shortlist_size], self.cluster_weight], dim=0)
        head_bias = torch.cat(
            [bias[:self.shortlist_size], self.cluster_bias], dim=0)

        head_logit = F.linear(hidden, head_weight, bias=head_bias)
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        offset = 0
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]

            mask_i = (target >= l_idx) & (target < h_idx)
            indices_i = mask_i.nonzero().squeeze()

            if indices_i.numel() == 0:
                continue

            target_i = target.index_select(0, indices_i) - l_idx
            head_logprob_i = head_logprob.index_select(0, indices_i)

            if i == 0:
                logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
            else:
                weight_i = weight[l_idx:h_idx]
                bias_i = bias[l_idx:h_idx]

                hidden_i = hidden.index_select(0, indices_i)

                tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                logprob_i = head_logprob_i[:, -i] \
                    + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)

            if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                nll.index_copy_(0, indices_i, -logprob_i)
            else:
                nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

            offset += logprob_i.size(0)

        return nll

--------------------------------------------------------------------------------
/pytorch/utils/data_parallel.py:
--------------------------------------------------------------------------------


from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)


--------------------------------------------------------------------------------
/pytorch/utils/exp_utils.py:
--------------------------------------------------------------------------------

import functools
import os, shutil

import numpy as np

import torch


def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')

def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)

def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    print('Experiment dir : {}'.format(dir_path))
    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        if not os.path.exists(script_path):
            os.makedirs(script_path)
        for script in scripts_to_save:
            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))

def save_checkpoint(model, optimizer, path, epoch):
    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))

--------------------------------------------------------------------------------
/pytorch/utils/log_uniform_sampler.py:
--------------------------------------------------------------------------------

import torch
from torch import nn
import numpy as np

class LogUniformSampler(object):
    def __init__(self, range_max, n_sample):
        """
        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

        expected count can be approximated by 1 - (1 - p)^n
        and we use a numerically stable version -expm1(num_tries * log1p(-p))

        Our implementation fixes num_tries at 2 * n_sample, and the actual
        #samples will vary from run to run
        """
        with torch.no_grad():
            self.range_max = range_max
            log_indices = torch.arange(1., range_max+2., 1.).log_()
            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
            # print('P', self.dist.numpy().tolist()[-30:])

            self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()

        self.n_sample = n_sample

    def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples

def sample_logits(embedding, bias, labels, inputs, sampler):
    """
        embedding: an nn.Embedding layer
        bias: [n_vocab]
        labels: [b1, b2]
        inputs: [b1, b2, n_emb]
        sampler: you may use a LogUniformSampler
    Return
        logits: [b1, b2, 1 + n_sample]
    """
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    n_sample = neg_samples.size(0)
    b1, b2 = labels.size(0), labels.size(1)
    all_ids = torch.cat([labels.view(-1), neg_samples])
    all_w = embedding(all_ids)
    true_w = all_w[: -n_sample].view(b1, b2, -1)
    sample_w = all_w[- n_sample:].view(n_sample, -1)

    all_b = bias[all_ids]
    true_b = all_b[: -n_sample].view(b1, b2)
    sample_b = all_b[- n_sample:]

    hit = (labels[:, :, None] == neg_samples).detach()

    true_logits = torch.einsum('ijk,ijk->ij',
        [true_w, inputs]) + true_b - true_log_probs
    sample_logits = torch.einsum('lk,ijk->ijl',
        [sample_w, inputs]) + sample_b - samp_log_probs
    sample_logits.masked_fill_(hit, -1e30)
    logits = torch.cat([true_logits[:, :, None], sample_logits], -1)

    return logits


# class LogUniformSampler(object):
#     def __init__(self, range_max, unique=False):
#         """
#         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
#             `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
#         """
#         self.range_max = range_max
#         log_indices = torch.arange(1., range_max+2., 1.).log_()
#         self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]

#         self.unique = unique

#         if self.unique:
#             self.exclude_mask = torch.ByteTensor(range_max).fill_(0)

#     def sample(self, n_sample, labels):
#         pos_sample, new_labels = labels.unique(return_inverse=True)
#         n_pos_sample = pos_sample.size(0)
#         n_neg_sample = n_sample - n_pos_sample

#         if self.unique:
#             self.exclude_mask.index_fill_(0, pos_sample, 1)
#             sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
#             self.exclude_mask.index_fill_(0, pos_sample, 0)
#         else:
#             sample_dist = self.dist

#         neg_sample = torch.multinomial(sample_dist, n_neg_sample)

#         sample = torch.cat([pos_sample, neg_sample])
#         sample_prob = self.dist[sample]

#         return new_labels, sample, sample_prob


if __name__ == '__main__':
    S, B = 3, 4
    n_vocab = 10000
    n_sample = 5
    H = 32

    labels = torch.LongTensor(S, B).random_(0, n_vocab)

    # sampler = LogUniformSampler(n_vocab, unique=False)
    # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)

    sampler = LogUniformSampler(n_vocab, n_sample)
    # true_probs, samp_probs, neg_samples = sampler.sample(labels)

    # print('true_probs', true_probs.numpy().tolist())
    # print('samp_probs', samp_probs.numpy().tolist())
    # print('neg_samples', neg_samples.numpy().tolist())

    # print('sum', torch.sum(sampler.dist).item())

    # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()

    embedding = nn.Embedding(n_vocab, H)
    bias = torch.zeros(n_vocab)
    inputs = torch.Tensor(S, B, H).normal_()

    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits', logits.detach().numpy().tolist())
    print('logits shape', logits.size())

--------------------------------------------------------------------------------
/pytorch/utils/proj_adaptive_softmax.py:
--------------------------------------------------------------------------------

from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])

class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()
-------------------------------------------------------------------------------- /pytorch/utils/proj_adaptive_softmax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) if torch.version.cuda else 0 10 | CUDA_MINOR = int(torch.version.cuda.split('.')[1]) if torch.version.cuda else 0 11 | 12 | class ProjectedAdaptiveLogSoftmax(nn.Module): 13 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 14 | keep_order=False): 15 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 16 | 17 | self.n_token = n_token 18 | self.d_embed = d_embed 19 | self.d_proj = d_proj 20 | 21 | self.cutoffs = cutoffs + [n_token] 22 | self.cutoff_ends = [0] + self.cutoffs 23 | self.div_val = div_val 24 | 25 | self.shortlist_size = self.cutoffs[0] 26 | self.n_clusters = len(self.cutoffs) - 1 27 | self.head_size = self.shortlist_size + self.n_clusters 28 | 29 | if self.n_clusters > 0: 30 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 31 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 32 | 33 | self.out_layers = nn.ModuleList() 34 | self.out_projs = nn.ParameterList() 35 | 36 | if div_val == 1: 37 | for i in range(len(self.cutoffs)): 38 | if d_proj != d_embed: 39 | self.out_projs.append( 40 | nn.Parameter(torch.Tensor(d_proj, d_embed)) 41 | ) 42 | else: 43 | self.out_projs.append(None) 44 | 45 | self.out_layers.append(nn.Linear(d_embed, n_token)) 46 | else: 47 | for i in range(len(self.cutoffs)): 48 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 49 | d_emb_i = d_embed // (div_val ** i) 50 | 51 | self.out_projs.append( 52 | nn.Parameter(torch.Tensor(d_proj, d_emb_i)) 53 | ) 54 | 55 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 56 | 57 | self.keep_order = keep_order 58 | 59 | def _compute_logit(self, hidden, weight, bias, proj): 60 | if proj is None: 61 | logit = F.linear(hidden, weight, bias=bias) 62 | else: 63 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 64 | proj_hid = F.linear(hidden, proj.t().contiguous()) 65 | logit = F.linear(proj_hid, weight, bias=bias) 66 | # else: 67 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 68 | # if bias is not None: 69 | # logit = logit + bias 70 | 71 | return logit 72 | 73 | def forward(self, hidden, target, keep_order=False): 74 | ''' 75 | hidden :: [len*bsz x d_proj] 76 | target :: [len*bsz] 77 | ''' 78 | 79 | if hidden.size(0) != target.size(0): 80 | raise RuntimeError('Input and target should have the same size ' 81 | 'in the batch dimension.') 82 | 83 | if self.n_clusters == 0: 84 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 85 | self.out_layers[0].bias, self.out_projs[0]) 86 | nll = -F.log_softmax(logit, dim=-1) \ 87 | .gather(1, target.unsqueeze(1)).squeeze(1) 88 | else: 89 | # construct weights and biases 90 | weights, biases = [], [] 91 | for i in range(len(self.cutoffs)): 92 | if self.div_val == 1: 93 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 94 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 95 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 96 | else: 97 | weight_i = self.out_layers[i].weight 98 | bias_i = self.out_layers[i].bias 99 | 100 | if i == 0: 101 | weight_i = torch.cat( 102 | [weight_i, self.cluster_weight], dim=0) 103 | bias_i = torch.cat( 104 | [bias_i, self.cluster_bias], dim=0) 105 | 106 | weights.append(weight_i) 107 | biases.append(bias_i) 108 | 109 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 110 | 111 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 112 | head_logprob = F.log_softmax(head_logit, dim=1) 113 | 114 | nll = torch.zeros_like(target, 115 | dtype=hidden.dtype, device=hidden.device) 116 | 117 | offset = 0 118 | cutoff_values = [0] + self.cutoffs 119 | for i in range(len(cutoff_values) - 1): 120 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 121 | 122 | mask_i = (target >= l_idx) & (target < r_idx) 123 | indices_i = mask_i.nonzero().squeeze(1) 124 | 125 | if indices_i.numel() == 0: 126 | continue 127 | 128 | target_i = target.index_select(0, indices_i) - l_idx 129 | head_logprob_i = head_logprob.index_select(0, indices_i) 130 | 131 | if i == 0: 132 | logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) 133 | else: 134 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 135 | 136 | hidden_i = hidden.index_select(0, indices_i) 137 | 138 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 139 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 140 | # the logit for cluster i sits right after the shortlist in the head 141 | logprob_i = head_logprob_i[:, self.shortlist_size + i - 1] \ 142 | + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) 143 |
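# At this point logprob_i = log P(cluster | h) + log P(token | cluster, h) for
# the targets falling in bin i; the branch below writes -logprob_i back into
# `nll`, either at the original token positions (keep_order) or packed
# contiguously in the order the bins are visited.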
144 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 145 | nll.index_copy_(0, indices_i, -logprob_i) 146 | else: 147 | nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 148 | 149 | offset += logprob_i.size(0) 150 | 151 | return nll 152 | -------------------------------------------------------------------------------- /pytorch/utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import Counter, OrderedDict 3 | 4 | import torch 5 | 6 | class Vocab(object): 7 | def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, 8 | delimiter=None, vocab_file=None): 9 | self.counter = Counter() 10 | self.special = special 11 | self.min_freq = min_freq 12 | self.max_size = max_size 13 | self.lower_case = lower_case 14 | self.delimiter = delimiter 15 | self.vocab_file = vocab_file 16 | 17 | def tokenize(self, line, add_eos=False, add_double_eos=False): 18 | line = line.strip() 19 | # convert to lower case 20 | if self.lower_case: 21 | line = line.lower() 22 | 23 | # empty delimiter '' will evaluate False 24 | if self.delimiter == '': 25 | symbols = line 26 | else: 27 | symbols = line.split(self.delimiter) 28 | 29 | if add_double_eos: # lm1b 30 | return ['<S>'] + symbols + ['<S>'] 31 | elif add_eos: 32 | return symbols + ['<eos>'] 33 | else: 34 | return symbols 35 | 36 | def count_file(self, path, verbose=False, add_eos=False): 37 | if verbose: print('counting file {} ...'.format(path)) 38 | assert os.path.exists(path) 39 | 40 | sents = [] 41 | with open(path, 'r', encoding='utf-8') as f: 42 | for idx, line in enumerate(f): 43 | if verbose and idx > 0 and idx % 500000 == 0: 44 | print(' line {}'.format(idx)) 45 | symbols = self.tokenize(line, add_eos=add_eos) 46 | self.counter.update(symbols) 47 | sents.append(symbols) 48 | 49 | return sents 50 | 51 | def count_sents(self, sents, verbose=False): 52 | """ 53 | sents : a list of sentences, each a list of tokenized symbols 54 | """ 55 | if verbose: print('counting {} sents ...'.format(len(sents))) 56 | for idx, symbols in enumerate(sents): 57 | if verbose and idx > 0 and idx % 500000 == 0: 58 | print(' line {}'.format(idx)) 59 | self.counter.update(symbols) 60 | 61 | def _build_from_file(self, vocab_file): 62 | self.idx2sym = [] 63 | self.sym2idx = OrderedDict() 64 | 65 | with open(vocab_file, 'r', encoding='utf-8') as f: 66 | for line in f: 67 | symb = line.strip().split()[0] 68 | self.add_symbol(symb) 69 | self.unk_idx = self.sym2idx['<UNK>'] 70 | 71 | def build_vocab(self): 72 | if self.vocab_file: 73 | print('building vocab from {}'.format(self.vocab_file)) 74 | self._build_from_file(self.vocab_file) 75 | print('final vocab size {}'.format(len(self))) 76 | else: 77 | print('building vocab with min_freq={}, max_size={}'.format( 78 | self.min_freq, self.max_size)) 79 | self.idx2sym = [] 80 | self.sym2idx = OrderedDict() 81 | 82 | for sym in self.special: 83 | self.add_special(sym) 84 | 85 | for sym, cnt in self.counter.most_common(self.max_size): 86 | if cnt < self.min_freq: break 87 | self.add_symbol(sym) 88 | 89 | print('final vocab size {} from {} unique tokens'.format( 90 | len(self), len(self.counter))) 91 | 92 | def encode_file(self, path, ordered=False, verbose=False, add_eos=True, 93 | add_double_eos=False): 94 | if verbose: print('encoding file {} ...'.format(path)) 95 | assert os.path.exists(path) 96 | encoded = [] 97 | with open(path, 'r', encoding='utf-8') as f: 98 | for idx, line in enumerate(f): 99 | if verbose and idx > 0 and idx %
500000 == 0: 100 | print(' line {}'.format(idx)) 101 | symbols = self.tokenize(line, add_eos=add_eos, 102 | add_double_eos=add_double_eos) 103 | encoded.append(self.convert_to_tensor(symbols)) 104 | 105 | if ordered: 106 | encoded = torch.cat(encoded) 107 | 108 | return encoded 109 | 110 | def encode_sents(self, sents, ordered=False, verbose=False): 111 | if verbose: print('encoding {} sents ...'.format(len(sents))) 112 | encoded = [] 113 | for idx, symbols in enumerate(sents): 114 | if verbose and idx > 0 and idx % 500000 == 0: 115 | print(' line {}'.format(idx)) 116 | encoded.append(self.convert_to_tensor(symbols)) 117 | 118 | if ordered: 119 | encoded = torch.cat(encoded) 120 | 121 | return encoded 122 | 123 | def add_special(self, sym): 124 | if sym not in self.sym2idx: 125 | self.idx2sym.append(sym) 126 | self.sym2idx[sym] = len(self.idx2sym) - 1 127 | setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) 128 | 129 | def add_symbol(self, sym): 130 | if sym not in self.sym2idx: 131 | self.idx2sym.append(sym) 132 | self.sym2idx[sym] = len(self.idx2sym) - 1 133 | 134 | def get_sym(self, idx): 135 | assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) 136 | return self.idx2sym[idx] 137 | 138 | def get_idx(self, sym): 139 | if sym in self.sym2idx: 140 | return self.sym2idx[sym] 141 | else: 142 | # print('encounter unk {}'.format(sym)) 143 | assert '<eos>' not in sym 144 | assert hasattr(self, 'unk_idx') 145 | return self.sym2idx.get(sym, self.unk_idx) 146 | 147 | def get_symbols(self, indices): 148 | return [self.get_sym(idx) for idx in indices] 149 | 150 | def get_indices(self, symbols): 151 | return [self.get_idx(sym) for sym in symbols] 152 | 153 | def convert_to_tensor(self, symbols): 154 | return torch.LongTensor(self.get_indices(symbols)) 155 | 156 | def convert_to_sent(self, indices, exclude=None): 157 | if exclude is None: 158 | return ' '.join([self.get_sym(idx) for idx in indices]) 159 | else: 160 | return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) 161 | 162 | def __len__(self): 163 | return len(self.idx2sym) 164 | -------------------------------------------------------------------------------- /tf/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Introduction 3 | 4 | This directory contains our TF implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our GPU codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts: 5 | - `*large_tpu.sh` are for the SoTA setting on TPUs. These are exactly the commands we used to obtain our best results. 6 | - `*base_gpu.sh` are for the base models which can be run on a few GPUs. 7 | 8 | 9 | ## Prerequisite 10 | 11 | - Python 2.7 12 | - Tensorflow [1.12.0](https://github.com/tensorflow/tensorflow/releases/tag/v1.12.0) 13 | 14 | 15 | 16 | ## Obtain and evaluate pretrained SoTA models 17 | 18 | #### 1. Download preprocessed data (vocab) & pretrained models 19 | 20 | (a) Set your own `DATA_ROOT` in `sota/download.sh` (default to `./`), which will be the root directory of the downloaded models. 21 | 22 | (b) Then, download the model & data by `bash sota/download.sh`.
After downloading, the expected directory structure is as follows 23 | 24 | ```markdown 25 | pretrained_xl 26 | tf_enwik8/ 27 | data/ 28 | cache.pkl 29 | corpus-info.json 30 | model/ 31 | checkpoint 32 | model.ckpt* 33 | tf_wt103/ 34 | ... 35 | ... 36 | ``` 37 | 38 | **Note**: we include preprocessed data in the download files to make sure the **same vocabulary** is used. Please see the code `tf/data_utils.py` to understand the data structure. 39 | 40 | 41 | 42 | #### 2. Run evaluation scripts to replicate SoTA results on GPUs 43 | 44 | - **enwik8**: modify the script `sota/enwik8.sh` accordingly (see below) 45 | - set `DATA_ROOT` to the same folder used in the download step (default to `./`) 46 | - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 2 GPUs => about 60 mins 47 | - run the script: `bash sota/enwik8.sh` 48 | 49 | - **lm1b**: modify the script `sota/lm1b.sh` accordingly (see below) 50 | - set `DATA_ROOT` to the same folder used in the download step (default to `./`) 51 | - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 1 GPU => less than 5 mins 52 | - run the script: `bash sota/lm1b.sh` 53 | 54 | - **wt103**: modify the script `sota/wt103.sh` accordingly (see below) 55 | - set `DATA_ROOT` to the same folder used in the download step (default to `./`) 56 | - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 1 GPU => less than 5 mins 57 | - run the script: `bash sota/wt103.sh` 58 | 59 | - **text8**: modify the script `sota/text8.sh` accordingly (see below) 60 | - set `DATA_ROOT` to the same folder used in the download step (default to `./`) 61 | - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 2 GPUs => about 60 mins 62 | - run the script: `bash sota/text8.sh` 63 | 64 | 65 | #### 3. Resources Needed for SoTA Model Training 66 | 67 | We used 32, 32, 64, and 512 TPU cores for training our best models on enwik8, text8, wt103, and lm1b respectively. The training time for each model ranges from 2 to 5 days. 68 | 69 | 70 | 71 | ## Train "Transformer-XL" from scratch with GPUs or TPUs 72 | 73 | ### 1. Download raw data 74 | 75 | `bash getdata.sh` 76 | 77 | 78 | 79 | ### 2. Preprocess, training and evaluation 80 | 81 | For `dataset` in `[enwik8, lm1b, wt103, text8]`: 82 | 83 | - check out `scripts/dataset_base_gpu.sh` for GPU training and evaluation 84 | - check out `scripts/dataset_large_tpu.sh` for TPU training and evaluation 85 | 86 | 87 | 88 | #### (1) Preprocess raw data and create tfrecords 89 | 90 | **NOTE**: The preprocessing for GPU and TPU is different, so you have to run it separately for each. 91 | 92 | GPU: 93 | 94 | - create training and validation data: `bash scripts/dataset_base_gpu.sh train_data` 95 | - create test data: `bash scripts/dataset_base_gpu.sh test_data` 96 | 97 | TPU: 98 | 99 | - Set the Google storage URL in `scripts/dataset_large_tpu.sh`: 100 | - `GSDATA`: data URL 101 | - `GSEXP`: experiment URL 102 | - create training and validation data: `bash scripts/dataset_large_tpu.sh train_data` 103 | - create test data: `bash scripts/dataset_large_tpu.sh test_data` 104 | 105 | 106 | 107 | #### (2) Run training 108 | 109 | Base models on GPUs: 110 | 111 | - Modify the configurations in `scripts/dataset_base_gpu.sh` according to your needs. 112 | - `bash scripts/dataset_base_gpu.sh train` 113 | - If enough resources are available, increase the model sizes (e.g., `N_LAYER`, `D_MODEL`, `D_EMBED`, `D_HEAD`, `D_INNER`) so that they are closer to the values defined in `scripts/dataset_large_tpu.sh`.
Likewise, when resources are limited, decrease the model sizes. It is recommended to ensure that `D_MODEL == D_EMBED` and `D_MODEL == N_HEAD x D_HEAD`. When the model sizes increase, remember to increase `warmup_steps` accordingly to alleviate optimization difficulties. 114 | - Adjust the `NUM_CORE` parameter to reflect the number of GPUs to use. 115 | 116 | Larger models on TPUs: 117 | 118 | - Modify the configurations in `scripts/dataset_large_tpu.sh` according to your needs. 119 | - `bash scripts/dataset_large_tpu.sh train` 120 | 121 | 122 | 123 | #### (3) Run evaluation 124 | 125 | Base models on GPUs: 126 | 127 | - `bash scripts/dataset_base_gpu.sh eval --eval_ckpt_path PATH_TO_CKPT` 128 | 129 | Larger models on TPUs: 130 | 131 | - `bash scripts/dataset_large_tpu.sh eval --eval_ckpt_path PATH_TO_CKPT` 132 | -------------------------------------------------------------------------------- /tf/avg_checkpoints.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Tensor2Tensor Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Script to average values of variables in a list of checkpoint files.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import numpy as np 23 | import six 24 | from six.moves import zip # pylint: disable=redefined-builtin 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("checkpoints", "", 31 | "Comma-separated list of checkpoints to average.") 32 | flags.DEFINE_integer("num_last_checkpoints", 0, 33 | "Averages the last N saved checkpoints." 34 | " If the checkpoints flag is set, this is ignored.") 35 | flags.DEFINE_string("prefix", "", 36 | "Prefix (e.g., directory) to append to each checkpoint.") 37 | flags.DEFINE_string("output_path", "/tmp/averaged.ckpt", 38 | "Path to output the averaged checkpoint to.") 39 | 40 | 41 | def checkpoint_exists(path): 42 | return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or 43 | tf.gfile.Exists(path + ".index")) 44 | 45 | 46 | def main(_): 47 | tf.logging.set_verbosity(tf.logging.INFO) 48 | if FLAGS.checkpoints: 49 | # Get the checkpoints list from flags and run some basic checks. 50 | checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] 51 | checkpoints = [c for c in checkpoints if c] 52 | if not checkpoints: 53 | raise ValueError("No checkpoints provided for averaging.") 54 | if FLAGS.prefix: 55 | checkpoints = [FLAGS.prefix + c for c in checkpoints] 56 | else: 57 | assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model" 58 | assert FLAGS.prefix, ("Prefix must be provided when averaging last" 59 | " N checkpoints") 60 | checkpoint_state = tf.train.get_checkpoint_state( 61 | os.path.dirname(FLAGS.prefix)) 62 | # Checkpoints are ordered from oldest to newest.
63 | checkpoints = checkpoint_state.all_model_checkpoint_paths[ 64 | -FLAGS.num_last_checkpoints:] 65 | 66 | checkpoints = [c for c in checkpoints if checkpoint_exists(c)] 67 | if not checkpoints: 68 | if FLAGS.checkpoints: 69 | raise ValueError( 70 | "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) 71 | else: 72 | raise ValueError("Could not find checkpoints at %s" % 73 | os.path.dirname(FLAGS.prefix)) 74 | 75 | # Read variables from all checkpoints and average them. 76 | tf.logging.info("Reading variables and averaging checkpoints:") 77 | for c in checkpoints: 78 | tf.logging.info("%s ", c) 79 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 80 | var_values, var_dtypes = {}, {} 81 | for (name, shape) in var_list: 82 | if not name.startswith("global_step"): 83 | var_values[name] = np.zeros(shape) 84 | for checkpoint in checkpoints: 85 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 86 | for name in var_values: 87 | tensor = reader.get_tensor(name) 88 | var_dtypes[name] = tensor.dtype 89 | var_values[name] += tensor 90 | tf.logging.info("Read from checkpoint %s", checkpoint) 91 | for name in var_values: # Average. 92 | var_values[name] /= len(checkpoints) 93 | 94 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): 95 | tf_vars = [ 96 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) 97 | for v in var_values 98 | ] 99 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 100 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 101 | global_step = tf.Variable( 102 | 0, name="global_step", trainable=False, dtype=tf.int64) 103 | saver = tf.train.Saver(tf.all_variables()) 104 | 105 | # Build a model consisting only of variables, set them to the average values. 106 | with tf.Session() as sess: 107 | sess.run(tf.initialize_all_variables()) 108 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 109 | six.iteritems(var_values)): 110 | sess.run(assign_op, {p: value}) 111 | # Use the built saver to save the averaged checkpoint. 
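# (With TF1 Saver semantics, passing global_step appends the step to the path,
# so the averaged checkpoint lands at FLAGS.output_path + "-0" here, since
# global_step is created at 0 and never incremented.)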
112 | saver.save(sess, FLAGS.output_path, global_step=global_step) 113 | 114 | tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path) 115 | 116 | 117 | if __name__ == "__main__": 118 | tf.app.run() 119 | -------------------------------------------------------------------------------- /tf/data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import os 7 | from functools import partial 8 | 9 | from collections import Counter, OrderedDict 10 | import pickle 11 | import json 12 | import multiprocessing as mp 13 | 14 | import numpy as np 15 | 16 | from absl import flags 17 | import tensorflow as tf 18 | from vocabulary import Vocab 19 | 20 | from tensorflow.gfile import Exists as exists 21 | from tensorflow.gfile import MakeDirs as makedirs 22 | from tensorflow.gfile import Glob as glob 23 | 24 | 25 | def _preprocess(shard, train, vocab, save_dir, cutoffs, bin_sizes, bsz, tgt_len, 26 | num_core_per_host, use_tpu, num_shuffle): 27 | file_names = [] 28 | num_batch = 0 29 | 30 | path = train[shard] 31 | data_shard = vocab.encode_file(path, ordered=False, add_double_eos=True) 32 | 33 | for shuffle in range(num_shuffle): 34 | basename = "train-{:03d}-{:02d}".format(shard, shuffle) 35 | print("Processing shard {} shuffle {}".format(shard, shuffle)) 36 | 37 | np.random.shuffle(data_shard) 38 | file_name, num_batch_shuffle = create_ordered_tfrecords( 39 | save_dir, basename, np.concatenate(data_shard), bsz, tgt_len, 40 | num_core_per_host, cutoffs, bin_sizes, use_tpu=use_tpu) 41 | file_names.append(file_name) 42 | num_batch += num_batch_shuffle 43 | 44 | return file_names, num_batch 45 | 46 | 47 | class Corpus(object): 48 | def __init__(self, path, dataset, *args, **kwargs): 49 | self.dataset = dataset 50 | self.vocab = Vocab(*args, **kwargs) 51 | 52 | if self.dataset in ["ptb", "wt2", "enwik8", "text8"]: 53 | self.vocab.count_file(os.path.join(path, "train.txt")) 54 | self.vocab.count_file(os.path.join(path, "valid.txt")) 55 | self.vocab.count_file(os.path.join(path, "test.txt")) 56 | elif self.dataset == "wt103": 57 | self.vocab.count_file(os.path.join(path, "train.txt")) 58 | elif self.dataset == "lm1b": 59 | train_path_pattern = os.path.join( 60 | path, "1-billion-word-language-modeling-benchmark-r13output", 61 | "training-monolingual.tokenized.shuffled", "news.en-*") 62 | train_paths = glob(train_path_pattern) 63 | 64 | # the vocab will load from file when build_vocab() is called 65 | # for train_path in sorted(train_paths): 66 | # self.vocab.count_file(train_path, verbose=True) 67 | 68 | self.vocab.build_vocab() 69 | 70 | if self.dataset in ["ptb", "wt2", "wt103"]: 71 | self.train = self.vocab.encode_file( 72 | os.path.join(path, "train.txt"), ordered=True) 73 | self.valid = self.vocab.encode_file( 74 | os.path.join(path, "valid.txt"), ordered=True) 75 | self.test = self.vocab.encode_file( 76 | os.path.join(path, "test.txt"), ordered=True) 77 | elif self.dataset in ["enwik8", "text8"]: 78 | self.train = self.vocab.encode_file( 79 | os.path.join(path, "train.txt"), ordered=True, add_eos=False) 80 | self.valid = self.vocab.encode_file( 81 | os.path.join(path, "valid.txt"), ordered=True, add_eos=False) 82 | self.test = self.vocab.encode_file( 83 | os.path.join(path, "test.txt"), ordered=True, add_eos=False) 84 | elif self.dataset == "lm1b": 85 | self.train = train_paths 86 | valid_path = 
os.path.join(path, "valid.txt") 87 | test_path = valid_path 88 | self.valid = self.vocab.encode_file( 89 | valid_path, ordered=True, add_double_eos=True) 90 | self.test = self.vocab.encode_file( 91 | test_path, ordered=True, add_double_eos=True) 92 | 93 | if self.dataset == "wt103": 94 | self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)] 95 | elif self.dataset == "lm1b": 96 | self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)] 97 | else: 98 | self.cutoffs = [] 99 | 100 | 101 | def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, 102 | num_core_per_host, **kwargs): 103 | FLAGS = kwargs.get('FLAGS') 104 | 105 | file_names = [] 106 | use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1) 107 | 108 | if use_tpu: 109 | record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format( 110 | split, bsz, tgt_len, num_core_per_host) 111 | else: 112 | record_name = "record_info-{}.bsz-{}.tlen-{}.json".format( 113 | split, bsz, tgt_len) 114 | 115 | record_info_path = os.path.join(save_dir, record_name) 116 | 117 | if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: 118 | data = getattr(self, split) 119 | bin_sizes = get_bin_sizes( 120 | data, bsz // num_core_per_host, tgt_len, self.cutoffs) 121 | file_name, num_batch = create_ordered_tfrecords( 122 | save_dir, split, data, bsz, tgt_len, num_core_per_host, 123 | self.cutoffs, bin_sizes, 124 | num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1, 125 | use_tpu=use_tpu) 126 | file_names.append(file_name) 127 | elif self.dataset == "lm1b": 128 | bin_sizes = get_bin_sizes( 129 | self.valid, bsz // num_core_per_host, tgt_len, self.cutoffs) 130 | if split == "train": 131 | np.random.seed(123456) 132 | num_batch = 0 133 | 134 | if FLAGS.num_procs > 1: 135 | _preprocess_wrapper = partial(_preprocess, 136 | train=self.train, vocab=self.vocab, save_dir=save_dir, 137 | cutoffs=self.cutoffs, bin_sizes=bin_sizes, bsz=bsz, 138 | tgt_len=tgt_len, num_core_per_host=num_core_per_host, 139 | use_tpu=use_tpu, num_shuffle=FLAGS.num_shuffle) 140 | 141 | pool = mp.Pool(processes=FLAGS.num_procs) 142 | results = pool.map(_preprocess_wrapper, range(len(self.train))) 143 | for res in results: 144 | file_names.extend(res[0]) 145 | num_batch += res[1] 146 | else: 147 | for shard, path in enumerate(self.train): 148 | data_shard = self.vocab.encode_file(path, ordered=False, 149 | add_double_eos=True) 150 | 151 | num_shuffle = FLAGS.num_shuffle 152 | 153 | for shuffle in range(num_shuffle): 154 | print("Processing shard {} shuffle {}".format(shard, shuffle)) 155 | basename = "train-{:03d}-{:02d}".format(shard, shuffle) 156 | np.random.shuffle(data_shard) 157 | file_name, num_batch_ = create_ordered_tfrecords( 158 | save_dir, basename, np.concatenate(data_shard), bsz, tgt_len, 159 | num_core_per_host, 160 | self.cutoffs, bin_sizes, use_tpu=use_tpu) 161 | file_names.append(file_name) 162 | num_batch += num_batch_ 163 | 164 | else: 165 | file_name, num_batch = create_ordered_tfrecords( 166 | save_dir, split, getattr(self, split), bsz, tgt_len, 167 | num_core_per_host, 168 | self.cutoffs, bin_sizes, use_tpu=use_tpu) 169 | file_names.append(file_name) 170 | 171 | with open(record_info_path, "w") as fp: 172 | record_info = { 173 | "filenames": file_names, 174 | "bin_sizes": bin_sizes, 175 | "num_batch": num_batch 176 | } 177 | json.dump(record_info, fp) 178 | 179 | 180 | def get_bin_sizes(data, batch_size, tgt_len, cutoffs, std_mult=[2.5, 2.5, 2.5]): 181 | """ 182 | Note: the `batch_size` here should be 
per-core batch size 183 | """ 184 | bin_sizes = [] 185 | 186 | def _nearest_to_eight(x): # so that it's faster on TPUs 187 | y = x - x % 8 188 | return y + 8 if x % 8 >= 4 else max(8, y) 189 | 190 | if cutoffs: 191 | num_batch = len(data) // batch_size // tgt_len 192 | 193 | data = data[:batch_size * num_batch * tgt_len] 194 | data = data.reshape(batch_size, num_batch, tgt_len) 195 | 196 | tot = batch_size * tgt_len 197 | for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])): 198 | mask = (data >= left) * (data < right) 199 | percents = mask.astype(np.float64).sum(2).sum(0) / tot 200 | mean = np.mean(percents) 201 | std = np.std(percents) 202 | 203 | bin_size = int(math.ceil(tgt_len * batch_size * (mean + std_mult[b] * std))) 204 | bin_size = _nearest_to_eight(bin_size) 205 | bin_sizes.append(bin_size) 206 | 207 | return bin_sizes 208 | 209 | 210 | def _int64_feature(values): 211 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 212 | 213 | def _float_feature(values): 214 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 215 | 216 | def batchify(data, batch_size, num_passes): 217 | """ 218 | When `use_tpu=True`, this is called with `num_passes > 1`. 219 | 220 | Since TPU training consumes whole [bsz x tgt_len] chunks, up to 221 | `bsz * tgt_len` tokens at the end of the stream can be discarded during 222 | training. When `bsz` and `tgt_len` are both large, as in TPU training for 223 | Transformer-XL, this can lead to a detectable performance drop. 224 | 225 | Here, we deal with this by concatenating `num_passes` randomly shifted copies of the data before reshaping it to [batch_size, num_step]. 226 | """ 227 | if num_passes > 1: 228 | data_len = len(data) 229 | double_data = np.concatenate([data, data]) 230 | data_list = [] 231 | for i in range(num_passes): 232 | start = np.random.randint(0, data_len) 233 | data_list.append(double_data[start:start+data_len]) 234 | data = np.concatenate(data_list) 235 | 236 | num_step = len(data) // batch_size 237 | data = data[:batch_size * num_step] 238 | data = data.reshape(batch_size, num_step) 239 | 240 | return data 241 | 242 | 243 | def create_ordered_tfrecords(save_dir, basename, data, batch_size, tgt_len, 244 | num_core_per_host, cutoffs=[], bin_sizes=[], 245 | num_passes=1, use_tpu=False): 246 | 247 | if use_tpu: 248 | file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format( 249 | basename, batch_size, tgt_len, num_core_per_host) 250 | else: 251 | file_name = "{}.bsz-{}.tlen-{}.tfrecords".format( 252 | basename, batch_size, tgt_len) 253 | 254 | save_path = os.path.join(save_dir, file_name) 255 | record_writer = tf.python_io.TFRecordWriter(save_path) 256 | 257 | batched_data = batchify(data, batch_size, num_passes) 258 | 259 | num_batch = 0 260 | # for t in range(0, batched_data.shape[1] - tgt_len - 1, tgt_len): 261 | for t in range(0, batched_data.shape[1] - 1, tgt_len): 262 | cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len) 263 | # drop the remainder if use tpu 264 | if use_tpu and cur_tgt_len < tgt_len: 265 | break 266 | if num_batch % 500 == 0: 267 | print(" processing batch {}".format(num_batch)) 268 | for idx in range(batch_size): 269 | inputs = batched_data[idx, t:t + cur_tgt_len] 270 | labels = batched_data[idx, t + 1:t + cur_tgt_len + 1] 271 | 272 | # features dict 273 | feature = { 274 | "inputs": _int64_feature(inputs), 275 | "labels": _int64_feature(labels), 276 | } 277 | 278 | if len(cutoffs) > 0 and use_tpu: 279 | # validate `bin_sizes` and `cutoffs` 280 | assert len(cutoffs) - len(bin_sizes) == 2, \ 281 | "len(cutoffs) - len(bin_sizes) != 2" 282 | 283 | # mask for
bin 0 284 | left, right = cutoffs[:2] 285 | inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32) 286 | tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32) 287 | 288 | feature["inp_mask"] = _float_feature(inp_mask) 289 | feature["tgt_mask"] = _float_feature(tgt_mask) 290 | 291 | # refresh `inp_cnts` and `tgt_cnts` for each TPU core 292 | if idx % (batch_size // num_core_per_host) == 0: 293 | inp_cnts = [0] * len(bin_sizes) 294 | tgt_cnts = [0] * len(bin_sizes) 295 | 296 | head_labels = np.copy(labels) 297 | inp_pos_per_bin, tgt_pos_per_bin = [], [] 298 | for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])): 299 | inp_pos = np.where((inputs >= left) * (inputs < right))[0] 300 | tgt_pos = np.where((labels >= left) * (labels < right))[0] 301 | inp_pos_per_bin.append(inp_pos) 302 | tgt_pos_per_bin.append(tgt_pos) 303 | 304 | head_labels[tgt_pos] = cutoffs[1] + b 305 | 306 | feature["head_labels"] = _int64_feature(head_labels) 307 | 308 | # permutation feature 309 | def _add_perm_feature(feature, pos_per_bin, cnts, prefix): 310 | for b, pos in enumerate(pos_per_bin): 311 | idx_tuple = [] 312 | for p in pos: 313 | if cnts[b] < bin_sizes[b]: 314 | idx_tuple.append([p, cnts[b]]) 315 | cnts[b] += 1 316 | else: 317 | break 318 | 319 | n_tup = len(idx_tuple) 320 | tup = np.array(idx_tuple).reshape(n_tup * 2) 321 | 322 | feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup]) 323 | feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup) 324 | 325 | _add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp") 326 | _add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt") 327 | 328 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 329 | record_writer.write(example.SerializeToString()) 330 | 331 | num_batch += 1 332 | 333 | record_writer.close() 334 | print("Done writing {}. 
batches: {}".format(file_name, num_batch)) 335 | 336 | return file_name, num_batch 337 | 338 | 339 | def get_lm_corpus(data_dir, dataset): 340 | fn = os.path.join(data_dir, "cache.pkl") 341 | 342 | if exists(fn): 343 | print("Loading cached dataset...") 344 | with open(fn, "rb") as fp: 345 | corpus = pickle.load(fp) 346 | else: 347 | print("Producing dataset...") 348 | kwargs = {} 349 | if dataset in ["wt103", "wt2"]: 350 | kwargs["special"] = ["<eos>"] 351 | kwargs["lower_case"] = False 352 | elif dataset == "ptb": 353 | kwargs["special"] = ["<eos>"] 354 | kwargs["lower_case"] = True 355 | elif dataset == "lm1b": 356 | kwargs["special"] = [] 357 | kwargs["lower_case"] = False 358 | kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt") 359 | elif dataset in ["enwik8", "text8"]: 360 | pass 361 | 362 | corpus = Corpus(data_dir, dataset, **kwargs) 363 | 364 | print("Saving dataset...") 365 | with open(fn, "wb") as fp: 366 | pickle.dump(corpus, fp, protocol=2) 367 | 368 | corpus_info = { 369 | "vocab_size" : len(corpus.vocab), 370 | "cutoffs" : corpus.cutoffs, 371 | "dataset" : corpus.dataset 372 | } 373 | with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp: 374 | json.dump(corpus_info, fp) 375 | 376 | return corpus 377 | 378 | 379 | def main(unused_argv): 380 | del unused_argv # Unused 381 | 382 | corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset) 383 | 384 | save_dir = os.path.join(FLAGS.data_dir, "tfrecords") 385 | if not exists(save_dir): 386 | makedirs(save_dir) 387 | 388 | # test mode 389 | if FLAGS.per_host_test_bsz > 0: 390 | corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz, 391 | FLAGS.tgt_len, FLAGS.num_core_per_host, 392 | FLAGS=FLAGS) 393 | return 394 | 395 | for split, batch_size in zip( 396 | ["train", "valid"], 397 | [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]): 398 | 399 | if batch_size <= 0: continue 400 | print("Converting {} set...".format(split)) 401 | corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len, 402 | FLAGS.num_core_per_host, FLAGS=FLAGS) 403 | 404 | 405 | def load_record_info(record_info_dir, split, per_host_bsz, tgt_len, 406 | num_core_per_host, use_tpu): 407 | if use_tpu: 408 | record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format( 409 | split, per_host_bsz, tgt_len, num_core_per_host) 410 | else: 411 | record_name = "record_info-{}.bsz-{}.tlen-{}.json".format( 412 | split, per_host_bsz, tgt_len) 413 | 414 | record_info_path = os.path.join(record_info_dir, record_name) 415 | with open(record_info_path, "r") as fp: 416 | record_info = json.load(fp) 417 | 418 | return record_info 419 | 420 | def get_input_fn(record_info_dir, split, per_host_bsz, tgt_len, 421 | num_core_per_host, num_hosts=1, use_tpu=False): 422 | """Creates input function.""" 423 | record_info = load_record_info(record_info_dir, split, per_host_bsz, tgt_len, 424 | num_core_per_host, use_tpu=use_tpu) 425 | 426 | file_names = record_info["filenames"] 427 | bin_sizes = record_info["bin_sizes"] 428 | num_batch = record_info["num_batch"] 429 | 430 | tf.logging.info("[{}] File names {}".format(split, file_names)) 431 | 432 | def input_fn(params): 433 | # per-core batch size 434 | per_core_bsz = params["batch_size"] 435 | 436 | # data_dir could be a remote path, e.g., a google storage url 437 | data_dir = params["data_dir"] 438 | 439 | def parser(record): 440 | # preprocess "inp_perm" and "tgt_perm" 441 | def _process_perm_feature(example, prefix): 442 | for b in range(len(bin_sizes)): 443 | cnt =
example.pop("{}_cnt_{}".format(prefix, b))[0] 444 | tup = example.pop("{}_tup_{}".format(prefix, b)) 445 | 446 | tup = tf.reshape( 447 | tf.sparse_tensor_to_dense(tup), 448 | shape=[cnt, 2]) 449 | 450 | # tf.float32 451 | perm = tf.sparse_to_dense( 452 | sparse_indices=tup, 453 | output_shape=[tgt_len, bin_sizes[b]], 454 | sparse_values=1.0, 455 | default_value=0.0) 456 | 457 | example["{}_perm_{}".format(prefix, b)] = perm 458 | 459 | # whether allow the last batch with a potentially shorter length 460 | if use_tpu: 461 | record_spec = { 462 | "inputs": tf.FixedLenFeature([tgt_len], tf.int64), 463 | "labels": tf.FixedLenFeature([tgt_len], tf.int64), 464 | } 465 | else: 466 | record_spec = { 467 | "inputs": tf.VarLenFeature(tf.int64), 468 | "labels": tf.VarLenFeature(tf.int64), 469 | } 470 | 471 | # permutation related features 472 | if bin_sizes and use_tpu: 473 | # tf.float32 474 | record_spec["inp_mask"] = tf.FixedLenFeature([tgt_len], tf.float32) 475 | record_spec["tgt_mask"] = tf.FixedLenFeature([tgt_len], tf.float32) 476 | 477 | record_spec["head_labels"] = tf.FixedLenFeature([tgt_len], tf.int64) 478 | 479 | for b in range(len(bin_sizes)): 480 | record_spec["inp_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64) 481 | record_spec["inp_tup_{}".format(b)] = tf.VarLenFeature(tf.int64) 482 | record_spec["tgt_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64) 483 | record_spec["tgt_tup_{}".format(b)] = tf.VarLenFeature(tf.int64) 484 | 485 | # retrieve serialized example 486 | example = tf.parse_single_example( 487 | serialized=record, 488 | features=record_spec) 489 | 490 | # transform permutation tuples to permutation matrices 491 | if bin_sizes and use_tpu: 492 | _process_perm_feature(example, "inp") 493 | _process_perm_feature(example, "tgt") 494 | 495 | # cast int64 into int32 496 | # cast sparse to dense 497 | for key in list(example.keys()): 498 | val = example[key] 499 | if tf.keras.backend.is_sparse(val): 500 | val = tf.sparse.to_dense(val) 501 | if val.dtype == tf.int64: 502 | val = tf.to_int32(val) 503 | example[key] = val 504 | 505 | if use_tpu: 506 | return example 507 | else: 508 | return example["inputs"], example["labels"] 509 | 510 | file_paths = [] 511 | for file_name in file_names: 512 | file_path = os.path.join(data_dir, file_name) 513 | file_paths.append(file_path) 514 | 515 | if split == "train": 516 | dataset = tf.data.Dataset.from_tensor_slices(file_paths) 517 | if len(file_paths) > 1: 518 | dataset = dataset.shuffle(len(file_paths)).repeat() 519 | dataset = tf.data.TFRecordDataset(dataset) 520 | elif num_hosts > 1: 521 | host_id = params["context"].current_host 522 | # drop the remaining batches 523 | num_batch_per_host = num_batch // num_hosts 524 | 525 | my_start_sample_id = (host_id * num_batch_per_host * num_core_per_host * 526 | per_core_bsz) 527 | my_sample_num = num_batch_per_host * num_core_per_host * per_core_bsz 528 | dataset = tf.data.TFRecordDataset(dataset).skip( 529 | my_start_sample_id).take(my_sample_num) 530 | else: 531 | dataset = tf.data.TFRecordDataset(dataset) 532 | 533 | dataset = dataset.map(parser).cache().repeat() 534 | dataset = dataset.batch(per_core_bsz, drop_remainder=True) 535 | dataset = dataset.prefetch(num_core_per_host * per_core_bsz) 536 | else: 537 | # do not shuffle, repeat or cache in evaluation 538 | dataset = tf.data.Dataset.from_tensor_slices(file_paths) 539 | dataset = tf.data.TFRecordDataset(dataset) 540 | dataset = dataset.map(parser) 541 | dataset = dataset.batch(per_core_bsz, drop_remainder=True) 542 | 543 | 
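# Both branches yield per-core batches with static shapes (drop_remainder=True),
# which TPU execution requires; on GPU the parser returns (inputs, labels)
# tuples, while on TPU it returns the full feature dict.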
return dataset 544 | 545 | if split == "train" and num_hosts > 1: 546 | record_info["num_batch"] = num_batch // num_hosts 547 | 548 | return input_fn, record_info 549 | 550 | def get_corpus_info(corpus_info_path): 551 | with open(corpus_info_path, "r") as fp: 552 | corpus_info = json.load(fp) 553 | return corpus_info 554 | 555 | if __name__ == "__main__": 556 | FLAGS = flags.FLAGS 557 | flags.DEFINE_string("data_dir", None, 558 | help="Location of the data corpus") 559 | flags.DEFINE_enum("dataset", "wt103", 560 | ["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"], 561 | help="Dataset name.") 562 | flags.DEFINE_integer("per_host_train_bsz", 60, 563 | help="train batch size each host") 564 | flags.DEFINE_integer("per_host_valid_bsz", 60, 565 | help="valid batch size each host") 566 | flags.DEFINE_integer("per_host_test_bsz", 0, 567 | help="If > 0, enter test mode and process test set only." 568 | " Otherwise, process train and dev sets only.") 569 | flags.DEFINE_integer("tgt_len", 70, 570 | help="number of tokens to predict") 571 | flags.DEFINE_integer("max_batch", -1, 572 | help="run in debug mode") 573 | flags.DEFINE_integer("num_core_per_host", 8, 574 | help="8 for TPU v2.") 575 | flags.DEFINE_bool("debug", default=False, 576 | help="Process only the first batch without shuffle for lm1b.") 577 | flags.DEFINE_integer("num_procs", 1, 578 | help="number of processes") 579 | flags.DEFINE_integer("num_passes", 10, 580 | help="number of passes when use_tpu=True") 581 | flags.DEFINE_integer("num_shuffle", 4, 582 | help="number of shuffles for lm1b") 583 | flags.DEFINE_bool("use_tpu", True, 584 | help="use tpu") 585 | 586 | tf.app.run(main) 587 | -------------------------------------------------------------------------------- /tf/gpu_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"): 5 | def _assign(op): 6 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 7 | if node_def.op == "Variable": 8 | return ps_dev 9 | else: 10 | return "/gpu:%d" % gpu 11 | return _assign 12 | 13 | 14 | def average_grads_and_vars(tower_grads_and_vars): 15 | def average_dense(grad_and_vars): 16 | if len(grad_and_vars) == 1: 17 | return grad_and_vars[0][0] 18 | 19 | grad = grad_and_vars[0][0] 20 | for g, _ in grad_and_vars[1:]: 21 | grad += g 22 | return grad / len(grad_and_vars) 23 | 24 | def average_sparse(grad_and_vars): 25 | if len(grad_and_vars) == 1: 26 | return grad_and_vars[0][0] 27 | 28 | indices = [] 29 | values = [] 30 | for g, _ in grad_and_vars: 31 | indices += [g.indices] 32 | values += [g.values] 33 | indices = tf.concat(indices, 0) 34 | values = tf.concat(values, 0) / len(grad_and_vars) 35 | return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape) 36 | 37 | average_grads_and_vars = [] 38 | for grad_and_vars in zip(*tower_grads_and_vars): 39 | if grad_and_vars[0][0] is None: 40 | grad = None 41 | elif isinstance(grad_and_vars[0][0], tf.IndexedSlices): 42 | grad = average_sparse(grad_and_vars) 43 | else: 44 | grad = average_dense(grad_and_vars) 45 | # Keep in mind that the Variables are redundant because they are shared 46 | # across towers. So .. we will just return the first tower's pointer to 47 | # the Variable.
48 | v = grad_and_vars[0][1] 49 | grad_and_var = (grad, v) 50 | average_grads_and_vars.append(grad_and_var) 51 | return average_grads_and_vars 52 | 53 | 54 | def load_from_checkpoint(saver, logdir): 55 | sess = tf.get_default_session() 56 | ckpt = tf.train.get_checkpoint_state(logdir) 57 | if ckpt and ckpt.model_checkpoint_path: 58 | if os.path.isabs(ckpt.model_checkpoint_path): 59 | # Restores from checkpoint with absolute path. 60 | saver.restore(sess, ckpt.model_checkpoint_path) 61 | else: 62 | # Restores from checkpoint with relative path. 63 | saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path)) 64 | return True 65 | return False 66 | -------------------------------------------------------------------------------- /tf/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def positional_embedding(pos_seq, inv_freq, bsz=None): 5 | sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq) 6 | pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) 7 | if bsz is not None: 8 | return tf.tile(pos_emb[:, None, :], [1, bsz, 1]) 9 | else: 10 | return pos_emb[:, None, :] 11 | 12 | 13 | def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer, 14 | scope='ff', is_training=True): 15 | output = inp 16 | with tf.variable_scope(scope): 17 | output = tf.layers.dense(inp, d_inner, activation=tf.nn.relu, 18 | kernel_initializer=kernel_initializer, 19 | name='layer_1') 20 | output = tf.layers.dropout(output, dropout, training=is_training, 21 | name='drop_1') 22 | output = tf.layers.dense(output, d_model, 23 | kernel_initializer=kernel_initializer, 24 | name='layer_2') 25 | output = tf.layers.dropout(output, dropout, training=is_training, 26 | name='drop_2') 27 | output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1) 28 | return output 29 | 30 | 31 | def rel_shift(x): 32 | x_size = tf.shape(x) 33 | 34 | x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) 35 | x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]]) 36 | x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) 37 | x = tf.reshape(x, x_size) 38 | 39 | return x 40 | 41 | 42 | def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model, 43 | n_head, d_head, dropout, dropatt, is_training, 44 | kernel_initializer, scope='rel_attn'): 45 | scale = 1 / (d_head ** 0.5) 46 | with tf.variable_scope(scope): 47 | qlen = tf.shape(w)[0] 48 | rlen = tf.shape(r)[0] 49 | bsz = tf.shape(w)[1] 50 | 51 | cat = tf.concat([mems, w], 52 | 0) if mems is not None and mems.shape.ndims > 1 else w 53 | w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False, 54 | kernel_initializer=kernel_initializer, name='qkv') 55 | r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False, 56 | kernel_initializer=kernel_initializer, name='r') 57 | 58 | w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1) 59 | w_head_q = w_head_q[-qlen:] 60 | 61 | klen = tf.shape(w_head_k)[0] 62 | 63 | w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head]) 64 | w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head]) 65 | w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head]) 66 | 67 | r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head]) 68 | 69 | rw_head_q = w_head_q + r_w_bias 70 | rr_head_q = w_head_q + r_r_bias 71 | 72 | AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k) 73 | BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k) 74 | BD = rel_shift(BD) 75 | 76 | attn_score = (AC + BD) * scale 77 | attn_mask_t 
= attn_mask[:, :, None, None] 78 | attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t 79 | 80 | attn_prob = tf.nn.softmax(attn_score, 1) 81 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training) 82 | 83 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v) 84 | size_t = tf.shape(attn_vec) 85 | attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head]) 86 | 87 | attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False, 88 | kernel_initializer=kernel_initializer, name='o') 89 | attn_out = tf.layers.dropout(attn_out, dropout, training=is_training) 90 | 91 | output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1) 92 | return output 93 | 94 | 95 | def embedding_lookup(lookup_table, x, use_tpu=True): 96 | if use_tpu: 97 | n_token = tf.shape(lookup_table)[0] 98 | one_hot_idx = tf.one_hot(x, n_token) 99 | if one_hot_idx.shape.ndims == 2: 100 | return tf.einsum('nd,in->id', lookup_table, one_hot_idx) 101 | else: 102 | return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx) 103 | else: 104 | return tf.nn.embedding_lookup(lookup_table, x) 105 | 106 | 107 | def mask_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer, 108 | proj_initializer, div_val=1, 109 | proj_same_dim=True, 110 | scope='adaptive_embed', **kwargs): 111 | emb_scale = d_proj ** 0.5 112 | with tf.variable_scope(scope): 113 | if div_val == 1: 114 | lookup_table = tf.get_variable('lookup_table', [n_token, d_embed], 115 | initializer=initializer) 116 | y = embedding_lookup(lookup_table, x, use_tpu=False) 117 | if d_proj != d_embed: 118 | proj_W = tf.get_variable('proj_W', [d_embed, d_proj], 119 | initializer=proj_initializer) 120 | y = tf.einsum('ibe,ed->ibd', y, proj_W) 121 | else: 122 | proj_W = None 123 | ret_params = [lookup_table, proj_W] 124 | else: 125 | tables, projs = [], [] 126 | cutoff_ends = [0] + cutoffs + [n_token] 127 | x_size = tf.shape(x) 128 | y = tf.zeros([x_size[0], x_size[1], d_proj]) 129 | for i in range(len(cutoff_ends) - 1): 130 | with tf.variable_scope('cutoff_{}'.format(i)): 131 | l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] 132 | mask = (x >= l_idx) & (x < r_idx) 133 | cur_x = tf.boolean_mask(x, mask) - l_idx 134 | cur_d_embed = d_embed // (div_val ** i) 135 | lookup_table = tf.get_variable('lookup_table', 136 | [r_idx - l_idx, cur_d_embed], 137 | initializer=initializer) 138 | cur_y = embedding_lookup(lookup_table, cur_x, use_tpu=False) 139 | if d_proj == cur_d_embed and not proj_same_dim: 140 | proj_W = None 141 | else: 142 | proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj], 143 | initializer=proj_initializer) 144 | cur_y = tf.einsum('id,de->ie', cur_y, proj_W) 145 | mask_idx = tf.to_int64(tf.where(mask)) 146 | y += tf.scatter_nd(mask_idx, cur_y, tf.to_int64(tf.shape(y))) 147 | tables.append(lookup_table) 148 | projs.append(proj_W) 149 | ret_params = [tables, projs] 150 | 151 | y *= emb_scale 152 | return y, ret_params 153 | 154 | 155 | def mul_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer, 156 | proj_initializer, div_val=1, perms=None, 157 | proj_same_dim=True, 158 | scope='adaptive_embed'): 159 | """ 160 | perms: If None, first compute W = W1 x W2 (projection for each bin), 161 | and then compute X x W (embedding lookup). If not None, 162 | use bin-based embedding lookup with max_bin_size defined by 163 | the shape of perms. 
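    (Concretely: perms[0] is the rank-2 head-bin mask written as "inp_mask" by
    data_utils.create_ordered_tfrecords, and each later perms[i] is a 0/1
    scatter matrix of shape [tgt_len, bsz, bin_size] decoded from the
    "inp_cnt_*" / "inp_tup_*" features; the einsums below use them to gather
    and scatter each bin's rows with static shapes, since dynamic-shape ops
    like tf.boolean_mask do not compile on TPU.)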
164 | """ 165 | emb_scale = d_proj ** 0.5 166 | with tf.variable_scope(scope): 167 | if div_val == 1: 168 | lookup_table = tf.get_variable('lookup_table', [n_token, d_embed], 169 | initializer=initializer) 170 | y = embedding_lookup(lookup_table, x) 171 | if d_proj != d_embed: 172 | proj_W = tf.get_variable('proj_W', [d_embed, d_proj], 173 | initializer=proj_initializer) 174 | y = tf.einsum('ibe,ed->ibd', y, proj_W) 175 | else: 176 | proj_W = None 177 | ret_params = [lookup_table, proj_W] 178 | else: 179 | tables, projs = [], [] 180 | cutoff_ends = [0] + cutoffs + [n_token] 181 | x_size = tf.shape(x) 182 | if perms is None: 183 | cat_lookup = [] 184 | else: 185 | cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj]) 186 | for i in range(len(cutoff_ends) - 1): 187 | with tf.variable_scope('cutoff_{}'.format(i)): 188 | l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] 189 | cur_d_embed = d_embed // (div_val ** i) 190 | lookup_table = tf.get_variable('lookup_table', 191 | [r_idx - l_idx, cur_d_embed], 192 | initializer=initializer) 193 | if cur_d_embed == d_proj and not proj_same_dim: 194 | proj_W = None 195 | else: 196 | proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj], 197 | initializer=proj_initializer) 198 | if perms is None: 199 | cat_lookup.append(tf.einsum('ie,ed->id', lookup_table, proj_W)) 200 | else: 201 | # speed up the computation of the first bin 202 | # also save some memory 203 | if i == 0: 204 | cur_y = embedding_lookup(lookup_table, tf.minimum(x, r_idx - 1)) 205 | if proj_W is not None: 206 | cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W) 207 | cur_y *= perms[i][:, :, None] 208 | cat_lookup += cur_y 209 | else: 210 | cur_x = tf.einsum('ib,ibk->k', tf.to_float(x - l_idx), perms[i]) 211 | cur_x = tf.to_int32(cur_x) 212 | cur_y = embedding_lookup(lookup_table, cur_x) 213 | if proj_W is not None: 214 | cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W) 215 | cat_lookup += tf.einsum('kd,ibk->ibd', cur_y, perms[i]) 216 | tables.append(lookup_table) 217 | projs.append(proj_W) 218 | if perms is None: 219 | cat_lookup = tf.concat(cat_lookup, 0) 220 | y = embedding_lookup(cat_lookup, x) 221 | else: 222 | y = cat_lookup 223 | ret_params = [tables, projs] 224 | 225 | y *= emb_scale 226 | return y, ret_params 227 | 228 | 229 | def mask_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs, 230 | params, tie_projs, 231 | initializer=None, proj_initializer=None, 232 | div_val=1, scope='adaptive_softmax', 233 | proj_same_dim=True, 234 | return_mean=True, **kwargs): 235 | def _logit(x, W, b, proj): 236 | y = x 237 | if proj is not None: 238 | y = tf.einsum('ibd,ed->ibe', y, proj) 239 | return tf.einsum('ibd,nd->ibn', y, W) + b 240 | 241 | params_W, params_projs = params[0], params[1] 242 | 243 | def _gather_logprob(logprob, target): 244 | lp_size = tf.shape(logprob) 245 | r = tf.range(lp_size[0]) 246 | idx = tf.stack([r, target], 1) 247 | return tf.gather_nd(logprob, idx) 248 | 249 | with tf.variable_scope(scope): 250 | if len(cutoffs) == 0: 251 | softmax_b = tf.get_variable('bias', [n_token], 252 | initializer=tf.zeros_initializer()) 253 | output = _logit(hidden, params_W, softmax_b, params_projs) 254 | nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, 255 | logits=output) 256 | else: 257 | cutoff_ends = [0] + cutoffs + [n_token] 258 | nll = tf.zeros_like(target, dtype=tf.float32) 259 | for i in range(len(cutoff_ends) - 1): 260 | with tf.variable_scope('cutoff_{}'.format(i)): 261 | l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] 262 | mask = (target >=
l_idx) & (target < r_idx) 263 | mask_idx = tf.where(mask) 264 | cur_target = tf.boolean_mask(target, mask) - l_idx 265 | cur_d_embed = d_embed // (div_val ** i) 266 | 267 | if div_val == 1: 268 | cur_W = params_W[l_idx: r_idx] 269 | else: 270 | cur_W = params_W[i] 271 | cur_b = tf.get_variable('b', [r_idx - l_idx], 272 | initializer=tf.zeros_initializer()) 273 | if tie_projs[i]: 274 | if div_val == 1: 275 | cur_proj = params_projs 276 | else: 277 | cur_proj = params_projs[i] 278 | else: 279 | if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed: 280 | cur_proj = None 281 | else: 282 | cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj], 283 | initializer=proj_initializer) 284 | if i == 0: 285 | cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed], 286 | initializer=tf.zeros_initializer()) 287 | cluster_b = tf.get_variable('cluster_b', [len(cutoffs)], 288 | initializer=tf.zeros_initializer()) 289 | cur_W = tf.concat([cur_W, cluster_W], 0) 290 | cur_b = tf.concat([cur_b, cluster_b], 0) 291 | 292 | head_logit = _logit(hidden, cur_W, cur_b, cur_proj) 293 | head_logprob = tf.nn.log_softmax(head_logit) 294 | cur_head_logprob = tf.boolean_mask(head_logprob, mask) 295 | cur_logprob = _gather_logprob(cur_head_logprob, cur_target) 296 | else: 297 | cur_head_logprob = tf.boolean_mask(head_logprob, mask) 298 | cur_hidden = tf.boolean_mask(hidden, mask) 299 | tail_logit = tf.squeeze(_logit( 300 | cur_hidden[None], cur_W, cur_b, cur_proj), 0) 301 | tail_logprob = tf.nn.log_softmax(tail_logit) 302 | cur_logprob = (cur_head_logprob[:, cutoff_ends[1] + i - 1] + 303 | _gather_logprob(tail_logprob, cur_target)) 304 | nll += tf.scatter_nd(mask_idx, -cur_logprob, 305 | tf.to_int64(tf.shape(nll))) 306 | if return_mean: 307 | nll = tf.reduce_mean(nll) 308 | return nll 309 | 310 | 311 | def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs, 312 | params, tie_projs, 313 | initializer=None, proj_initializer=None, 314 | div_val=1, perms=None, proj_same_dim=True, 315 | scope='adaptive_softmax', 316 | **kwargs): 317 | def _logit(x, W, b, proj): 318 | y = x 319 | if x.shape.ndims == 3: 320 | if proj is not None: 321 | y = tf.einsum('ibd,ed->ibe', y, proj) 322 | return tf.einsum('ibd,nd->ibn', y, W) + b 323 | else: 324 | if proj is not None: 325 | y = tf.einsum('id,ed->ie', y, proj) 326 | return tf.einsum('id,nd->in', y, W) + b 327 | 328 | params_W, params_projs = params[0], params[1] 329 | 330 | with tf.variable_scope(scope): 331 | if len(cutoffs) == 0: 332 | softmax_b = tf.get_variable('bias', [n_token], 333 | initializer=tf.zeros_initializer()) 334 | output = _logit(hidden, params_W, softmax_b, params_projs) 335 | nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, 336 | logits=output) 337 | nll = tf.reduce_mean(nll) 338 | else: 339 | total_loss, total_cnt = 0, 0 340 | cutoff_ends = [0] + cutoffs + [n_token] 341 | for i in range(len(cutoff_ends) - 1): 342 | with tf.variable_scope('cutoff_{}'.format(i)): 343 | l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] 344 | 345 | cur_d_embed = d_embed // (div_val ** i) 346 | 347 | if div_val == 1: 348 | cur_W = params_W[l_idx: r_idx] 349 | else: 350 | cur_W = params_W[i] 351 | cur_b = tf.get_variable('b', [r_idx - l_idx], 352 | initializer=tf.zeros_initializer()) 353 | if tie_projs[i]: 354 | if div_val == 1: 355 | cur_proj = params_projs 356 | else: 357 | cur_proj = params_projs[i] 358 | else: 359 | if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed: 360 | cur_proj = None 361 | else: 362 | 
cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj], 363 | initializer=proj_initializer) 364 | 365 | if i == 0: 366 | cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed], 367 | initializer=tf.zeros_initializer()) 368 | cluster_b = tf.get_variable('cluster_b', [len(cutoffs)], 369 | initializer=tf.zeros_initializer()) 370 | cur_W = tf.concat([cur_W, cluster_W], 0) 371 | cur_b = tf.concat([cur_b, cluster_b], 0) 372 | 373 | head_logit = _logit(hidden, cur_W, cur_b, cur_proj) 374 | 375 | head_target = kwargs.get("head_target") 376 | head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits( 377 | labels=head_target, 378 | logits=head_logit) 379 | 380 | masked_loss = head_nll * perms[i] 381 | total_loss += tf.reduce_sum(masked_loss) 382 | total_cnt += tf.reduce_sum(perms[i]) 383 | 384 | # head_logprob = tf.nn.log_softmax(head_logit) 385 | 386 | # final_logprob = head_logprob * perms[i][:, :, None] 387 | # final_target = tf.one_hot(target, tf.shape(head_logprob)[2]) 388 | # total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target) 389 | # total_cnt += tf.reduce_sum(perms[i]) 390 | else: 391 | cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i]) 392 | 393 | cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i]) 394 | tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj) 395 | 396 | tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx), 397 | perms[i]) 398 | tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits( 399 | labels=tf.to_int32(tail_target), 400 | logits=tail_logit) 401 | 402 | sum_nll = cur_head_nll + tail_nll 403 | mask = tf.reduce_sum(perms[i], [0, 1]) 404 | 405 | masked_loss = sum_nll * mask 406 | total_loss += tf.reduce_sum(masked_loss) 407 | total_cnt += tf.reduce_sum(mask) 408 | 409 | nll = total_loss / total_cnt 410 | 411 | return nll 412 | 413 | 414 | def _create_mask(qlen, mlen, same_length=False): 415 | attn_mask = tf.ones([qlen, qlen]) 416 | mask_u = tf.matrix_band_part(attn_mask, 0, -1) 417 | mask_dia = tf.matrix_band_part(attn_mask, 0, 0) 418 | attn_mask_pad = tf.zeros([qlen, mlen]) 419 | ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) 420 | if same_length: 421 | mask_l = tf.matrix_band_part(attn_mask, -1, 0) 422 | ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) 423 | return ret 424 | 425 | def _cache_mem(curr_out, prev_mem, mem_len=None): 426 | if mem_len is None or prev_mem is None: 427 | new_mem = curr_out 428 | elif mem_len == 0: 429 | return prev_mem 430 | else: 431 | new_mem = tf.concat([prev_mem, curr_out], 0)[- mem_len:] 432 | 433 | return tf.stop_gradient(new_mem) 434 | 435 | 436 | def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed, 437 | n_head, d_head, d_inner, dropout, dropatt, 438 | initializer, is_training, proj_initializer=None, 439 | mem_len=None, cutoffs=[], div_val=1, tie_projs=[], 440 | same_length=False, clamp_len=-1, use_tpu=True, 441 | input_perms=None, target_perms=None, head_target=None, 442 | untie_r=False, proj_same_dim=True, 443 | scope='transformer'): 444 | """ 445 | cutoffs: a list of python int. Cutoffs for adaptive softmax. 446 | tie_projs: a list of python bools. Whether to tie the projections. 447 | use_tpu: if True, use one_hot in embedding lookup and bin-based implementation 448 | of adaptive softmax. 449 | perms: a list of tensors. Each tensor should be of size [len, bsz, bin_size]. 450 | Only used in the adaptive setting.
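    A minimal GPU-style call, with names and shapes assumed for illustration
    only (time-major int32 `inp`/`tgt` of shape [tgt_len, bsz]; `mems` a list
    of n_layer tensors of shape [mem_len, bsz, d_model], or None on the first
    segment):

        loss, new_mems = transformer(
            dec_inp=inp, target=tgt, mems=mems, n_token=vocab_size,
            n_layer=12, d_model=512, d_embed=512, n_head=8, d_head=64,
            d_inner=2048, dropout=0.1, dropatt=0.0,
            initializer=tf.initializers.random_normal(stddev=0.02),
            is_training=True, mem_len=512, use_tpu=False)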
451 | """ 452 | new_mems = [] 453 | with tf.variable_scope(scope): 454 | if untie_r: 455 | r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head], 456 | initializer=initializer) 457 | r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head], 458 | initializer=initializer) 459 | else: 460 | r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head], 461 | initializer=initializer) 462 | r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head], 463 | initializer=initializer) 464 | 465 | qlen = tf.shape(dec_inp)[0] 466 | mlen = tf.shape(mems[0])[0] if mems is not None else 0 467 | klen = mlen + qlen 468 | 469 | if proj_initializer is None: 470 | proj_initializer = initializer 471 | lookup_fn = (mul_adaptive_embedding_lookup if use_tpu else 472 | mask_adaptive_embedding_lookup) 473 | embeddings, shared_params = lookup_fn( 474 | x=dec_inp, 475 | n_token=n_token, 476 | d_embed=d_embed, 477 | d_proj=d_model, 478 | cutoffs=cutoffs, 479 | initializer=initializer, 480 | proj_initializer=proj_initializer, 481 | div_val= div_val, 482 | perms=input_perms, 483 | proj_same_dim=proj_same_dim) 484 | 485 | attn_mask = _create_mask(qlen, mlen, same_length) 486 | 487 | pos_seq = tf.range(klen - 1, -1, -1.0) 488 | if clamp_len > 0: 489 | pos_seq = tf.minimum(pos_seq, clamp_len) 490 | inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model)) 491 | pos_emb = positional_embedding(pos_seq, inv_freq) 492 | 493 | output = tf.layers.dropout(embeddings, dropout, training=is_training) 494 | pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training) 495 | 496 | if mems is None: 497 | mems = [None] * n_layer 498 | 499 | for i in range(n_layer): 500 | # cache new mems 501 | new_mems.append(_cache_mem(output, mems[i], mem_len)) 502 | 503 | with tf.variable_scope('layer_{}'.format(i)): 504 | output = rel_multihead_attn( 505 | w=output, 506 | r=pos_emb, 507 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i], 508 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i], 509 | attn_mask=attn_mask, 510 | mems=mems[i], 511 | d_model=d_model, 512 | n_head=n_head, 513 | d_head=d_head, 514 | dropout=dropout, 515 | dropatt=dropatt, 516 | is_training=is_training, 517 | kernel_initializer=initializer) 518 | output = positionwise_FF( 519 | inp=output, 520 | d_model=d_model, 521 | d_inner=d_inner, 522 | dropout=dropout, 523 | kernel_initializer=initializer, 524 | is_training=is_training) 525 | 526 | output = tf.layers.dropout(output, dropout, training=is_training) 527 | 528 | logsoftmax_fn = (mul_adaptive_logsoftmax if use_tpu else 529 | mask_adaptive_logsoftmax) 530 | loss = logsoftmax_fn( 531 | hidden=output, 532 | target=target, 533 | n_token=n_token, 534 | d_embed=d_embed, 535 | d_proj=d_model, 536 | cutoffs=cutoffs, 537 | params=shared_params, 538 | tie_projs=tie_projs, 539 | initializer=initializer, 540 | proj_initializer=proj_initializer, 541 | div_val=div_val, 542 | perms=target_perms, 543 | head_target=head_target, 544 | proj_same_dim=proj_same_dim) 545 | return loss, new_mems 546 | 547 | -------------------------------------------------------------------------------- /tf/scripts/enwik8_base_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=../data/enwik8/ 5 | 6 | # Model 7 | N_LAYER=12 8 | D_MODEL=512 9 | D_EMBED=512 10 | N_HEAD=8 11 | D_HEAD=64 12 | D_INNER=2048 13 | 14 | # Training 15 | TGT_LEN=512 16 | MEM_LEN=512 17 | 18 | BSZ=24 19 | NUM_CORE=4 20 | 21 | # Testing 22 | TEST_TGT_LEN=80 23 | TEST_MEM_LEN=2100 24 | 
TEST_CLAMP_LEN=820 25 | 26 | TEST_BSZ=10 27 | TEST_NUM_CORE=1 28 | 29 | if [[ $1 == 'train_data' ]]; then 30 | python data_utils.py \ 31 | --data_dir=${DATA_ROOT}/ \ 32 | --dataset=enwik8 \ 33 | --tgt_len=${TGT_LEN} \ 34 | --per_host_train_bsz=${BSZ} \ 35 | --per_host_valid_bsz=${BSZ} \ 36 | --num_passes=1 \ 37 | --use_tpu=False \ 38 | ${@:2} 39 | elif [[ $1 == 'test_data' ]]; then 40 | python data_utils.py \ 41 | --data_dir=${DATA_ROOT}/ \ 42 | --dataset=enwik8 \ 43 | --tgt_len=${TEST_TGT_LEN} \ 44 | --per_host_test_bsz=${TEST_BSZ} \ 45 | --num_passes=1 \ 46 | --use_tpu=False \ 47 | ${@:2} 48 | elif [[ $1 == 'train' ]]; then 49 | echo 'Run training...' 50 | python train_gpu.py \ 51 | --data_dir=${DATA_ROOT}/tfrecords \ 52 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 53 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 54 | --model_dir=EXP-enwik8 \ 55 | --n_layer=${N_LAYER} \ 56 | --d_model=${D_MODEL} \ 57 | --d_embed=${D_EMBED} \ 58 | --n_head=${N_HEAD} \ 59 | --d_head=${D_HEAD} \ 60 | --d_inner=${D_INNER} \ 61 | --dropout=0.1 \ 62 | --dropatt=0.0 \ 63 | --learning_rate=0.00025 \ 64 | --warmup_steps=0 \ 65 | --train_steps=400000 \ 66 | --tgt_len=${TGT_LEN} \ 67 | --mem_len=${MEM_LEN} \ 68 | --train_batch_size=${BSZ} \ 69 | --num_core_per_host=${NUM_CORE} \ 70 | --iterations=200 \ 71 | --save_steps=4000 \ 72 | --do_train=True \ 73 | --do_eval=False \ 74 | ${@:2} 75 | elif [[ $1 == 'eval' ]]; then 76 | echo 'Run evaluation...' 77 | python train_gpu.py \ 78 | --data_dir=${DATA_ROOT}/tfrecords \ 79 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 80 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 81 | --model_dir=EXP-enwik8 \ 82 | --n_layer=${N_LAYER} \ 83 | --d_model=${D_MODEL} \ 84 | --d_embed=${D_EMBED} \ 85 | --n_head=${N_HEAD} \ 86 | --d_head=${D_HEAD} \ 87 | --d_inner=${D_INNER} \ 88 | --dropout=0.0 \ 89 | --dropatt=0.0 \ 90 | --tgt_len=${TEST_TGT_LEN} \ 91 | --mem_len=${TEST_MEM_LEN} \ 92 | --clamp_len=${TEST_CLAMP_LEN} \ 93 | --same_length=True \ 94 | --eval_batch_size=${TEST_BSZ} \ 95 | --num_core_per_host=${TEST_NUM_CORE} \ 96 | --do_train=False \ 97 | --do_eval=True \ 98 | --eval_split=test \ 99 | ${@:2} 100 | else 101 | echo 'unknown argment 1' 102 | fi -------------------------------------------------------------------------------- /tf/scripts/enwik8_large_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Path 4 | LOCAL_DIR=../data/enwik8/ 5 | GSDATA= 6 | GSEXP= 7 | 8 | # TPU setting 9 | NUM_HOST=2 10 | NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 11 | 12 | TEST_NUM_HOST=1 13 | TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 14 | 15 | # Model 16 | N_LAYER=24 17 | D_MODEL=1024 18 | D_EMBED=1024 19 | N_HEAD=8 20 | D_HEAD=128 21 | D_INNER=3072 22 | 23 | # Training 24 | TGT_LEN=768 25 | MEM_LEN=768 26 | TRAIN_BSZ=64 27 | VALID_BSZ=64 28 | 29 | # Testing 30 | TEST_TGT_LEN=128 31 | TEST_MEM_LEN=3800 32 | TEST_CLAMP_LEN=1000 33 | TEST_BSZ=16 34 | 35 | if [[ $1 == 'train_data' ]]; then 36 | python data_utils.py \ 37 | --data_dir=${LOCAL_DIR}/ \ 38 | --dataset=enwik8 \ 39 | --tgt_len=${TGT_LEN} \ 40 | --per_host_train_bsz=${TRAIN_BSZ} \ 41 | --per_host_valid_bsz=${VALID_BSZ} \ 42 | --num_core_per_host=${NUM_CORE} \ 43 | --num_passes=10 \ 44 | --use_tpu=True \ 45 | ${@:2} 46 | 47 | SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 48 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/ 49 | 50 | SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 51 | gsutil cp 
${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/ 52 | 53 | elif [[ $1 == 'test_data' ]]; then 54 | python data_utils.py \ 55 | --data_dir=${LOCAL_DIR}/ \ 56 | --dataset=enwik8 \ 57 | --tgt_len=${TEST_TGT_LEN} \ 58 | --per_host_test_bsz=${TEST_BSZ} \ 59 | --num_core_per_host=${TEST_NUM_CORE} \ 60 | --num_passes=1 \ 61 | --use_tpu=True \ 62 | ${@:2} 63 | 64 | SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* 65 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/ 66 | 67 | elif [[ $1 == 'train' ]]; then 68 | echo 'Run training...' 69 | python train.py \ 70 | --data_dir=${GSDATA}/enwik8-tfrecords \ 71 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 72 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 73 | --model_dir=${GSEXP}/enwik8 \ 74 | --n_layer=${N_LAYER} \ 75 | --d_model=${D_MODEL} \ 76 | --d_embed=${D_EMBED} \ 77 | --n_head=${N_HEAD} \ 78 | --d_head=${D_HEAD} \ 79 | --d_inner=${D_INNER} \ 80 | --dropout=0.15 \ 81 | --dropatt=0.15 \ 82 | --learning_rate=0.00025 \ 83 | --warmup_steps=4000 \ 84 | --train_steps=400000 \ 85 | --tgt_len=${TGT_LEN} \ 86 | --mem_len=${MEM_LEN} \ 87 | --train_batch_size=${TRAIN_BSZ} \ 88 | --use_tpu=True \ 89 | --num_host=${NUM_HOST} \ 90 | --num_core_per_host=${NUM_CORE} \ 91 | --iterations=1000 \ 92 | --save_steps=10000 \ 93 | --do_train=True \ 94 | --do_eval=False \ 95 | ${@:2} 96 | 97 | elif [[ $1 == 'eval' ]]; then 98 | echo 'Run evaluation...' 99 | python train.py \ 100 | --data_dir=${GSDATA}/enwik8-tfrecords \ 101 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 102 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 103 | --model_dir=${GSEXP}/enwik8 \ 104 | --n_layer=${N_LAYER} \ 105 | --d_model=${D_MODEL} \ 106 | --d_embed=${D_EMBED} \ 107 | --n_head=${N_HEAD} \ 108 | --d_head=${D_HEAD} \ 109 | --d_inner=${D_INNER} \ 110 | --tgt_len=${TEST_TGT_LEN} \ 111 | --mem_len=${TEST_MEM_LEN} \ 112 | --eval_batch_size=${TEST_BSZ} \ 113 | --num_host=${TEST_NUM_HOST} \ 114 | --num_core_per_host=${TEST_NUM_CORE} \ 115 | --use_tpu=True \ 116 | --do_train=False \ 117 | --do_eval_only=True \ 118 | --eval_split=test \ 119 | ${@:2} 120 | else 121 | echo 'unknown argment 1' 122 | fi 123 | -------------------------------------------------------------------------------- /tf/scripts/lm1b_base_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=../data/one-billion-words/ 5 | 6 | # Model 7 | DIV_VAL=4 8 | N_LAYER=18 9 | D_MODEL=1024 10 | D_EMBED=1024 11 | N_HEAD=8 12 | D_HEAD=128 13 | D_INNER=4096 14 | 15 | # Training 16 | TGT_LEN=256 17 | MEM_LEN=256 18 | 19 | BSZ=256 20 | NUM_CORE=4 21 | 22 | # Testing 23 | TEST_TGT_LEN=32 24 | TEST_MEM_LEN=128 25 | TEST_CLAMP_LEN=-1 26 | 27 | TEST_BSZ=16 28 | TEST_NUM_CORE=1 29 | 30 | 31 | if [[ $1 == 'train_data' ]]; then 32 | python data_utils.py \ 33 | --data_dir=${DATA_ROOT}/ \ 34 | --dataset=lm1b \ 35 | --tgt_len=${TGT_LEN} \ 36 | --per_host_train_bsz=${BSZ} \ 37 | --per_host_valid_bsz=${BSZ} \ 38 | --num_passes=1 \ 39 | --use_tpu=False \ 40 | ${@:2} 41 | elif [[ $1 == 'test_data' ]]; then 42 | python data_utils.py \ 43 | --data_dir=${DATA_ROOT}/ \ 44 | --dataset=lm1b \ 45 | --tgt_len=${TEST_TGT_LEN} \ 46 | --per_host_test_bsz=${TEST_BSZ} \ 47 | --num_passes=1 \ 48 | --use_tpu=False \ 49 | ${@:2} 50 | elif [[ $1 == 'train' ]]; then 51 | echo 'Run training...' 
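# The flags below configure the adaptive embedding/softmax for lm1b's large
# vocabulary: div_val=4 divides the embedding size of each successive vocab
# bin by 4, and proj_same_dim=False omits the projection for bins whose
# embedding size already equals d_proj (see the adaptive softmax in model.py).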
52 | python train_gpu.py \ 53 | --data_dir=${DATA_ROOT}/tfrecords \ 54 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 55 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 56 | --model_dir=EXP-lm1b \ 57 | --div_val=${DIV_VAL} \ 58 | --untie_r=True \ 59 | --proj_share_all_but_first=False \ 60 | --proj_same_dim=False \ 61 | --n_layer=${N_LAYER} \ 62 | --d_model=${D_MODEL} \ 63 | --d_embed=${D_EMBED} \ 64 | --n_head=${N_HEAD} \ 65 | --d_head=${D_HEAD} \ 66 | --d_inner=${D_INNER} \ 67 | --dropout=0.1 \ 68 | --dropatt=0.0 \ 69 | --learning_rate=0.00025 \ 70 | --warmup_steps=0 \ 71 | --train_steps=400000 \ 72 | --tgt_len=${TGT_LEN} \ 73 | --mem_len=${MEM_LEN} \ 74 | --train_batch_size=${BSZ} \ 75 | --num_core_per_host=${NUM_CORE} \ 76 | --iterations=200 \ 77 | --save_steps=4000 \ 78 | ${@:2} 79 | elif [[ $1 == 'eval' ]]; then 80 | echo 'Run evaluation...' 81 | python train_gpu.py \ 82 | --data_dir=${DATA_ROOT}/tfrecords \ 83 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 84 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 85 | --model_dir=EXP-lm1b \ 86 | --div_val=${DIV_VAL} \ 87 | --untie_r=True \ 88 | --proj_share_all_but_first=False \ 89 | --proj_same_dim=False \ 90 | --n_layer=${N_LAYER} \ 91 | --d_model=${D_MODEL} \ 92 | --d_embed=${D_EMBED} \ 93 | --n_head=${N_HEAD} \ 94 | --d_head=${D_HEAD} \ 95 | --d_inner=${D_INNER} \ 96 | --dropout=0.0 \ 97 | --dropatt=0.0 \ 98 | --tgt_len=${TEST_TGT_LEN} \ 99 | --mem_len=${TEST_MEM_LEN} \ 100 | --clamp_len=${TEST_CLAMP_LEN} \ 101 | --same_length=True \ 102 | --eval_batch_size=${TEST_BSZ} \ 103 | --num_core_per_host=${TEST_NUM_CORE} \ 104 | --do_train=False \ 105 | --do_eval=True \ 106 | --eval_split=test \ 107 | ${@:2} 108 | else 109 | echo 'unknown argment 1' 110 | fi 111 | -------------------------------------------------------------------------------- /tf/scripts/lm1b_large_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Path 4 | LOCAL_DIR=../data/one-billion-words/ 5 | GSDATA= 6 | GSEXP= 7 | 8 | # TPU setting 9 | NUM_HOST=32 10 | NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 11 | 12 | TEST_NUM_HOST=1 13 | TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 14 | 15 | # Model 16 | DIV_VAL=4 17 | N_LAYER=24 18 | D_MODEL=1280 19 | D_EMBED=1280 20 | N_HEAD=16 21 | D_HEAD=80 22 | D_INNER=8192 23 | 24 | # Training 25 | TGT_LEN=32 26 | MEM_LEN=32 27 | TRAIN_BSZ=512 28 | VALID_BSZ=512 29 | TRAIN_BSZ_PER_HOST=$((TRAIN_BSZ / NUM_HOST)) 30 | VALID_BSZ_PER_HOST=$((VALID_BSZ / NUM_HOST)) 31 | 32 | # Testing 33 | TEST_TGT_LEN=32 34 | TEST_MEM_LEN=128 35 | TEST_CLAMP_LEN=-1 36 | TEST_BSZ=8 37 | 38 | if [[ $1 == 'train_data' ]]; then 39 | python data_utils.py \ 40 | --data_dir=${LOCAL_DIR}/ \ 41 | --dataset=lm1b \ 42 | --tgt_len=${TGT_LEN} \ 43 | --per_host_train_bsz=${TRAIN_BSZ_PER_HOST} \ 44 | --per_host_valid_bsz=${VALID_BSZ_PER_HOST} \ 45 | --num_core_per_host=${NUM_CORE} \ 46 | --num_passes=10 \ 47 | --use_tpu=True \ 48 | ${@:2} 49 | 50 | SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 51 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/ 52 | 53 | SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 54 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/ 55 | 56 | elif [[ $1 == 'test_data' ]]; then 57 | python data_utils.py \ 58 | --data_dir=${LOCAL_DIR}/ \ 59 | --dataset=lm1b \ 60 | --tgt_len=${TEST_TGT_LEN} \ 61 | --per_host_test_bsz=${TEST_BSZ} \ 62 | --num_core_per_host=${TEST_NUM_CORE} \ 63 | --num_passes=1 \ 64 | 
--use_tpu=True \ 65 | ${@:2} 66 | 67 | SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* 68 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/ 69 | 70 | elif [[ $1 == 'train' ]]; then 71 | echo 'Run training...' 72 | python train.py \ 73 | --data_dir=${GSDATA}/lm1b-tfrecords \ 74 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 75 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 76 | --model_dir=${GSEXP}/lm1b \ 77 | --div_val=${DIV_VAL} \ 78 | --untie_r=True \ 79 | --proj_share_all_but_first=False \ 80 | --proj_same_dim=False \ 81 | --n_layer=${N_LAYER} \ 82 | --d_model=${D_MODEL} \ 83 | --d_embed=${D_EMBED} \ 84 | --n_head=${N_HEAD} \ 85 | --d_head=${D_HEAD} \ 86 | --d_inner=${D_INNER} \ 87 | --dropout=0.05 \ 88 | --dropatt=0.05 \ 89 | --init_std=0.005 \ 90 | --learning_rate=0.0001 \ 91 | --warmup_steps=30000 \ 92 | --train_steps=1200000 \ 93 | --tgt_len=${TGT_LEN} \ 94 | --mem_len=${MEM_LEN} \ 95 | --train_batch_size=${TRAIN_BSZ} \ 96 | --num_hosts=${NUM_HOST} \ 97 | --num_core_per_host=${NUM_CORE} \ 98 | --iterations=1000 \ 99 | --save_steps=10000 \ 100 | --use_tpu=True \ 101 | --do_eval=False \ 102 | ${@:2} 103 | 104 | elif [[ $1 == 'eval' ]]; then 105 | echo 'Run evaluation...' 106 | python train.py \ 107 | --data_dir=${GSDATA}/lm1b-tfrecords \ 108 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 109 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 110 | --model_dir=${GSEXP}/lm1b \ 111 | --div_val=${DIV_VAL} \ 112 | --untie_r=True \ 113 | --proj_share_all_but_first=False \ 114 | --proj_same_dim=False \ 115 | --n_layer=${N_LAYER} \ 116 | --d_model=${D_MODEL} \ 117 | --d_embed=${D_EMBED} \ 118 | --n_head=${N_HEAD} \ 119 | --d_head=${D_HEAD} \ 120 | --d_inner=${D_INNER} \ 121 | --tgt_len=${TEST_TGT_LEN} \ 122 | --mem_len=${TEST_MEM_LEN} \ 123 | --clamp_len=${TEST_CLAMP_LEN} \ 124 | --same_length=True \ 125 | --eval_batch_size=${TEST_BSZ} \ 126 | --num_host=${TEST_NUM_HOST} \ 127 | --num_core_per_host=${TEST_NUM_CORE} \ 128 | --use_tpu=True \ 129 | --do_train=False \ 130 | --do_eval_only=True \ 131 | --eval_split=test \ 132 | ${@:2} 133 | 134 | else 135 | echo 'unknown argment 1' 136 | fi 137 | -------------------------------------------------------------------------------- /tf/scripts/text8_base_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=../data/text8/ 5 | 6 | # Model 7 | N_LAYER=12 8 | D_MODEL=512 9 | D_EMBED=512 10 | N_HEAD=8 11 | D_HEAD=64 12 | D_INNER=2048 13 | 14 | # Training 15 | TGT_LEN=512 16 | MEM_LEN=512 17 | 18 | BSZ=24 19 | NUM_CORE=4 20 | 21 | # Testing 22 | TEST_TGT_LEN=80 23 | TEST_MEM_LEN=2100 24 | TEST_CLAMP_LEN=820 25 | 26 | TEST_BSZ=10 27 | TEST_NUM_CORE=1 28 | 29 | if [[ $1 == 'train_data' ]]; then 30 | python data_utils.py \ 31 | --data_dir=${DATA_ROOT}/ \ 32 | --dataset=text8 \ 33 | --tgt_len=${TGT_LEN} \ 34 | --per_host_train_bsz=${BSZ} \ 35 | --per_host_valid_bsz=${BSZ} \ 36 | --num_passes=1 \ 37 | --use_tpu=False \ 38 | ${@:2} 39 | elif [[ $1 == 'test_data' ]]; then 40 | python data_utils.py \ 41 | --data_dir=${DATA_ROOT}/ \ 42 | --dataset=text8 \ 43 | --tgt_len=${TEST_TGT_LEN} \ 44 | --per_host_test_bsz=${TEST_BSZ} \ 45 | --num_passes=1 \ 46 | --use_tpu=False \ 47 | ${@:2} 48 | elif [[ $1 == 'train' ]]; then 49 | echo 'Run training...' 
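# text8 is a character-level corpus, so evaluation quality is reported in
# bits-per-character; metric_fn in tf/train.py computes bpc = mean_loss / ln(2).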
50 | python train_gpu.py \ 51 | --data_dir=${DATA_ROOT}/tfrecords \ 52 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 53 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 54 | --model_dir=EXP-text8 \ 55 | --n_layer=${N_LAYER} \ 56 | --d_model=${D_MODEL} \ 57 | --d_embed=${D_EMBED} \ 58 | --n_head=${N_HEAD} \ 59 | --d_head=${D_HEAD} \ 60 | --d_inner=${D_INNER} \ 61 | --dropout=0.1 \ 62 | --dropatt=0.0 \ 63 | --learning_rate=0.00025 \ 64 | --warmup_steps=0 \ 65 | --train_steps=400000 \ 66 | --tgt_len=${TGT_LEN} \ 67 | --mem_len=${MEM_LEN} \ 68 | --train_batch_size=${BSZ} \ 69 | --num_core_per_host=${NUM_CORE} \ 70 | --iterations=200 \ 71 | --save_steps=4000 \ 72 | --do_train=True \ 73 | --do_eval=False \ 74 | ${@:2} 75 | elif [[ $1 == 'eval' ]]; then 76 | echo 'Run evaluation...' 77 | python train_gpu.py \ 78 | --data_dir=${DATA_ROOT}/tfrecords \ 79 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 80 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 81 | --model_dir=EXP-text8 \ 82 | --n_layer=${N_LAYER} \ 83 | --d_model=${D_MODEL} \ 84 | --d_embed=${D_EMBED} \ 85 | --n_head=${N_HEAD} \ 86 | --d_head=${D_HEAD} \ 87 | --d_inner=${D_INNER} \ 88 | --dropout=0.0 \ 89 | --dropatt=0.0 \ 90 | --tgt_len=${TEST_TGT_LEN} \ 91 | --mem_len=${TEST_MEM_LEN} \ 92 | --clamp_len=${TEST_CLAMP_LEN} \ 93 | --same_length=True \ 94 | --eval_batch_size=${TEST_BSZ} \ 95 | --num_core_per_host=${TEST_NUM_CORE} \ 96 | --do_train=False \ 97 | --do_eval=True \ 98 | --eval_split=test \ 99 | ${@:2} 100 | else 101 | echo 'unknown argment 1' 102 | fi -------------------------------------------------------------------------------- /tf/scripts/text8_large_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Path 4 | LOCAL_DIR=../data/text8/ 5 | GSDATA= 6 | GSEXP= 7 | 8 | # TPU setting 9 | NUM_HOST=2 10 | NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 11 | 12 | TEST_NUM_HOST=1 13 | TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 14 | 15 | # Model 16 | N_LAYER=24 17 | D_MODEL=1024 18 | D_EMBED=1024 19 | N_HEAD=8 20 | D_HEAD=128 21 | D_INNER=3072 22 | 23 | # Training 24 | TGT_LEN=768 25 | MEM_LEN=768 26 | TRAIN_BSZ=64 27 | VALID_BSZ=64 28 | 29 | # Testing 30 | TEST_TGT_LEN=128 31 | TEST_MEM_LEN=3800 32 | TEST_CLAMP_LEN=1000 33 | TEST_BSZ=16 34 | 35 | if [[ $1 == 'train_data' ]]; then 36 | python data_utils.py \ 37 | --data_dir=${LOCAL_DIR}/ \ 38 | --dataset=text8 \ 39 | --tgt_len=${TGT_LEN} \ 40 | --per_host_train_bsz=${TRAIN_BSZ} \ 41 | --per_host_valid_bsz=${VALID_BSZ} \ 42 | --num_core_per_host=${NUM_CORE} \ 43 | --num_passes=10 \ 44 | --use_tpu=True \ 45 | ${@:2} 46 | 47 | SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 48 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/ 49 | 50 | SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 51 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/ 52 | 53 | elif [[ $1 == 'test_data' ]]; then 54 | python data_utils.py \ 55 | --data_dir=${LOCAL_DIR}/ \ 56 | --dataset=text8 \ 57 | --tgt_len=${TEST_TGT_LEN} \ 58 | --per_host_test_bsz=${TEST_BSZ} \ 59 | --num_core_per_host=${TEST_NUM_CORE} \ 60 | --num_passes=1 \ 61 | --use_tpu=True \ 62 | ${@:2} 63 | 64 | SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* 65 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/ 66 | 67 | elif [[ $1 == 'train' ]]; then 68 | echo 'Run training...' 
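# Relative to the base GPU config, this large TPU run adds heavier
# regularization (dropout/dropatt 0.15) and a 4000-step linear warmup.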
69 | python train.py \ 70 | --data_dir=${GSDATA}/text8-tfrecords \ 71 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 72 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 73 | --model_dir=${GSEXP}/text8 \ 74 | --n_layer=${N_LAYER} \ 75 | --d_model=${D_MODEL} \ 76 | --d_embed=${D_EMBED} \ 77 | --n_head=${N_HEAD} \ 78 | --d_head=${D_HEAD} \ 79 | --d_inner=${D_INNER} \ 80 | --dropout=0.15 \ 81 | --dropatt=0.15 \ 82 | --learning_rate=0.00025 \ 83 | --warmup_steps=4000 \ 84 | --train_steps=400000 \ 85 | --tgt_len=${TGT_LEN} \ 86 | --mem_len=${MEM_LEN} \ 87 | --train_batch_size=${TRAIN_BSZ} \ 88 | --use_tpu=True \ 89 | --num_host=${NUM_HOST} \ 90 | --num_core_per_host=${NUM_CORE} \ 91 | --iterations=1000 \ 92 | --save_steps=10000 \ 93 | --do_train=True \ 94 | --do_eval=False \ 95 | ${@:2} 96 | 97 | elif [[ $1 == 'eval' ]]; then 98 | echo 'Run evaluation...' 99 | python train.py \ 100 | --data_dir=${GSDATA}/text8-tfrecords \ 101 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 102 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 103 | --model_dir=${GSEXP}/text8 \ 104 | --n_layer=${N_LAYER} \ 105 | --d_model=${D_MODEL} \ 106 | --d_embed=${D_EMBED} \ 107 | --n_head=${N_HEAD} \ 108 | --d_head=${D_HEAD} \ 109 | --d_inner=${D_INNER} \ 110 | --tgt_len=${TEST_TGT_LEN} \ 111 | --mem_len=${TEST_MEM_LEN} \ 112 | --eval_batch_size=${TEST_BSZ} \ 113 | --num_host=${TEST_NUM_HOST} \ 114 | --num_core_per_host=${TEST_NUM_CORE} \ 115 | --use_tpu=True \ 116 | --do_train=False \ 117 | --do_eval_only=True \ 118 | --eval_split=test \ 119 | ${@:2} 120 | else 121 | echo 'unknown argument 1' 122 | fi 123 | -------------------------------------------------------------------------------- /tf/scripts/wt103_base_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=../data/wikitext-103/ 5 | 6 | # Model 7 | DIV_VAL=1 8 | N_LAYER=16 9 | D_MODEL=410 10 | D_EMBED=410 11 | N_HEAD=10 12 | D_HEAD=41 13 | D_INNER=2100 14 | 15 | # Training 16 | TGT_LEN=150 17 | MEM_LEN=150 18 | 19 | BSZ=60 20 | NUM_CORE=4 21 | 22 | # Testing 23 | TEST_TGT_LEN=64 24 | TEST_MEM_LEN=640 25 | TEST_CLAMP_LEN=400 26 | 27 | TEST_BSZ=10 28 | TEST_NUM_CORE=1 29 | 30 | 31 | if [[ $1 == 'train_data' ]]; then 32 | python data_utils.py \ 33 | --data_dir=${DATA_ROOT}/ \ 34 | --dataset=wt103 \ 35 | --tgt_len=${TGT_LEN} \ 36 | --per_host_train_bsz=${BSZ} \ 37 | --per_host_valid_bsz=${BSZ} \ 38 | --num_passes=1 \ 39 | --use_tpu=False \ 40 | ${@:2} 41 | elif [[ $1 == 'test_data' ]]; then 42 | python data_utils.py \ 43 | --data_dir=${DATA_ROOT}/ \ 44 | --dataset=wt103 \ 45 | --tgt_len=${TEST_TGT_LEN} \ 46 | --per_host_test_bsz=${TEST_BSZ} \ 47 | --num_passes=1 \ 48 | --use_tpu=False \ 49 | ${@:2} 50 | elif [[ $1 == 'train' ]]; then 51 | echo 'Run training...' 
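# wt103 keeps full-size embeddings in every vocab bin (DIV_VAL=1) while
# sharing all softmax projection matrices except the first across bins
# (proj_share_all_but_first=True below).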
52 | python train_gpu.py \ 53 | --data_dir=${DATA_ROOT}/tfrecords \ 54 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 55 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 56 | --model_dir=EXP-wt103 \ 57 | --div_val=${DIV_VAL} \ 58 | --untie_r=True \ 59 | --proj_share_all_but_first=True \ 60 | --n_layer=${N_LAYER} \ 61 | --d_model=${D_MODEL} \ 62 | --d_embed=${D_EMBED} \ 63 | --n_head=${N_HEAD} \ 64 | --d_head=${D_HEAD} \ 65 | --d_inner=${D_INNER} \ 66 | --dropout=0.1 \ 67 | --dropatt=0.0 \ 68 | --learning_rate=0.00025 \ 69 | --warmup_steps=0 \ 70 | --train_steps=400000 \ 71 | --tgt_len=${TGT_LEN} \ 72 | --mem_len=${MEM_LEN} \ 73 | --train_batch_size=${BSZ} \ 74 | --num_core_per_host=${NUM_CORE} \ 75 | --iterations=200 \ 76 | --save_steps=4000 \ 77 | ${@:2} 78 | elif [[ $1 == 'eval' ]]; then 79 | echo 'Run evaluation...' 80 | python train_gpu.py \ 81 | --data_dir=${DATA_ROOT}/tfrecords \ 82 | --record_info_dir=${DATA_ROOT}/tfrecords/ \ 83 | --corpus_info_path=${DATA_ROOT}/corpus-info.json \ 84 | --model_dir=EXP-wt103 \ 85 | --div_val=${DIV_VAL} \ 86 | --untie_r=True \ 87 | --proj_share_all_but_first=True \ 88 | --n_layer=${N_LAYER} \ 89 | --d_model=${D_MODEL} \ 90 | --d_embed=${D_EMBED} \ 91 | --n_head=${N_HEAD} \ 92 | --d_head=${D_HEAD} \ 93 | --d_inner=${D_INNER} \ 94 | --dropout=0.0 \ 95 | --dropatt=0.0 \ 96 | --tgt_len=${TEST_TGT_LEN} \ 97 | --mem_len=${TEST_MEM_LEN} \ 98 | --clamp_len=${TEST_CLAMP_LEN} \ 99 | --same_length=True \ 100 | --eval_batch_size=${TEST_BSZ} \ 101 | --num_core_per_host=${TEST_NUM_CORE} \ 102 | --do_train=False \ 103 | --do_eval=True \ 104 | --eval_split=test \ 105 | ${@:2} 106 | else 107 | echo 'unknown argment 1' 108 | fi -------------------------------------------------------------------------------- /tf/scripts/wt103_large_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Path 4 | LOCAL_DIR=../data/wikitext-103/ 5 | GSDATA= 6 | GSEXP= 7 | 8 | # TPU setting 9 | NUM_HOST=4 10 | NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 11 | 12 | TEST_NUM_HOST=1 13 | TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 14 | 15 | # Model 16 | DIV_VAL=4 17 | N_LAYER=18 18 | D_MODEL=1024 19 | D_EMBED=1024 20 | N_HEAD=16 21 | D_HEAD=64 22 | D_INNER=4096 23 | 24 | # Training 25 | TGT_LEN=384 26 | MEM_LEN=384 27 | TRAIN_BSZ=128 28 | VALID_BSZ=128 29 | 30 | # Testing 31 | TEST_TGT_LEN=128 32 | TEST_MEM_LEN=1600 33 | TEST_CLAMP_LEN=1000 34 | TEST_BSZ=8 35 | 36 | if [[ $1 == 'train_data' ]]; then 37 | python data_utils.py \ 38 | --data_dir=${LOCAL_DIR}/ \ 39 | --dataset=wt103 \ 40 | --tgt_len=${TGT_LEN} \ 41 | --per_host_train_bsz=${TRAIN_BSZ} \ 42 | --per_host_valid_bsz=${VALID_BSZ} \ 43 | --num_core_per_host=${NUM_CORE} \ 44 | --num_passes=10 \ 45 | --use_tpu=True \ 46 | ${@:2} 47 | 48 | SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 49 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/ 50 | 51 | SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* 52 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/ 53 | 54 | elif [[ $1 == 'test_data' ]]; then 55 | python data_utils.py \ 56 | --data_dir=${LOCAL_DIR}/ \ 57 | --dataset=wt103 \ 58 | --tgt_len=${TEST_TGT_LEN} \ 59 | --per_host_test_bsz=${TEST_BSZ} \ 60 | --num_core_per_host=${TEST_NUM_CORE} \ 61 | --num_passes=1 \ 62 | --use_tpu=True \ 63 | ${@:2} 64 | 65 | SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* 66 | gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} 
${GSDATA}/wt103-tfrecords/ 67 | 68 | elif [[ $1 == 'train' ]]; then 69 | echo 'Run training...' 70 | python train.py \ 71 | --data_dir=${GSDATA}/wt103-tfrecords \ 72 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 73 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 74 | --model_dir=${GSEXP}/wt103 \ 75 | --div_val=${DIV_VAL} \ 76 | --untie_r=True \ 77 | --proj_share_all_but_first=True \ 78 | --proj_same_dim=True \ 79 | --n_layer=${N_LAYER} \ 80 | --d_model=${D_MODEL} \ 81 | --d_embed=${D_EMBED} \ 82 | --n_head=${N_HEAD} \ 83 | --d_head=${D_HEAD} \ 84 | --d_inner=${D_INNER} \ 85 | --dropout=0.2 \ 86 | --dropatt=0.2 \ 87 | --init_std=0.005 \ 88 | --learning_rate=0.00025 \ 89 | --warmup_steps=16000 \ 90 | --train_steps=4000000 \ 91 | --tgt_len=${TGT_LEN} \ 92 | --mem_len=${MEM_LEN} \ 93 | --train_batch_size=${TRAIN_BSZ} \ 94 | --num_hosts=${NUM_HOST} \ 95 | --num_core_per_host=${NUM_CORE} \ 96 | --iterations=1000 \ 97 | --save_steps=10000 \ 98 | --use_tpu=True \ 99 | --do_eval=False \ 100 | ${@:2} 101 | 102 | elif [[ $1 == 'eval' ]]; then 103 | echo 'Run evaluation...' 104 | python train.py \ 105 | --data_dir=${GSDATA}/wt103-tfrecords \ 106 | --record_info_dir=${LOCAL_DIR}/tfrecords/ \ 107 | --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ 108 | --model_dir=${GSEXP}/wt103 \ 109 | --div_val=${DIV_VAL} \ 110 | --untie_r=True \ 111 | --proj_share_all_but_first=True \ 112 | --proj_same_dim=True \ 113 | --n_layer=${N_LAYER} \ 114 | --d_model=${D_MODEL} \ 115 | --d_embed=${D_EMBED} \ 116 | --n_head=${N_HEAD} \ 117 | --d_head=${D_HEAD} \ 118 | --d_inner=${D_INNER} \ 119 | --tgt_len=${TEST_TGT_LEN} \ 120 | --mem_len=${TEST_MEM_LEN} \ 121 | --clamp_len=${TEST_CLAMP_LEN} \ 122 | --same_length=True \ 123 | --eval_batch_size=${TEST_BSZ} \ 124 | --num_host=${TEST_NUM_HOST} \ 125 | --num_core_per_host=${TEST_NUM_CORE} \ 126 | --use_tpu=True \ 127 | --do_train=False \ 128 | --do_eval_only=True \ 129 | --eval_split=test \ 130 | ${@:2} 131 | 132 | else 133 | echo 'unknown argment 1' 134 | fi 135 | -------------------------------------------------------------------------------- /tf/sota/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | URL=http://curtis.ml.cmu.edu/datasets/pretrained_xl 4 | 5 | DATA_ROOT=./ 6 | 7 | function download () { 8 | fileurl=${1} 9 | filename=${fileurl##*/} 10 | if [ ! -f ${filename} ]; then 11 | echo ">>> Download '${filename}' from '${fileurl}'." 12 | wget --quiet ${fileurl} 13 | else 14 | echo "*** File '${filename}' exists. Skip." 15 | fi 16 | } 17 | 18 | cd $DATA_ROOT 19 | mkdir -p pretrained_xl && cd pretrained_xl 20 | 21 | # enwik8 22 | mkdir -p tf_enwik8 && cd tf_enwik8 23 | 24 | mkdir -p data && cd data 25 | download ${URL}/tf_enwiki8/data/cache.pkl 26 | download ${URL}/tf_enwiki8/data/corpus-info.json 27 | cd .. 28 | 29 | mkdir -p model && cd model 30 | download ${URL}/tf_enwiki8/model/checkpoint 31 | download ${URL}/tf_enwiki8/model/model.ckpt-0.data-00000-of-00001 32 | download ${URL}/tf_enwiki8/model/model.ckpt-0.index 33 | download ${URL}/tf_enwiki8/model/model.ckpt-0.meta 34 | cd .. 35 | 36 | cd .. 37 | 38 | # text8 39 | mkdir -p tf_text8 && cd tf_text8 40 | 41 | mkdir -p data && cd data 42 | download ${URL}/tf_text8/data/cache.pkl 43 | download ${URL}/tf_text8/data/corpus-info.json 44 | cd .. 
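# Each pretrained bundle (enwik8, text8, wt103, lm1b) follows the same layout:
# data/{cache.pkl, corpus-info.json} plus a TF checkpoint
# (.data/.index/.meta files) under model/.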
45 | 46 | mkdir -p model && cd model 47 | download ${URL}/tf_text8/model/checkpoint 48 | download ${URL}/tf_text8/model/model.ckpt-0.data-00000-of-00001 49 | download ${URL}/tf_text8/model/model.ckpt-0.index 50 | download ${URL}/tf_text8/model/model.ckpt-0.meta 51 | cd .. 52 | 53 | cd .. 54 | 55 | # wt103 56 | mkdir -p tf_wt103 && cd tf_wt103 57 | 58 | mkdir -p data && cd data 59 | download ${URL}/tf_wt103/data/cache.pkl 60 | download ${URL}/tf_wt103/data/corpus-info.json 61 | cd .. 62 | 63 | mkdir -p model && cd model 64 | download ${URL}/tf_wt103/model/checkpoint 65 | download ${URL}/tf_wt103/model/model.ckpt-0.data-00000-of-00001 66 | download ${URL}/tf_wt103/model/model.ckpt-0.index 67 | download ${URL}/tf_wt103/model/model.ckpt-0.meta 68 | cd .. 69 | 70 | cd .. 71 | 72 | # lm1b 73 | mkdir -p tf_lm1b && cd tf_lm1b 74 | 75 | mkdir -p data && cd data 76 | download ${URL}/tf_lm1b/data/cache.pkl 77 | download ${URL}/tf_lm1b/data/corpus-info.json 78 | cd .. 79 | 80 | mkdir -p model && cd model 81 | download ${URL}/tf_lm1b/model/checkpoint 82 | download ${URL}/tf_lm1b/model/model.ckpt-1191000.data-00000-of-00001 83 | download ${URL}/tf_lm1b/model/model.ckpt-1191000.index 84 | download ${URL}/tf_lm1b/model/model.ckpt-1191000.meta 85 | cd .. 86 | 87 | cd .. 88 | -------------------------------------------------------------------------------- /tf/sota/enwik8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=./ 5 | DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/data 6 | MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/model 7 | 8 | # Model 9 | N_LAYER=24 10 | D_MODEL=1024 11 | D_EMBED=1024 12 | N_HEAD=8 13 | D_HEAD=128 14 | D_INNER=3072 15 | 16 | # Testing 17 | TEST_TGT_LEN=128 18 | TEST_MEM_LEN=3800 19 | TEST_CLAMP_LEN=1000 20 | 21 | TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0 22 | TEST_BSZ=16 23 | TEST_NUM_CORE=2 24 | 25 | 26 | echo 'Preprocess test set...' 27 | python data_utils.py \ 28 | --data_dir=${DATA_DIR}/ \ 29 | --dataset=enwik8 \ 30 | --tgt_len=${TEST_TGT_LEN} \ 31 | --per_host_test_bsz=${TEST_BSZ} \ 32 | --num_passes=1 \ 33 | --use_tpu=False 34 | 35 | echo 'Run evaluation on test set...' 
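# Evaluation settings differ from training: dropout is disabled, same_length
# attention is enabled, and the model attends over a 3800-step memory with
# relative positions clamped at 1000.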
36 | python train_gpu.py \ 37 | --data_dir=${DATA_DIR}/tfrecords \ 38 | --record_info_dir=${DATA_DIR}/tfrecords/ \ 39 | --corpus_info_path=${DATA_DIR}/corpus-info.json \ 40 | --eval_ckpt_path=${TEST_CKPT_PATH} \ 41 | --model_dir=EXP-enwik8 \ 42 | --n_layer=${N_LAYER} \ 43 | --d_model=${D_MODEL} \ 44 | --d_embed=${D_EMBED} \ 45 | --n_head=${N_HEAD} \ 46 | --d_head=${D_HEAD} \ 47 | --d_inner=${D_INNER} \ 48 | --dropout=0.0 \ 49 | --dropatt=0.0 \ 50 | --tgt_len=${TEST_TGT_LEN} \ 51 | --mem_len=${TEST_MEM_LEN} \ 52 | --clamp_len=${TEST_CLAMP_LEN} \ 53 | --same_length=True \ 54 | --eval_batch_size=${TEST_BSZ} \ 55 | --num_core_per_host=${TEST_NUM_CORE} \ 56 | --do_train=False \ 57 | --do_eval=True \ 58 | --eval_split=test 59 | -------------------------------------------------------------------------------- /tf/sota/lm1b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=./ 5 | DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/data 6 | MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/model 7 | 8 | # Model 9 | DIV_VAL=4 10 | N_LAYER=24 11 | D_MODEL=1280 12 | D_EMBED=1280 13 | N_HEAD=16 14 | D_HEAD=80 15 | D_INNER=8192 16 | 17 | # Testing 18 | TEST_TGT_LEN=32 19 | TEST_MEM_LEN=128 20 | TEST_CLAMP_LEN=-1 21 | 22 | TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-1191000 23 | TEST_BSZ=16 24 | TEST_NUM_CORE=1 25 | 26 | 27 | echo 'Preprocess test set...' 28 | python data_utils.py \ 29 | --data_dir=${DATA_DIR}/ \ 30 | --dataset=lm1b \ 31 | --tgt_len=${TEST_TGT_LEN} \ 32 | --per_host_test_bsz=${TEST_BSZ} \ 33 | --num_passes=1 \ 34 | --use_tpu=False 35 | 36 | echo 'Run evaluation on test set...' 37 | python train_gpu.py \ 38 | --data_dir=${DATA_DIR}/tfrecords \ 39 | --record_info_dir=${DATA_DIR}/tfrecords/ \ 40 | --corpus_info_path=${DATA_DIR}/corpus-info.json \ 41 | --eval_ckpt_path=${TEST_CKPT_PATH} \ 42 | --model_dir=EXP-lm1b \ 43 | --div_val=${DIV_VAL} \ 44 | --untie_r=True \ 45 | --proj_share_all_but_first=False \ 46 | --proj_same_dim=False \ 47 | --n_layer=${N_LAYER} \ 48 | --d_model=${D_MODEL} \ 49 | --d_embed=${D_EMBED} \ 50 | --n_head=${N_HEAD} \ 51 | --d_head=${D_HEAD} \ 52 | --d_inner=${D_INNER} \ 53 | --dropout=0.0 \ 54 | --dropatt=0.0 \ 55 | --tgt_len=${TEST_TGT_LEN} \ 56 | --mem_len=${TEST_MEM_LEN} \ 57 | --clamp_len=${TEST_CLAMP_LEN} \ 58 | --same_length=True \ 59 | --eval_batch_size=${TEST_BSZ} \ 60 | --num_core_per_host=${TEST_NUM_CORE} \ 61 | --do_train=False \ 62 | --do_eval=True \ 63 | --eval_split=test 64 | -------------------------------------------------------------------------------- /tf/sota/text8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=./ 5 | DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/data 6 | MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/model 7 | 8 | # Model 9 | N_LAYER=24 10 | D_MODEL=1024 11 | D_EMBED=1024 12 | N_HEAD=8 13 | D_HEAD=128 14 | D_INNER=3072 15 | 16 | # Testing 17 | TEST_TGT_LEN=128 18 | TEST_MEM_LEN=3800 19 | TEST_CLAMP_LEN=1000 20 | 21 | TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0 22 | TEST_BSZ=16 23 | TEST_NUM_CORE=2 24 | 25 | 26 | echo 'Preprocess test set...' 27 | python data_utils.py \ 28 | --data_dir=${DATA_DIR}/ \ 29 | --dataset=text8 \ 30 | --tgt_len=${TEST_TGT_LEN} \ 31 | --per_host_test_bsz=${TEST_BSZ} \ 32 | --num_passes=1 \ 33 | --use_tpu=False 34 | 35 | echo 'Run evaluation on test set...' 
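# As with enwik8, test-time bpc benefits from a memory (3800) far longer than
# the training segment; the recurrence cache makes this possible without
# retraining.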
36 | python train_gpu.py \ 37 | --data_dir=${DATA_DIR}/tfrecords \ 38 | --record_info_dir=${DATA_DIR}/tfrecords/ \ 39 | --corpus_info_path=${DATA_DIR}/corpus-info.json \ 40 | --eval_ckpt_path=${TEST_CKPT_PATH} \ 41 | --model_dir=EXP-text8 \ 42 | --n_layer=${N_LAYER} \ 43 | --d_model=${D_MODEL} \ 44 | --d_embed=${D_EMBED} \ 45 | --n_head=${N_HEAD} \ 46 | --d_head=${D_HEAD} \ 47 | --d_inner=${D_INNER} \ 48 | --dropout=0.0 \ 49 | --dropatt=0.0 \ 50 | --tgt_len=${TEST_TGT_LEN} \ 51 | --mem_len=${TEST_MEM_LEN} \ 52 | --clamp_len=${TEST_CLAMP_LEN} \ 53 | --same_length=True \ 54 | --eval_batch_size=${TEST_BSZ} \ 55 | --num_core_per_host=${TEST_NUM_CORE} \ 56 | --do_train=False \ 57 | --do_eval=True \ 58 | --eval_split=test 59 | -------------------------------------------------------------------------------- /tf/sota/wt103.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Data 4 | DATA_ROOT=./ 5 | DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/data 6 | MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/model 7 | 8 | # Model 9 | DIV_VAL=4 10 | N_LAYER=18 11 | D_MODEL=1024 12 | D_EMBED=1024 13 | N_HEAD=16 14 | D_HEAD=64 15 | D_INNER=4096 16 | 17 | # Training 18 | TGT_LEN=256 19 | MEM_LEN=256 20 | 21 | BSZ=16 22 | NUM_CORE=2 23 | 24 | # Testing 25 | TEST_TGT_LEN=128 26 | TEST_MEM_LEN=1600 27 | TEST_CLAMP_LEN=1000 28 | 29 | TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0 30 | TEST_BSZ=16 31 | TEST_NUM_CORE=1 32 | 33 | 34 | echo 'Preprocess test set...' 35 | python data_utils.py \ 36 | --data_dir=${DATA_DIR}/ \ 37 | --dataset=wt103 \ 38 | --tgt_len=${TEST_TGT_LEN} \ 39 | --per_host_test_bsz=${TEST_BSZ} \ 40 | --num_passes=1 \ 41 | --use_tpu=False 42 | 43 | 44 | echo 'Run evaluation on test set...' 45 | python train_gpu.py \ 46 | --data_dir=${DATA_DIR}/tfrecords \ 47 | --record_info_dir=${DATA_DIR}/tfrecords/ \ 48 | --corpus_info_path=${DATA_DIR}/corpus-info.json \ 49 | --eval_ckpt_path=${TEST_CKPT_PATH} \ 50 | --model_dir=EXP-wt103 \ 51 | --div_val=${DIV_VAL} \ 52 | --untie_r=True \ 53 | --proj_share_all_but_first=True \ 54 | --n_layer=${N_LAYER} \ 55 | --d_model=${D_MODEL} \ 56 | --d_embed=${D_EMBED} \ 57 | --n_head=${N_HEAD} \ 58 | --d_head=${D_HEAD} \ 59 | --d_inner=${D_INNER} \ 60 | --dropout=0.0 \ 61 | --dropatt=0.0 \ 62 | --tgt_len=${TEST_TGT_LEN} \ 63 | --mem_len=${TEST_MEM_LEN} \ 64 | --clamp_len=${TEST_CLAMP_LEN} \ 65 | --same_length=True \ 66 | --eval_batch_size=${TEST_BSZ} \ 67 | --num_core_per_host=${TEST_NUM_CORE} \ 68 | --do_train=False \ 69 | --do_eval=True \ 70 | --eval_split=test 71 | 72 | -------------------------------------------------------------------------------- /tf/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import time 7 | 8 | from absl import flags 9 | import absl.logging as _logging # pylint: disable=unused-import 10 | 11 | from six.moves import xrange # pylint: disable=redefined-builtin 12 | 13 | import tensorflow as tf 14 | from tensorflow.gfile import Exists as exists 15 | import model 16 | import data_utils 17 | import tpu_estimator 18 | 19 | import numpy as np 20 | from time import sleep 21 | 22 | 23 | # TPU parameters 24 | flags.DEFINE_string("master", default=None, 25 | help="master") 26 | flags.DEFINE_string("tpu", default=None, 27 | help="The Cloud TPU to use for training. 
This should be either the name " 28 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") 29 | flags.DEFINE_string("gcp_project", default=None, 30 | help="Project name for the Cloud TPU-enabled project. If not specified, " 31 | "we will attempt to automatically detect the GCE project from metadata.") 32 | flags.DEFINE_string("tpu_zone",default=None, 33 | help="GCE zone where the Cloud TPU is located in. If not specified, we " 34 | "will attempt to automatically detect the GCE project from metadata.") 35 | flags.DEFINE_bool("use_tpu", default=True, 36 | help="Use TPUs rather than plain CPUs.") 37 | flags.DEFINE_integer("num_hosts", default=1, 38 | help="number of TPU hosts") 39 | flags.DEFINE_integer("num_core_per_host", default=8, 40 | help="number of cores per host") 41 | 42 | # Experiment (data/checkpoint/directory) parameters 43 | flags.DEFINE_string("data_dir", default="", 44 | help="Path to tf-records directory.") 45 | flags.DEFINE_string("record_info_dir", default="", 46 | help="Path to local directory containing filenames.txt.") 47 | flags.DEFINE_string("corpus_info_path", default="", 48 | help="Path to corpus-info.json file.") 49 | flags.DEFINE_string("model_dir", default=None, 50 | help="Estimator model_dir.") 51 | flags.DEFINE_bool("do_eval", default=False, 52 | help="Whether to run eval on the dev set.") 53 | flags.DEFINE_bool("track_mean", default=True, 54 | help="Trace mean loss during training.") 55 | flags.DEFINE_string("eval_ckpt_path", None, 56 | help="Checkpoint path for evaluation." 57 | "If set, model_dir will be ignored." 58 | "If unset, will use the latest ckpt in model_dir.") 59 | flags.DEFINE_string("warm_start_path", None, 60 | help="Checkpoint path for warm start." 61 | "If set, will clear Adam states." 62 | "Note that the new model_dir should be different" 63 | " from warm_start_path.") 64 | 65 | # Optimization paramenters 66 | flags.DEFINE_float("learning_rate", default=2.5e-4, 67 | help="Maximum learning rate.") 68 | flags.DEFINE_float("clip", default=0.25, 69 | help="Gradient clipping value.") 70 | # for cosine decay 71 | flags.DEFINE_float("min_lr_ratio", default=0.01, 72 | help="Minimum ratio learning rate.") 73 | flags.DEFINE_integer("warmup_steps", default=0, 74 | help="Number of steps for linear lr warmup.") 75 | 76 | # Training parameters 77 | flags.DEFINE_integer("train_batch_size", default=60, 78 | help="Size of train batch.") 79 | flags.DEFINE_integer("eval_batch_size", default=60, 80 | help="Size of valid batch.") 81 | flags.DEFINE_integer("train_steps", default=100000, 82 | help="Total number of training steps.") 83 | flags.DEFINE_integer("iterations", default=500, 84 | help="Number of iterations per repeat loop.") 85 | flags.DEFINE_integer("save_steps", default=10000, 86 | help="number of steps for model checkpointing.") 87 | 88 | # Evaluation parameters 89 | flags.DEFINE_integer("max_eval_batch", default=-1, 90 | help="Set -1 to turn off. 
Only used in test mode.") 91 | flags.DEFINE_bool("do_eval_only", default=False, 92 | help="Run evaluation only.") 93 | flags.DEFINE_integer("start_eval_steps", default=10000, 94 | help="Which checkpoint to start with in `do_eval_only` mode.") 95 | flags.DEFINE_string("eval_split", "valid", 96 | help="Which data split to evaluate.") 97 | 98 | # Model paramenters 99 | flags.DEFINE_integer("tgt_len", default=70, 100 | help="Number of steps to predict") 101 | flags.DEFINE_integer("mem_len", default=70, 102 | help="Number of steps to cache") 103 | flags.DEFINE_bool("same_length", default=False, 104 | help="Same length attention") 105 | flags.DEFINE_integer("clamp_len", default=-1, 106 | help="Clamp length") 107 | 108 | flags.DEFINE_integer("n_layer", default=6, 109 | help="Number of layers.") 110 | flags.DEFINE_integer("d_model", default=500, 111 | help="Dimension of the model.") 112 | flags.DEFINE_integer("d_embed", default=500, 113 | help="Dimension of the embeddings.") 114 | flags.DEFINE_integer("n_head", default=10, 115 | help="Number of attention heads.") 116 | flags.DEFINE_integer("d_head", default=50, 117 | help="Dimension of each attention head.") 118 | flags.DEFINE_integer("d_inner", default=1000, 119 | help="Dimension of inner hidden size in positionwise feed-forward.") 120 | flags.DEFINE_float("dropout", default=0.1, 121 | help="Dropout rate.") 122 | flags.DEFINE_float("dropatt", default=0.1, 123 | help="Attention dropout rate.") 124 | flags.DEFINE_bool("untie_r", default=False, 125 | help="untie r_w_bias and r_r_bias") 126 | 127 | # Adaptive Softmax / Embedding 128 | flags.DEFINE_bool("tie_weight", default=True, 129 | help="Tie embedding and softmax weight.") 130 | flags.DEFINE_integer("div_val", default=1, 131 | help="Divide the embedding size by this val for each bin") 132 | flags.DEFINE_bool("proj_share_all_but_first", default=False, 133 | help="True to share all but first projs, False not to share.") 134 | flags.DEFINE_bool("proj_same_dim", default=True, 135 | help="Project the bin with the same dimension.") 136 | 137 | # Parameter initialization 138 | flags.DEFINE_enum("init", default="normal", 139 | enum_values=["normal", "uniform"], 140 | help="Initialization method.") 141 | flags.DEFINE_float("init_std", default=0.02, 142 | help="Initialization std when init is normal.") 143 | flags.DEFINE_float("proj_init_std", default=0.01, 144 | help="Initialization std for embedding projection.") 145 | flags.DEFINE_float("init_range", default=0.1, 146 | help="Initialization std when init is uniform.") 147 | 148 | 149 | FLAGS = flags.FLAGS 150 | 151 | def metric_fn(loss): 152 | """Evaluation metric Fn which runs on CPU.""" 153 | perplexity = tf.exp(tf.reduce_mean(loss)) 154 | bpc = tf.reduce_mean(loss) / tf.constant(math.log(2)) 155 | return { 156 | "perplexity": tf.metrics.mean(perplexity), 157 | "bpc": tf.metrics.mean(bpc), 158 | } 159 | 160 | 161 | def get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes): 162 | def model_fn(features, labels, mode, params): 163 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 164 | 165 | 166 | batch_size = params["batch_size"] 167 | 168 | mems = params["cache"] 169 | inp = tf.transpose(features["inputs"], [1, 0]) 170 | tgt = tf.transpose(features["labels"], [1, 0]) 171 | 172 | bin_sizes = train_bin_sizes if is_training else eval_bin_sizes 173 | if bin_sizes: 174 | inp_perms = [tf.transpose(features["inp_mask"], [1, 0])] 175 | tgt_perms = [tf.transpose(features["tgt_mask"], [1, 0])] 176 | 177 | head_tgt = 
tf.transpose(features["head_labels"], [1, 0]) 178 | 179 | for b in range(len(bin_sizes)): 180 | inp_perm = tf.transpose(features["inp_perm_{}".format(b)], [1, 0, 2]) 181 | tgt_perm = tf.transpose(features["tgt_perm_{}".format(b)], [1, 0, 2]) 182 | 183 | inp_perms.append(inp_perm) 184 | tgt_perms.append(tgt_perm) 185 | else: 186 | inp_perms, tgt_perms, head_tgt = None, None, None 187 | 188 | if FLAGS.init == "uniform": 189 | initializer = tf.initializers.random_uniform( 190 | minval=-FLAGS.init_range, 191 | maxval=FLAGS.init_range, 192 | seed=None) 193 | elif FLAGS.init == "normal": 194 | initializer = tf.initializers.random_normal( 195 | stddev=FLAGS.init_std, 196 | seed=None) 197 | proj_initializer = tf.initializers.random_normal( 198 | stddev=FLAGS.proj_init_std, 199 | seed=None) 200 | 201 | tie_projs = [False for _ in range(len(cutoffs) + 1)] 202 | if FLAGS.proj_share_all_but_first: 203 | for i in range(1, len(tie_projs)): 204 | tie_projs[i] = True 205 | 206 | tf.logging.info("Vocab size : {}".format(n_token)) 207 | tf.logging.info("Batch size : {}".format(batch_size)) 208 | 209 | loss, new_mems = model.transformer( 210 | dec_inp=inp, 211 | target=tgt, 212 | mems=mems, 213 | n_token=n_token, 214 | n_layer=FLAGS.n_layer, 215 | d_model=FLAGS.d_model, 216 | d_embed=FLAGS.d_embed, 217 | n_head=FLAGS.n_head, 218 | d_head=FLAGS.d_head, 219 | d_inner=FLAGS.d_inner, 220 | dropout=FLAGS.dropout, 221 | dropatt=FLAGS.dropatt, 222 | initializer=initializer, 223 | is_training=is_training, 224 | mem_len=FLAGS.mem_len, 225 | cutoffs=cutoffs, 226 | div_val=FLAGS.div_val, 227 | tie_projs=tie_projs, 228 | input_perms=inp_perms, 229 | target_perms=tgt_perms, 230 | head_target=head_tgt, 231 | same_length=FLAGS.same_length, 232 | clamp_len=FLAGS.clamp_len, 233 | use_tpu=FLAGS.use_tpu, 234 | untie_r=FLAGS.untie_r, 235 | proj_same_dim=FLAGS.proj_same_dim) 236 | 237 | total_loss = tf.reduce_mean(loss) 238 | 239 | if mode == tf.estimator.ModeKeys.EVAL: 240 | if FLAGS.use_tpu: 241 | with tf.colocate_with(total_loss): 242 | total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) \ 243 | / FLAGS.num_hosts / FLAGS.num_core_per_host 244 | metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [batch_size, 1]) 245 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec( 246 | mode=mode, 247 | loss=total_loss, 248 | eval_metrics=(metric_fn, [metric_loss])) 249 | 250 | eval_spec.cache = new_mems 251 | 252 | return eval_spec 253 | 254 | # Configuring the optimization step. 
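# Linear warmup followed by cosine decay. With
# tf.train.cosine_decay(lr, step, decay_steps, alpha), the decayed rate is
#   lr * (alpha + (1 - alpha) * 0.5 * (1 + cos(pi * step / decay_steps))),
# so the rate anneals from learning_rate down to min_lr_ratio * learning_rate
# over the train_steps - warmup_steps steps after warmup.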
255 | global_step = tf.train.get_global_step() 256 | 257 | # increase the learning rate linearly 258 | if FLAGS.warmup_steps > 0: 259 | warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \ 260 | * FLAGS.learning_rate 261 | else: 262 | warmup_lr = 0.0 263 | 264 | # number of parameters 265 | num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) 266 | tf.logging.info("#params: {}".format(num_params)) 267 | 268 | # format_str = '{{:<{0}s}}\t{{}}'.format( 269 | # max([len(v.name) for v in tf.trainable_variables()])) 270 | # for v in tf.trainable_variables(): 271 | # tf.logging.info(format_str.format(v.name, v.get_shape())) 272 | 273 | 274 | # decay the learning rate using the cosine schedule 275 | decay_lr = tf.train.cosine_decay( 276 | FLAGS.learning_rate, 277 | global_step=global_step-FLAGS.warmup_steps, 278 | decay_steps=FLAGS.train_steps-FLAGS.warmup_steps, 279 | alpha=FLAGS.min_lr_ratio) 280 | 281 | learning_rate = tf.where(global_step < FLAGS.warmup_steps, 282 | warmup_lr, decay_lr) 283 | 284 | if FLAGS.use_tpu: 285 | optimizer = tf.contrib.tpu.CrossShardOptimizer( 286 | tf.train.AdamOptimizer(learning_rate=learning_rate)) 287 | #GradientDescentOptimizer 288 | else: 289 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 290 | 291 | grads_and_vars = optimizer.compute_gradients(total_loss) 292 | gradients, variables = zip(*grads_and_vars) 293 | clipped, _ = tf.clip_by_global_norm(gradients, FLAGS.clip) 294 | train_op = optimizer.apply_gradients( 295 | zip(clipped, variables), global_step=tf.train.get_global_step()) 296 | 297 | # Constructing TPUEstimatorSpec with cache. 298 | train_spec = tf.contrib.tpu.TPUEstimatorSpec( 299 | mode=mode, loss=total_loss, train_op=train_op) 300 | 301 | if FLAGS.mem_len < FLAGS.tgt_len: 302 | new_mems = [mem_t[: FLAGS.mem_len] for mem_t in new_mems] 303 | train_spec.cache = new_mems 304 | 305 | return train_spec 306 | 307 | return model_fn 308 | 309 | 310 | def get_cache_fn(mem_len): 311 | 312 | def cache_fn(batch_size): 313 | mems = [] 314 | for l in xrange(FLAGS.n_layer): 315 | if mem_len > 0: 316 | mems.append( 317 | tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf.float32)) 318 | else: 319 | mems.append(tf.zeros([mem_len], dtype=tf.float32)) 320 | 321 | return mems 322 | 323 | return cache_fn 324 | 325 | 326 | def main(unused_argv): 327 | del unused_argv # Unused 328 | 329 | tf.logging.set_verbosity(tf.logging.INFO) 330 | 331 | # Get corpus info 332 | corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path) 333 | n_token = corpus_info["vocab_size"] 334 | cutoffs = corpus_info["cutoffs"][1:-1] 335 | 336 | if FLAGS.save_steps == 0: 337 | FLAGS.save_steps = None 338 | 339 | if not FLAGS.do_eval_only: 340 | # Get train input function 341 | train_input_fn, train_record_info = data_utils.get_input_fn( 342 | record_info_dir=FLAGS.record_info_dir, 343 | split="train", 344 | per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts, 345 | tgt_len=FLAGS.tgt_len, 346 | num_core_per_host=FLAGS.num_core_per_host, 347 | num_hosts=FLAGS.num_hosts, 348 | use_tpu=FLAGS.use_tpu) 349 | train_bin_sizes = train_record_info["bin_sizes"] 350 | num_train_batch = train_record_info["num_batch"] 351 | 352 | # Get train cache function 353 | train_cache_fn = get_cache_fn(FLAGS.mem_len) 354 | else: 355 | train_bin_sizes = [] 356 | num_train_batch = None 357 | train_cache_fn = None 358 | 359 | if FLAGS.do_eval or FLAGS.do_eval_only: 360 | assert FLAGS.num_hosts == 1 361 | # Get eval input function 362 | 
eval_input_fn, eval_record_info = data_utils.get_input_fn( 363 | record_info_dir=FLAGS.record_info_dir, 364 | split=FLAGS.eval_split, 365 | per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts, 366 | tgt_len=FLAGS.tgt_len, 367 | num_core_per_host=FLAGS.num_core_per_host, 368 | num_hosts=FLAGS.num_hosts, 369 | use_tpu=FLAGS.use_tpu) 370 | eval_bin_sizes = eval_record_info["bin_sizes"] 371 | num_eval_batch = eval_record_info["num_batch"] 372 | 373 | if FLAGS.max_eval_batch > 0: 374 | num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch) 375 | 376 | # Get eval cache function 377 | eval_cache_fn = get_cache_fn(FLAGS.mem_len) 378 | model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes) 379 | else: 380 | eval_cache_fn = None 381 | model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, []) 382 | 383 | ##### Create estimator 384 | # TPU Configuration 385 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 386 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 387 | 388 | per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 389 | run_config = tf.contrib.tpu.RunConfig( 390 | cluster=tpu_cluster_resolver, 391 | model_dir=FLAGS.model_dir, 392 | session_config=tf.ConfigProto( 393 | allow_soft_placement=True, log_device_placement=True), 394 | tpu_config=tf.contrib.tpu.TPUConfig( 395 | iterations_per_loop=FLAGS.iterations, 396 | num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts, 397 | per_host_input_for_training=per_host_input), 398 | keep_checkpoint_max=100000, # effectively save all checkpoints 399 | save_checkpoints_secs=None, 400 | save_checkpoints_steps=FLAGS.save_steps 401 | ) 402 | 403 | # warm start 404 | warm_start_from = None 405 | if FLAGS.warm_start_path is not None: 406 | warm_start_from = tf.estimator.WarmStartSettings( 407 | ckpt_to_initialize_from=FLAGS.warm_start_path) 408 | 409 | # TPU Estimator 410 | estimator = tpu_estimator.TPUEstimator( 411 | model_fn=model_fn, 412 | train_cache_fn=train_cache_fn, 413 | eval_cache_fn=eval_cache_fn, 414 | use_tpu=FLAGS.use_tpu, 415 | config=run_config, 416 | params={"data_dir":FLAGS.data_dir, "track_mean":FLAGS.track_mean}, 417 | train_batch_size=FLAGS.train_batch_size, 418 | eval_batch_size=FLAGS.eval_batch_size, 419 | warm_start_from=warm_start_from) 420 | 421 | if FLAGS.do_eval_only: 422 | if FLAGS.eval_ckpt_path is not None: 423 | ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch, 424 | checkpoint_path=FLAGS.eval_ckpt_path) 425 | tf.logging.info("=" * 200) 426 | log_str = "Eval results | " 427 | for key, val in ret.items(): 428 | log_str += "{} {} | ".format(key, val) 429 | tf.logging.info(log_str) 430 | tf.logging.info("=" * 200) 431 | else: 432 | ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir) 433 | eval_results = [] 434 | for eval_checkpoint in ckpt_state.all_model_checkpoint_paths: 435 | if not exists(eval_checkpoint + ".index"): continue 436 | global_step = int(eval_checkpoint.split("-")[-1]) 437 | if global_step < FLAGS.start_eval_steps or global_step > FLAGS.train_steps: 438 | continue 439 | ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch, 440 | checkpoint_path=eval_checkpoint) 441 | eval_results.append(ret) 442 | 443 | eval_results.sort(key = lambda x: x["perplexity"]) 444 | 445 | tf.logging.info("=" * 200) 446 | log_str = "Best results | " 447 | for key, val in eval_results[0].items(): 448 | log_str += "{} {} | ".format(key, val) 449 | tf.logging.info(log_str) 450 | tf.logging.info("=" * 200) 451 | else: 452 | 
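# Training branch: either run straight through for train_steps, or, when
# do_eval is set, alternate one pass over the training data with a full
# evaluation pass.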
if not FLAGS.do_eval: 453 | estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps) 454 | else: 455 | for step in range(0, FLAGS.train_steps, num_train_batch): 456 | train_steps = min(FLAGS.train_steps - step, num_train_batch) 457 | estimator.train(input_fn=train_input_fn, steps=train_steps) 458 | estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch) 459 | 460 | 461 | if __name__ == "__main__": 462 | tf.app.run() 463 | -------------------------------------------------------------------------------- /tf/train_gpu.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import math 7 | import time 8 | 9 | from absl import flags 10 | import absl.logging as _logging # pylint: disable=unused-import 11 | 12 | import tensorflow as tf 13 | import model 14 | import data_utils 15 | 16 | from gpu_utils import assign_to_gpu, average_grads_and_vars 17 | 18 | import numpy as np 19 | 20 | # GPU config 21 | flags.DEFINE_integer("num_hosts", default=1, 22 | help="Number of TPU hosts") 23 | flags.DEFINE_integer("num_core_per_host", default=8, 24 | help="Number of cores per host") 25 | 26 | # Experiment (data/checkpoint/directory) config 27 | flags.DEFINE_string("data_dir", default="", 28 | help="Path to tf-records directory.") 29 | flags.DEFINE_string("record_info_dir", default="", 30 | help="Path to local directory containing filenames.txt.") 31 | flags.DEFINE_string("corpus_info_path", default="", 32 | help="Path to corpus-info.json file.") 33 | flags.DEFINE_string("model_dir", default=None, 34 | help="Estimator model_dir.") 35 | flags.DEFINE_bool("do_train", default=True, 36 | help="Whether to run training.") 37 | flags.DEFINE_bool("do_eval", default=False, 38 | help="Whether to run eval on the dev set.") 39 | flags.DEFINE_string("eval_ckpt_path", None, 40 | help="Checkpoint path for do_test evaluation." 41 | "If set, model_dir will be ignored." 42 | "If unset, will use the latest ckpt in model_dir.") 43 | flags.DEFINE_string("warm_start_path", None, 44 | help="Checkpoint path for warm start." 45 | "If set, will clear Adam states." 46 | "Note that the new model_dir should be different" 47 | " from warm_start_path.") 48 | 49 | # Optimization config 50 | flags.DEFINE_float("learning_rate", default=2.5e-4, 51 | help="Maximum learning rate.") 52 | flags.DEFINE_float("clip", default=0.25, 53 | help="Gradient clipping value.") 54 | # for cosine decay 55 | flags.DEFINE_float("min_lr_ratio", default=0.004, 56 | help="Minimum ratio learning rate.") 57 | flags.DEFINE_integer("warmup_steps", default=0, 58 | help="Number of steps for linear lr warmup.") 59 | 60 | # Training config 61 | flags.DEFINE_integer("train_batch_size", default=60, 62 | help="Size of train batch.") 63 | flags.DEFINE_integer("eval_batch_size", default=60, 64 | help="Size of valid batch.") 65 | flags.DEFINE_integer("train_steps", default=100000, 66 | help="Total number of training steps.") 67 | flags.DEFINE_integer("iterations", default=500, 68 | help="Number of iterations per repeat loop.") 69 | flags.DEFINE_integer("save_steps", default=10000, 70 | help="number of steps for model checkpointing.") 71 | 72 | # Evaluation config 73 | flags.DEFINE_bool("do_test", default=False, 74 | help="Run on the test set.") 75 | flags.DEFINE_integer("max_eval_batch", default=-1, 76 | help="Set -1 to turn off. 
Only used in test mode.") 77 | flags.DEFINE_bool("do_eval_only", default=False, 78 | help="Run evaluation only.") 79 | flags.DEFINE_integer("start_eval_steps", default=10000, 80 | help="Which checkpoint to start with in `do_eval_only` mode.") 81 | flags.DEFINE_string("eval_split", "valid", 82 | help="Which data split to evaluate.") 83 | 84 | # Model config 85 | flags.DEFINE_integer("tgt_len", default=70, 86 | help="Number of steps to predict") 87 | flags.DEFINE_integer("mem_len", default=70, 88 | help="Number of steps to cache") 89 | flags.DEFINE_bool("same_length", default=False, 90 | help="Same length attention") 91 | flags.DEFINE_integer("clamp_len", default=-1, 92 | help="Clamp length") 93 | 94 | flags.DEFINE_integer("n_layer", default=6, 95 | help="Number of layers.") 96 | flags.DEFINE_integer("d_model", default=500, 97 | help="Dimension of the model.") 98 | flags.DEFINE_integer("d_embed", default=500, 99 | help="Dimension of the embeddings.") 100 | flags.DEFINE_integer("n_head", default=10, 101 | help="Number of attention heads.") 102 | flags.DEFINE_integer("d_head", default=50, 103 | help="Dimension of each attention head.") 104 | flags.DEFINE_integer("d_inner", default=1000, 105 | help="Dimension of inner hidden size in positionwise feed-forward.") 106 | flags.DEFINE_float("dropout", default=0.1, 107 | help="Dropout rate.") 108 | flags.DEFINE_float("dropatt", default=0.1, 109 | help="Attention dropout rate.") 110 | flags.DEFINE_bool("untie_r", default=False, 111 | help="untie r_w_bias and r_r_bias") 112 | 113 | # Adaptive Softmax / Embedding 114 | flags.DEFINE_bool("tie_weight", default=True, 115 | help="Tie embedding and softmax weight.") 116 | flags.DEFINE_integer("div_val", default=1, 117 | help="Divide the embedding size by this val for each bin") 118 | flags.DEFINE_bool("proj_share_all_but_first", default=False, 119 | help="True to share all but first projs, False not to share.") 120 | flags.DEFINE_bool("proj_same_dim", default=True, 121 | help="Project the bin with the same dimension.") 122 | 123 | # Parameter initialization 124 | flags.DEFINE_enum("init", default="normal", 125 | enum_values=["normal", "uniform"], 126 | help="Initialization method.") 127 | flags.DEFINE_float("init_std", default=0.02, 128 | help="Initialization std when init is normal.") 129 | flags.DEFINE_float("proj_init_std", default=0.01, 130 | help="Initialization std for embedding projection.") 131 | flags.DEFINE_float("init_range", default=0.1, 132 | help="Initialization std when init is uniform.") 133 | 134 | FLAGS = flags.FLAGS 135 | 136 | def get_model_fn(n_token, cutoffs): 137 | def model_fn(inp, tgt, mems, is_training): 138 | inp = tf.transpose(inp, [1, 0]) 139 | tgt = tf.transpose(tgt, [1, 0]) 140 | 141 | if FLAGS.init == "uniform": 142 | initializer = tf.initializers.random_uniform( 143 | minval=-FLAGS.init_range, 144 | maxval=FLAGS.init_range, 145 | seed=None) 146 | elif FLAGS.init == "normal": 147 | initializer = tf.initializers.random_normal( 148 | stddev=FLAGS.init_std, 149 | seed=None) 150 | proj_initializer = tf.initializers.random_normal( 151 | stddev=FLAGS.proj_init_std, 152 | seed=None) 153 | 154 | tie_projs = [False for _ in range(len(cutoffs) + 1)] 155 | if FLAGS.proj_share_all_but_first: 156 | for i in range(1, len(tie_projs)): 157 | tie_projs[i] = True 158 | 159 | loss, new_mems = model.transformer( 160 | dec_inp=inp, 161 | target=tgt, 162 | mems=mems, 163 | n_token=n_token, 164 | n_layer=FLAGS.n_layer, 165 | d_model=FLAGS.d_model, 166 | d_embed=FLAGS.d_embed, 167 | 
def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
  model_fn = get_model_fn(
      n_token=n_token,
      cutoffs=cutoffs)

  model_ret = model_fn(
      inp=inp,
      tgt=tgt,
      mems=mems,
      is_training=is_training)

  return model_ret


def train(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  train_input_fn, train_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split="train",
      per_host_bsz=FLAGS.train_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

  ##### Create computational graph
  train_set = train_input_fn({
      "batch_size": FLAGS.train_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=True,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]
  grads, all_vars = zip(*grads_and_vars)

  ## clip gradient
  clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
  grads_and_vars = list(zip(clipped, all_vars))

  ## configure the optimizer
  global_step = tf.train.get_or_create_global_step()

  # warmup stage: increase the learning rate linearly
  if FLAGS.warmup_steps > 0:
    warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
        * FLAGS.learning_rate
  else:
    warmup_lr = 0.0

  # decay stage: decay the learning rate using the cosine schedule
  decay_lr = tf.train.cosine_decay(
      FLAGS.learning_rate,
      global_step=global_step - FLAGS.warmup_steps,
      decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
      alpha=FLAGS.min_lr_ratio)

  # choose warmup or decay
  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)

  # get the train op
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  train_op = optimizer.apply_gradients(grads_and_vars, global_step)

  ##### Training loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.warm_start_path is not None:
      tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
      saver.restore(sess, FLAGS.warm_start_path)

    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                            curr_step, fetched[-3], fetched[-2],
                            curr_loss, math.exp(curr_loss),
                            curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      if curr_step == FLAGS.train_steps:
        break
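# [Editor's example] Not part of the original file: the warmup + cosine
# schedule built in train(), reproduced in plain Python/NumPy so it can be
# inspected without a session (uses this module's `np` import).
# tf.train.cosine_decay's `alpha` is a *ratio* of the base rate, matching
# FLAGS.min_lr_ratio, and it clips the step at decay_steps. warmup=4000 is
# illustrative; this script's default is warmup_steps=0.
# Sanity checks: _example_lr(4000) == 2.5e-4 (peak), _example_lr(100000) == 1e-6.
def _example_lr(step, base_lr=2.5e-4, warmup=4000, train_steps=100000,
                min_lr_ratio=0.004):
  if step < warmup:
    return step / warmup * base_lr  # linear warmup
  t = min(step - warmup, train_steps - warmup)  # cosine_decay clips the step
  decayed = (1 - min_lr_ratio) * 0.5 * (
      1 + np.cos(np.pi * t / (train_steps - warmup)))
  return base_lr * (decayed + min_lr_ratio)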
def evaluate(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  eval_input_fn, eval_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split=FLAGS.eval_split,
      per_host_bsz=FLAGS.eval_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  num_batch = eval_record_info["num_batch"]
  if FLAGS.max_eval_batch > 0:
    # cap at the actual number of batches, as the TPU script does
    num_batch = min(FLAGS.max_eval_batch, num_batch)
  tf.logging.info("num of batches {}".format(num_batch))

  ##### Create computational graph
  eval_set = eval_input_fn({
      "batch_size": FLAGS.eval_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
  tower_mems, tower_losses, tower_new_mems = [], [], []

  for i in range(FLAGS.num_core_per_host):
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=False,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)

  ## average losses across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
  else:
    loss = tower_losses[0]

  ##### Evaluation loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.eval_ckpt_path is None:
      eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    else:
      eval_ckpt_path = FLAGS.eval_ckpt_path
    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
    saver.restore(sess, eval_ckpt_path)

    fetches = [loss, tower_new_mems, tf.size(label_feed)]

    format_str = " >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
        len(str(num_batch)))

    total_loss, total_cnt = 0, 0
    for step in range(num_batch):
      if step % max(num_batch // 10, 1) == 0:  # guard against num_batch < 10
        tf.logging.info(format_str.format(step, num_batch))

      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, cnt_np = fetched[:3]
      total_loss += loss_np * cnt_np
      total_cnt += cnt_np

    avg_loss = total_loss / total_cnt
    tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
        avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
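# [Editor's example] Not part of the original file: the two numbers logged by
# train() and evaluate() are both direct functions of the average per-token
# cross entropy in nats: perplexity = exp(loss), bits-per-character = loss / ln 2
# (uses this module's `math` import).
def _example_metrics(avg_loss=0.80):
  # exp(0.80) ~= 2.23 pplx; 0.80 / ln(2) ~= 1.1542 bpc
  return math.exp(avg_loss), avg_loss / math.log(2)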
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]
  tf.logging.info("n_token {}".format(n_token))

  if FLAGS.do_train:
    train(n_token, cutoffs, "/gpu:0")
  if FLAGS.do_eval:
    evaluate(n_token, cutoffs, "/gpu:0")


if __name__ == "__main__":
  tf.app.run()
--------------------------------------------------------------------------------
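# [Editor's example] A self-contained NumPy mock of the session-level memory
# recurrence used by both loops above: each sess.run() consumes the cached
# states fed through the `mems_i` placeholders and returns the states to feed
# next time, one [mem_len, per_core_bsz, d_model] array per layer. Sizes below
# follow the flag defaults with 4 cores (60 // 4 = 15 per core).

import numpy as np

mem_len, bsz, d_model, n_layer = 70, 15, 500, 6

def mock_run_step(mems):
  # stand-in for sess.run: a real step returns the last mem_len hidden
  # states of each layer as the new memory
  return [np.random.randn(mem_len, bsz, d_model).astype(np.float32)
          for _ in mems]

mems = [np.zeros([mem_len, bsz, d_model], np.float32) for _ in range(n_layer)]
for _ in range(3):
  mems = mock_run_step(mems)  # this step's new_mems feed the next step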
/tf/vocabulary.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import Counter, OrderedDict

import numpy as np

import tensorflow as tf

from tensorflow.gfile import Open as open
from tensorflow.gfile import Exists as exists

class Vocab(object):
  def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
               delimiter=None, vocab_file=None):
    self.counter = Counter()
    self.special = special
    self.min_freq = min_freq
    self.max_size = max_size
    self.lower_case = lower_case
    self.delimiter = delimiter
    self.vocab_file = vocab_file

  def tokenize(self, line, add_eos=False, add_double_eos=False):
    line = line.strip()
    # convert to lower case
    if self.lower_case:
      line = line.lower()

    # empty delimiter '' will evaluate False
    if self.delimiter == '':
      symbols = line
    else:
      symbols = line.split(self.delimiter)

    if add_double_eos:  # lm1b
      return ['<S>'] + symbols + ['<S>']
    elif add_eos:
      return symbols + ['<eos>']
    else:
      return symbols

  def count_file(self, path, verbose=False, add_eos=False):
    if verbose: print('counting file {} ...'.format(path))
    assert exists(path)

    sents = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('  line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos)
        self.counter.update(symbols)
        sents.append(symbols)

    return sents

  def count_sents(self, sents, verbose=False):
    """
    sents : a list of sentences, each a list of tokenized symbols
    """
    if verbose: print('counting {} sents ...'.format(len(sents)))
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('  line {}'.format(idx))
      self.counter.update(symbols)

  def _build_from_file(self, vocab_file):
    self.idx2sym = []
    self.sym2idx = OrderedDict()

    with open(vocab_file, 'r') as f:
      for line in f:
        symb = line.strip().split()[0]
        self.add_symbol(symb)
    self.unk_idx = self.sym2idx['<UNK>']

  def build_vocab(self):
    if self.vocab_file:
      print('building vocab from {}'.format(self.vocab_file))
      self._build_from_file(self.vocab_file)
      print('final vocab size {}'.format(len(self)))
    else:
      print('building vocab with min_freq={}, max_size={}'.format(
          self.min_freq, self.max_size))
      self.idx2sym = []
      self.sym2idx = OrderedDict()

      for sym in self.special:
        self.add_special(sym)

      for sym, cnt in self.counter.most_common(self.max_size):
        if cnt < self.min_freq: break
        self.add_symbol(sym)

      print('final vocab size {} from {} unique tokens'.format(
          len(self), len(self.counter)))

  def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                  add_double_eos=False):
    if verbose: print('encoding file {} ...'.format(path))
    assert exists(path)
    encoded = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('  line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos,
                                add_double_eos=add_double_eos)
        encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded
  def encode_sents(self, sents, ordered=False, verbose=False):
    if verbose: print('encoding {} sents ...'.format(len(sents)))
    encoded = []
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('  line {}'.format(idx))
      encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def add_special(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1
      setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

  def add_symbol(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1

  def get_sym(self, idx):
    assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
    return self.idx2sym[idx]

  def get_idx(self, sym):
    if sym in self.sym2idx:
      return self.sym2idx[sym]
    else:
      assert hasattr(self, 'unk_idx')
      return self.sym2idx.get(sym, self.unk_idx)

  def get_symbols(self, indices):
    return [self.get_sym(idx) for idx in indices]

  def get_indices(self, symbols):
    return [self.get_idx(sym) for sym in symbols]

  def convert_to_nparray(self, symbols):
    nparray = np.array(self.get_indices(symbols), dtype=np.int64)
    return nparray

  def convert_to_sent(self, indices, exclude=None):
    if exclude is None:
      return ' '.join([self.get_sym(idx) for idx in indices])
    else:
      return ' '.join([self.get_sym(idx) for idx in indices
                       if idx not in exclude])

  def __len__(self):
    return len(self.idx2sym)
--------------------------------------------------------------------------------
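# [Editor's example] Minimal end-to-end use of the Vocab class above, assuming
# vocabulary.py is importable and ./corpus.txt is a hypothetical plain-text
# file. '<unk>' is included in `special` so OOV lookups in get_idx() have an
# unk_idx (add_special() derives the attribute name by stripping the angle
# brackets from the symbol).

from vocabulary import Vocab

vocab = Vocab(special=['<unk>', '<eos>'], lower_case=True)
vocab.count_file('./corpus.txt', add_eos=True)
vocab.build_vocab()

ids = vocab.encode_file('./corpus.txt', ordered=True, add_eos=True)
print(len(vocab), ids[:10])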