├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── classifier_utils.py ├── data_utils.py ├── function_builder.py ├── gpu_utils.py ├── misc ├── race_example.md └── slides.pdf ├── model_utils.py ├── modeling.py ├── notebooks └── colab_imdb_gpu.ipynb ├── prepro_utils.py ├── run_classifier.py ├── run_race.py ├── run_squad.py ├── scripts ├── gpu_squad_base.sh ├── prepro_squad.sh ├── tpu_race_large_bsz32.sh ├── tpu_race_large_bsz8.sh └── tpu_squad_large.sh ├── squad_utils.py ├── tpu_estimator.py ├── train.py ├── train_gpu.py └── xlnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 XLNet Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | **XLNet** is a new unsupervised language representation learning method based on a novel generalized permutation language modeling objective. Additionally, XLNet employs [Transformer-XL](https://arxiv.org/abs/1901.02860) as the backbone model, exhibiting excellent performance for language tasks involving long context. Overall, XLNet achieves state-of-the-art (SOTA) results on various downstream language tasks including question answering, natural language inference, sentiment analysis, and document ranking. 4 | 5 | For a detailed description of the method and experimental results, please refer to our paper: 6 | 7 | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) 8 | 9 | Zhilin Yang\*, Zihang Dai\*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 10 | 11 | (*: equal contribution) 12 | 13 | Preprint 2019 14 | 15 | 16 | 17 | 18 | ## Release Notes 19 | 20 | * July 16, 2019: XLNet-Base. 21 | * June 19, 2019: initial release with XLNet-Large and code. 22 | 23 | ## Results 24 | 25 | As of June 19, 2019, XLNet outperforms BERT on 20 tasks and achieves state-of-the-art results on 18 tasks. Below are some comparisons between XLNet-Large and BERT-Large, which have similar model sizes: 26 | 27 | ### Results on Reading Comprehension 28 | 29 | Model | [RACE accuracy](http://www.qizhexie.com/data/RACE_leaderboard.html) | SQuAD1.1 EM | SQuAD2.0 EM 30 | --- | --- | --- | --- 31 | BERT-Large | 72.0 | 84.1 | 78.98 32 | XLNet-Base | | | 80.18 33 | XLNet-Large | **81.75** | **88.95** | **86.12** 34 | 35 | We use SQuAD dev results in the table to exclude other factors such as using additional training data or other data augmentation techniques. See [SQuAD leaderboard](https://rajpurkar.github.io/SQuAD-explorer/) for test numbers. 36 | 37 | ### Results on Text Classification 38 | 39 | Model | IMDB | Yelp-2 | Yelp-5 | DBpedia | Amazon-2 | Amazon-5 40 | --- | --- | --- | --- | --- | --- | --- 41 | BERT-Large | 4.51 | 1.89 | 29.32 | 0.64 | 2.63 | 34.17 42 | XLNet-Large | **3.79** | **1.55** | **27.80** | **0.62** | **2.40** | **32.26** 43 | 44 | The above numbers are error rates. 45 | 46 | ### Results on GLUE 47 | 48 | Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B 49 | --- | --- | --- | --- | --- | --- | --- | --- | --- 50 | BERT-Large | 86.6 | 92.3 | 91.3 | 70.4 | 93.2 | 88.0 | 60.6 | 90.0 51 | XLNet-Base | 86.8 | 91.7 | 91.4 | 74.0 | 94.7 | 88.2 | 60.2 | 89.5 52 | XLNet-Large | **89.8** | **93.9** | **91.8** | **83.8** | **95.6** | **89.2** | **63.6** | **91.8** 53 | 54 | We use single-task dev results in the table to exclude other factors such as multi-task learning or using ensembles. 55 | 56 | ## Pre-trained models 57 | 58 | ### Released Models 59 | 60 | As of July 16, 2019, the following models have been made available: 61 | * **[`XLNet-Large, Cased`](https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip)**: 24-layer, 1024-hidden, 16-heads 62 | * **[`XLNet-Base, Cased`](https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip)**: 12-layer, 768-hidden, 12-heads. This model is trained on full data (different from the one in the paper).
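For example, either checkpoint can be fetched and unpacked as follows (a sketch assuming a standard Linux shell; the name of the extracted folder is determined by the archive and may differ):

```shell
wget https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip
unzip cased_L-24_H-1024_A-16.zip
export LARGE_DIR=$(pwd)/xlnet_cased_L-24_H-1024_A-16  # path assumed by the examples below
```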
63 | 64 | We only release cased models for now because on the tasks we consider, we found: (1) for the base setting, cased and uncased models have similar performance; (2) for the large setting, cased models are a bit better on some tasks. 65 | 66 | Each .zip file contains three items: 67 | * A TensorFlow checkpoint (`xlnet_model.ckpt`) containing the pre-trained weights (which is actually 3 files). 68 | * A [SentencePiece](https://github.com/google/sentencepiece) model (`spiece.model`) used for (de)tokenization. 69 | * A config file (`xlnet_config.json`) which specifies the hyperparameters of the model. 70 | 71 | 72 | ### Future Release Plan 73 | 74 | We also plan to continuously release more pretrained models under different settings, including: 75 | * A pretrained model that is **finetuned on Wikipedia**. This can be used for tasks with Wikipedia text such as SQuAD and HotpotQA. 76 | * Pretrained models with other hyperparameter configurations, targeting specific downstream tasks. 77 | * Pretrained models that benefit from new techniques. 78 | 79 | ### Subscribing to XLNet on Google Groups 80 | 81 | To receive notifications about updates, announcements and new releases, we recommend subscribing to XLNet on [Google Groups](https://groups.google.com/forum/#!forum/xlnet). 82 | 83 | 84 | 85 | ## Fine-tuning with XLNet 86 | 87 | As of June 19, 2019, this code base has been tested with TensorFlow 1.13.1 under Python 2. 88 | 89 | ### Memory Issue during Finetuning 90 | 91 | - Most of the SOTA results in our paper were produced on TPUs, which generally have more RAM than common GPUs. As a result, it is currently very difficult (costly) to reproduce most of the `XLNet-Large` SOTA results in the paper using GPUs with 12GB - 16GB of RAM, because a 16GB GPU is only able to hold a single sequence of length 512 for `XLNet-Large`. Therefore, a large number of GPUs (ranging from 32 to 128, equal to `batch_size`) are required to reproduce many results in the paper. 92 | - We are experimenting with gradient accumulation to potentially relieve the memory burden, which could be included in a near-future update; a minimal sketch of the idea is given after the table below. 93 | - **Alternative methods** of finetuning XLNet on **constrained hardware** have been presented in [renatoviolin's repo](https://github.com/renatoviolin/xlnet), which obtained 86.24 F1 on SQuAD2.0 with an 8GB GPU. 94 | 95 | Given the memory issue mentioned above, using the default finetuning scripts (`run_classifier.py` and `run_squad.py`), we benchmarked the maximum batch size on a single **16GB** GPU with TensorFlow **1.13.1**: 96 | 97 | | System | Seq Length | Max Batch Size | 98 | | ------------- | ---------- | -------------- | 99 | | `XLNet-Base` | 64 | 120 | 100 | | ... | 128 | 56 | 101 | | ... | 256 | 24 | 102 | | ... | 512 | 8 | 103 | | `XLNet-Large` | 64 | 16 | 104 | | ... | 128 | 8 | 105 | | ... | 256 | 2 | 106 | | ... | 512 | 1 | 107 | 108 | In most cases, it is possible to reduce the batch size `train_batch_size` or the maximum sequence length `max_seq_length` to fit the given hardware. The decrease in performance depends on the task and the available resources. 109 | 110 |
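The gradient-accumulation idea mentioned above can be sketched as follows. This is a minimal, hypothetical TF 1.x pattern, not part of this code base; `loss` is assumed to be the finetuning loss tensor built elsewhere:

```python
import tensorflow as tf

# Accumulate gradients over `accum_steps` minibatches, then apply them once,
# so the effective batch size is train_batch_size * accum_steps.
accum_steps = 4
opt = tf.train.AdamOptimizer(learning_rate=1e-5)
tvars = tf.trainable_variables()
grads = tf.gradients(loss, tvars)  # `loss` is assumed to be defined elsewhere

# Non-trainable buffers holding the running sum of gradients.
accum = [tf.Variable(tf.zeros_like(v), trainable=False) for v in tvars]
accum_op = tf.group(*[a.assign_add(g / accum_steps) for a, g in zip(accum, grads)])
apply_op = opt.apply_gradients(list(zip(accum, tvars)))
zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accum])
```

In the training loop, one would run `accum_op` for `accum_steps` batches, then `apply_op` once, then `zero_op` before the next accumulation cycle.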
111 | ### Text Classification/Regression 112 | 113 | The code used to perform classification/regression finetuning is in `run_classifier.py`. It also contains examples for standard one-document classification, one-document regression, and document pair classification. Here, we provide two concrete examples of how `run_classifier.py` can be used. 114 | 115 | From here on, we assume XLNet-Large and XLNet-Base have been downloaded to `$LARGE_DIR` and `$BASE_DIR` respectively. 116 | 117 | 118 | #### (1) STS-B: sentence pair relevance regression (with GPUs) 119 | 120 | - Download the [GLUE data](https://gluebenchmark.com/tasks) by running [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) and unpack it to some directory `$GLUE_DIR`. 121 | 122 | - Perform **multi-GPU** (4 V100 GPUs) finetuning with XLNet-Large by running 123 | 124 | ```shell 125 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_classifier.py \ 126 | --do_train=True \ 127 | --do_eval=False \ 128 | --task_name=sts-b \ 129 | --data_dir=${GLUE_DIR}/STS-B \ 130 | --output_dir=proc_data/sts-b \ 131 | --model_dir=exp/sts-b \ 132 | --uncased=False \ 133 | --spiece_model_file=${LARGE_DIR}/spiece.model \ 134 | --model_config_path=${LARGE_DIR}/xlnet_config.json \ 135 | --init_checkpoint=${LARGE_DIR}/xlnet_model.ckpt \ 136 | --max_seq_length=128 \ 137 | --train_batch_size=8 \ 138 | --num_hosts=1 \ 139 | --num_core_per_host=4 \ 140 | --learning_rate=5e-5 \ 141 | --train_steps=1200 \ 142 | --warmup_steps=120 \ 143 | --save_steps=600 \ 144 | --is_regression=True 145 | ``` 146 | 147 | - Evaluate the finetuning results with a single GPU by 148 | 149 | ```shell 150 | CUDA_VISIBLE_DEVICES=0 python run_classifier.py \ 151 | --do_train=False \ 152 | --do_eval=True \ 153 | --task_name=sts-b \ 154 | --data_dir=${GLUE_DIR}/STS-B \ 155 | --output_dir=proc_data/sts-b \ 156 | --model_dir=exp/sts-b \ 157 | --uncased=False \ 158 | --spiece_model_file=${LARGE_DIR}/spiece.model \ 159 | --model_config_path=${LARGE_DIR}/xlnet_config.json \ 160 | --max_seq_length=128 \ 161 | --eval_batch_size=8 \ 162 | --num_hosts=1 \ 163 | --num_core_per_host=1 \ 164 | --eval_all_ckpt=True \ 165 | --is_regression=True 166 | 167 | # Expected performance: "eval_pearsonr 0.916+ " 168 | ``` 169 | 170 | **Notes**: 171 | 172 | - In the context of GPU training, `num_core_per_host` denotes the number of GPUs to use. 173 | - In the multi-GPU setting, `train_batch_size` refers to the per-GPU batch size. 174 | - `eval_all_ckpt` allows one to evaluate all saved checkpoints (save frequency is controlled by `save_steps`) after training finishes and choose the best model based on dev performance. 175 | - `data_dir` and `output_dir` refer to the directories of the "raw data" and "preprocessed tfrecords" respectively, while `model_dir` is the working directory for saving checkpoints and tensorflow events. **`model_dir` should be set to a folder separate from `init_checkpoint`.** 176 | - To try out XLNet-Base, one can simply set `--train_batch_size=32` and `--num_core_per_host=1`, along with the corresponding changes in `init_checkpoint` and `model_config_path`. 177 | - For GPUs with smaller RAM, please proportionally decrease the `train_batch_size` and increase `num_core_per_host` to use the same training setting. 178 | - **Important**: we separate the training and evaluation into "two phases", as using multiple GPUs to perform evaluation is tricky (one has to correctly separate the data across GPUs). To ensure correctness, we only support single-GPU evaluation for now.
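As a rough sanity check of the schedule above (an aside; the STS-B training-set size below is an assumption, not something defined in this repo):

```python
# Effective batch size and epoch count implied by the multi-GPU command above.
n_gpus, per_gpu_bsz, train_steps = 4, 8, 1200
sts_b_train_examples = 5749  # assumed STS-B train size
effective_bsz = n_gpus * per_gpu_bsz  # 32 examples per step
epochs = train_steps * effective_bsz / float(sts_b_train_examples)
print(effective_bsz, round(epochs, 1))  # 32 6.7 -> roughly 6-7 passes over the data
```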
179 | 180 | 181 | #### (2) IMDB: movie review sentiment classification (with TPU V3-8) 182 | 183 | - Download and unpack the IMDB dataset by running 184 | 185 | ```shell 186 | wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz 187 | tar zxvf aclImdb_v1.tar.gz 188 | ``` 189 | 190 | - Launch a Google Cloud TPU V3-8 instance (see the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for how to set up Cloud TPUs). 191 | 192 | - Set up your Google Storage bucket path `$GS_ROOT` and move the IMDB dataset and pretrained checkpoint into your Google Storage bucket. 193 | 194 | - Perform TPU finetuning with XLNet-Large by running 195 | 196 | ```shell 197 | python run_classifier.py \ 198 | --use_tpu=True \ 199 | --tpu=${TPU_NAME} \ 200 | --do_train=True \ 201 | --do_eval=True \ 202 | --eval_all_ckpt=True \ 203 | --task_name=imdb \ 204 | --data_dir=${IMDB_DIR} \ 205 | --output_dir=${GS_ROOT}/proc_data/imdb \ 206 | --model_dir=${GS_ROOT}/exp/imdb \ 207 | --uncased=False \ 208 | --spiece_model_file=${LARGE_DIR}/spiece.model \ 209 | --model_config_path=${GS_ROOT}/${LARGE_DIR}/xlnet_config.json \ 210 | --init_checkpoint=${GS_ROOT}/${LARGE_DIR}/xlnet_model.ckpt \ 211 | --max_seq_length=512 \ 212 | --train_batch_size=32 \ 213 | --eval_batch_size=8 \ 214 | --num_hosts=1 \ 215 | --num_core_per_host=8 \ 216 | --learning_rate=2e-5 \ 217 | --train_steps=4000 \ 218 | --warmup_steps=500 \ 219 | --save_steps=500 \ 220 | --iterations=500 221 | 222 | # Expected performance: "eval_accuracy 0.962+ " 223 | ``` 224 | 225 | **Notes**: 226 | 227 | - To obtain the SOTA on the IMDB dataset, using sequence length 512 is **necessary**. Therefore, we show how this can be done with a TPU V3-8. 228 | - Alternatively, one can use a sequence length smaller than 512, a smaller batch size, or switch to XLNet-Base to train on GPUs, but a performance drop is expected. 229 | - Notice that `data_dir` and `spiece_model_file` both use a local path rather than a Google Storage path. The reason is that data preprocessing is actually performed locally, so using local paths leads to faster preprocessing. 230 | 231 | ### SQuAD2.0 232 | 233 | The code for the SQuAD dataset is included in `run_squad.py`. 234 | 235 | To run the code: 236 | 237 | (1) Download the SQuAD2.0 dataset into `$SQUAD_DIR` by: 238 | 239 | ```shell 240 | mkdir -p ${SQUAD_DIR} && cd ${SQUAD_DIR} 241 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json 242 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json 243 | ``` 244 | 245 | (2) Perform data preprocessing using the script `scripts/prepro_squad.sh`. 246 | 247 | - This will take quite some time in order to accurately map character positions (raw data) to SentencePiece positions (used for training). 248 | 249 | - For faster parallel preprocessing, please refer to the flags `--num_proc` and `--proc_id` in `run_squad.py`. 250 | 251 | (3) Perform training and evaluation. 252 | 253 | For the best performance, XLNet-Large uses sequence length 512 and batch size 48 for training. 254 | 255 | - As a result, reproducing the best result with GPUs is quite difficult. 256 | 257 | - For training with one TPU v3-8, one can simply run the script `scripts/tpu_squad_large.sh` after both the TPU and Google Storage have been set up. 258 | - `run_squad.py` will automatically perform threshold searching on the dev set of SQuAD and output the score.
With `scripts/tpu_squad_large.sh`, the expected F1 score should be around 88.6 (median of our multiple runs). 259 | 260 | Alternatively, one can use XLNet-Base with GPUs (e.g., three V100s). One set of reasonable hyperparameters can be found in the script `scripts/gpu_squad_base.sh`. 261 | 262 | 263 | ### RACE reading comprehension 264 | 265 | The code for the reading comprehension task [RACE](https://www.cs.cmu.edu/~glai1/data/race/) is included in `run_race.py`. 266 | 267 | - Notably, the average length of the passages in RACE is over 300 tokens (not pieces), which is significantly longer than that of other popular reading comprehension datasets such as SQuAD. 268 | - Also, many questions can be very difficult and require complex reasoning for machines to solve (see [one example here](misc/race_example.md)). 269 | 270 | 271 | To run the code: 272 | 273 | (1) Download the RACE dataset from the [official website](https://www.cs.cmu.edu/~glai1/data/race/) and unpack the raw data to `$RACE_DIR`. 274 | 275 | (2) Perform training and evaluation: 276 | 277 | - The SOTA performance (accuracy 81.75) on RACE is produced using XLNet-Large with sequence length 512 and batch size 32, which requires a large TPU v3-32 in the pod setting. Please refer to the script `scripts/tpu_race_large_bsz32.sh` for this setting. 278 | - Using XLNet-Large with sequence length 512 and batch size 8 on a TPU v3-8 can give you an accuracy of around 80.3 (see `scripts/tpu_race_large_bsz8.sh`). 279 | 280 | ### Using Google Colab 281 | 282 | [An example](notebooks/colab_imdb_gpu.ipynb) of using Google Colab with GPUs has been provided. Note that since the hardware is constrained in the example, the results are worse than the best we can get. It mainly serves as an example and should be modified accordingly to maximize performance. 283 | 284 | 285 | ## Custom Usage of XLNet 286 | 287 | ### XLNet Abstraction 288 | 289 | For finetuning, it is likely that you will be able to modify existing files such as `run_classifier.py`, `run_squad.py` and `run_race.py` for your task at hand. However, we also provide an abstraction of XLNet to enable more flexible usage. Below is an example: 290 | 291 | ```python 292 | import xlnet 293 | 294 | # some code omitted here... 295 | # initialize FLAGS 296 | # initialize instances of tf.Tensor, including input_ids, seg_ids, and input_mask 297 | 298 | # XLNetConfig contains hyperparameters that are specific to a model checkpoint. 299 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 300 | 301 | # RunConfig contains hyperparameters that could be different between pretraining and finetuning. 302 | run_config = xlnet.create_run_config(is_training=True, is_finetune=True, FLAGS=FLAGS) 303 | 304 | # Construct an XLNet model 305 | xlnet_model = xlnet.XLNetModel( 306 | xlnet_config=xlnet_config, 307 | run_config=run_config, 308 | input_ids=input_ids, 309 | seg_ids=seg_ids, 310 | input_mask=input_mask) 311 | 312 | # Get a summary of the sequence using the last hidden state 313 | summary = xlnet_model.get_pooled_out(summary_type="last") 314 | 315 | # Get a sequence output 316 | seq_out = xlnet_model.get_sequence_output() 317 | 318 | # build your applications based on `summary` or `seq_out` 319 | ``` 320 | 321 | ### Tokenization 322 | 323 | Below is an example of doing tokenization in XLNet: 324 | ```python 325 | import sentencepiece as spm 326 | from prepro_utils import preprocess_text, encode_ids 327 | 328 | # some code omitted here...
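# (Aside, not in the original snippet: to run this standalone, FLAGS can be
# stubbed with hypothetical values, e.g.
#   class FLAGS(object):
#     spiece_model_file = "spiece.model"
#     uncased = False
# )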
329 | # initialize FLAGS 330 | 331 | text = "An input text string." 332 | 333 | sp_model = spm.SentencePieceProcessor() 334 | sp_model.Load(FLAGS.spiece_model_file) 335 | text = preprocess_text(text, lower=FLAGS.uncased) 336 | ids = encode_ids(sp_model, text) 337 | ``` 338 | where `FLAGS.spiece_model_file` is the SentencePiece model file in the same zip as the pretrained model, and `FLAGS.uncased` is a bool indicating whether to lowercase the text. 339 | 340 | 341 | ## Pretraining with XLNet 342 | 343 | Refer to `train.py` for pretraining on TPUs and `train_gpu.py` for pretraining on GPUs. First we need to preprocess the text data into tfrecords. 344 | 345 | ```shell 346 | python data_utils.py \ 347 | --bsz_per_host=32 \ 348 | --num_core_per_host=16 \ 349 | --seq_len=512 \ 350 | --reuse_len=256 \ 351 | --input_glob=*.txt \ 352 | --save_dir=${SAVE_DIR} \ 353 | --num_passes=20 \ 354 | --bi_data=True \ 355 | --sp_path=spiece.model \ 356 | --mask_alpha=6 \ 357 | --mask_beta=1 \ 358 | --num_predict=85 359 | ``` 360 | 361 | where `input_glob` defines all input text files, `save_dir` is the output directory for tfrecords, and `sp_path` is a [SentencePiece](https://github.com/google/sentencepiece) model. Here is our script to train the SentencePiece model: 362 | 363 | ```bash 364 | spm_train \ 365 | --input=$INPUT \ 366 | --model_prefix=sp10m.cased.v3 \ 367 | --vocab_size=32000 \ 368 | --character_coverage=0.99995 \ 369 | --model_type=unigram \ 370 | --control_symbols=<cls>,<sep>,<pad>,<mask>,<eod> \ 371 | --user_defined_symbols=<eop>,.,(,),",-,–,£,€ \ 372 | --shuffle_input_sentence \ 373 | --input_sentence_size=10000000 374 | ``` 375 | 376 | Special symbols are used, including `control_symbols` and `user_defined_symbols`. We use `<eop>` and `<eod>` to denote End of Paragraph and End of Document respectively. 377 | 378 | The input text files to `data_utils.py` must use the following format: 379 | * Each line is a sentence. 380 | * An empty line means End of Document. 381 | * (Optional) If one also wants to model paragraph structures, `<eop>` can be inserted at the end of certain lines (without any space) to indicate that the corresponding sentence ends a paragraph. 382 | 383 | For example, the text input file could be: 384 | ``` 385 | This is the first sentence. 386 | This is the second sentence and also the end of the paragraph.<eop> 387 | Another paragraph. 388 | 389 | Another document starts here. 390 | ``` 391 | 392 | After preprocessing, we are ready to pretrain an XLNet model. Below are the hyperparameters used for pretraining XLNet-Large: 393 | 394 | ```shell 395 | python train.py \ 396 | --record_info_dir=$DATA/tfrecords \ 397 | --train_batch_size=2048 \ 398 | --seq_len=512 \ 399 | --reuse_len=256 \ 400 | --mem_len=384 \ 401 | --perm_size=256 \ 402 | --n_layer=24 \ 403 | --d_model=1024 \ 404 | --d_embed=1024 \ 405 | --n_head=16 \ 406 | --d_head=64 \ 407 | --d_inner=4096 \ 408 | --untie_r=True \ 409 | --mask_alpha=6 \ 410 | --mask_beta=1 \ 411 | --num_predict=85 412 | ``` 413 | 414 | where we list only the most important flags; the other flags can be adjusted based on specific use cases.
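A back-of-the-envelope check on the masking flags above (an aside; this assumes, per `data_utils.py`, that roughly `mask_beta / mask_alpha` of the `seq_len` tokens are selected for prediction):

```python
seq_len, mask_alpha, mask_beta = 512, 6, 1
approx_num_predict = seq_len * mask_beta / float(mask_alpha)
print(approx_num_predict)  # 85.33..., consistent with --num_predict=85
```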
415 | 416 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zihangdai/xlnet/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/__init__.py -------------------------------------------------------------------------------- /classifier_utils.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | 3 | import re 4 | import numpy as np 5 | 6 | import tensorflow as tf 7 | from data_utils import SEP_ID, CLS_ID 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | SEG_ID_A = 0 12 | SEG_ID_B = 1 13 | SEG_ID_CLS = 2 14 | SEG_ID_SEP = 3 15 | SEG_ID_PAD = 4 16 | 17 | class PaddingInputExample(object): 18 | """Fake example so that the number of input examples is a multiple of the batch size. 19 | When running eval/predict on the TPU, we need to pad the number of examples 20 | to be a multiple of the batch size, because the TPU requires a fixed batch 21 | size. The alternative is to drop the last batch, which is bad because it means 22 | the entire output data won't be generated. 23 | We use this class instead of `None` because treating `None` as padding 24 | batches could cause silent errors. 25 | """ 26 | 27 | 28 | class InputFeatures(object): 29 | """A single set of features of data.""" 30 | 31 | def __init__(self, 32 | input_ids, 33 | input_mask, 34 | segment_ids, 35 | label_id, 36 | is_real_example=True): 37 | self.input_ids = input_ids 38 | self.input_mask = input_mask 39 | self.segment_ids = segment_ids 40 | self.label_id = label_id 41 | self.is_real_example = is_real_example 42 | 43 | 44 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 45 | """Truncates a sequence pair in place to the maximum length.""" 46 | 47 | # This is a simple heuristic which will always truncate the longer sequence 48 | # one token at a time. This makes more sense than truncating an equal percent 49 | # of tokens from each, since if one sequence is very short then each token 50 | # that's truncated likely contains more information than a longer sequence. 51 | while True: 52 | total_length = len(tokens_a) + len(tokens_b) 53 | if total_length <= max_length: 54 | break 55 | if len(tokens_a) > len(tokens_b): 56 | tokens_a.pop() 57 | else: 58 | tokens_b.pop() 59 | 60 | 61 | def convert_single_example(ex_index, example, label_list, max_seq_length, 62 | tokenize_fn): 63 | """Converts a single `InputExample` into a single `InputFeatures`.""" 64 | 65 | if isinstance(example, PaddingInputExample): 66 | return InputFeatures( 67 | input_ids=[0] * max_seq_length, 68 | input_mask=[1] * max_seq_length, 69 | segment_ids=[0] * max_seq_length, 70 | label_id=0, 71 | is_real_example=False) 72 | 73 | if label_list is not None: 74 | label_map = {} 75 | for (i, label) in enumerate(label_list): 76 | label_map[label] = i 77 | 78 | tokens_a = tokenize_fn(example.text_a) 79 | tokens_b = None 80 | if example.text_b: 81 | tokens_b = tokenize_fn(example.text_b) 82 | 83 | if tokens_b: 84 | # Modifies `tokens_a` and `tokens_b` in place so that the total 85 | # length is less than the specified length.
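    # (Added note: unlike BERT, the layout built below appends SEP after each
    # segment and places CLS at the *end* of the sequence, and padding is
    # added on the *left*:
    #   pair:   A ... A SEP B ... B SEP CLS
    #   single: A ... A SEP CLS )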
86 | # Account for two [SEP] & one [CLS] with "- 3" 87 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 88 | else: 89 | # Account for one [SEP] & one [CLS] with "- 2" 90 | if len(tokens_a) > max_seq_length - 2: 91 | tokens_a = tokens_a[:max_seq_length - 2] 92 | 93 | tokens = [] 94 | segment_ids = [] 95 | for token in tokens_a: 96 | tokens.append(token) 97 | segment_ids.append(SEG_ID_A) 98 | tokens.append(SEP_ID) 99 | segment_ids.append(SEG_ID_A) 100 | 101 | if tokens_b: 102 | for token in tokens_b: 103 | tokens.append(token) 104 | segment_ids.append(SEG_ID_B) 105 | tokens.append(SEP_ID) 106 | segment_ids.append(SEG_ID_B) 107 | 108 | tokens.append(CLS_ID) 109 | segment_ids.append(SEG_ID_CLS) 110 | 111 | input_ids = tokens 112 | 113 | # The mask has 0 for real tokens and 1 for padding tokens. Only real 114 | # tokens are attended to. 115 | input_mask = [0] * len(input_ids) 116 | 117 | # Zero-pad up to the sequence length. 118 | if len(input_ids) < max_seq_length: 119 | delta_len = max_seq_length - len(input_ids) 120 | input_ids = [0] * delta_len + input_ids 121 | input_mask = [1] * delta_len + input_mask 122 | segment_ids = [SEG_ID_PAD] * delta_len + segment_ids 123 | 124 | assert len(input_ids) == max_seq_length 125 | assert len(input_mask) == max_seq_length 126 | assert len(segment_ids) == max_seq_length 127 | 128 | if label_list is not None: 129 | label_id = label_map[example.label] 130 | else: 131 | label_id = example.label 132 | if ex_index < 5: 133 | tf.logging.info("*** Example ***") 134 | tf.logging.info("guid: %s" % (example.guid)) 135 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 136 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 137 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 138 | tf.logging.info("label: {} (id = {})".format(example.label, label_id)) 139 | 140 | feature = InputFeatures( 141 | input_ids=input_ids, 142 | input_mask=input_mask, 143 | segment_ids=segment_ids, 144 | label_id=label_id) 145 | return feature 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /function_builder.py: -------------------------------------------------------------------------------- 1 | """doc.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import functools 7 | import os 8 | import tensorflow as tf 9 | import modeling 10 | import xlnet 11 | 12 | 13 | def construct_scalar_host_call( 14 | monitor_dict, 15 | model_dir, 16 | prefix="", 17 | reduce_fn=None): 18 | """ 19 | Construct host calls to monitor training progress on TPUs. 
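  Args (description added; inferred from the implementation below):
    monitor_dict: dict mapping metric names to scalar tensors to record.
    model_dir: directory where the summary file writer logs events.
    prefix: optional string prepended to each summary name.
    reduce_fn: optional function reducing the per-core values of a metric to
      a single scalar; if None, the first element of each tensor is used.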
20 | """ 21 | 22 | metric_names = list(monitor_dict.keys()) 23 | 24 | def host_call_fn(global_step, *args): 25 | """actual host call function.""" 26 | step = global_step[0] 27 | with tf.contrib.summary.create_file_writer( 28 | logdir=model_dir, filename_suffix=".host_call").as_default(): 29 | with tf.contrib.summary.always_record_summaries(): 30 | for i, name in enumerate(metric_names): 31 | if reduce_fn is None: 32 | scalar = args[i][0] 33 | else: 34 | scalar = reduce_fn(args[i]) 35 | with tf.contrib.summary.record_summaries_every_n_global_steps( 36 | 100, global_step=step): 37 | tf.contrib.summary.scalar(prefix + name, scalar, step=step) 38 | 39 | return tf.contrib.summary.all_summary_ops() 40 | 41 | global_step_tensor = tf.reshape(tf.train.get_or_create_global_step(), [1]) 42 | other_tensors = [tf.reshape(monitor_dict[key], [1]) for key in metric_names] 43 | 44 | return host_call_fn, [global_step_tensor] + other_tensors 45 | 46 | 47 | def two_stream_loss(FLAGS, features, labels, mems, is_training): 48 | """Pretraining loss with two-stream attention Transformer-XL.""" 49 | 50 | #### Unpack input 51 | mem_name = "mems" 52 | mems = mems.get(mem_name, None) 53 | 54 | inp_k = tf.transpose(features["input_k"], [1, 0]) 55 | inp_q = tf.transpose(features["input_q"], [1, 0]) 56 | 57 | seg_id = tf.transpose(features["seg_id"], [1, 0]) 58 | 59 | inp_mask = None 60 | perm_mask = tf.transpose(features["perm_mask"], [1, 2, 0]) 61 | 62 | if FLAGS.num_predict is not None: 63 | # [num_predict x tgt_len x bsz] 64 | target_mapping = tf.transpose(features["target_mapping"], [1, 2, 0]) 65 | else: 66 | target_mapping = None 67 | 68 | # target for LM loss 69 | tgt = tf.transpose(features["target"], [1, 0]) 70 | 71 | # target mask for LM loss 72 | tgt_mask = tf.transpose(features["target_mask"], [1, 0]) 73 | 74 | # construct xlnet config and save to model_dir 75 | xlnet_config = xlnet.XLNetConfig(FLAGS=FLAGS) 76 | xlnet_config.to_json(os.path.join(FLAGS.model_dir, "config.json")) 77 | 78 | # construct run config from FLAGS 79 | run_config = xlnet.create_run_config(is_training, False, FLAGS) 80 | 81 | xlnet_model = xlnet.XLNetModel( 82 | xlnet_config=xlnet_config, 83 | run_config=run_config, 84 | input_ids=inp_k, 85 | seg_ids=seg_id, 86 | input_mask=inp_mask, 87 | mems=mems, 88 | perm_mask=perm_mask, 89 | target_mapping=target_mapping, 90 | inp_q=inp_q) 91 | 92 | output = xlnet_model.get_sequence_output() 93 | new_mems = {mem_name: xlnet_model.get_new_memory()} 94 | lookup_table = xlnet_model.get_embedding_table() 95 | 96 | initializer = xlnet_model.get_initializer() 97 | 98 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 99 | # LM loss 100 | lm_loss = modeling.lm_loss( 101 | hidden=output, 102 | target=tgt, 103 | n_token=xlnet_config.n_token, 104 | d_model=xlnet_config.d_model, 105 | initializer=initializer, 106 | lookup_table=lookup_table, 107 | tie_weight=True, 108 | bi_data=run_config.bi_data, 109 | use_tpu=run_config.use_tpu) 110 | 111 | #### Quantity to monitor 112 | monitor_dict = {} 113 | 114 | if FLAGS.use_bfloat16: 115 | tgt_mask = tf.cast(tgt_mask, tf.float32) 116 | lm_loss = tf.cast(lm_loss, tf.float32) 117 | 118 | total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask) 119 | monitor_dict["total_loss"] = total_loss 120 | 121 | return total_loss, new_mems, monitor_dict 122 | 123 | 124 | def get_loss(FLAGS, features, labels, mems, is_training): 125 | """Pretraining loss with two-stream attention Transformer-XL.""" 126 | if FLAGS.use_bfloat16: 127 | with 
tf.tpu.bfloat16_scope(): 128 | return two_stream_loss(FLAGS, features, labels, mems, is_training) 129 | else: 130 | return two_stream_loss(FLAGS, features, labels, mems, is_training) 131 | 132 | 133 | def get_classification_loss( 134 | FLAGS, features, n_class, is_training): 135 | """Loss for downstream classification tasks.""" 136 | 137 | bsz_per_core = tf.shape(features["input_ids"])[0] 138 | 139 | inp = tf.transpose(features["input_ids"], [1, 0]) 140 | seg_id = tf.transpose(features["segment_ids"], [1, 0]) 141 | inp_mask = tf.transpose(features["input_mask"], [1, 0]) 142 | label = tf.reshape(features["label_ids"], [bsz_per_core]) 143 | 144 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 145 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 146 | 147 | xlnet_model = xlnet.XLNetModel( 148 | xlnet_config=xlnet_config, 149 | run_config=run_config, 150 | input_ids=inp, 151 | seg_ids=seg_id, 152 | input_mask=inp_mask) 153 | 154 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj) 155 | 156 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 157 | 158 | if FLAGS.cls_scope is not None and FLAGS.cls_scope: 159 | cls_scope = "classification_{}".format(FLAGS.cls_scope) 160 | else: 161 | cls_scope = "classification_{}".format(FLAGS.task_name.lower()) 162 | 163 | per_example_loss, logits = modeling.classification_loss( 164 | hidden=summary, 165 | labels=label, 166 | n_class=n_class, 167 | initializer=xlnet_model.get_initializer(), 168 | scope=cls_scope, 169 | return_logits=True) 170 | 171 | total_loss = tf.reduce_mean(per_example_loss) 172 | 173 | return total_loss, per_example_loss, logits 174 | 175 | 176 | def get_regression_loss( 177 | FLAGS, features, is_training): 178 | """Loss for downstream regression tasks.""" 179 | 180 | bsz_per_core = tf.shape(features["input_ids"])[0] 181 | 182 | inp = tf.transpose(features["input_ids"], [1, 0]) 183 | seg_id = tf.transpose(features["segment_ids"], [1, 0]) 184 | inp_mask = tf.transpose(features["input_mask"], [1, 0]) 185 | label = tf.reshape(features["label_ids"], [bsz_per_core]) 186 | 187 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 188 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 189 | 190 | xlnet_model = xlnet.XLNetModel( 191 | xlnet_config=xlnet_config, 192 | run_config=run_config, 193 | input_ids=inp, 194 | seg_ids=seg_id, 195 | input_mask=inp_mask) 196 | 197 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj) 198 | 199 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 200 | per_example_loss, logits = modeling.regression_loss( 201 | hidden=summary, 202 | labels=label, 203 | initializer=xlnet_model.get_initializer(), 204 | scope="regression_{}".format(FLAGS.task_name.lower()), 205 | return_logits=True) 206 | 207 | total_loss = tf.reduce_mean(per_example_loss) 208 | 209 | return total_loss, per_example_loss, logits 210 | 211 | 212 | def get_qa_outputs(FLAGS, features, is_training): 213 | """Loss for downstream span-extraction QA tasks such as SQuAD.""" 214 | 215 | inp = tf.transpose(features["input_ids"], [1, 0]) 216 | seg_id = tf.transpose(features["segment_ids"], [1, 0]) 217 | inp_mask = tf.transpose(features["input_mask"], [1, 0]) 218 | cls_index = tf.reshape(features["cls_index"], [-1]) 219 | 220 | seq_len = tf.shape(inp)[0] 221 | 222 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 223 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 224 | 225 | xlnet_model = 
xlnet.XLNetModel( 226 | xlnet_config=xlnet_config, 227 | run_config=run_config, 228 | input_ids=inp, 229 | seg_ids=seg_id, 230 | input_mask=inp_mask) 231 | output = xlnet_model.get_sequence_output() 232 | initializer = xlnet_model.get_initializer() 233 | 234 | return_dict = {} 235 | 236 | # invalid position mask such as query and special symbols (PAD, SEP, CLS) 237 | p_mask = features["p_mask"] 238 | 239 | # logit of the start position 240 | with tf.variable_scope("start_logits"): 241 | start_logits = tf.layers.dense( 242 | output, 243 | 1, 244 | kernel_initializer=initializer) 245 | start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0]) 246 | start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask 247 | start_log_probs = tf.nn.log_softmax(start_logits_masked, -1) 248 | 249 | # logit of the end position 250 | with tf.variable_scope("end_logits"): 251 | if is_training: 252 | # during training, compute the end logits based on the 253 | # ground truth of the start position 254 | 255 | start_positions = tf.reshape(features["start_positions"], [-1]) 256 | start_index = tf.one_hot(start_positions, depth=seq_len, axis=-1, 257 | dtype=tf.float32) 258 | start_features = tf.einsum("lbh,bl->bh", output, start_index) 259 | start_features = tf.tile(start_features[None], [seq_len, 1, 1]) 260 | end_logits = tf.layers.dense( 261 | tf.concat([output, start_features], axis=-1), xlnet_config.d_model, 262 | kernel_initializer=initializer, activation=tf.tanh, name="dense_0") 263 | end_logits = tf.contrib.layers.layer_norm( 264 | end_logits, begin_norm_axis=-1) 265 | 266 | end_logits = tf.layers.dense( 267 | end_logits, 1, 268 | kernel_initializer=initializer, 269 | name="dense_1") 270 | end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0]) 271 | end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask 272 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1) 273 | else: 274 | # during inference, compute the end logits based on beam search 275 | 276 | start_top_log_probs, start_top_index = tf.nn.top_k( 277 | start_log_probs, k=FLAGS.start_n_top) 278 | start_index = tf.one_hot(start_top_index, 279 | depth=seq_len, axis=-1, dtype=tf.float32) 280 | start_features = tf.einsum("lbh,bkl->bkh", output, start_index) 281 | end_input = tf.tile(output[:, :, None], 282 | [1, 1, FLAGS.start_n_top, 1]) 283 | start_features = tf.tile(start_features[None], 284 | [seq_len, 1, 1, 1]) 285 | end_input = tf.concat([end_input, start_features], axis=-1) 286 | end_logits = tf.layers.dense( 287 | end_input, 288 | xlnet_config.d_model, 289 | kernel_initializer=initializer, 290 | activation=tf.tanh, 291 | name="dense_0") 292 | end_logits = tf.contrib.layers.layer_norm(end_logits, 293 | begin_norm_axis=-1) 294 | end_logits = tf.layers.dense( 295 | end_logits, 296 | 1, 297 | kernel_initializer=initializer, 298 | name="dense_1") 299 | end_logits = tf.reshape(end_logits, [seq_len, -1, FLAGS.start_n_top]) 300 | end_logits = tf.transpose(end_logits, [1, 2, 0]) 301 | end_logits_masked = end_logits * ( 302 | 1 - p_mask[:, None]) - 1e30 * p_mask[:, None] 303 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1) 304 | end_top_log_probs, end_top_index = tf.nn.top_k( 305 | end_log_probs, k=FLAGS.end_n_top) 306 | end_top_log_probs = tf.reshape( 307 | end_top_log_probs, 308 | [-1, FLAGS.start_n_top * FLAGS.end_n_top]) 309 | end_top_index = tf.reshape( 310 | end_top_index, 311 | [-1, FLAGS.start_n_top * FLAGS.end_n_top]) 312 | 313 | if is_training: 314 | return_dict["start_log_probs"] = start_log_probs 315 | 
return_dict["end_log_probs"] = end_log_probs 316 | else: 317 | return_dict["start_top_log_probs"] = start_top_log_probs 318 | return_dict["start_top_index"] = start_top_index 319 | return_dict["end_top_log_probs"] = end_top_log_probs 320 | return_dict["end_top_index"] = end_top_index 321 | 322 | # an additional layer to predict answerability 323 | with tf.variable_scope("answer_class"): 324 | # get the representation of CLS 325 | cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32) 326 | cls_feature = tf.einsum("lbh,bl->bh", output, cls_index) 327 | 328 | # get the representation of START 329 | start_p = tf.nn.softmax(start_logits_masked, axis=-1, 330 | name="softmax_start") 331 | start_feature = tf.einsum("lbh,bl->bh", output, start_p) 332 | 333 | # note(zhiliny): no dependency on end_feature so that we can obtain 334 | # one single `cls_logits` for each sample 335 | ans_feature = tf.concat([start_feature, cls_feature], -1) 336 | ans_feature = tf.layers.dense( 337 | ans_feature, 338 | xlnet_config.d_model, 339 | activation=tf.tanh, 340 | kernel_initializer=initializer, name="dense_0") 341 | ans_feature = tf.layers.dropout(ans_feature, FLAGS.dropout, 342 | training=is_training) 343 | cls_logits = tf.layers.dense( 344 | ans_feature, 345 | 1, 346 | kernel_initializer=initializer, 347 | name="dense_1", 348 | use_bias=False) 349 | cls_logits = tf.squeeze(cls_logits, -1) 350 | 351 | return_dict["cls_logits"] = cls_logits 352 | 353 | return return_dict 354 | 355 | 356 | def get_race_loss(FLAGS, features, is_training): 357 | """Loss for downstream multi-choice QA tasks such as RACE.""" 358 | 359 | bsz_per_core = tf.shape(features["input_ids"])[0] 360 | 361 | def _transform_features(feature): 362 | out = tf.reshape(feature, [bsz_per_core, 4, -1]) 363 | out = tf.transpose(out, [2, 0, 1]) 364 | out = tf.reshape(out, [-1, bsz_per_core * 4]) 365 | return out 366 | 367 | inp = _transform_features(features["input_ids"]) 368 | seg_id = _transform_features(features["segment_ids"]) 369 | inp_mask = _transform_features(features["input_mask"]) 370 | label = tf.reshape(features["label_ids"], [bsz_per_core]) 371 | 372 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path) 373 | run_config = xlnet.create_run_config(is_training, True, FLAGS) 374 | 375 | xlnet_model = xlnet.XLNetModel( 376 | xlnet_config=xlnet_config, 377 | run_config=run_config, 378 | input_ids=inp, 379 | seg_ids=seg_id, 380 | input_mask=inp_mask) 381 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj) 382 | 383 | with tf.variable_scope("logits"): 384 | logits = tf.layers.dense(summary, 1, 385 | kernel_initializer=xlnet_model.get_initializer()) 386 | logits = tf.reshape(logits, [bsz_per_core, 4]) 387 | 388 | one_hot_target = tf.one_hot(label, 4) 389 | per_example_loss = -tf.reduce_sum( 390 | tf.nn.log_softmax(logits) * one_hot_target, -1) 391 | total_loss = tf.reduce_mean(per_example_loss) 392 | 393 | return total_loss, per_example_loss, logits 394 | -------------------------------------------------------------------------------- /gpu_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import tensorflow as tf 7 | 8 | def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"): 9 | def _assign(op): 10 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 11 | if node_def.op == "Variable": 12 | return ps_dev 13 
| else: 14 | return "/gpu:%d" % gpu 15 | return _assign 16 | 17 | 18 | def average_grads_and_vars(tower_grads_and_vars): 19 | def average_dense(grad_and_vars): 20 | if len(grad_and_vars) == 1: 21 | return grad_and_vars[0][0] 22 | 23 | grad = grad_and_vars[0][0] 24 | for g, _ in grad_and_vars[1:]: 25 | grad += g 26 | return grad / len(grad_and_vars) 27 | 28 | def average_sparse(grad_and_vars): 29 | if len(grad_and_vars) == 1: 30 | return grad_and_vars[0][0] 31 | 32 | indices = [] 33 | values = [] 34 | for g, _ in grad_and_vars: 35 | indices += [g.indices] 36 | values += [g.values] 37 | indices = tf.concat(indices, 0) 38 | values = tf.concat(values, 0) / len(grad_and_vars) 39 | return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape) 40 | 41 | average_grads_and_vars = [] 42 | for grad_and_vars in zip(*tower_grads_and_vars): 43 | if grad_and_vars[0][0] is None: 44 | grad = None 45 | elif isinstance(grad_and_vars[0][0], tf.IndexedSlices): 46 | grad = average_sparse(grad_and_vars) 47 | else: 48 | grad = average_dense(grad_and_vars) 49 | # Keep in mind that the Variables are redundant because they are shared 50 | # across towers. So we will just return the first tower's pointer to 51 | # the Variable. 52 | v = grad_and_vars[0][1] 53 | grad_and_var = (grad, v) 54 | average_grads_and_vars.append(grad_and_var) 55 | return average_grads_and_vars 56 | 57 | 58 | def load_from_checkpoint(saver, logdir): 59 | sess = tf.get_default_session() 60 | ckpt = tf.train.get_checkpoint_state(logdir) 61 | if ckpt and ckpt.model_checkpoint_path: 62 | if os.path.isabs(ckpt.model_checkpoint_path): 63 | # Restores from checkpoint with absolute path. 64 | saver.restore(sess, ckpt.model_checkpoint_path) 65 | else: 66 | # Restores from checkpoint with relative path. 67 | saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path)) 68 | return True 69 | return False 70 | -------------------------------------------------------------------------------- /misc/race_example.md: -------------------------------------------------------------------------------- 1 | ## A RACE Example 2 | 3 | **Paragraph** 4 | 5 | It was a cold night. The taxi driver didn't take even one passenger all day. 6 | When he went by the railway station, he saw a young man coming out with two bags in his hands. 7 | So he drove to him and asked, "Where are you going, sir?" 8 | "To the Red Hotel," the young man answered. 9 | When the taxi driver heard this, he didn't feel happy any more. 10 | The young man would give him only three dollars because the hotel was near the railway station. 11 | But suddenly, he had an idea. He took the young man through many streets of the big city. 12 | After a long time, they arrived at the hotel. "Here we are! You should pay me fifteen dollars, please," the taxi driver said to the young man. 13 | "What? Fifteen dollars! Do you think I'm a fool? Only last week, I took a taxi from the railway station to this hotel and I only gave the driver thirteen dollars. 14 | I know how much I have to pay for the trip." 15 | 16 | **Question 1** 17 | 18 | Maybe the taxi driver got _ dollars at last. 19 | 20 | **Answer Candidates** 21 | 22 | - 3 23 | - 2 24 | - 13 25 | - 15 26 | 27 | **Question 2** 28 | 29 | Which of the following is TRUE? 30 | 31 | **Answer Candidates** 32 | 33 | - The two taxi drivers were both honest. 34 | - The two taxi drivers cheated the young man. 35 | - It is very far from the railway station to the Red Hotel.
36 | - The young man knew how far it was from the railway station to the hotel. 37 | -------------------------------------------------------------------------------- /misc/slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zihangdai/xlnet/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/misc/slides.pdf -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import os 7 | import re 8 | import numpy as np 9 | import six 10 | from os.path import join 11 | from six.moves import zip 12 | 13 | from absl import flags 14 | 15 | import tensorflow as tf 16 | 17 | 18 | def configure_tpu(FLAGS): 19 | if FLAGS.use_tpu: 20 | tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver( 21 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 22 | master = tpu_cluster.get_master() 23 | else: 24 | tpu_cluster = None 25 | master = FLAGS.master 26 | 27 | session_config = tf.ConfigProto(allow_soft_placement=True) 28 | # Uncomment the following line if you hope to monitor GPU RAM growth 29 | # session_config.gpu_options.allow_growth = True 30 | 31 | if FLAGS.use_tpu: 32 | strategy = None 33 | tf.logging.info('Use TPU without distribute strategy.') 34 | elif FLAGS.num_core_per_host == 1: 35 | strategy = None 36 | tf.logging.info('Single device mode.') 37 | else: 38 | strategy = tf.contrib.distribute.MirroredStrategy( 39 | num_gpus=FLAGS.num_core_per_host) 40 | tf.logging.info('Use MirroredStrategy with %d devices.', 41 | strategy.num_replicas_in_sync) 42 | 43 | per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 44 | run_config = tf.contrib.tpu.RunConfig( 45 | master=master, 46 | model_dir=FLAGS.model_dir, 47 | session_config=session_config, 48 | tpu_config=tf.contrib.tpu.TPUConfig( 49 | iterations_per_loop=FLAGS.iterations, 50 | num_shards=FLAGS.num_hosts * FLAGS.num_core_per_host, 51 | per_host_input_for_training=per_host_input), 52 | keep_checkpoint_max=FLAGS.max_save, 53 | save_checkpoints_secs=None, 54 | save_checkpoints_steps=FLAGS.save_steps, 55 | train_distribute=strategy 56 | ) 57 | return run_config 58 | 59 | 60 | def init_from_checkpoint(FLAGS, global_vars=False): 61 | tvars = tf.global_variables() if global_vars else tf.trainable_variables() 62 | initialized_variable_names = {} 63 | scaffold_fn = None 64 | if FLAGS.init_checkpoint is not None: 65 | if FLAGS.init_checkpoint.endswith("latest"): 66 | ckpt_dir = os.path.dirname(FLAGS.init_checkpoint) 67 | init_checkpoint = tf.train.latest_checkpoint(ckpt_dir) 68 | else: 69 | init_checkpoint = FLAGS.init_checkpoint 70 | 71 | tf.logging.info("Initialize from the ckpt {}".format(init_checkpoint)) 72 | 73 | (assignment_map, initialized_variable_names 74 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint) 75 | if FLAGS.use_tpu: 76 | def tpu_scaffold(): 77 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 78 | return tf.train.Scaffold() 79 | 80 | scaffold_fn = tpu_scaffold 81 | else: 82 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 83 | 84 | # Log customized initialization 85 | tf.logging.info("**** Global Variables ****") 86 | for var in tvars: 87 | init_string = "" 88 | if var.name in initialized_variable_names: 89 | init_string = ", *INIT_FROM_CKPT*" 90 
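    # The log below tags every warm-started variable, making it easy to spot
    # weights that were *not* found in the checkpoint.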
| tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 91 | init_string) 92 | return scaffold_fn 93 | 94 | 95 | def get_train_op(FLAGS, total_loss, grads_and_vars=None): 96 | global_step = tf.train.get_or_create_global_step() 97 | 98 | # increase the learning rate linearly 99 | if FLAGS.warmup_steps > 0: 100 | warmup_lr = (tf.cast(global_step, tf.float32) 101 | / tf.cast(FLAGS.warmup_steps, tf.float32) 102 | * FLAGS.learning_rate) 103 | else: 104 | warmup_lr = 0.0 105 | 106 | # decay the learning rate 107 | if FLAGS.decay_method == "poly": 108 | decay_lr = tf.train.polynomial_decay( 109 | FLAGS.learning_rate, 110 | global_step=global_step - FLAGS.warmup_steps, 111 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps, 112 | end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio) 113 | elif FLAGS.decay_method == "cos": 114 | decay_lr = tf.train.cosine_decay( 115 | FLAGS.learning_rate, 116 | global_step=global_step - FLAGS.warmup_steps, 117 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps, 118 | alpha=FLAGS.min_lr_ratio) 119 | else: 120 | raise ValueError(FLAGS.decay_method) 121 | 122 | learning_rate = tf.where(global_step < FLAGS.warmup_steps, 123 | warmup_lr, decay_lr) 124 | 125 | if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu and 126 | FLAGS.num_core_per_host > 1): 127 | raise ValueError("Do not support `weight_decay > 0` with multi-gpu " 128 | "training so far.") 129 | 130 | if FLAGS.weight_decay == 0: 131 | optimizer = tf.train.AdamOptimizer( 132 | learning_rate=learning_rate, 133 | epsilon=FLAGS.adam_epsilon) 134 | else: 135 | optimizer = AdamWeightDecayOptimizer( 136 | learning_rate=learning_rate, 137 | epsilon=FLAGS.adam_epsilon, 138 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], 139 | weight_decay_rate=FLAGS.weight_decay) 140 | 141 | if FLAGS.use_tpu: 142 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 143 | 144 | if grads_and_vars is None: 145 | grads_and_vars = optimizer.compute_gradients(total_loss) 146 | gradients, variables = zip(*grads_and_vars) 147 | clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip) 148 | 149 | if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0: 150 | n_layer = 0 151 | for i in range(len(clipped)): 152 | m = re.search(r"model/transformer/layer_(\d+?)/", variables[i].name) 153 | if not m: continue 154 | n_layer = max(n_layer, int(m.group(1)) + 1) 155 | 156 | for i in range(len(clipped)): 157 | for l in range(n_layer): 158 | if "model/transformer/layer_{}/".format(l) in variables[i].name: 159 | abs_rate = FLAGS.lr_layer_decay_rate ** (n_layer - 1 - l) 160 | clipped[i] *= abs_rate 161 | tf.logging.info("Apply mult {:.4f} to layer-{} grad of {}".format( 162 | abs_rate, l, variables[i].name)) 163 | break 164 | 165 | train_op = optimizer.apply_gradients( 166 | zip(clipped, variables), global_step=global_step) 167 | 168 | # Manually increment `global_step` for AdamWeightDecayOptimizer 169 | if FLAGS.weight_decay > 0: 170 | new_global_step = global_step + 1 171 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 172 | 173 | return train_op, learning_rate, gnorm 174 | 175 | 176 | def clean_ckpt(_): 177 | input_ckpt = FLAGS.clean_input_ckpt 178 | output_model_dir = FLAGS.clean_output_model_dir 179 | 180 | tf.reset_default_graph() 181 | 182 | var_list = tf.contrib.framework.list_variables(input_ckpt) 183 | var_values, var_dtypes = {}, {} 184 | for (name, shape) in var_list: 185 | if not name.startswith("global_step") and "adam" not in name.lower(): 186 | var_values[name] = None 187 | 
tf.logging.info("Include {}".format(name)) 188 | else: 189 | tf.logging.info("Exclude {}".format(name)) 190 | 191 | tf.logging.info("Loading from {}".format(input_ckpt)) 192 | reader = tf.contrib.framework.load_checkpoint(input_ckpt) 193 | for name in var_values: 194 | tensor = reader.get_tensor(name) 195 | var_dtypes[name] = tensor.dtype 196 | var_values[name] = tensor 197 | 198 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): 199 | tf_vars = [ 200 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) 201 | for v in var_values 202 | ] 203 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 204 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 205 | global_step = tf.Variable( 206 | 0, name="global_step", trainable=False, dtype=tf.int64) 207 | saver = tf.train.Saver(tf.all_variables()) 208 | 209 | if not tf.gfile.Exists(output_model_dir): 210 | tf.gfile.MakeDirs(output_model_dir) 211 | 212 | # Build a model consisting only of variables, set them to the average values. 213 | with tf.Session() as sess: 214 | sess.run(tf.initialize_all_variables()) 215 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 216 | six.iteritems(var_values)): 217 | sess.run(assign_op, {p: value}) 218 | 219 | # Use the built saver to save the averaged checkpoint. 220 | saver.save(sess, join(output_model_dir, "model.ckpt"), 221 | global_step=global_step) 222 | 223 | 224 | def avg_checkpoints(model_dir, output_model_dir, last_k): 225 | tf.reset_default_graph() 226 | 227 | checkpoint_state = tf.train.get_checkpoint_state(model_dir) 228 | checkpoints = checkpoint_state.all_model_checkpoint_paths[- last_k:] 229 | var_list = tf.contrib.framework.list_variables(checkpoints[0]) 230 | var_values, var_dtypes = {}, {} 231 | for (name, shape) in var_list: 232 | if not name.startswith("global_step"): 233 | var_values[name] = np.zeros(shape) 234 | for checkpoint in checkpoints: 235 | reader = tf.contrib.framework.load_checkpoint(checkpoint) 236 | for name in var_values: 237 | tensor = reader.get_tensor(name) 238 | var_dtypes[name] = tensor.dtype 239 | var_values[name] += tensor 240 | tf.logging.info("Read from checkpoint %s", checkpoint) 241 | for name in var_values: # Average. 242 | var_values[name] /= len(checkpoints) 243 | 244 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): 245 | tf_vars = [ 246 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) 247 | for v in var_values 248 | ] 249 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] 250 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] 251 | global_step = tf.Variable( 252 | 0, name="global_step", trainable=False, dtype=tf.int64) 253 | saver = tf.train.Saver(tf.all_variables()) 254 | 255 | # Build a model consisting only of variables, set them to the average values. 256 | with tf.Session() as sess: 257 | sess.run(tf.initialize_all_variables()) 258 | for p, assign_op, (name, value) in zip(placeholders, assign_ops, 259 | six.iteritems(var_values)): 260 | sess.run(assign_op, {p: value}) 261 | # Use the built saver to save the averaged checkpoint. 
262 | saver.save(sess, join(output_model_dir, "model.ckpt"), 263 | global_step=global_step) 264 | 265 | 266 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 267 | """Compute the union of the current variables and checkpoint variables.""" 268 | assignment_map = {} 269 | initialized_variable_names = {} 270 | 271 | name_to_variable = collections.OrderedDict() 272 | for var in tvars: 273 | name = var.name 274 | m = re.match("^(.*):\\d+$", name) 275 | if m is not None: 276 | name = m.group(1) 277 | name_to_variable[name] = var 278 | 279 | init_vars = tf.train.list_variables(init_checkpoint) 280 | 281 | assignment_map = collections.OrderedDict() 282 | for x in init_vars: 283 | (name, var) = (x[0], x[1]) 284 | # tf.logging.info('original name: %s', name) 285 | if name not in name_to_variable: 286 | continue 287 | # assignment_map[name] = name 288 | assignment_map[name] = name_to_variable[name] 289 | initialized_variable_names[name] = 1 290 | initialized_variable_names[name + ":0"] = 1 291 | 292 | return (assignment_map, initialized_variable_names) 293 | 294 | 295 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 296 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 297 | 298 | def __init__(self, 299 | learning_rate, 300 | weight_decay_rate=0.0, 301 | beta_1=0.9, 302 | beta_2=0.999, 303 | epsilon=1e-6, 304 | exclude_from_weight_decay=None, 305 | include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"], 306 | name="AdamWeightDecayOptimizer"): 307 | """Constructs an AdamWeightDecayOptimizer.""" 308 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 309 | 310 | self.learning_rate = learning_rate 311 | self.weight_decay_rate = weight_decay_rate 312 | self.beta_1 = beta_1 313 | self.beta_2 = beta_2 314 | self.epsilon = epsilon 315 | self.exclude_from_weight_decay = exclude_from_weight_decay 316 | self.include_in_weight_decay = include_in_weight_decay 317 | 318 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 319 | """See base class.""" 320 | assignments = [] 321 | for (grad, param) in grads_and_vars: 322 | if grad is None or param is None: 323 | continue 324 | 325 | param_name = self._get_variable_name(param.name) 326 | 327 | m = tf.get_variable( 328 | name=param_name + "/adam_m", 329 | shape=param.shape.as_list(), 330 | dtype=tf.float32, 331 | trainable=False, 332 | initializer=tf.zeros_initializer()) 333 | v = tf.get_variable( 334 | name=param_name + "/adam_v", 335 | shape=param.shape.as_list(), 336 | dtype=tf.float32, 337 | trainable=False, 338 | initializer=tf.zeros_initializer()) 339 | 340 | # Standard Adam update. 341 | next_m = ( 342 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 343 | next_v = ( 344 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 345 | tf.square(grad))) 346 | 347 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 348 | 349 | # Just adding the square of the weights to the loss function is *not* 350 | # the correct way of using L2 regularization/weight decay with Adam, 351 | # since that will interact with the m and v parameters in strange ways. 352 | # 353 | # Instead we want to decay the weights in a manner that doesn't interact 354 | # with the m/v parameters. This is equivalent to adding the square 355 | # of the weights to the loss with plain (non-momentum) SGD.
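      #
      # In update form, with the names used above:
      #   param <- param - learning_rate * (update + weight_decay_rate * param)
      # where update = m / (sqrt(v) + epsilon), so the decay term never flows
      # through the m/v moment estimates (decoupled weight decay, as in AdamW).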
356 | if self._do_use_weight_decay(param_name): 357 | update += self.weight_decay_rate * param 358 | 359 | update_with_lr = self.learning_rate * update 360 | 361 | next_param = param - update_with_lr 362 | 363 | assignments.extend( 364 | [param.assign(next_param), 365 | m.assign(next_m), 366 | v.assign(next_v)]) 367 | 368 | return tf.group(*assignments, name=name) 369 | 370 | def _do_use_weight_decay(self, param_name): 371 | """Whether to use L2 weight decay for `param_name`.""" 372 | if not self.weight_decay_rate: 373 | return False 374 | for r in self.include_in_weight_decay: 375 | if re.search(r, param_name) is not None: 376 | return True 377 | 378 | if self.exclude_from_weight_decay: 379 | for r in self.exclude_from_weight_decay: 380 | if re.search(r, param_name) is not None: 381 | tf.logging.info('Adam WD excludes {}'.format(param_name)) 382 | return False 383 | return True 384 | 385 | def _get_variable_name(self, param_name): 386 | """Get the variable name from the tensor name.""" 387 | m = re.match("^(.*):\\d+$", param_name) 388 | if m is not None: 389 | param_name = m.group(1) 390 | return param_name 391 | 392 | 393 | if __name__ == "__main__": 394 | flags.DEFINE_string("clean_input_ckpt", "", "input ckpt for cleaning") 395 | flags.DEFINE_string("clean_output_model_dir", "", "output dir for cleaned ckpt") 396 | 397 | FLAGS = flags.FLAGS 398 | 399 | tf.app.run(clean_ckpt) 400 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | 9 | def gelu(x): 10 | """Gaussian Error Linear Unit. 11 | 12 | This is a smoother version of the RELU. 13 | Original paper: https://arxiv.org/abs/1606.08415 14 | Args: 15 | x: float Tensor to perform activation. 16 | 17 | Returns: 18 | `x` with the GELU activation applied. 
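  Example (tanh approximation; values are approximate):
    gelu(-1.0) ~= -0.159, gelu(0.0) = 0.0, gelu(1.0) ~= 0.841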
19 | """ 20 | cdf = 0.5 * (1.0 + tf.tanh( 21 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) 22 | return x * cdf 23 | 24 | 25 | def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True, 26 | scope='embedding', reuse=None, dtype=tf.float32): 27 | """TPU and GPU embedding_lookup function.""" 28 | with tf.variable_scope(scope, reuse=reuse): 29 | lookup_table = tf.get_variable('lookup_table', [n_token, d_embed], 30 | dtype=dtype, initializer=initializer) 31 | if use_tpu: 32 | one_hot_idx = tf.one_hot(x, n_token, dtype=dtype) 33 | if one_hot_idx.shape.ndims == 2: 34 | return tf.einsum('in,nd->id', one_hot_idx, lookup_table), lookup_table 35 | else: 36 | return tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table), lookup_table 37 | else: 38 | return tf.nn.embedding_lookup(lookup_table, x), lookup_table 39 | 40 | 41 | def positional_embedding(pos_seq, inv_freq, bsz=None): 42 | sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq) 43 | pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) 44 | pos_emb = pos_emb[:, None, :] 45 | 46 | if bsz is not None: 47 | pos_emb = tf.tile(pos_emb, [1, bsz, 1]) 48 | 49 | return pos_emb 50 | 51 | 52 | def positionwise_ffn(inp, d_model, d_inner, dropout, kernel_initializer, 53 | activation_type='relu', scope='ff', is_training=True, 54 | reuse=None): 55 | """Position-wise Feed-forward Network.""" 56 | if activation_type == 'relu': 57 | activation = tf.nn.relu 58 | elif activation_type == 'gelu': 59 | activation = gelu 60 | else: 61 | raise ValueError('Unsupported activation type {}'.format(activation_type)) 62 | 63 | output = inp 64 | with tf.variable_scope(scope, reuse=reuse): 65 | output = tf.layers.dense(output, d_inner, activation=activation, 66 | kernel_initializer=kernel_initializer, 67 | name='layer_1') 68 | output = tf.layers.dropout(output, dropout, training=is_training, 69 | name='drop_1') 70 | output = tf.layers.dense(output, d_model, 71 | kernel_initializer=kernel_initializer, 72 | name='layer_2') 73 | output = tf.layers.dropout(output, dropout, training=is_training, 74 | name='drop_2') 75 | output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1, 76 | scope='LayerNorm') 77 | return output 78 | 79 | 80 | def head_projection(h, d_model, n_head, d_head, kernel_initializer, name): 81 | """Project hidden states to a specific head with a 4D-shape.""" 82 | proj_weight = tf.get_variable('{}/kernel'.format(name), 83 | [d_model, n_head, d_head], dtype=h.dtype, 84 | initializer=kernel_initializer) 85 | head = tf.einsum('ibh,hnd->ibnd', h, proj_weight) 86 | 87 | return head 88 | 89 | 90 | def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training, 91 | kernel_initializer, residual=True): 92 | """Post-attention processing.""" 93 | # post-attention projection (back to `d_model`) 94 | proj_o = tf.get_variable('o/kernel', [d_model, n_head, d_head], 95 | dtype=h.dtype, initializer=kernel_initializer) 96 | attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o) 97 | 98 | attn_out = tf.layers.dropout(attn_out, dropout, training=is_training) 99 | if residual: 100 | output = tf.contrib.layers.layer_norm(attn_out + h, begin_norm_axis=-1, 101 | scope='LayerNorm') 102 | else: 103 | output = tf.contrib.layers.layer_norm(attn_out, begin_norm_axis=-1, 104 | scope='LayerNorm') 105 | 106 | return output 107 | 108 | 109 | def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training, 110 | scale): 111 | """Core absolute positional attention operations.""" 112 | 113 | attn_score = 
tf.einsum('ibnd,jbnd->ijbn', q_head, k_head) 114 | attn_score *= scale 115 | if attn_mask is not None: 116 | attn_score = attn_score - 1e30 * attn_mask 117 | 118 | # attention probability 119 | attn_prob = tf.nn.softmax(attn_score, 1) 120 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training) 121 | 122 | # attention output 123 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head) 124 | 125 | return attn_vec 126 | 127 | 128 | def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, 129 | r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training, 130 | scale): 131 | """Core relative positional attention operations.""" 132 | 133 | # content based attention score 134 | ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h) 135 | 136 | # position based attention score 137 | bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r) 138 | bd = rel_shift(bd, klen=tf.shape(ac)[1]) 139 | 140 | # segment based attention score 141 | if seg_mat is None: 142 | ef = 0 143 | else: 144 | ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed) 145 | ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef) 146 | 147 | # merge attention scores and perform masking 148 | attn_score = (ac + bd + ef) * scale 149 | if attn_mask is not None: 150 | # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask 151 | attn_score = attn_score - 1e30 * attn_mask 152 | 153 | # attention probability 154 | attn_prob = tf.nn.softmax(attn_score, 1) 155 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training) 156 | 157 | # attention output 158 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h) 159 | 160 | return attn_vec 161 | 162 | 163 | def rel_shift(x, klen=-1): 164 | """perform relative shift to form the relative attention score.""" 165 | x_size = tf.shape(x) 166 | 167 | x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]]) 168 | x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) 169 | x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]]) 170 | x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1]) 171 | 172 | return x 173 | 174 | 175 | def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False): 176 | """create causal attention mask.""" 177 | attn_mask = tf.ones([qlen, qlen], dtype=dtype) 178 | mask_u = tf.matrix_band_part(attn_mask, 0, -1) 179 | mask_dia = tf.matrix_band_part(attn_mask, 0, 0) 180 | attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype) 181 | ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) 182 | if same_length: 183 | mask_l = tf.matrix_band_part(attn_mask, -1, 0) 184 | ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) 185 | 186 | return ret 187 | 188 | 189 | def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None): 190 | """cache hidden states into memory.""" 191 | if mem_len is None or mem_len == 0: 192 | return None 193 | else: 194 | if reuse_len is not None and reuse_len > 0: 195 | curr_out = curr_out[:reuse_len] 196 | 197 | if prev_mem is None: 198 | new_mem = curr_out[-mem_len:] 199 | else: 200 | new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:] 201 | 202 | return tf.stop_gradient(new_mem) 203 | 204 | 205 | def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type, 206 | bi_data, bsz=None, dtype=None): 207 | """create relative positional encoding.""" 208 | freq_seq = tf.range(0, d_model, 2.0) 209 | if dtype is not None and dtype != tf.float32: 210 | freq_seq = tf.cast(freq_seq, dtype=dtype) 211 | inv_freq = 1 / (10000 ** (freq_seq / 
d_model)) 212 | 213 | if attn_type == 'bi': 214 | # beg, end = klen - 1, -qlen 215 | beg, end = klen, -qlen 216 | elif attn_type == 'uni': 217 | # beg, end = klen - 1, -1 218 | beg, end = klen, -1 219 | else: 220 | raise ValueError('Unknown `attn_type` {}.'.format(attn_type)) 221 | 222 | if bi_data: 223 | fwd_pos_seq = tf.range(beg, end, -1.0) 224 | bwd_pos_seq = tf.range(-beg, -end, 1.0) 225 | 226 | if dtype is not None and dtype != tf.float32: 227 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) 228 | bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype) 229 | 230 | if clamp_len > 0: 231 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len) 232 | bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len) 233 | 234 | if bsz is not None: 235 | # With bi_data, the batch size should be divisible by 2. 236 | assert bsz%2 == 0 237 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2) 238 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2) 239 | else: 240 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq) 241 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq) 242 | 243 | pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1) 244 | else: 245 | fwd_pos_seq = tf.range(beg, end, -1.0) 246 | if dtype is not None and dtype != tf.float32: 247 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) 248 | if clamp_len > 0: 249 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len) 250 | pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz) 251 | 252 | return pos_emb 253 | 254 | 255 | def multihead_attn(q, k, v, attn_mask, d_model, n_head, d_head, dropout, 256 | dropatt, is_training, kernel_initializer, residual=True, 257 | scope='abs_attn', reuse=None): 258 | """Standard multi-head attention with absolute positional embedding.""" 259 | 260 | scale = 1 / (d_head ** 0.5) 261 | with tf.variable_scope(scope, reuse=reuse): 262 | # attention heads 263 | q_head = head_projection( 264 | q, d_model, n_head, d_head, kernel_initializer, 'q') 265 | k_head = head_projection( 266 | k, d_model, n_head, d_head, kernel_initializer, 'k') 267 | v_head = head_projection( 268 | v, d_model, n_head, d_head, kernel_initializer, 'v') 269 | 270 | # attention vector 271 | attn_vec = abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, 272 | is_training, scale) 273 | 274 | # post processing 275 | output = post_attention(v, attn_vec, d_model, n_head, d_head, dropout, 276 | is_training, kernel_initializer, residual) 277 | 278 | return output 279 | 280 | 281 | 282 | def rel_multihead_attn(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, 283 | attn_mask, mems, d_model, n_head, d_head, dropout, 284 | dropatt, is_training, kernel_initializer, 285 | scope='rel_attn', reuse=None): 286 | """Multi-head attention with relative positional encoding.""" 287 | 288 | scale = 1 / (d_head ** 0.5) 289 | with tf.variable_scope(scope, reuse=reuse): 290 | if mems is not None and mems.shape.ndims > 1: 291 | cat = tf.concat([mems, h], 0) 292 | else: 293 | cat = h 294 | 295 | # content heads 296 | q_head_h = head_projection( 297 | h, d_model, n_head, d_head, kernel_initializer, 'q') 298 | k_head_h = head_projection( 299 | cat, d_model, n_head, d_head, kernel_initializer, 'k') 300 | v_head_h = head_projection( 301 | cat, d_model, n_head, d_head, kernel_initializer, 'v') 302 | 303 | # positional heads 304 | k_head_r = head_projection( 305 | r, d_model, n_head, d_head, kernel_initializer, 'r') 306 | 307 | # core attention ops 308 | attn_vec = rel_attn_core( 309 | 
q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 310 | r_r_bias, r_s_bias, attn_mask, dropatt, is_training, scale) 311 | 312 | # post processing 313 | output = post_attention(h, attn_vec, d_model, n_head, d_head, dropout, 314 | is_training, kernel_initializer) 315 | 316 | return output 317 | 318 | 319 | def two_stream_rel_attn(h, g, r, mems, r_w_bias, r_r_bias, seg_mat, r_s_bias, 320 | seg_embed, attn_mask_h, attn_mask_g, target_mapping, 321 | d_model, n_head, d_head, dropout, dropatt, is_training, 322 | kernel_initializer, scope='rel_attn'): 323 | """Two-stream attention with relative positional encoding.""" 324 | 325 | scale = 1 / (d_head ** 0.5) 326 | with tf.variable_scope(scope, reuse=False): 327 | 328 | # content based attention score 329 | if mems is not None and mems.shape.ndims > 1: 330 | cat = tf.concat([mems, h], 0) 331 | else: 332 | cat = h 333 | 334 | # content-based key head 335 | k_head_h = head_projection( 336 | cat, d_model, n_head, d_head, kernel_initializer, 'k') 337 | 338 | # content-based value head 339 | v_head_h = head_projection( 340 | cat, d_model, n_head, d_head, kernel_initializer, 'v') 341 | 342 | # position-based key head 343 | k_head_r = head_projection( 344 | r, d_model, n_head, d_head, kernel_initializer, 'r') 345 | 346 | ##### h-stream 347 | # content-stream query head 348 | q_head_h = head_projection( 349 | h, d_model, n_head, d_head, kernel_initializer, 'q') 350 | 351 | # core attention ops 352 | attn_vec_h = rel_attn_core( 353 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 354 | r_r_bias, r_s_bias, attn_mask_h, dropatt, is_training, scale) 355 | 356 | # post processing 357 | output_h = post_attention(h, attn_vec_h, d_model, n_head, d_head, dropout, 358 | is_training, kernel_initializer) 359 | 360 | with tf.variable_scope(scope, reuse=True): 361 | ##### g-stream 362 | # query-stream query head 363 | q_head_g = head_projection( 364 | g, d_model, n_head, d_head, kernel_initializer, 'q') 365 | 366 | # core attention ops 367 | if target_mapping is not None: 368 | q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping) 369 | attn_vec_g = rel_attn_core( 370 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 371 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale) 372 | attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping) 373 | else: 374 | attn_vec_g = rel_attn_core( 375 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, 376 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale) 377 | 378 | # post processing 379 | output_g = post_attention(g, attn_vec_g, d_model, n_head, d_head, dropout, 380 | is_training, kernel_initializer) 381 | 382 | return output_h, output_g 383 | 384 | 385 | def transformer_xl(inp_k, n_token, n_layer, d_model, n_head, 386 | d_head, d_inner, dropout, dropatt, attn_type, 387 | bi_data, initializer, is_training, mem_len=None, 388 | inp_q=None, mems=None, 389 | same_length=False, clamp_len=-1, untie_r=False, 390 | use_tpu=True, input_mask=None, 391 | perm_mask=None, seg_id=None, reuse_len=None, 392 | ff_activation='relu', target_mapping=None, 393 | use_bfloat16=False, scope='transformer', **kwargs): 394 | """ 395 | Defines a Transformer-XL computation graph with additional 396 | support for XLNet. 397 | 398 | Args: 399 | 400 | inp_k: int32 Tensor in shape [len, bsz], the input token IDs. 401 | seg_id: int32 Tensor in shape [len, bsz], the input segment IDs. 
402 | input_mask: float32 Tensor in shape [len, bsz], the input mask. 403 | 0 for real tokens and 1 for padding. 404 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory 405 | from previous batches. The length of the list equals n_layer. 406 | If None, no memory is used. 407 | perm_mask: float32 Tensor in shape [len, len, bsz]. 408 | If perm_mask[i, j, k] = 0, i attends to j in batch k; 409 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k. 410 | If None, each position attends to all the others. 411 | target_mapping: float32 Tensor in shape [num_predict, len, bsz]. 412 | If target_mapping[i, j, k] = 1, the i-th prediction in batch k is 413 | on the j-th token. 414 | Only used during pretraining for partial prediction. 415 | Set to None during finetuning. 416 | inp_q: float32 Tensor in shape [len, bsz]. 417 | 1 for tokens with losses and 0 for tokens without losses. 418 | Only used during pretraining for two-stream attention. 419 | Set to None during finetuning. 420 | 421 | n_layer: int, the number of layers. 422 | d_model: int, the hidden size. 423 | n_head: int, the number of attention heads. 424 | d_head: int, the dimension size of each attention head. 425 | d_inner: int, the hidden size in feed-forward layers. 426 | ff_activation: str, "relu" or "gelu". 427 | untie_r: bool, whether to untie the biases in attention. 428 | n_token: int, the vocab size. 429 | 430 | is_training: bool, whether in training mode. 431 | use_tpu: bool, whether TPUs are used. 432 | use_bfloat16: bool, use bfloat16 instead of float32. 433 | dropout: float, dropout rate. 434 | dropatt: float, dropout rate on attention probabilities. 435 | init: str, the initialization scheme, either "normal" or "uniform". 436 | init_range: float, initialize the parameters with a uniform distribution 437 | in [-init_range, init_range]. Only effective when init="uniform". 438 | init_std: float, initialize the parameters with a normal distribution 439 | with mean 0 and stddev init_std. Only effective when init="normal". 440 | mem_len: int, the number of tokens to cache. 441 | reuse_len: int, the number of tokens in the current batch to be cached 442 | and reused in the future. 443 | bi_data: bool, whether to use bidirectional input pipeline. 444 | Usually set to True during pretraining and False during finetuning. 445 | clamp_len: int, clamp all relative distances larger than clamp_len. 446 | -1 means no clamping. 447 | same_length: bool, whether to use the same attention length for each token. 448 | summary_type: str, "last", "first", "mean", or "attn". The method 449 | to pool the input to get a vector representation. 450 | initializer: A tf initializer. 451 | scope: scope name for the computation graph.
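  Returns:
    output: float Tensor in shape [len, bsz, d_model], the last-layer hidden
      states (the query stream if `inp_q` is given, otherwise the content
      stream).
    new_mems: a list of `n_layer` updated memory Tensors.
    lookup_table: the word embedding table, e.g. for weight tying in `lm_loss`.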
452 | """ 453 | tf.logging.info('memory input {}'.format(mems)) 454 | tf_float = tf.bfloat16 if use_bfloat16 else tf.float32 455 | tf.logging.info('Use float type {}'.format(tf_float)) 456 | 457 | new_mems = [] 458 | with tf.variable_scope(scope): 459 | if untie_r: 460 | r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head], 461 | dtype=tf_float, initializer=initializer) 462 | r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head], 463 | dtype=tf_float, initializer=initializer) 464 | else: 465 | r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head], 466 | dtype=tf_float, initializer=initializer) 467 | r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head], 468 | dtype=tf_float, initializer=initializer) 469 | 470 | bsz = tf.shape(inp_k)[1] 471 | qlen = tf.shape(inp_k)[0] 472 | mlen = tf.shape(mems[0])[0] if mems is not None else 0 473 | klen = mlen + qlen 474 | 475 | ##### Attention mask 476 | # causal attention mask 477 | if attn_type == 'uni': 478 | attn_mask = _create_mask(qlen, mlen, tf_float, same_length) 479 | attn_mask = attn_mask[:, :, None, None] 480 | elif attn_type == 'bi': 481 | attn_mask = None 482 | else: 483 | raise ValueError('Unsupported attention type: {}'.format(attn_type)) 484 | 485 | # data mask: input mask & perm mask 486 | if input_mask is not None and perm_mask is not None: 487 | data_mask = input_mask[None] + perm_mask 488 | elif input_mask is not None and perm_mask is None: 489 | data_mask = input_mask[None] 490 | elif input_mask is None and perm_mask is not None: 491 | data_mask = perm_mask 492 | else: 493 | data_mask = None 494 | 495 | if data_mask is not None: 496 | # all mems can be attended to 497 | mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], 498 | dtype=tf_float) 499 | data_mask = tf.concat([mems_mask, data_mask], 1) 500 | if attn_mask is None: 501 | attn_mask = data_mask[:, :, :, None] 502 | else: 503 | attn_mask += data_mask[:, :, :, None] 504 | 505 | if attn_mask is not None: 506 | attn_mask = tf.cast(attn_mask > 0, dtype=tf_float) 507 | 508 | if attn_mask is not None: 509 | non_tgt_mask = -tf.eye(qlen, dtype=tf_float) 510 | non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float), 511 | non_tgt_mask], axis=-1) 512 | non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, 513 | dtype=tf_float) 514 | else: 515 | non_tgt_mask = None 516 | 517 | ##### Word embedding 518 | word_emb_k, lookup_table = embedding_lookup( 519 | x=inp_k, 520 | n_token=n_token, 521 | d_embed=d_model, 522 | initializer=initializer, 523 | use_tpu=use_tpu, 524 | dtype=tf_float, 525 | scope='word_embedding') 526 | 527 | if inp_q is not None: 528 | with tf.variable_scope('mask_emb'): 529 | mask_emb = tf.get_variable('mask_emb', [1, 1, d_model], dtype=tf_float) 530 | if target_mapping is not None: 531 | word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1]) 532 | else: 533 | inp_q_ext = inp_q[:, :, None] 534 | word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k 535 | output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training) 536 | if inp_q is not None: 537 | output_g = tf.layers.dropout(word_emb_q, dropout, training=is_training) 538 | 539 | ##### Segment embedding 540 | if seg_id is not None: 541 | if untie_r: 542 | r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head], 543 | dtype=tf_float, initializer=initializer) 544 | else: 545 | # default case (tie) 546 | r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head], 547 | dtype=tf_float, initializer=initializer) 548 | 549 | 
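      # XLNet uses *relative* segment encodings: `seg_mat` below only encodes
      # whether two positions lie in the same segment (a binary relation), not
      # the absolute segment id of each position.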
seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head], 550 | dtype=tf_float, initializer=initializer) 551 | 552 | # Convert `seg_id` to one-hot `seg_mat` 553 | mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) 554 | cat_ids = tf.concat([mem_pad, seg_id], 0) 555 | 556 | # `1` indicates not in the same segment [qlen x klen x bsz] 557 | seg_mat = tf.cast( 558 | tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])), 559 | tf.int32) 560 | seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float) 561 | else: 562 | seg_mat = None 563 | 564 | ##### Positional encoding 565 | pos_emb = relative_positional_encoding( 566 | qlen, klen, d_model, clamp_len, attn_type, bi_data, 567 | bsz=bsz, dtype=tf_float) 568 | pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training) 569 | 570 | ##### Attention layers 571 | if mems is None: 572 | mems = [None] * n_layer 573 | 574 | for i in range(n_layer): 575 | # cache new mems 576 | new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len)) 577 | 578 | # segment bias 579 | if seg_id is None: 580 | r_s_bias_i = None 581 | seg_embed_i = None 582 | else: 583 | r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i] 584 | seg_embed_i = seg_embed[i] 585 | 586 | with tf.variable_scope('layer_{}'.format(i)): 587 | if inp_q is not None: 588 | output_h, output_g = two_stream_rel_attn( 589 | h=output_h, 590 | g=output_g, 591 | r=pos_emb, 592 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i], 593 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i], 594 | seg_mat=seg_mat, 595 | r_s_bias=r_s_bias_i, 596 | seg_embed=seg_embed_i, 597 | attn_mask_h=non_tgt_mask, 598 | attn_mask_g=attn_mask, 599 | mems=mems[i], 600 | target_mapping=target_mapping, 601 | d_model=d_model, 602 | n_head=n_head, 603 | d_head=d_head, 604 | dropout=dropout, 605 | dropatt=dropatt, 606 | is_training=is_training, 607 | kernel_initializer=initializer) 608 | reuse = True 609 | else: 610 | reuse = False 611 | 612 | output_h = rel_multihead_attn( 613 | h=output_h, 614 | r=pos_emb, 615 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i], 616 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i], 617 | seg_mat=seg_mat, 618 | r_s_bias=r_s_bias_i, 619 | seg_embed=seg_embed_i, 620 | attn_mask=non_tgt_mask, 621 | mems=mems[i], 622 | d_model=d_model, 623 | n_head=n_head, 624 | d_head=d_head, 625 | dropout=dropout, 626 | dropatt=dropatt, 627 | is_training=is_training, 628 | kernel_initializer=initializer, 629 | reuse=reuse) 630 | 631 | if inp_q is not None: 632 | output_g = positionwise_ffn( 633 | inp=output_g, 634 | d_model=d_model, 635 | d_inner=d_inner, 636 | dropout=dropout, 637 | kernel_initializer=initializer, 638 | activation_type=ff_activation, 639 | is_training=is_training) 640 | 641 | output_h = positionwise_ffn( 642 | inp=output_h, 643 | d_model=d_model, 644 | d_inner=d_inner, 645 | dropout=dropout, 646 | kernel_initializer=initializer, 647 | activation_type=ff_activation, 648 | is_training=is_training, 649 | reuse=reuse) 650 | 651 | if inp_q is not None: 652 | output = tf.layers.dropout(output_g, dropout, training=is_training) 653 | else: 654 | output = tf.layers.dropout(output_h, dropout, training=is_training) 655 | 656 | return output, new_mems, lookup_table 657 | 658 | 659 | def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None, 660 | tie_weight=False, bi_data=True, use_tpu=False): 661 | """doc.""" 662 | 663 | with tf.variable_scope('lm_loss'): 664 | if tie_weight: 665 | assert lookup_table is not None, \ 666 | 'lookup_table cannot be None for tie_weight' 
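      # Weight tying: reuse the input embedding matrix as the output softmax
      # projection, saving an [n_token, d_model] parameter matrix.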
667 | softmax_w = lookup_table 668 | else: 669 | softmax_w = tf.get_variable('weight', [n_token, d_model], 670 | dtype=hidden.dtype, initializer=initializer) 671 | 672 | softmax_b = tf.get_variable('bias', [n_token], dtype=hidden.dtype, 673 | initializer=tf.zeros_initializer()) 674 | 675 | logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b 676 | 677 | if use_tpu: 678 | one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype) 679 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1) 680 | else: 681 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, 682 | logits=logits) 683 | 684 | return loss 685 | 686 | 687 | def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout, 688 | dropatt, input_mask, is_training, initializer, 689 | scope=None, reuse=None, use_proj=True): 690 | 691 | """ 692 | Different classification tasks may or may not share the same parameters 693 | to summarize the sequence features. 694 | 695 | If shared, one can leave `scope` at the default value `None`. 696 | Otherwise, one should specify a different `scope` for each task. 697 | """ 698 | 699 | with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):  # scope name spelling kept as released, for checkpoint compatibility 700 | if summary_type == 'last': 701 | summary = hidden[-1] 702 | elif summary_type == 'first': 703 | summary = hidden[0] 704 | elif summary_type == 'mean': 705 | summary = tf.reduce_mean(hidden, axis=0) 706 | elif summary_type == 'attn': 707 | bsz = tf.shape(hidden)[1] 708 | 709 | summary_bias = tf.get_variable('summary_bias', [d_model], 710 | dtype=hidden.dtype, 711 | initializer=initializer) 712 | summary_bias = tf.tile(summary_bias[None, None], [1, bsz, 1]) 713 | 714 | if input_mask is not None: 715 | input_mask = input_mask[None, :, :, None] 716 | 717 | summary = multihead_attn(summary_bias, hidden, hidden, input_mask, 718 | d_model, n_head, d_head, dropout, dropatt, 719 | is_training, initializer, residual=False) 720 | summary = summary[0] 721 | else: 722 | raise ValueError('Unsupported summary type {}'.format(summary_type)) 723 | 724 | # use another projection as in BERT 725 | if use_proj: 726 | summary = tf.layers.dense( 727 | summary, 728 | d_model, 729 | activation=tf.tanh, 730 | kernel_initializer=initializer, 731 | name='summary') 732 | 733 | # dropout 734 | summary = tf.layers.dropout( 735 | summary, dropout, training=is_training, 736 | name='dropout') 737 | 738 | return summary 739 | 740 | 741 | def classification_loss(hidden, labels, n_class, initializer, scope, reuse=None, 742 | return_logits=False): 743 | """ 744 | Different classification tasks should use different scope names to ensure 745 | different dense layers (parameters) are used to produce the logits. 746 | 747 | An exception is transfer learning, where one hopes to transfer 748 | the classification weights.
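  A minimal call sketch (the scope name is illustrative):

    per_example_loss = classification_loss(
        hidden=summary, labels=label_ids, n_class=2,
        initializer=xlnet_model.get_initializer(), scope='classification_imdb')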
749 | """ 750 | 751 | with tf.variable_scope(scope, reuse=reuse): 752 | logits = tf.layers.dense( 753 | hidden, 754 | n_class, 755 | kernel_initializer=initializer, 756 | name='logit') 757 | 758 | one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype) 759 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1) 760 | 761 | if return_logits: 762 | return loss, logits 763 | 764 | return loss 765 | 766 | 767 | def regression_loss(hidden, labels, initializer, scope, reuse=None, 768 | return_logits=False): 769 | with tf.variable_scope(scope, reuse=reuse): 770 | logits = tf.layers.dense( 771 | hidden, 772 | 1, 773 | kernel_initializer=initializer, 774 | name='logit') 775 | 776 | logits = tf.squeeze(logits, axis=-1) 777 | loss = tf.square(logits - labels) 778 | 779 | if return_logits: 780 | return loss, logits 781 | 782 | return loss 783 | 784 | -------------------------------------------------------------------------------- /notebooks/colab_imdb_gpu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "colab_type": "text", 17 | "id": "fnOHnctkG6kW" 18 | }, 19 | "source": [ 20 | "# XLNet IMDB movie review classification project\n", 21 | "\n", 22 | "This notebook is for classifying the [imdb sentiment dataset](https://ai.stanford.edu/~amaas/data/sentiment/). It is easy to edit this notebook to run any of the classification tasks referenced in the [XLNet paper](https://arxiv.org/abs/1906.08237). Whilst you cannot expect to obtain the state-of-the-art results in the paper on a GPU, this model will still score very highly. " 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": { 28 | "colab_type": "text", 29 | "id": "2mBzLdrdzodb" 30 | }, 31 | "source": [ 32 | "## Setup\n", 33 | "Install dependencies" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 0, 39 | "metadata": { 40 | "colab": {}, 41 | "colab_type": "code", 42 | "id": "hRHRPImGUth7" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "! pip install sentencepiece" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": { 52 | "colab_type": "text", 53 | "id": "jy8gUsPuJNyw" 54 | }, 55 | "source": [ 56 | "Download the pretrained XLNet model and unzip" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 0, 62 | "metadata": { 63 | "colab": {}, 64 | "colab_type": "code", 65 | "id": "HfPDGsUtHKG0" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "# only needs to be done once\n", 70 | "! wget https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip\n", 71 | "! unzip cased_L-24_H-1024_A-16.zip " 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "colab_type": "text", 78 | "id": "4uUwjq3BJRbu" 79 | }, 80 | "source": [ 81 | "Download and extract the imdb dataset, suppressing output" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 0, 87 | "metadata": { 88 | "colab": {}, 89 | "colab_type": "code", 90 | "id": "QOGRICbOIsU8" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "! wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n", 95 | "! 
tar zxf aclImdb_v1.tar.gz" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": { 101 | "colab_type": "text", 102 | "id": "yGY_ggUUMrwU" 103 | }, 104 | "source": [ 105 | "Git clone the XLNet repo for access to run_classifier and the rest of the xlnet module" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 0, 111 | "metadata": { 112 | "colab": {}, 113 | "colab_type": "code", 114 | "id": "-r190eYVMpiG" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "! git clone https://github.com/zihangdai/xlnet.git" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": { 124 | "colab_type": "text", 125 | "id": "jDP-IaVuPC-z" 126 | }, 127 | "source": [ 128 | "## Define Variables\n", 129 | "Define all the dirs: data, xlnet scripts & pretrained model. \n", 130 | "If you would like to save models then you can authenticate a GCP account and use that for the OUTPUT_DIR & CHECKPOINT_DIR - you will need a large amount of storage to store these models. \n", 131 | "\n", 132 | "Alternatively it is easy to integrate a Google Drive account; check out this guide for [I/O in colab](https://colab.research.google.com/notebooks/io.ipynb) but remember these will take up a large amount of storage. \n" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 0, 138 | "metadata": { 139 | "colab": {}, 140 | "colab_type": "code", 141 | "id": "y7N_xVwavQlV" 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "SCRIPTS_DIR = 'xlnet' #@param {type:\"string\"}\n", 146 | "DATA_DIR = 'aclImdb' #@param {type:\"string\"}\n", 147 | "OUTPUT_DIR = 'proc_data/imdb' #@param {type:\"string\"}\n", 148 | "PRETRAINED_MODEL_DIR = 'xlnet_cased_L-24_H-1024_A-16' #@param {type:\"string\"}\n", 149 | "CHECKPOINT_DIR = 'exp/imdb' #@param {type:\"string\"}" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": { 155 | "colab_type": "text", 156 | "id": "jR6euqwL1KBV" 157 | }, 158 | "source": [ 159 | "## Run Model\n", 160 | "This will kick off the fine-tuning of XLNet. There are a few things to note here:\n", 161 | "\n", 162 | "\n", 163 | "1. This script will train and evaluate the model\n", 164 | "2. This will store the results locally on colab; they will be lost when you are disconnected from the runtime\n", 165 | "3. This uses the large version of the model (base not released presently)\n", 166 | "4. We are using a max seq length of 128 with a batch size of 8; please refer to the [README](https://github.com/zihangdai/xlnet#memory-issue-during-finetuning) for why this is.\n", 167 | "5. 
This will take approx 4hrs to run on GPU.\n", 168 | "\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 0, 174 | "metadata": { 175 | "colab": {}, 176 | "colab_type": "code", 177 | "id": "CEMuT6LU0avg" 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "train_command = \"python xlnet/run_classifier.py \\\n", 182 | " --do_train=True \\\n", 183 | " --do_eval=True \\\n", 184 | " --eval_all_ckpt=True \\\n", 185 | " --task_name=imdb \\\n", 186 | " --data_dir=\"+DATA_DIR+\" \\\n", 187 | " --output_dir=\"+OUTPUT_DIR+\" \\\n", 188 | " --model_dir=\"+CHECKPOINT_DIR+\" \\\n", 189 | " --uncased=False \\\n", 190 | " --spiece_model_file=\"+PRETRAINED_MODEL_DIR+\"/spiece.model \\\n", 191 | " --model_config_path=\"+PRETRAINED_MODEL_DIR+\"/xlnet_config.json \\\n", 192 | " --init_checkpoint=\"+PRETRAINED_MODEL_DIR+\"/xlnet_model.ckpt \\\n", 193 | " --max_seq_length=128 \\\n", 194 | " --train_batch_size=8 \\\n", 195 | " --eval_batch_size=8 \\\n", 196 | " --num_hosts=1 \\\n", 197 | " --num_core_per_host=1 \\\n", 198 | " --learning_rate=2e-5 \\\n", 199 | " --train_steps=4000 \\\n", 200 | " --warmup_steps=500 \\\n", 201 | " --save_steps=500 \\\n", 202 | " --iterations=500\"\n", 203 | "\n", 204 | "! {train_command}\n" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": { 210 | "colab_type": "text", 211 | "id": "VvhqD-sO0Kyh" 212 | }, 213 | "source": [ 214 | "## Running & Results\n", 215 | "These are the results that I got from running this experiment\n", 216 | "### Params\n", 217 | "* --max_seq_length=128 \\\n", 218 | "* --train_batch_size= 8 \n", 219 | "\n", 220 | "### Times\n", 221 | "* Training: 1hr 11mins\n", 222 | "* Evaluation: 2.5hr\n", 223 | "\n", 224 | "### Results\n", 225 | "* Most accurate model on final step\n", 226 | "* Accuracy: 0.92416, eval_loss: 0.31708\n" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": { 232 | "colab_type": "text", 233 | "id": "XUW2avFM_fi_" 234 | }, 235 | "source": [ 236 | "### Model\n", 237 | "\n", 238 | "* The trained model checkpoints can be found in 'exp/imdb'\n", 239 | "\n" 240 | ] 241 | } 242 | ], 243 | "metadata": { 244 | "accelerator": "GPU", 245 | "colab": { 246 | "collapsed_sections": [], 247 | "include_colab_link": true, 248 | "name": "XLNet-imdb-GPU.ipynb", 249 | "provenance": [], 250 | "toc_visible": true, 251 | "version": "0.3.2" 252 | }, 253 | "kernelspec": { 254 | "display_name": "Python 3", 255 | "language": "python", 256 | "name": "python3" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 1 261 | } 262 | -------------------------------------------------------------------------------- /prepro_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import unicodedata 7 | import six 8 | from functools import partial 9 | 10 | 11 | SPIECE_UNDERLINE = '▁' 12 | 13 | 14 | def printable_text(text): 15 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 16 | 17 | # These functions want `str` for both Python2 and Python3, but in one case 18 | # it's a Unicode string and in the other it's a byte string. 
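  # e.g. on Python 3, printable_text(b'caf\xc3\xa9') returns 'café'.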
19 | if six.PY3: 20 | if isinstance(text, str): 21 | return text 22 | elif isinstance(text, bytes): 23 | return text.decode("utf-8", "ignore") 24 | else: 25 | raise ValueError("Unsupported string type: %s" % (type(text))) 26 | elif six.PY2: 27 | if isinstance(text, str): 28 | return text 29 | elif isinstance(text, unicode): 30 | return text.encode("utf-8") 31 | else: 32 | raise ValueError("Unsupported string type: %s" % (type(text))) 33 | else: 34 | raise ValueError("Not running on Python2 or Python 3?") 35 | 36 | 37 | def print_(*args): 38 | new_args = [] 39 | for arg in args: 40 | if isinstance(arg, list): 41 | s = [printable_text(i) for i in arg] 42 | s = ' '.join(s) 43 | new_args.append(s) 44 | else: 45 | new_args.append(printable_text(arg)) 46 | print(*new_args) 47 | 48 | 49 | def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False): 50 | if remove_space: 51 | outputs = ' '.join(inputs.strip().split()) 52 | else: 53 | outputs = inputs 54 | outputs = outputs.replace("``", '"').replace("''", '"') 55 | 56 | if six.PY2 and isinstance(outputs, str): 57 | outputs = outputs.decode('utf-8') 58 | 59 | if not keep_accents: 60 | outputs = unicodedata.normalize('NFKD', outputs) 61 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 62 | if lower: 63 | outputs = outputs.lower() 64 | 65 | return outputs 66 | 67 | 68 | def encode_pieces(sp_model, text, return_unicode=True, sample=False): 69 | # return_unicode is used only for py2 70 | 71 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2 72 | if six.PY2 and isinstance(text, unicode): 73 | text = text.encode('utf-8') 74 | 75 | if not sample: 76 | pieces = sp_model.EncodeAsPieces(text) 77 | else: 78 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1) 79 | new_pieces = [] 80 | for piece in pieces: 81 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): 82 | cur_pieces = sp_model.EncodeAsPieces( 83 | piece[:-1].replace(SPIECE_UNDERLINE, '')) 84 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 85 | if len(cur_pieces[0]) == 1: 86 | cur_pieces = cur_pieces[1:] 87 | else: 88 | cur_pieces[0] = cur_pieces[0][1:] 89 | cur_pieces.append(piece[-1]) 90 | new_pieces.extend(cur_pieces) 91 | else: 92 | new_pieces.append(piece) 93 | 94 | # note(zhiliny): convert back to unicode for py2 95 | if six.PY2 and return_unicode: 96 | ret_pieces = [] 97 | for piece in new_pieces: 98 | if isinstance(piece, str): 99 | piece = piece.decode('utf-8') 100 | ret_pieces.append(piece) 101 | new_pieces = ret_pieces 102 | 103 | return new_pieces 104 | 105 | 106 | def encode_ids(sp_model, text, sample=False): 107 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample) 108 | ids = [sp_model.PieceToId(piece) for piece in pieces] 109 | return ids 110 | 111 | 112 | if __name__ == '__main__': 113 | import sentencepiece as spm 114 | 115 | sp = spm.SentencePieceProcessor() 116 | sp.load('sp10m.uncased.v3.model') 117 | 118 | print_(u'I was born in 2000, and this is falsé.') 119 | print_(u'ORIGINAL', sp.EncodeAsPieces(u'I was born in 2000, and this is falsé.')) 120 | print_(u'OURS', encode_pieces(sp, u'I was born in 2000, and this is falsé.')) 121 | print(encode_ids(sp, u'I was born in 2000, and this is falsé.')) 122 | print_('') 123 | prepro_func = partial(preprocess_text, lower=True) 124 | print_(prepro_func('I was born in 2000, and this is falsé.')) 125 | print_('ORIGINAL', sp.EncodeAsPieces(prepro_func('I was born in 2000, and this is falsé.'))) 126 | 
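  # The 'OURS' output differs from 'ORIGINAL' on digit+comma pieces such as
  # '2000,': encode_pieces splits the trailing comma into its own piece.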
print_('OURS', encode_pieces(sp, prepro_func('I was born in 2000, and this is falsé.'))) 127 | print(encode_ids(sp, prepro_func('I was born in 2000, and this is falsé.'))) 128 | print_('') 129 | print_('I was born in 2000, and this is falsé.') 130 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 2000, and this is falsé.')) 131 | print_('OURS', encode_pieces(sp, 'I was born in 2000, and this is falsé.')) 132 | print(encode_ids(sp, 'I was born in 2000, and this is falsé.')) 133 | print_('') 134 | print_('I was born in 92000, and this is falsé.') 135 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 92000, and this is falsé.')) 136 | print_('OURS', encode_pieces(sp, 'I was born in 92000, and this is falsé.')) 137 | print(encode_ids(sp, 'I was born in 92000, and this is falsé.')) 138 | 139 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from os.path import join 6 | from absl import flags 7 | import os 8 | import sys 9 | import csv 10 | import collections 11 | import numpy as np 12 | import time 13 | import math 14 | import json 15 | import random 16 | from copy import copy 17 | from collections import defaultdict as dd 18 | 19 | import absl.logging as _logging # pylint: disable=unused-import 20 | import tensorflow as tf 21 | 22 | import sentencepiece as spm 23 | 24 | from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID 25 | import model_utils 26 | import function_builder 27 | from classifier_utils import PaddingInputExample 28 | from classifier_utils import convert_single_example 29 | from prepro_utils import preprocess_text, encode_ids 30 | 31 | 32 | # Model 33 | flags.DEFINE_string("model_config_path", default=None, 34 | help="Model config path.") 35 | flags.DEFINE_float("dropout", default=0.1, 36 | help="Dropout rate.") 37 | flags.DEFINE_float("dropatt", default=0.1, 38 | help="Attention dropout rate.") 39 | flags.DEFINE_integer("clamp_len", default=-1, 40 | help="Clamp length") 41 | flags.DEFINE_string("summary_type", default="last", 42 | help="Method used to summarize a sequence into a compact vector.") 43 | flags.DEFINE_bool("use_summ_proj", default=True, 44 | help="Whether to use projection for summarizing sequences.") 45 | flags.DEFINE_bool("use_bfloat16", False, 46 | help="Whether to use bfloat16.") 47 | 48 | # Parameter initialization 49 | flags.DEFINE_enum("init", default="normal", 50 | enum_values=["normal", "uniform"], 51 | help="Initialization method.") 52 | flags.DEFINE_float("init_std", default=0.02, 53 | help="Initialization std when init is normal.") 54 | flags.DEFINE_float("init_range", default=0.1, 55 | help="Initialization std when init is uniform.") 56 | 57 | # I/O paths 58 | flags.DEFINE_bool("overwrite_data", default=False, 59 | help="If False, will use cached data if available.") 60 | flags.DEFINE_string("init_checkpoint", default=None, 61 | help="checkpoint path for initializing the model. 
" 62 | "Could be a pretrained model or a finetuned model.") 63 | flags.DEFINE_string("output_dir", default="", 64 | help="Output dir for TF records.") 65 | flags.DEFINE_string("spiece_model_file", default="", 66 | help="Sentence Piece model path.") 67 | flags.DEFINE_string("model_dir", default="", 68 | help="Directory for saving the finetuned model.") 69 | flags.DEFINE_string("data_dir", default="", 70 | help="Directory for input data.") 71 | 72 | # TPUs and machines 73 | flags.DEFINE_bool("use_tpu", default=False, help="whether to use TPU.") 74 | flags.DEFINE_integer("num_hosts", default=1, help="How many TPU hosts.") 75 | flags.DEFINE_integer("num_core_per_host", default=8, 76 | help="8 for TPU v2 and v3-8, 16 for larger TPU v3 pod. In the context " 77 | "of GPU training, it refers to the number of GPUs used.") 78 | flags.DEFINE_string("tpu_job_name", default=None, help="TPU worker job name.") 79 | flags.DEFINE_string("tpu", default=None, help="TPU name.") 80 | flags.DEFINE_string("tpu_zone", default=None, help="TPU zone.") 81 | flags.DEFINE_string("gcp_project", default=None, help="gcp project.") 82 | flags.DEFINE_string("master", default=None, help="master") 83 | flags.DEFINE_integer("iterations", default=1000, 84 | help="number of iterations per TPU training loop.") 85 | 86 | # training 87 | flags.DEFINE_bool("do_train", default=False, help="whether to do training") 88 | flags.DEFINE_integer("train_steps", default=1000, 89 | help="Number of training steps") 90 | flags.DEFINE_integer("warmup_steps", default=0, help="number of warmup steps") 91 | flags.DEFINE_float("learning_rate", default=1e-5, help="initial learning rate") 92 | flags.DEFINE_float("lr_layer_decay_rate", 1.0, 93 | "Top layer: lr[L] = FLAGS.learning_rate." 94 | "Low layer: lr[l-1] = lr[l] * lr_layer_decay_rate.") 95 | flags.DEFINE_float("min_lr_ratio", default=0.0, 96 | help="min lr ratio for cos decay.") 97 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping") 98 | flags.DEFINE_integer("max_save", default=0, 99 | help="Max number of checkpoints to save. Use 0 to save all.") 100 | flags.DEFINE_integer("save_steps", default=None, 101 | help="Save the model for every save_steps. " 102 | "If None, not to save any model.") 103 | flags.DEFINE_integer("train_batch_size", default=8, 104 | help="Batch size for training") 105 | flags.DEFINE_float("weight_decay", default=0.00, help="Weight decay rate") 106 | flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon") 107 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos") 108 | 109 | # evaluation 110 | flags.DEFINE_bool("do_eval", default=False, help="whether to do eval") 111 | flags.DEFINE_bool("do_predict", default=False, help="whether to do prediction") 112 | flags.DEFINE_float("predict_threshold", default=0, 113 | help="Threshold for binary prediction.") 114 | flags.DEFINE_string("eval_split", default="dev", help="could be dev or test") 115 | flags.DEFINE_integer("eval_batch_size", default=128, 116 | help="batch size for evaluation") 117 | flags.DEFINE_integer("predict_batch_size", default=128, 118 | help="batch size for prediction.") 119 | flags.DEFINE_string("predict_dir", default=None, 120 | help="Dir for saving prediction files.") 121 | flags.DEFINE_bool("eval_all_ckpt", default=False, 122 | help="Eval all ckpts. If False, only evaluate the last one.") 123 | flags.DEFINE_string("predict_ckpt", default=None, 124 | help="Ckpt path for do_predict. 
If None, use the last one.") 125 | 126 | # task specific 127 | flags.DEFINE_string("task_name", default=None, help="Task name") 128 | flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length") 129 | flags.DEFINE_integer("shuffle_buffer", default=2048, 130 | help="Buffer size used for shuffle.") 131 | flags.DEFINE_integer("num_passes", default=1, 132 | help="Num passes for processing training data. " 133 | "This is used to batch data without loss for TPUs.") 134 | flags.DEFINE_bool("uncased", default=False, 135 | help="Use uncased.") 136 | flags.DEFINE_string("cls_scope", default=None, 137 | help="Classifier layer scope.") 138 | flags.DEFINE_bool("is_regression", default=False, 139 | help="Whether it's a regression task.") 140 | 141 | FLAGS = flags.FLAGS 142 | 143 | 144 | class InputExample(object): 145 | """A single training/test example for simple sequence classification.""" 146 | 147 | def __init__(self, guid, text_a, text_b=None, label=None): 148 | """Constructs an InputExample. 149 | Args: 150 | guid: Unique id for the example. 151 | text_a: string. The untokenized text of the first sequence. For single 152 | sequence tasks, only this sequence must be specified. 153 | text_b: (Optional) string. The untokenized text of the second sequence. 154 | Must only be specified for sequence pair tasks. 155 | label: (Optional) string. The label of the example. This should be 156 | specified for train and dev examples, but not for test examples. 157 | """ 158 | self.guid = guid 159 | self.text_a = text_a 160 | self.text_b = text_b 161 | self.label = label 162 | 163 | 164 | class DataProcessor(object): 165 | """Base class for data converters for sequence classification data sets.""" 166 | 167 | def get_train_examples(self, data_dir): 168 | """Gets a collection of `InputExample`s for the train set.""" 169 | raise NotImplementedError() 170 | 171 | def get_dev_examples(self, data_dir): 172 | """Gets a collection of `InputExample`s for the dev set.""" 173 | raise NotImplementedError() 174 | 175 | def get_test_examples(self, data_dir): 176 | """Gets a collection of `InputExample`s for prediction.""" 177 | raise NotImplementedError() 178 | 179 | def get_labels(self): 180 | """Gets the list of labels for this data set.""" 181 | raise NotImplementedError() 182 | 183 | @classmethod 184 | def _read_tsv(cls, input_file, quotechar=None): 185 | """Reads a tab separated value file.""" 186 | with tf.gfile.Open(input_file, "r") as f: 187 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 188 | lines = [] 189 | for line in reader: 190 | if len(line) == 0: continue 191 | lines.append(line) 192 | return lines 193 | 194 | 195 | class GLUEProcessor(DataProcessor): 196 | def __init__(self): 197 | self.train_file = "train.tsv" 198 | self.dev_file = "dev.tsv" 199 | self.test_file = "test.tsv" 200 | self.label_column = None 201 | self.text_a_column = None 202 | self.text_b_column = None 203 | self.contains_header = True 204 | self.test_text_a_column = None 205 | self.test_text_b_column = None 206 | self.test_contains_header = True 207 | 208 | def get_train_examples(self, data_dir): 209 | """See base class.""" 210 | return self._create_examples( 211 | self._read_tsv(os.path.join(data_dir, self.train_file)), "train") 212 | 213 | def get_dev_examples(self, data_dir): 214 | """See base class.""" 215 | return self._create_examples( 216 | self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev") 217 | 218 | def get_test_examples(self, data_dir): 219 | """See base class.""" 220 | if
self.test_text_a_column is None: 221 | self.test_text_a_column = self.text_a_column 222 | if self.test_text_b_column is None: 223 | self.test_text_b_column = self.text_b_column 224 | 225 | return self._create_examples( 226 | self._read_tsv(os.path.join(data_dir, self.test_file)), "test") 227 | 228 | def get_labels(self): 229 | """See base class.""" 230 | return ["0", "1"] 231 | 232 | def _create_examples(self, lines, set_type): 233 | """Creates examples for the training and dev sets.""" 234 | examples = [] 235 | for (i, line) in enumerate(lines): 236 | if i == 0 and self.contains_header and set_type != "test": 237 | continue 238 | if i == 0 and self.test_contains_header and set_type == "test": 239 | continue 240 | guid = "%s-%s" % (set_type, i) 241 | 242 | a_column = (self.text_a_column if set_type != "test" else 243 | self.test_text_a_column) 244 | b_column = (self.text_b_column if set_type != "test" else 245 | self.test_text_b_column) 246 | 247 | # there are some incomplete lines in QNLI 248 | if len(line) <= a_column: 249 | tf.logging.warning('Incomplete line, ignored.') 250 | continue 251 | text_a = line[a_column] 252 | 253 | if b_column is not None: 254 | if len(line) <= b_column: 255 | tf.logging.warning('Incomplete line, ignored.') 256 | continue 257 | text_b = line[b_column] 258 | else: 259 | text_b = None 260 | 261 | if set_type == "test": 262 | label = self.get_labels()[0] 263 | else: 264 | if len(line) <= self.label_column: 265 | tf.logging.warning('Incomplete line, ignored.') 266 | continue 267 | label = line[self.label_column] 268 | examples.append( 269 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 270 | return examples 271 | 272 | 273 | class Yelp5Processor(DataProcessor): 274 | def get_train_examples(self, data_dir): 275 | return self._create_examples(os.path.join(data_dir, "train.csv")) 276 | 277 | def get_dev_examples(self, data_dir): 278 | return self._create_examples(os.path.join(data_dir, "test.csv")) 279 | 280 | def get_labels(self): 281 | """See base class.""" 282 | return ["1", "2", "3", "4", "5"] 283 | 284 | def _create_examples(self, input_file): 285 | """Creates examples for the training and dev sets.""" 286 | examples = [] 287 | with tf.gfile.Open(input_file) as f: 288 | reader = csv.reader(f) 289 | for i, line in enumerate(reader): 290 | 291 | label = line[0] 292 | text_a = line[1].replace('""', '"').replace('\\"', '"') 293 | examples.append( 294 | InputExample(guid=str(i), text_a=text_a, text_b=None, label=label)) 295 | return examples 296 | 297 | 298 | class ImdbProcessor(DataProcessor): 299 | def get_labels(self): 300 | return ["neg", "pos"] 301 | 302 | def get_train_examples(self, data_dir): 303 | return self._create_examples(os.path.join(data_dir, "train")) 304 | 305 | def get_dev_examples(self, data_dir): 306 | return self._create_examples(os.path.join(data_dir, "test")) 307 | 308 | def _create_examples(self, data_dir): 309 | examples = [] 310 | for label in ["neg", "pos"]: 311 | cur_dir = os.path.join(data_dir, label) 312 | for filename in tf.gfile.ListDirectory(cur_dir): 313 | if not filename.endswith("txt"): continue 314 | 315 | path = os.path.join(cur_dir, filename) 316 | with tf.gfile.Open(path) as f: 317 | text = f.read().strip().replace("<br />", " ")
", " ") 318 | examples.append(InputExample( 319 | guid="unused_id", text_a=text, text_b=None, label=label)) 320 | return examples 321 | 322 | 323 | class MnliMatchedProcessor(GLUEProcessor): 324 | def __init__(self): 325 | super(MnliMatchedProcessor, self).__init__() 326 | self.dev_file = "dev_matched.tsv" 327 | self.test_file = "test_matched.tsv" 328 | self.label_column = -1 329 | self.text_a_column = 8 330 | self.text_b_column = 9 331 | 332 | def get_labels(self): 333 | return ["contradiction", "entailment", "neutral"] 334 | 335 | 336 | class MnliMismatchedProcessor(MnliMatchedProcessor): 337 | def __init__(self): 338 | super(MnliMismatchedProcessor, self).__init__() 339 | self.dev_file = "dev_mismatched.tsv" 340 | self.test_file = "test_mismatched.tsv" 341 | 342 | 343 | class StsbProcessor(GLUEProcessor): 344 | def __init__(self): 345 | super(StsbProcessor, self).__init__() 346 | self.label_column = 9 347 | self.text_a_column = 7 348 | self.text_b_column = 8 349 | 350 | def get_labels(self): 351 | return [0.0] 352 | 353 | def _create_examples(self, lines, set_type): 354 | """Creates examples for the training and dev sets.""" 355 | examples = [] 356 | for (i, line) in enumerate(lines): 357 | if i == 0 and self.contains_header and set_type != "test": 358 | continue 359 | if i == 0 and self.test_contains_header and set_type == "test": 360 | continue 361 | guid = "%s-%s" % (set_type, i) 362 | 363 | a_column = (self.text_a_column if set_type != "test" else 364 | self.test_text_a_column) 365 | b_column = (self.text_b_column if set_type != "test" else 366 | self.test_text_b_column) 367 | 368 | # there are some incomplete lines in QNLI 369 | if len(line) <= a_column: 370 | tf.logging.warning('Incomplete line, ignored.') 371 | continue 372 | text_a = line[a_column] 373 | 374 | if b_column is not None: 375 | if len(line) <= b_column: 376 | tf.logging.warning('Incomplete line, ignored.') 377 | continue 378 | text_b = line[b_column] 379 | else: 380 | text_b = None 381 | 382 | if set_type == "test": 383 | label = self.get_labels()[0] 384 | else: 385 | if len(line) <= self.label_column: 386 | tf.logging.warning('Incomplete line, ignored.') 387 | continue 388 | label = float(line[self.label_column]) 389 | examples.append( 390 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 391 | 392 | return examples 393 | 394 | 395 | def file_based_convert_examples_to_features( 396 | examples, label_list, max_seq_length, tokenize_fn, output_file, 397 | num_passes=1): 398 | """Convert a set of `InputExample`s to a TFRecord file.""" 399 | 400 | # do not create duplicated records 401 | if tf.gfile.Exists(output_file) and not FLAGS.overwrite_data: 402 | tf.logging.info("Do not overwrite tfrecord {} exists.".format(output_file)) 403 | return 404 | 405 | tf.logging.info("Create new tfrecord {}.".format(output_file)) 406 | 407 | writer = tf.python_io.TFRecordWriter(output_file) 408 | 409 | if num_passes > 1: 410 | examples *= num_passes 411 | 412 | for (ex_index, example) in enumerate(examples): 413 | if ex_index % 10000 == 0: 414 | tf.logging.info("Writing example {} of {}".format(ex_index, 415 | len(examples))) 416 | 417 | feature = convert_single_example(ex_index, example, label_list, 418 | max_seq_length, tokenize_fn) 419 | 420 | def create_int_feature(values): 421 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 422 | return f 423 | 424 | def create_float_feature(values): 425 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 426 | return f 427 | 428 
| features = collections.OrderedDict() 429 | features["input_ids"] = create_int_feature(feature.input_ids) 430 | features["input_mask"] = create_float_feature(feature.input_mask) 431 | features["segment_ids"] = create_int_feature(feature.segment_ids) 432 | if label_list is not None: 433 | features["label_ids"] = create_int_feature([feature.label_id]) 434 | else: 435 | features["label_ids"] = create_float_feature([float(feature.label_id)]) 436 | features["is_real_example"] = create_int_feature( 437 | [int(feature.is_real_example)]) 438 | 439 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 440 | writer.write(tf_example.SerializeToString()) 441 | writer.close() 442 | 443 | 444 | def file_based_input_fn_builder(input_file, seq_length, is_training, 445 | drop_remainder): 446 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 447 | 448 | 449 | name_to_features = { 450 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 451 | "input_mask": tf.FixedLenFeature([seq_length], tf.float32), 452 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 453 | "label_ids": tf.FixedLenFeature([], tf.int64), 454 | "is_real_example": tf.FixedLenFeature([], tf.int64), 455 | } 456 | if FLAGS.is_regression: 457 | name_to_features["label_ids"] = tf.FixedLenFeature([], tf.float32) 458 | 459 | tf.logging.info("Input tfrecord file {}".format(input_file)) 460 | 461 | def _decode_record(record, name_to_features): 462 | """Decodes a record to a TensorFlow example.""" 463 | example = tf.parse_single_example(record, name_to_features) 464 | 465 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 466 | # So cast all int64 to int32. 467 | for name in list(example.keys()): 468 | t = example[name] 469 | if t.dtype == tf.int64: 470 | t = tf.cast(t, tf.int32) 471 | example[name] = t 472 | 473 | return example 474 | 475 | def input_fn(params, input_context=None): 476 | """The actual input function.""" 477 | if FLAGS.use_tpu: 478 | batch_size = params["batch_size"] 479 | elif is_training: 480 | batch_size = FLAGS.train_batch_size 481 | elif FLAGS.do_eval: 482 | batch_size = FLAGS.eval_batch_size 483 | else: 484 | batch_size = FLAGS.predict_batch_size 485 | 486 | d = tf.data.TFRecordDataset(input_file) 487 | # Shard the dataset to different devices 488 | if input_context is not None: 489 | tf.logging.info("Input pipeline id %d out of %d", 490 | input_context.input_pipeline_id, input_context.num_replicas_in_sync) 491 | d = d.shard(input_context.num_input_pipelines, 492 | input_context.input_pipeline_id) 493 | 494 | # For training, we want a lot of parallel reading and shuffling. 495 | # For eval, we want no shuffling and parallel reading doesn't matter.
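# Note on the two lines below: `shuffle` keeps only `shuffle_buffer` records
# in memory and samples from that window, so shuffling is approximate for
# datasets larger than the buffer; `repeat()` without an argument loops the
# data indefinitely, which is why the length of training is governed solely
# by the `max_steps` passed to `estimator.train`.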
496 | if is_training: 497 | d = d.shuffle(buffer_size=FLAGS.shuffle_buffer) 498 | d = d.repeat() 499 | 500 | d = d.apply( 501 | tf.contrib.data.map_and_batch( 502 | lambda record: _decode_record(record, name_to_features), 503 | batch_size=batch_size, 504 | drop_remainder=drop_remainder)) 505 | 506 | return d 507 | 508 | return input_fn 509 | 510 | 511 | def get_model_fn(n_class): 512 | def model_fn(features, labels, mode, params): 513 | #### Training or Evaluation 514 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 515 | 516 | #### Get loss from inputs 517 | if FLAGS.is_regression: 518 | (total_loss, per_example_loss, logits 519 | ) = function_builder.get_regression_loss(FLAGS, features, is_training) 520 | else: 521 | (total_loss, per_example_loss, logits 522 | ) = function_builder.get_classification_loss( 523 | FLAGS, features, n_class, is_training) 524 | 525 | #### Check model parameters 526 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) 527 | tf.logging.info('#params: {}'.format(num_params)) 528 | 529 | #### load pretrained models 530 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS) 531 | 532 | #### Evaluation mode 533 | if mode == tf.estimator.ModeKeys.EVAL: 534 | assert FLAGS.num_hosts == 1 535 | 536 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 537 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 538 | eval_input_dict = { 539 | 'labels': label_ids, 540 | 'predictions': predictions, 541 | 'weights': is_real_example 542 | } 543 | accuracy = tf.metrics.accuracy(**eval_input_dict) 544 | 545 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 546 | return { 547 | 'eval_accuracy': accuracy, 548 | 'eval_loss': loss} 549 | 550 | def regression_metric_fn( 551 | per_example_loss, label_ids, logits, is_real_example): 552 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 553 | pearsonr = tf.contrib.metrics.streaming_pearson_correlation( 554 | logits, label_ids, weights=is_real_example) 555 | return {'eval_loss': loss, 'eval_pearsonr': pearsonr} 556 | 557 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 558 | 559 | #### Constructing evaluation TPUEstimatorSpec with new cache.
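# How the metrics below are wired up: on TPU, `eval_metrics` is passed as a
# (metric_fn, args) pair and the metric_fn is executed on the CPU host
# rather than on the TPU device; on GPU/CPU the very same metric_fn is
# called inline to produce `eval_metric_ops`.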
560 | label_ids = tf.reshape(features['label_ids'], [-1]) 561 | 562 | if FLAGS.is_regression: 563 | metric_fn = regression_metric_fn 564 | else: 565 | metric_fn = metric_fn 566 | metric_args = [per_example_loss, label_ids, logits, is_real_example] 567 | 568 | if FLAGS.use_tpu: 569 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec( 570 | mode=mode, 571 | loss=total_loss, 572 | eval_metrics=(metric_fn, metric_args), 573 | scaffold_fn=scaffold_fn) 574 | else: 575 | eval_spec = tf.estimator.EstimatorSpec( 576 | mode=mode, 577 | loss=total_loss, 578 | eval_metric_ops=metric_fn(*metric_args)) 579 | 580 | return eval_spec 581 | 582 | elif mode == tf.estimator.ModeKeys.PREDICT: 583 | label_ids = tf.reshape(features["label_ids"], [-1]) 584 | 585 | predictions = { 586 | "logits": logits, 587 | "labels": label_ids, 588 | "is_real": features["is_real_example"] 589 | } 590 | 591 | if FLAGS.use_tpu: 592 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 593 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 594 | else: 595 | output_spec = tf.estimator.EstimatorSpec( 596 | mode=mode, predictions=predictions) 597 | return output_spec 598 | 599 | #### Configuring the optimizer 600 | train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss) 601 | 602 | monitor_dict = {} 603 | monitor_dict["lr"] = learning_rate 604 | 605 | #### Constructing training TPUEstimatorSpec with new cache. 606 | if FLAGS.use_tpu: 607 | #### Creating host calls 608 | if not FLAGS.is_regression: 609 | label_ids = tf.reshape(features['label_ids'], [-1]) 610 | predictions = tf.argmax(logits, axis=-1, output_type=label_ids.dtype) 611 | is_correct = tf.equal(predictions, label_ids) 612 | accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32)) 613 | 614 | monitor_dict["accuracy"] = accuracy 615 | 616 | host_call = function_builder.construct_scalar_host_call( 617 | monitor_dict=monitor_dict, 618 | model_dir=FLAGS.model_dir, 619 | prefix="train/", 620 | reduce_fn=tf.reduce_mean) 621 | else: 622 | host_call = None 623 | 624 | train_spec = tf.contrib.tpu.TPUEstimatorSpec( 625 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call, 626 | scaffold_fn=scaffold_fn) 627 | else: 628 | train_spec = tf.estimator.EstimatorSpec( 629 | mode=mode, loss=total_loss, train_op=train_op) 630 | 631 | return train_spec 632 | 633 | return model_fn 634 | 635 | 636 | def main(_): 637 | tf.logging.set_verbosity(tf.logging.INFO) 638 | 639 | #### Validate flags 640 | if FLAGS.save_steps is not None: 641 | FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps) 642 | 643 | if FLAGS.do_predict: 644 | predict_dir = FLAGS.predict_dir 645 | if not tf.gfile.Exists(predict_dir): 646 | tf.gfile.MakeDirs(predict_dir) 647 | 648 | processors = { 649 | "mnli_matched": MnliMatchedProcessor, 650 | "mnli_mismatched": MnliMismatchedProcessor, 651 | 'sts-b': StsbProcessor, 652 | 'imdb': ImdbProcessor, 653 | "yelp5": Yelp5Processor 654 | } 655 | 656 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 657 | raise ValueError( 658 | "At least one of `do_train`, `do_eval` or " 659 | "`do_predict` must be True.") 660 | 661 | if not tf.gfile.Exists(FLAGS.output_dir): 662 | tf.gfile.MakeDirs(FLAGS.output_dir) 663 | 664 | task_name = FLAGS.task_name.lower() 665 | 666 | if task_name not in processors: 667 | raise ValueError("Task not found: %s" % (task_name)) 668 | 669 | processor = processors[task_name]() 670 | label_list = processor.get_labels() if not FLAGS.is_regression else None 671 | 672 | sp =
spm.SentencePieceProcessor() 673 | sp.Load(FLAGS.spiece_model_file) 674 | def tokenize_fn(text): 675 | text = preprocess_text(text, lower=FLAGS.uncased) 676 | return encode_ids(sp, text) 677 | 678 | run_config = model_utils.configure_tpu(FLAGS) 679 | 680 | model_fn = get_model_fn(len(label_list) if label_list is not None else None) 681 | 682 | spm_basename = os.path.basename(FLAGS.spiece_model_file) 683 | 684 | # If TPU is not available, this will fall back to normal Estimator on CPU 685 | # or GPU. 686 | if FLAGS.use_tpu: 687 | estimator = tf.contrib.tpu.TPUEstimator( 688 | use_tpu=FLAGS.use_tpu, 689 | model_fn=model_fn, 690 | config=run_config, 691 | train_batch_size=FLAGS.train_batch_size, 692 | predict_batch_size=FLAGS.predict_batch_size, 693 | eval_batch_size=FLAGS.eval_batch_size) 694 | else: 695 | estimator = tf.estimator.Estimator( 696 | model_fn=model_fn, 697 | config=run_config) 698 | 699 | if FLAGS.do_train: 700 | train_file_base = "{}.len-{}.train.tf_record".format( 701 | spm_basename, FLAGS.max_seq_length) 702 | train_file = os.path.join(FLAGS.output_dir, train_file_base) 703 | tf.logging.info("Use tfrecord file {}".format(train_file)) 704 | 705 | train_examples = processor.get_train_examples(FLAGS.data_dir) 706 | np.random.shuffle(train_examples) 707 | tf.logging.info("Num of train samples: {}".format(len(train_examples))) 708 | 709 | file_based_convert_examples_to_features( 710 | train_examples, label_list, FLAGS.max_seq_length, tokenize_fn, 711 | train_file, FLAGS.num_passes) 712 | 713 | train_input_fn = file_based_input_fn_builder( 714 | input_file=train_file, 715 | seq_length=FLAGS.max_seq_length, 716 | is_training=True, 717 | drop_remainder=True) 718 | 719 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps) 720 | 721 | if FLAGS.do_eval or FLAGS.do_predict: 722 | if FLAGS.eval_split == "dev": 723 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 724 | else: 725 | eval_examples = processor.get_test_examples(FLAGS.data_dir) 726 | 727 | tf.logging.info("Num of eval samples: {}".format(len(eval_examples))) 728 | 729 | if FLAGS.do_eval: 730 | # TPU requires a fixed batch size for all batches, therefore the number 731 | # of examples must be a multiple of the batch size, or else examples 732 | # will get dropped. So we pad with fake examples which are ignored 733 | # later on. These do NOT count towards the metric (all tf.metrics 734 | # support a per-instance weight, and these get a weight of 0.0). 735 | # 736 | # Modified in XL: We also adopt the same mechanism for GPUs. 
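# A minimal sketch of the padding arithmetic (assuming, say, 1000 eval
# examples with the default eval_batch_size of 128):
#
#   n_fake = (-1000) % 128     # == 24 fake examples appended
#   n_total = 1000 + n_fake    # == 1024, i.e. exactly 8 full batches
#
# The fake examples carry is_real_example=0 and therefore enter the metrics
# with weight 0.0, leaving eval_accuracy and eval_loss unaffected.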
737 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 738 | eval_examples.append(PaddingInputExample()) 739 | 740 | eval_file_base = "{}.len-{}.{}.eval.tf_record".format( 741 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split) 742 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base) 743 | 744 | file_based_convert_examples_to_features( 745 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn, 746 | eval_file) 747 | 748 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 749 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 750 | 751 | eval_input_fn = file_based_input_fn_builder( 752 | input_file=eval_file, 753 | seq_length=FLAGS.max_seq_length, 754 | is_training=False, 755 | drop_remainder=True) 756 | 757 | # Filter out all checkpoints in the directory 758 | steps_and_files = [] 759 | filenames = tf.gfile.ListDirectory(FLAGS.model_dir) 760 | 761 | for filename in filenames: 762 | if filename.endswith(".index"): 763 | ckpt_name = filename[:-6] 764 | cur_filename = join(FLAGS.model_dir, ckpt_name) 765 | global_step = int(cur_filename.split("-")[-1]) 766 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 767 | steps_and_files.append([global_step, cur_filename]) 768 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0]) 769 | 770 | # Decide whether to evaluate all ckpts 771 | if not FLAGS.eval_all_ckpt: 772 | steps_and_files = steps_and_files[-1:] 773 | 774 | eval_results = [] 775 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]): 776 | ret = estimator.evaluate( 777 | input_fn=eval_input_fn, 778 | steps=eval_steps, 779 | checkpoint_path=filename) 780 | 781 | ret["step"] = global_step 782 | ret["path"] = filename 783 | 784 | eval_results.append(ret) 785 | 786 | tf.logging.info("=" * 80) 787 | log_str = "Eval result | " 788 | for key, val in sorted(ret.items(), key=lambda x: x[0]): 789 | log_str += "{} {} | ".format(key, val) 790 | tf.logging.info(log_str) 791 | 792 | key_name = "eval_pearsonr" if FLAGS.is_regression else "eval_accuracy" 793 | eval_results.sort(key=lambda x: x[key_name], reverse=True) 794 | 795 | tf.logging.info("=" * 80) 796 | log_str = "Best result | " 797 | for key, val in sorted(eval_results[0].items(), key=lambda x: x[0]): 798 | log_str += "{} {} | ".format(key, val) 799 | tf.logging.info(log_str) 800 | 801 | if FLAGS.do_predict: 802 | eval_file_base = "{}.len-{}.{}.predict.tf_record".format( 803 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split) 804 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base) 805 | 806 | file_based_convert_examples_to_features( 807 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn, 808 | eval_file) 809 | 810 | pred_input_fn = file_based_input_fn_builder( 811 | input_file=eval_file, 812 | seq_length=FLAGS.max_seq_length, 813 | is_training=False, 814 | drop_remainder=False) 815 | 816 | predict_results = [] 817 | with tf.gfile.Open(os.path.join(predict_dir, "{}.tsv".format( 818 | task_name)), "w") as fout: 819 | fout.write("index\tprediction\n") 820 | 821 | for pred_cnt, result in enumerate(estimator.predict( 822 | input_fn=pred_input_fn, 823 | yield_single_examples=True, 824 | checkpoint_path=FLAGS.predict_ckpt)): 825 | if pred_cnt % 1000 == 0: 826 | tf.logging.info("Predicting submission for example: {}".format( 827 | pred_cnt)) 828 | 829 | logits = [float(x) for x in result["logits"].flat] 830 | predict_results.append(logits) 831 | 832 | if len(logits) == 1: 833 | label_out = logits[0] 834 | elif len(logits) == 2: 835 | if 
logits[1] - logits[0] > FLAGS.predict_threshold: 836 | label_out = label_list[1] 837 | else: 838 | label_out = label_list[0] 839 | elif len(logits) > 2: 840 | max_index = np.argmax(np.array(logits, dtype=np.float32)) 841 | label_out = label_list[max_index] 842 | else: 843 | raise NotImplementedError 844 | 845 | fout.write("{}\t{}\n".format(pred_cnt, label_out)) 846 | 847 | predict_json_path = os.path.join(predict_dir, "{}.logits.json".format( 848 | task_name)) 849 | 850 | with tf.gfile.Open(predict_json_path, "w") as fp: 851 | json.dump(predict_results, fp, indent=4) 852 | 853 | 854 | if __name__ == "__main__": 855 | tf.app.run() 856 | -------------------------------------------------------------------------------- /run_race.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from os.path import join 6 | from absl import flags 7 | import os 8 | import csv 9 | import collections 10 | import numpy as np 11 | import time 12 | import math 13 | import json 14 | import random 15 | from copy import copy 16 | from collections import defaultdict as dd 17 | 18 | from scipy.stats import pearsonr, spearmanr 19 | from sklearn.metrics import matthews_corrcoef, f1_score 20 | 21 | import absl.logging as _logging # pylint: disable=unused-import 22 | 23 | import tensorflow as tf 24 | import sentencepiece as spm 25 | 26 | from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID 27 | import model_utils 28 | import function_builder 29 | from classifier_utils import PaddingInputExample 30 | from classifier_utils import convert_single_example 31 | from prepro_utils import preprocess_text, encode_ids 32 | 33 | # Model 34 | flags.DEFINE_string("model_config_path", default=None, 35 | help="Model config path.") 36 | flags.DEFINE_float("dropout", default=0.1, 37 | help="Dropout rate.") 38 | flags.DEFINE_float("dropatt", default=0.1, 39 | help="Attention dropout rate.") 40 | flags.DEFINE_integer("clamp_len", default=-1, 41 | help="Clamp length") 42 | flags.DEFINE_string("summary_type", default="last", 43 | help="Method used to summarize a sequence into a compact vector.") 44 | flags.DEFINE_bool("use_summ_proj", default=True, 45 | help="Whether to use projection for summarizing sequences.") 46 | flags.DEFINE_bool("use_bfloat16", default=False, 47 | help="Whether to use bfloat16.") 48 | 49 | # Parameter initialization 50 | flags.DEFINE_enum("init", default="normal", 51 | enum_values=["normal", "uniform"], 52 | help="Initialization method.") 53 | flags.DEFINE_float("init_std", default=0.02, 54 | help="Initialization std when init is normal.") 55 | flags.DEFINE_float("init_range", default=0.1, 56 | help="Initialization std when init is uniform.") 57 | 58 | # I/O paths 59 | flags.DEFINE_bool("overwrite_data", default=False, 60 | help="If False, will use cached data if available.") 61 | flags.DEFINE_string("init_checkpoint", default=None, 62 | help="checkpoint path for initializing the model. 
" 63 | "Could be a pretrained model or a finetuned model.") 64 | flags.DEFINE_string("output_dir", default="", 65 | help="Output dir for TF records.") 66 | flags.DEFINE_string("spiece_model_file", default="", 67 | help="Sentence Piece model path.") 68 | flags.DEFINE_string("model_dir", default="", 69 | help="Directory for saving the finetuned model.") 70 | flags.DEFINE_string("data_dir", default="", 71 | help="Directory for input data.") 72 | 73 | # TPUs and machines 74 | flags.DEFINE_bool("use_tpu", default=False, help="whether to use TPU.") 75 | flags.DEFINE_integer("num_hosts", default=1, help="How many TPU hosts.") 76 | flags.DEFINE_integer("num_core_per_host", default=8, 77 | help="8 for TPU v2 and v3-8, 16 for larger TPU v3 pod. In the context " 78 | "of GPU training, it refers to the number of GPUs used.") 79 | flags.DEFINE_string("tpu_job_name", default=None, help="TPU worker job name.") 80 | flags.DEFINE_string("tpu", default=None, help="TPU name.") 81 | flags.DEFINE_string("tpu_zone", default=None, help="TPU zone.") 82 | flags.DEFINE_string("gcp_project", default=None, help="gcp project.") 83 | flags.DEFINE_string("master", default=None, help="master") 84 | flags.DEFINE_integer("iterations", default=1000, 85 | help="number of iterations per TPU training loop.") 86 | 87 | # Training 88 | flags.DEFINE_bool("do_train", default=False, help="whether to do training") 89 | flags.DEFINE_integer("train_steps", default=12000, 90 | help="Number of training steps") 91 | flags.DEFINE_integer("warmup_steps", default=0, help="number of warmup steps") 92 | flags.DEFINE_float("learning_rate", default=2e-5, help="initial learning rate") 93 | flags.DEFINE_float("lr_layer_decay_rate", 1.0, 94 | "Top layer: lr[L] = FLAGS.learning_rate." 95 | "Low layer: lr[l-1] = lr[l] * lr_layer_decay_rate.") 96 | flags.DEFINE_float("min_lr_ratio", default=0.0, 97 | help="min lr ratio for cos decay.") 98 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping") 99 | flags.DEFINE_integer("max_save", default=0, 100 | help="Max number of checkpoints to save. Use 0 to save all.") 101 | flags.DEFINE_integer("save_steps", default=None, 102 | help="Save the model for every save_steps. " 103 | "If None, not to save any model.") 104 | flags.DEFINE_integer("train_batch_size", default=8, 105 | help="Batch size for training. 
Note that batch size 1 corresponds to " 106 | "4 sequences: one paragraph + one question + 4 candidate answers.") 107 | flags.DEFINE_float("weight_decay", default=0.00, help="weight decay rate") 108 | flags.DEFINE_float("adam_epsilon", default=1e-6, help="adam epsilon") 109 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos") 110 | 111 | # Evaluation 112 | flags.DEFINE_bool("do_eval", default=False, help="whether to do eval") 113 | flags.DEFINE_string("eval_split", default="dev", 114 | help="could be dev or test") 115 | flags.DEFINE_integer("eval_batch_size", default=32, 116 | help="Batch size for evaluation.") 117 | 118 | # Data config 119 | flags.DEFINE_integer("max_seq_length", default=512, 120 | help="Max length for the paragraph.") 121 | flags.DEFINE_integer("max_qa_length", default=128, 122 | help="Max length for the concatenated question and answer.") 123 | flags.DEFINE_integer("shuffle_buffer", default=2048, 124 | help="Buffer size used for shuffle.") 125 | flags.DEFINE_bool("uncased", default=False, 126 | help="Use uncased.") 127 | flags.DEFINE_bool("high_only", default=False, 128 | help="Evaluate on high school only.") 129 | flags.DEFINE_bool("middle_only", default=False, 130 | help="Evaluate on middle school only.") 131 | 132 | FLAGS = flags.FLAGS 133 | 134 | SEG_ID_A = 0 135 | SEG_ID_B = 1 136 | SEG_ID_CLS = 2 137 | SEG_ID_SEP = 3 138 | SEG_ID_PAD = 4 139 | 140 | 141 | class PaddingInputExample(object): 142 | """Fake example so the num input examples is a multiple of the batch size. 143 | When running eval/predict on the TPU, we need to pad the number of examples 144 | to be a multiple of the batch size, because the TPU requires a fixed batch 145 | size. The alternative is to drop the last batch, which is bad because it means 146 | the entire output data won't be generated. 147 | We use this class instead of `None` because treating `None` as padding 148 | batches could cause silent errors.
149 | """ 150 | 151 | 152 | class InputFeatures(object): 153 | """A single set of features of data.""" 154 | 155 | def __init__(self, 156 | input_ids, 157 | input_mask, 158 | segment_ids, 159 | label_id, 160 | is_real_example=True): 161 | self.input_ids = input_ids 162 | self.input_mask = input_mask 163 | self.segment_ids = segment_ids 164 | self.label_id = label_id 165 | self.is_real_example = is_real_example 166 | 167 | 168 | def convert_single_example(example, tokenize_fn): 169 | """Converts a single `InputExample` into a single `InputFeatures`.""" 170 | 171 | if isinstance(example, PaddingInputExample): 172 | return InputFeatures( 173 | input_ids=[0] * FLAGS.max_seq_length * 4, 174 | input_mask=[1] * FLAGS.max_seq_length * 4, 175 | segment_ids=[0] * FLAGS.max_seq_length * 4, 176 | label_id=0, 177 | is_real_example=False) 178 | 179 | input_ids, input_mask, all_seg_ids = [], [], [] 180 | tokens_context = tokenize_fn(example.context) 181 | for i in range(len(example.qa_list)): 182 | tokens_qa = tokenize_fn(example.qa_list[i]) 183 | if len(tokens_qa) > FLAGS.max_qa_length: 184 | tokens_qa = tokens_qa[- FLAGS.max_qa_length:] 185 | 186 | if len(tokens_context) + len(tokens_qa) > FLAGS.max_seq_length - 3: 187 | tokens = tokens_context[: FLAGS.max_seq_length - 3 - len(tokens_qa)] 188 | else: 189 | tokens = tokens_context 190 | 191 | segment_ids = [SEG_ID_A] * len(tokens) 192 | 193 | tokens.append(SEP_ID) 194 | segment_ids.append(SEG_ID_A) 195 | 196 | tokens.extend(tokens_qa) 197 | segment_ids.extend([SEG_ID_B] * len(tokens_qa)) 198 | 199 | tokens.append(SEP_ID) 200 | segment_ids.append(SEG_ID_B) 201 | 202 | tokens.append(CLS_ID) 203 | segment_ids.append(SEG_ID_CLS) 204 | 205 | cur_input_ids = tokens 206 | cur_input_mask = [0] * len(cur_input_ids) 207 | 208 | if len(cur_input_ids) < FLAGS.max_seq_length: 209 | delta_len = FLAGS.max_seq_length - len(cur_input_ids) 210 | cur_input_ids = [0] * delta_len + cur_input_ids 211 | cur_input_mask = [1] * delta_len + cur_input_mask 212 | segment_ids = [SEG_ID_PAD] * delta_len + segment_ids 213 | 214 | assert len(cur_input_ids) == FLAGS.max_seq_length 215 | assert len(cur_input_mask) == FLAGS.max_seq_length 216 | assert len(segment_ids) == FLAGS.max_seq_length 217 | 218 | input_ids.extend(cur_input_ids) 219 | input_mask.extend(cur_input_mask) 220 | all_seg_ids.extend(segment_ids) 221 | 222 | label_id = example.label 223 | 224 | feature = InputFeatures( 225 | input_ids=input_ids, 226 | input_mask=input_mask, 227 | segment_ids=all_seg_ids, 228 | label_id=label_id) 229 | return feature 230 | 231 | 232 | class InputExample(object): 233 | def __init__(self, context, qa_list, label, level): 234 | self.context = context 235 | self.qa_list = qa_list 236 | self.label = label 237 | self.level = level 238 | 239 | 240 | def get_examples(data_dir, set_type): 241 | examples = [] 242 | 243 | for level in ["middle", "high"]: 244 | if level == "middle" and FLAGS.high_only: continue 245 | if level == "high" and FLAGS.middle_only: continue 246 | 247 | cur_dir = os.path.join(data_dir, set_type, level) 248 | for filename in tf.gfile.ListDirectory(cur_dir): 249 | cur_path = os.path.join(cur_dir, filename) 250 | with tf.gfile.Open(cur_path) as f: 251 | cur_data = json.load(f) 252 | 253 | answers = cur_data["answers"] 254 | options = cur_data["options"] 255 | questions = cur_data["questions"] 256 | context = cur_data["article"] 257 | 258 | for i in range(len(answers)): 259 | label = ord(answers[i]) - ord("A") 260 | qa_list = [] 261 | 262 | question = questions[i] 263 | for j in 
range(4): 264 | option = options[i][j] 265 | 266 | if "_" in question: 267 | qa_cat = question.replace("_", option) 268 | else: 269 | qa_cat = " ".join([question, option]) 270 | 271 | qa_list.append(qa_cat) 272 | 273 | examples.append(InputExample(context, qa_list, label, level)) 274 | 275 | return examples 276 | 277 | 278 | def file_based_convert_examples_to_features(examples, tokenize_fn, output_file): 279 | if tf.gfile.Exists(output_file) and not FLAGS.overwrite_data: 280 | return 281 | 282 | tf.logging.info("Start writing tfrecord %s.", output_file) 283 | writer = tf.python_io.TFRecordWriter(output_file) 284 | 285 | for ex_index, example in enumerate(examples): 286 | if ex_index % 10000 == 0: 287 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 288 | 289 | feature = convert_single_example(example, tokenize_fn) 290 | 291 | def create_int_feature(values): 292 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 293 | return f 294 | 295 | def create_float_feature(values): 296 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 297 | return f 298 | 299 | features = collections.OrderedDict() 300 | features["input_ids"] = create_int_feature(feature.input_ids) 301 | features["input_mask"] = create_float_feature(feature.input_mask) 302 | features["segment_ids"] = create_int_feature(feature.segment_ids) 303 | features["label_ids"] = create_int_feature([feature.label_id]) 304 | features["is_real_example"] = create_int_feature( 305 | [int(feature.is_real_example)]) 306 | 307 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 308 | writer.write(tf_example.SerializeToString()) 309 | writer.close() 310 | 311 | 312 | def file_based_input_fn_builder(input_file, seq_length, is_training, 313 | drop_remainder): 314 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 315 | 316 | name_to_features = { 317 | "input_ids": tf.FixedLenFeature([seq_length * 4], tf.int64), 318 | "input_mask": tf.FixedLenFeature([seq_length * 4], tf.float32), 319 | "segment_ids": tf.FixedLenFeature([seq_length * 4], tf.int64), 320 | "label_ids": tf.FixedLenFeature([], tf.int64), 321 | "is_real_example": tf.FixedLenFeature([], tf.int64), 322 | } 323 | 324 | tf.logging.info("Input tfrecord file {}".format(input_file)) 325 | 326 | def _decode_record(record, name_to_features): 327 | """Decodes a record to a TensorFlow example.""" 328 | example = tf.parse_single_example(record, name_to_features) 329 | 330 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 331 | # So cast all int64 to int32. 332 | for name in list(example.keys()): 333 | t = example[name] 334 | if t.dtype == tf.int64: 335 | t = tf.cast(t, tf.int32) 336 | example[name] = t 337 | 338 | return example 339 | 340 | def input_fn(params): 341 | """The actual input function.""" 342 | if FLAGS.use_tpu: 343 | batch_size = params["batch_size"] 344 | elif is_training: 345 | batch_size = FLAGS.train_batch_size 346 | elif FLAGS.do_eval: 347 | batch_size = FLAGS.eval_batch_size 348 | 349 | # For training, we want a lot of parallel reading and shuffling. 350 | # For eval, we want no shuffling and parallel reading doesn't matter. 
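# Note: unlike run_classifier, run_race has no predict mode, so the
# batch-size dispatch above only covers the TPU, training, and eval cases;
# main() rejects runs where neither do_train nor do_eval is set, so
# batch_size is always bound before this point is reached.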
351 | d = tf.data.TFRecordDataset(input_file) 352 | if is_training: 353 | d = d.shuffle(buffer_size=FLAGS.shuffle_buffer) 354 | d = d.repeat() 355 | # d = d.shuffle(buffer_size=100) 356 | 357 | d = d.apply( 358 | tf.contrib.data.map_and_batch( 359 | lambda record: _decode_record(record, name_to_features), 360 | batch_size=batch_size, 361 | drop_remainder=drop_remainder)) 362 | 363 | return d 364 | 365 | return input_fn 366 | 367 | 368 | def get_model_fn(): 369 | def model_fn(features, labels, mode, params): 370 | #### Training or Evaluation 371 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 372 | 373 | total_loss, per_example_loss, logits = function_builder.get_race_loss( 374 | FLAGS, features, is_training) 375 | 376 | #### Check model parameters 377 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) 378 | tf.logging.info('#params: {}'.format(num_params)) 379 | 380 | #### load pretrained models 381 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS) 382 | 383 | #### Evaluation mode 384 | if mode == tf.estimator.ModeKeys.EVAL: 385 | assert FLAGS.num_hosts == 1 386 | 387 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 388 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 389 | eval_input_dict = { 390 | 'labels': label_ids, 391 | 'predictions': predictions, 392 | 'weights': is_real_example 393 | } 394 | accuracy = tf.metrics.accuracy(**eval_input_dict) 395 | 396 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 397 | return { 398 | 'eval_accuracy': accuracy, 399 | 'eval_loss': loss} 400 | 401 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 402 | 403 | #### Constructing evaluation TPUEstimatorSpec with new cache. 404 | label_ids = tf.reshape(features['label_ids'], [-1]) 405 | metric_args = [per_example_loss, label_ids, logits, is_real_example] 406 | 407 | if FLAGS.use_tpu: 408 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec( 409 | mode=mode, 410 | loss=total_loss, 411 | eval_metrics=(metric_fn, metric_args), 412 | scaffold_fn=scaffold_fn) 413 | else: 414 | eval_spec = tf.estimator.EstimatorSpec( 415 | mode=mode, 416 | loss=total_loss, 417 | eval_metric_ops=metric_fn(*metric_args)) 418 | 419 | return eval_spec 420 | 421 | 422 | #### Configuring the optimizer 423 | train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss) 424 | 425 | monitor_dict = {} 426 | monitor_dict["lr"] = learning_rate 427 | 428 | #### Constructing training TPUEstimatorSpec with new cache.
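# monitor_dict above is populated but, unlike run_classifier, never wired
# into a scalar host call: host_call stays None below, so run_race streams
# no per-step training summaries back from the TPU.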
429 | if FLAGS.use_tpu: 430 | #### Creating host calls 431 | host_call = None 432 | 433 | train_spec = tf.contrib.tpu.TPUEstimatorSpec( 434 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call, 435 | scaffold_fn=scaffold_fn) 436 | else: 437 | train_spec = tf.estimator.EstimatorSpec( 438 | mode=mode, loss=total_loss, train_op=train_op) 439 | 440 | return train_spec 441 | 442 | return model_fn 443 | 444 | 445 | def main(_): 446 | tf.logging.set_verbosity(tf.logging.INFO) 447 | 448 | #### Validate flags 449 | if FLAGS.save_steps is not None: 450 | FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps) 451 | 452 | if not FLAGS.do_train and not FLAGS.do_eval: 453 | raise ValueError( 454 | "At least one of `do_train` or `do_eval` must be True.") 455 | 456 | if not tf.gfile.Exists(FLAGS.output_dir): 457 | tf.gfile.MakeDirs(FLAGS.output_dir) 458 | 459 | sp = spm.SentencePieceProcessor() 460 | sp.Load(FLAGS.spiece_model_file) 461 | def tokenize_fn(text): 462 | text = preprocess_text(text, lower=FLAGS.uncased) 463 | return encode_ids(sp, text) 464 | 465 | # TPU Configuration 466 | run_config = model_utils.configure_tpu(FLAGS) 467 | 468 | model_fn = get_model_fn() 469 | 470 | spm_basename = os.path.basename(FLAGS.spiece_model_file) 471 | 472 | # If TPU is not available, this will fall back to normal Estimator on CPU 473 | # or GPU. 474 | if FLAGS.use_tpu: 475 | estimator = tf.contrib.tpu.TPUEstimator( 476 | use_tpu=FLAGS.use_tpu, 477 | model_fn=model_fn, 478 | config=run_config, 479 | train_batch_size=FLAGS.train_batch_size, 480 | eval_batch_size=FLAGS.eval_batch_size) 481 | else: 482 | estimator = tf.estimator.Estimator( 483 | model_fn=model_fn, 484 | config=run_config) 485 | 486 | if FLAGS.do_train: 487 | train_file_base = "{}.len-{}.train.tf_record".format( 488 | spm_basename, FLAGS.max_seq_length) 489 | train_file = os.path.join(FLAGS.output_dir, train_file_base) 490 | 491 | if not tf.gfile.Exists(train_file) or FLAGS.overwrite_data: 492 | train_examples = get_examples(FLAGS.data_dir, "train") 493 | random.shuffle(train_examples) 494 | file_based_convert_examples_to_features( 495 | train_examples, tokenize_fn, train_file) 496 | 497 | train_input_fn = file_based_input_fn_builder( 498 | input_file=train_file, 499 | seq_length=FLAGS.max_seq_length, 500 | is_training=True, 501 | drop_remainder=True) 502 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps) 503 | 504 | if FLAGS.do_eval: 505 | eval_examples = get_examples(FLAGS.data_dir, FLAGS.eval_split) 506 | tf.logging.info("Num of eval samples: {}".format(len(eval_examples))) 507 | 508 | # TPU requires a fixed batch size for all batches, therefore the number 509 | # of examples must be a multiple of the batch size, or else examples 510 | # will get dropped. So we pad with fake examples which are ignored 511 | # later on. These do NOT count towards the metric (all tf.metrics 512 | # support a per-instance weight, and these get a weight of 0.0). 513 | # 514 | # Modified in XL: We also adopt the same mechanism for GPUs. 515 | 516 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 517 | eval_examples.append(PaddingInputExample()) 518 | 519 | eval_file_base = "{}.len-{}.{}.tf_record".format( 520 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split) 521 | 522 | if FLAGS.high_only: 523 | eval_file_base = "high." + eval_file_base 524 | elif FLAGS.middle_only: 525 | eval_file_base = "middle." 
+ eval_file_base 526 | 527 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base) 528 | file_based_convert_examples_to_features( 529 | eval_examples, tokenize_fn, eval_file) 530 | 531 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 532 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 533 | 534 | eval_input_fn = file_based_input_fn_builder( 535 | input_file=eval_file, 536 | seq_length=FLAGS.max_seq_length, 537 | is_training=False, 538 | drop_remainder=True) 539 | 540 | ret = estimator.evaluate( 541 | input_fn=eval_input_fn, 542 | steps=eval_steps) 543 | 544 | # Log current result 545 | tf.logging.info("=" * 80) 546 | log_str = "Eval | " 547 | for key, val in ret.items(): 548 | log_str += "{} {} | ".format(key, val) 549 | tf.logging.info(log_str) 550 | tf.logging.info("=" * 80) 551 | 552 | 553 | if __name__ == "__main__": 554 | tf.app.run() 555 | -------------------------------------------------------------------------------- /scripts/gpu_squad_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### local path 4 | SQUAD_DIR=data/squad 5 | INIT_CKPT_DIR=xlnet_cased_L-12_H-768_A-12 6 | PROC_DATA_DIR=proc_data/squad 7 | MODEL_DIR=experiment/squad 8 | 9 | #### Use 3 GPUs, each with 8 seqlen-512 samples 10 | 11 | python run_squad.py \ 12 | --use_tpu=False \ 13 | --num_hosts=1 \ 14 | --num_core_per_host=3 \ 15 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \ 16 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \ 17 | --output_dir=${PROC_DATA_DIR} \ 18 | --init_checkpoint=${INIT_CKPT_DIR}/xlnet_model.ckpt \ 19 | --model_dir=${MODEL_DIR} \ 20 | --train_file=${SQUAD_DIR}/train-v2.0.json \ 21 | --predict_file=${SQUAD_DIR}/dev-v2.0.json \ 22 | --uncased=False \ 23 | --max_seq_length=512 \ 24 | --do_train=True \ 25 | --train_batch_size=8 \ 26 | --do_predict=True \ 27 | --predict_batch_size=32 \ 28 | --learning_rate=2e-5 \ 29 | --adam_epsilon=1e-6 \ 30 | --iterations=1000 \ 31 | --save_steps=1000 \ 32 | --train_steps=12000 \ 33 | --warmup_steps=1000 \ 34 | $@ 35 | -------------------------------------------------------------------------------- /scripts/prepro_squad.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### local path 4 | SQUAD_DIR=data/squad 5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16 6 | 7 | #### google storage path 8 | GS_ROOT= 9 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/squad 10 | 11 | python run_squad.py \ 12 | --use_tpu=False \ 13 | --do_prepro=True \ 14 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \ 15 | --train_file=${SQUAD_DIR}/train-v2.0.json \ 16 | --output_dir=${GS_PROC_DATA_DIR} \ 17 | --uncased=False \ 18 | --max_seq_length=512 \ 19 | $@ 20 | 21 | #### Potential multi-processing version 22 | # NUM_PROC=8 23 | # for i in `seq 0 $((NUM_PROC - 1))`; do 24 | # python run_squad.py \ 25 | # --use_tpu=False \ 26 | # --do_prepro=True \ 27 | # --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \ 28 | # --train_file=${SQUAD_DIR}/train-v2.0.json \ 29 | # --output_dir=${GS_PROC_DATA_DIR} \ 30 | # --uncased=False \ 31 | # --max_seq_length=512 \ 32 | # --num_proc=${NUM_PROC} \ 33 | # --proc_id=${i} \ 34 | # $@ & 35 | # done 36 | -------------------------------------------------------------------------------- /scripts/tpu_race_large_bsz32.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### local path 4 | RACE_DIR=data/RACE 5 | 
INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16 6 | 7 | #### google storage path 8 | GS_ROOT= 9 | GS_INIT_CKPT_DIR=${GS_ROOT}/${INIT_CKPT_DIR} 10 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/race 11 | GS_MODEL_DIR=${GS_ROOT}/experiment/race 12 | 13 | # TPU name in google cloud 14 | TPU_NAME= 15 | 16 | python run_race.py \ 17 | --use_tpu=True \ 18 | --tpu=${TPU_NAME} \ 19 | --num_hosts=4 \ 20 | --num_core_per_host=8 \ 21 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \ 22 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \ 23 | --output_dir=${GS_PROC_DATA_DIR} \ 24 | --init_checkpoint=${GS_INIT_CKPT_DIR}/xlnet_model.ckpt \ 25 | --model_dir=${GS_MODEL_DIR} \ 26 | --data_dir=${RACE_DIR} \ 27 | --max_seq_length=512 \ 28 | --max_qa_length=128 \ 29 | --uncased=False \ 30 | --do_train=True \ 31 | --train_batch_size=32 \ 32 | --do_eval=True \ 33 | --eval_batch_size=32 \ 34 | --train_steps=12000 \ 35 | --save_steps=1000 \ 36 | --iterations=1000 \ 37 | --warmup_steps=1000 \ 38 | --learning_rate=2e-5 \ 39 | --weight_decay=0 \ 40 | --adam_epsilon=1e-6 \ 41 | $@ 42 | -------------------------------------------------------------------------------- /scripts/tpu_race_large_bsz8.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### local path 4 | RACE_DIR=data/RACE 5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16 6 | 7 | #### google storage path 8 | GS_ROOT= 9 | GS_INIT_CKPT_DIR=${GS_ROOT}/${INIT_CKPT_DIR} 10 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/race 11 | GS_MODEL_DIR=${GS_ROOT}/experiment/race 12 | 13 | # TPU name in google cloud 14 | TPU_NAME= 15 | 16 | python run_race.py \ 17 | --use_tpu=True \ 18 | --tpu=${TPU_NAME} \ 19 | --num_hosts=1 \ 20 | --num_core_per_host=8 \ 21 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \ 22 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \ 23 | --output_dir=${GS_PROC_DATA_DIR} \ 24 | --init_checkpoint=${GS_INIT_CKPT_DIR}/xlnet_model.ckpt \ 25 | --model_dir=${GS_MODEL_DIR} \ 26 | --data_dir=${RACE_DIR} \ 27 | --max_seq_length=512 \ 28 | --max_qa_length=128 \ 29 | --uncased=False \ 30 | --do_train=True \ 31 | --train_batch_size=8 \ 32 | --do_eval=True \ 33 | --eval_batch_size=32 \ 34 | --train_steps=12000 \ 35 | --save_steps=1000 \ 36 | --iterations=1000 \ 37 | --warmup_steps=1000 \ 38 | --learning_rate=2e-5 \ 39 | --weight_decay=0 \ 40 | --adam_epsilon=1e-6 \ 41 | $@ 42 | -------------------------------------------------------------------------------- /scripts/tpu_squad_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #### local path 4 | SQUAD_DIR=data/squad 5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16 6 | 7 | #### google storage path 8 | GS_ROOT= 9 | GS_INIT_CKPT_DIR=${GS_ROOT}/${INIT_CKPT_DIR} 10 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/squad 11 | GS_MODEL_DIR=${GS_ROOT}/experiment/squad 12 | 13 | # TPU name in google cloud 14 | TPU_NAME= 15 | 16 | python run_squad.py \ 17 | --use_tpu=True \ 18 | --tpu=${TPU_NAME} \ 19 | --num_hosts=1 \ 20 | --num_core_per_host=8 \ 21 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \ 22 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \ 23 | --output_dir=${GS_PROC_DATA_DIR} \ 24 | --init_checkpoint=${GS_INIT_CKPT_DIR}/xlnet_model.ckpt \ 25 | --model_dir=${GS_MODEL_DIR} \ 26 | --train_file=${SQUAD_DIR}/train-v2.0.json \ 27 | --predict_file=${SQUAD_DIR}/dev-v2.0.json \ 28 | --uncased=False \ 29 | --max_seq_length=512 \ 30 | --do_train=True \ 31 | --train_batch_size=48 \ 32 | 
--do_predict=True \ 33 | --predict_batch_size=32 \ 34 | --learning_rate=3e-5 \ 35 | --adam_epsilon=1e-6 \ 36 | --iterations=1000 \ 37 | --save_steps=1000 \ 38 | --train_steps=8000 \ 39 | --warmup_steps=1000 \ 40 | $@ 41 | -------------------------------------------------------------------------------- /squad_utils.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def 
get_raw_scores(dataset, preds): 82 | exact_scores = {} 83 | f1_scores = {} 84 | for article in dataset: 85 | for p in article['paragraphs']: 86 | for qa in p['qas']: 87 | qid = qa['id'] 88 | gold_answers = [a['text'] for a in qa['answers'] 89 | if normalize_answer(a['text'])] 90 | if not gold_answers: 91 | # For unanswerable questions, only correct answer is empty string 92 | gold_answers = [''] 93 | if qid not in preds: 94 | print('Missing prediction for %s' % qid) 95 | continue 96 | a_pred = preds[qid] 97 | # Take max over all gold answers 98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 100 | return exact_scores, f1_scores 101 | 102 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 103 | new_scores = {} 104 | for qid, s in scores.items(): 105 | pred_na = na_probs[qid] > na_prob_thresh 106 | if pred_na: 107 | new_scores[qid] = float(not qid_to_has_ans[qid]) 108 | else: 109 | new_scores[qid] = s 110 | return new_scores 111 | 112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 113 | if not qid_list: 114 | total = len(exact_scores) 115 | return collections.OrderedDict([ 116 | ('exact', 100.0 * sum(exact_scores.values()) / total), 117 | ('f1', 100.0 * sum(f1_scores.values()) / total), 118 | ('total', total), 119 | ]) 120 | else: 121 | total = len(qid_list) 122 | return collections.OrderedDict([ 123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 125 | ('total', total), 126 | ]) 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 131 | 132 | def plot_pr_curve(precisions, recalls, out_image, title): 133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 135 | plt.xlabel('Recall') 136 | plt.ylabel('Precision') 137 | plt.xlim([0.0, 1.05]) 138 | plt.ylim([0.0, 1.05]) 139 | plt.title(title) 140 | plt.savefig(out_image) 141 | plt.clf() 142 | 143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 144 | out_image=None, title=None): 145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 146 | true_pos = 0.0 147 | cur_p = 1.0 148 | cur_r = 0.0 149 | precisions = [1.0] 150 | recalls = [0.0] 151 | avg_prec = 0.0 152 | for i, qid in enumerate(qid_list): 153 | if qid_to_has_ans[qid]: 154 | true_pos += scores[qid] 155 | cur_p = true_pos / float(i+1) 156 | cur_r = true_pos / float(num_true_pos) 157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 158 | # i.e., if we can put a threshold after this point 159 | avg_prec += cur_p * (cur_r - recalls[-1]) 160 | precisions.append(cur_p) 161 | recalls.append(cur_r) 162 | if out_image: 163 | plot_pr_curve(precisions, recalls, out_image, title) 164 | return {'ap': 100.0 * avg_prec} 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact 
Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 191 | if not qid_list: 192 | return 193 | x = [na_probs[k] for k in qid_list] 194 | weights = np.ones_like(x) / float(len(x)) 195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 196 | plt.xlabel('Model probability of no-answer') 197 | plt.ylabel('Proportion of dataset') 198 | plt.title('Histogram of no-answer probability: %s' % name) 199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 200 | plt.clf() 201 | 202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 204 | cur_score = num_no_ans 205 | best_score = cur_score 206 | best_thresh = 0.0 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | for i, qid in enumerate(qid_list): 209 | if qid not in scores: continue 210 | if qid_to_has_ans[qid]: 211 | diff = scores[qid] 212 | else: 213 | if preds[qid]: 214 | diff = -1 215 | else: 216 | diff = 0 217 | cur_score += diff 218 | if cur_score > best_score: 219 | best_score = cur_score 220 | best_thresh = na_probs[qid] 221 | return 100.0 * best_score / len(scores), best_thresh 222 | 223 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): 224 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 225 | cur_score = num_no_ans 226 | best_score = cur_score 227 | best_thresh = 0.0 228 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 229 | for i, qid in enumerate(qid_list): 230 | if qid not in scores: continue 231 | if qid_to_has_ans[qid]: 232 | diff = scores[qid] 233 | else: 234 | if preds[qid]: 235 | diff = -1 236 | else: 237 | diff = 0 238 | cur_score += diff 239 | if cur_score > best_score: 240 | best_score = cur_score 241 | best_thresh = na_probs[qid] 242 | 243 | has_ans_score, has_ans_cnt = 0, 0 244 | for qid in qid_list: 245 | if not qid_to_has_ans[qid]: continue 246 | has_ans_cnt += 1 247 | 248 | if qid not in scores: continue 249 | has_ans_score += scores[qid] 250 | 251 | return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt 252 | 253 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 254 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 255 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 256 | main_eval['best_exact'] = best_exact 257 | main_eval['best_exact_thresh'] = exact_thresh 258 | main_eval['best_f1'] = best_f1 259 | main_eval['best_f1_thresh'] = f1_thresh 260 | 261 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 262 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans) 263 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, 
f1_raw, na_probs, qid_to_has_ans) 264 | main_eval['best_exact'] = best_exact 265 | main_eval['best_exact_thresh'] = exact_thresh 266 | main_eval['best_f1'] = best_f1 267 | main_eval['best_f1_thresh'] = f1_thresh 268 | main_eval['has_ans_exact'] = has_ans_exact 269 | main_eval['has_ans_f1'] = has_ans_f1 270 | 271 | def main(): 272 | with open(OPTS.data_file) as f: 273 | dataset_json = json.load(f) 274 | dataset = dataset_json['data'] 275 | with open(OPTS.pred_file) as f: 276 | preds = json.load(f) 277 | 278 | new_orig_data = [] 279 | for article in dataset: 280 | for p in article['paragraphs']: 281 | for qa in p['qas']: 282 | if qa['id'] in preds: 283 | new_para = {'qas': [qa]} 284 | new_article = {'paragraphs': [new_para]} 285 | new_orig_data.append(new_article) 286 | dataset = new_orig_data 287 | 288 | if OPTS.na_prob_file: 289 | with open(OPTS.na_prob_file) as f: 290 | na_probs = json.load(f) 291 | else: 292 | na_probs = {k: 0.0 for k in preds} 293 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 294 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 295 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 296 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 297 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 298 | OPTS.na_prob_thresh) 299 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 300 | OPTS.na_prob_thresh) 301 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 302 | if has_ans_qids: 303 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 304 | merge_eval(out_eval, has_ans_eval, 'HasAns') 305 | if no_ans_qids: 306 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 307 | merge_eval(out_eval, no_ans_eval, 'NoAns') 308 | if OPTS.na_prob_file: 309 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 310 | if OPTS.na_prob_file and OPTS.out_image_dir: 311 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 312 | qid_to_has_ans, OPTS.out_image_dir) 313 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 314 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 315 | if OPTS.out_file: 316 | with open(OPTS.out_file, 'w') as f: 317 | json.dump(out_eval, f) 318 | else: 319 | print(json.dumps(out_eval, indent=2)) 320 | 321 | if __name__ == '__main__': 322 | OPTS = parse_args() 323 | if OPTS.out_image_dir: 324 | import matplotlib 325 | matplotlib.use('Agg') 326 | import matplotlib.pyplot as plt 327 | main() 328 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Pretraining on TPUs.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | 8 | from absl import app 9 | from absl import flags 10 | import absl.logging as _logging # pylint: disable=unused-import 11 | 12 | import numpy as np 13 | 14 | import tensorflow as tf 15 | import model_utils 16 | import tpu_estimator 17 | import function_builder 18 | import data_utils 19 | 20 | # TPU parameters 21 | flags.DEFINE_string("master", default=None, 22 | help="master") 23 | flags.DEFINE_string("tpu", default=None, 24 | help="The Cloud TPU to use for training. 
This should be either the name " 25 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") 26 | flags.DEFINE_string("gcp_project", default=None, 27 | help="Project name for the Cloud TPU-enabled project. If not specified, " 28 | "we will attempt to automatically detect the GCE project from metadata.") 29 | flags.DEFINE_string("tpu_zone", default=None, 30 | help="GCE zone where the Cloud TPU is located in. If not specified, we " 31 | "will attempt to automatically detect the GCE project from metadata.") 32 | flags.DEFINE_bool("use_tpu", default=True, 33 | help="Use TPUs rather than plain CPUs.") 34 | flags.DEFINE_integer("num_hosts", default=1, 35 | help="Number of TPU hosts.") 36 | flags.DEFINE_integer("num_core_per_host", default=8, 37 | help="Number of cores per host.") 38 | flags.DEFINE_bool("track_mean", default=False, 39 | help="Whether to track mean loss.") 40 | 41 | # Experiment (data/checkpoint/directory) config 42 | flags.DEFINE_integer("num_passes", default=1, 43 | help="Number of passes used for training.") 44 | flags.DEFINE_string("record_info_dir", default=None, 45 | help="Path to local directory containing `record_info-lm.json`.") 46 | flags.DEFINE_string("model_dir", default=None, 47 | help="Estimator model_dir.") 48 | flags.DEFINE_string("init_checkpoint", default=None, 49 | help="Checkpoint path for initializing the model.") 50 | 51 | # Optimization config 52 | flags.DEFINE_float("learning_rate", default=1e-4, 53 | help="Maximum learning rate.") 54 | flags.DEFINE_float("clip", default=1.0, 55 | help="Gradient clipping value.") 56 | # lr decay 57 | flags.DEFINE_float("min_lr_ratio", default=0.001, 58 | help="Minimum ratio learning rate.") 59 | flags.DEFINE_integer("warmup_steps", default=0, 60 | help="Number of steps for linear lr warmup.") 61 | flags.DEFINE_float("adam_epsilon", default=1e-8, 62 | help="Adam epsilon.") 63 | flags.DEFINE_string("decay_method", default="poly", 64 | help="Poly or cos.") 65 | flags.DEFINE_float("weight_decay", default=0.0, 66 | help="Weight decay rate.") 67 | 68 | # Training config 69 | flags.DEFINE_integer("train_batch_size", default=16, 70 | help="Size of the train batch across all hosts.") 71 | flags.DEFINE_integer("train_steps", default=100000, 72 | help="Total number of training steps.") 73 | flags.DEFINE_integer("iterations", default=1000, 74 | help="Number of iterations per repeat loop.") 75 | flags.DEFINE_integer("save_steps", default=None, 76 | help="Number of steps for model checkpointing. " 77 | "None for not saving checkpoints.") 78 | flags.DEFINE_integer("max_save", default=100000, 79 | help="Maximum number of checkpoints to save.") 80 | 81 | # Data config 82 | flags.DEFINE_integer("seq_len", default=0, 83 | help="Sequence length for pretraining.") 84 | flags.DEFINE_integer("reuse_len", default=0, 85 | help="How many tokens to be reused in the next batch. 
" 86 | "Could be half of `seq_len`.") 87 | flags.DEFINE_bool("uncased", False, 88 | help="Use uncased inputs or not.") 89 | flags.DEFINE_integer("perm_size", 0, 90 | help="Window size of permutation.") 91 | flags.DEFINE_bool("bi_data", default=True, 92 | help="Use bidirectional data streams, i.e., forward & backward.") 93 | flags.DEFINE_integer("mask_alpha", default=6, 94 | help="How many tokens to form a group.") 95 | flags.DEFINE_integer("mask_beta", default=1, 96 | help="How many tokens to mask within each group.") 97 | flags.DEFINE_integer("num_predict", default=None, 98 | help="Number of tokens to predict in partial prediction.") 99 | flags.DEFINE_integer("n_token", 32000, help="Vocab size") 100 | 101 | # Model config 102 | flags.DEFINE_integer("mem_len", default=0, 103 | help="Number of steps to cache") 104 | flags.DEFINE_bool("same_length", default=False, 105 | help="Same length attention") 106 | flags.DEFINE_integer("clamp_len", default=-1, 107 | help="Clamp length") 108 | 109 | flags.DEFINE_integer("n_layer", default=6, 110 | help="Number of layers.") 111 | flags.DEFINE_integer("d_model", default=32, 112 | help="Dimension of the model.") 113 | flags.DEFINE_integer("d_embed", default=32, 114 | help="Dimension of the embeddings.") 115 | flags.DEFINE_integer("n_head", default=4, 116 | help="Number of attention heads.") 117 | flags.DEFINE_integer("d_head", default=8, 118 | help="Dimension of each attention head.") 119 | flags.DEFINE_integer("d_inner", default=32, 120 | help="Dimension of inner hidden size in positionwise feed-forward.") 121 | flags.DEFINE_float("dropout", default=0.0, 122 | help="Dropout rate.") 123 | flags.DEFINE_float("dropatt", default=0.0, 124 | help="Attention dropout rate.") 125 | flags.DEFINE_bool("untie_r", default=False, 126 | help="Untie r_w_bias and r_r_bias") 127 | flags.DEFINE_string("summary_type", default="last", 128 | help="Method used to summarize a sequence into a compact vector.") 129 | flags.DEFINE_string("ff_activation", default="relu", 130 | help="Activation type used in position-wise feed-forward.") 131 | flags.DEFINE_bool("use_bfloat16", False, 132 | help="Whether to use bfloat16.") 133 | 134 | # Parameter initialization 135 | flags.DEFINE_enum("init", default="normal", 136 | enum_values=["normal", "uniform"], 137 | help="Initialization method.") 138 | flags.DEFINE_float("init_std", default=0.02, 139 | help="Initialization std when init is normal.") 140 | flags.DEFINE_float("init_range", default=0.1, 141 | help="Initialization std when init is uniform.") 142 | 143 | FLAGS = flags.FLAGS 144 | 145 | 146 | def get_model_fn(): 147 | """doc.""" 148 | def model_fn(features, labels, mode, params): 149 | """doc.""" 150 | #### Training or Evaluation 151 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 152 | assert is_training 153 | 154 | #### Retrieve `mems` from `params["cache"]` 155 | mems = {} 156 | idx = 0 157 | if FLAGS.mem_len > 0: 158 | mems["mems"] = params["cache"] 159 | 160 | #### Get loss from inputs 161 | total_loss, new_mems, monitor_dict = function_builder.get_loss( 162 | FLAGS, features, labels, mems, is_training) 163 | 164 | #### Turn `new_mems` into `new_cache` 165 | new_cache = [] 166 | if FLAGS.mem_len > 0: 167 | new_cache += new_mems["mems"] 168 | 169 | #### Check model parameters 170 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) 171 | tf.logging.info("#params: {}".format(num_params)) 172 | 173 | #### Configuring the optimizer 174 | train_op, learning_rate, gnorm = model_utils.get_train_op( 175 | 
FLAGS, total_loss) 176 | monitor_dict["lr"] = learning_rate 177 | monitor_dict["gnorm"] = gnorm 178 | 179 | #### Customized initial checkpoint 180 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS, global_vars=True) 181 | 182 | #### Creating host calls 183 | host_call = function_builder.construct_scalar_host_call( 184 | monitor_dict=monitor_dict, 185 | model_dir=FLAGS.model_dir, 186 | prefix="train/", 187 | reduce_fn=tf.reduce_mean) 188 | 189 | #### Constructing training TPUEstimatorSpec with new cache. 190 | train_spec = tf.contrib.tpu.TPUEstimatorSpec( 191 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call, 192 | scaffold_fn=scaffold_fn) 193 | 194 | train_spec.cache = new_cache 195 | 196 | return train_spec 197 | 198 | return model_fn 199 | 200 | 201 | def get_cache_fn(mem_len): 202 | """Return a function that creates zero-initialized memory caches.""" 203 | tf_float = tf.bfloat16 if FLAGS.use_bfloat16 else tf.float32 204 | def cache_fn(batch_size): 205 | mems = [] 206 | if FLAGS.mem_len > 0: 207 | for _ in range(FLAGS.n_layer): 208 | zeros = tf.zeros( 209 | [mem_len, batch_size, FLAGS.d_model], 210 | dtype=tf_float) 211 | mems.append(zeros) 212 | 213 | return mems 214 | 215 | if mem_len > 0: 216 | return cache_fn 217 | else: 218 | return None 219 | 220 | 221 | def get_input_fn(split): 222 | """Build the training input function and its record metadata.""" 223 | assert split == "train" 224 | batch_size = FLAGS.train_batch_size 225 | 226 | input_fn, record_info_dict = data_utils.get_input_fn( 227 | tfrecord_dir=FLAGS.record_info_dir, 228 | split=split, 229 | bsz_per_host=batch_size // FLAGS.num_hosts, 230 | seq_len=FLAGS.seq_len, 231 | reuse_len=FLAGS.reuse_len, 232 | bi_data=FLAGS.bi_data, 233 | num_hosts=FLAGS.num_hosts, 234 | num_core_per_host=FLAGS.num_core_per_host, 235 | perm_size=FLAGS.perm_size, 236 | mask_alpha=FLAGS.mask_alpha, 237 | mask_beta=FLAGS.mask_beta, 238 | uncased=FLAGS.uncased, 239 | num_passes=FLAGS.num_passes, 240 | use_bfloat16=FLAGS.use_bfloat16, 241 | num_predict=FLAGS.num_predict) 242 | 243 | return input_fn, record_info_dict 244 | 245 | 246 | def main(unused_argv): 247 | del unused_argv # Unused 248 | 249 | tf.logging.set_verbosity(tf.logging.INFO) 250 | 251 | assert FLAGS.seq_len > 0 252 | assert FLAGS.perm_size > 0 253 | 254 | FLAGS.n_token = data_utils.VOCAB_SIZE 255 | tf.logging.info("n_token {}".format(FLAGS.n_token)) 256 | 257 | if not tf.gfile.Exists(FLAGS.model_dir): 258 | tf.gfile.MakeDirs(FLAGS.model_dir) 259 | 260 | # Get train input function 261 | train_input_fn, train_record_info_dict = get_input_fn("train") 262 | 263 | tf.logging.info("num of batches {}".format( 264 | train_record_info_dict["num_batch"])) 265 | 266 | # Get train cache function 267 | train_cache_fn = get_cache_fn(FLAGS.mem_len) 268 | 269 | ##### Get model function 270 | model_fn = get_model_fn() 271 | 272 | ##### Create TPUEstimator 273 | # TPU Configuration 274 | run_config = model_utils.configure_tpu(FLAGS) 275 | 276 | # TPU Estimator 277 | estimator = tpu_estimator.TPUEstimator( 278 | model_fn=model_fn, 279 | train_cache_fn=train_cache_fn, 280 | use_tpu=FLAGS.use_tpu, 281 | config=run_config, 282 | params={"track_mean": FLAGS.track_mean}, 283 | train_batch_size=FLAGS.train_batch_size, 284 | eval_on_tpu=FLAGS.use_tpu) 285 | 286 | #### Training 287 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps) 288 | 289 | 290 | if __name__ == "__main__": 291 | app.run(main) 292 | -------------------------------------------------------------------------------- /train_gpu.py: -------------------------------------------------------------------------------- 1 | 
"""Pretraining on GPUs.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os, sys 7 | import math 8 | import json 9 | import time 10 | import numpy as np 11 | 12 | from absl import flags 13 | import absl.logging as _logging # pylint: disable=unused-import 14 | 15 | import tensorflow as tf 16 | 17 | import data_utils 18 | import model_utils 19 | from gpu_utils import assign_to_gpu, average_grads_and_vars 20 | import function_builder 21 | 22 | 23 | # GPU config 24 | flags.DEFINE_integer("num_hosts", default=1, 25 | help="Number of hosts") 26 | flags.DEFINE_integer("num_core_per_host", default=8, 27 | help="Number of cores per host") 28 | flags.DEFINE_bool("use_tpu", default=False, 29 | help="Whether to use TPUs for training.") 30 | 31 | # Experiment (data/checkpoint/directory) config 32 | flags.DEFINE_integer("num_passes", default=1, 33 | help="Number of passed used for training.") 34 | flags.DEFINE_string("record_info_dir", default=None, 35 | help="Path to local directory containing `record_info-lm.json`.") 36 | flags.DEFINE_string("model_dir", default=None, 37 | help="Estimator model_dir.") 38 | flags.DEFINE_string("init_checkpoint", default=None, 39 | help="checkpoint path for initializing the model.") 40 | 41 | # Optimization config 42 | flags.DEFINE_float("learning_rate", default=1e-4, 43 | help="Maximum learning rate.") 44 | flags.DEFINE_float("clip", default=1.0, 45 | help="Gradient clipping value.") 46 | # for cosine decay 47 | flags.DEFINE_float("min_lr_ratio", default=0.001, 48 | help="Minimum ratio learning rate.") 49 | flags.DEFINE_integer("warmup_steps", default=0, 50 | help="Number of steps for linear lr warmup.") 51 | flags.DEFINE_float("adam_epsilon", default=1e-8, 52 | help="Adam epsilon") 53 | flags.DEFINE_string("decay_method", default="poly", 54 | help="poly or cos") 55 | flags.DEFINE_float("weight_decay", default=0.0, 56 | help="weight decay") 57 | 58 | # Training config 59 | flags.DEFINE_integer("train_batch_size", default=16, 60 | help="Size of train batch.") 61 | flags.DEFINE_integer("train_steps", default=100000, 62 | help="Total number of training steps.") 63 | flags.DEFINE_integer("iterations", default=1000, 64 | help="Number of iterations per repeat loop.") 65 | flags.DEFINE_integer("save_steps", default=None, 66 | help="number of steps for model checkpointing.") 67 | 68 | # Data config 69 | flags.DEFINE_integer('seq_len', default=0, 70 | help='Sequence length for pretraining.') 71 | flags.DEFINE_integer('reuse_len', default=0, 72 | help="How many tokens to be reused in the next batch. 
" 73 | "Could be half of seq_len") 74 | flags.DEFINE_bool("bi_data", default=True, 75 | help="Use bidirectional data streams, i.e., forward & backward.") 76 | flags.DEFINE_integer("mask_alpha", default=6, 77 | help="How many tokens to form a group.") 78 | flags.DEFINE_integer("mask_beta", default=1, 79 | help="How many tokens to mask within each group.") 80 | flags.DEFINE_integer("num_predict", default=None, 81 | help="Number of tokens to predict in partial prediction.") 82 | flags.DEFINE_integer('perm_size', default=None, 83 | help='perm size.') 84 | flags.DEFINE_bool("uncased", False, 85 | help="Use uncased inputs or not.") 86 | flags.DEFINE_integer("n_token", 32000, help="Vocab size") 87 | 88 | # Model config 89 | flags.DEFINE_integer("mem_len", default=0, 90 | help="Number of steps to cache") 91 | flags.DEFINE_bool("same_length", default=False, 92 | help="Same length attention") 93 | flags.DEFINE_integer("clamp_len", default=-1, 94 | help="Clamp length") 95 | 96 | flags.DEFINE_integer("n_layer", default=6, 97 | help="Number of layers.") 98 | flags.DEFINE_integer("d_model", default=32, 99 | help="Dimension of the model.") 100 | flags.DEFINE_integer("d_embed", default=32, 101 | help="Dimension of the embeddings.") 102 | flags.DEFINE_integer("n_head", default=4, 103 | help="Number of attention heads.") 104 | flags.DEFINE_integer("d_head", default=8, 105 | help="Dimension of each attention head.") 106 | flags.DEFINE_integer("d_inner", default=32, 107 | help="Dimension of inner hidden size in positionwise feed-forward.") 108 | flags.DEFINE_float("dropout", default=0.0, 109 | help="Dropout rate.") 110 | flags.DEFINE_float("dropatt", default=0.0, 111 | help="Attention dropout rate.") 112 | flags.DEFINE_bool("untie_r", default=False, 113 | help="Untie r_w_bias and r_r_bias") 114 | flags.DEFINE_string("summary_type", default="last", 115 | help="Method used to summarize a sequence into a compact vector.") 116 | flags.DEFINE_string("ff_activation", default="relu", 117 | help="Activation type used in position-wise feed-forward.") 118 | flags.DEFINE_bool("use_bfloat16", False, 119 | help="Whether to use bfloat16.") 120 | 121 | # Parameter initialization 122 | flags.DEFINE_enum("init", default="normal", 123 | enum_values=["normal", "uniform"], 124 | help="Initialization method.") 125 | flags.DEFINE_float("init_std", default=0.02, 126 | help="Initialization std when init is normal.") 127 | flags.DEFINE_float("init_range", default=0.1, 128 | help="Initialization std when init is uniform.") 129 | 130 | 131 | FLAGS = flags.FLAGS 132 | 133 | 134 | def get_model_fn(): 135 | def model_fn(features, labels, mems, is_training): 136 | #### Get loss from inputs 137 | total_loss, new_mems, monitor_dict = function_builder.get_loss( 138 | FLAGS, features, labels, mems, is_training) 139 | 140 | #### Check model parameters 141 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) 142 | tf.logging.info('#params: {}'.format(num_params)) 143 | 144 | # GPU 145 | assert is_training 146 | all_vars = tf.trainable_variables() 147 | grads = tf.gradients(total_loss, all_vars) 148 | grads_and_vars = list(zip(grads, all_vars)) 149 | 150 | return total_loss, new_mems, grads_and_vars 151 | 152 | return model_fn 153 | 154 | 155 | def single_core_graph(is_training, features, mems): 156 | model_fn = get_model_fn() 157 | 158 | model_ret = model_fn( 159 | features=features, 160 | labels=None, 161 | mems=mems, 162 | is_training=is_training) 163 | 164 | return model_ret 165 | 166 | 167 | def 
create_mems_tf(bsz_per_core): 168 | mems = [tf.placeholder(dtype=tf.float32, 169 | shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model]) 170 | for layer in range(FLAGS.n_layer)] 171 | 172 | return mems 173 | 174 | 175 | def initialize_mems_np(bsz_per_core): 176 | mems_np = [np.zeros(shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model], 177 | dtype=np.float32) 178 | for layer in range(FLAGS.n_layer)] 179 | 180 | return mems_np 181 | 182 | 183 | def train(ps_device): 184 | ##### Get input function and model function 185 | 186 | train_input_fn, record_info_dict = data_utils.get_input_fn( 187 | tfrecord_dir=FLAGS.record_info_dir, 188 | split="train", 189 | bsz_per_host=FLAGS.train_batch_size, 190 | seq_len=FLAGS.seq_len, 191 | reuse_len=FLAGS.reuse_len, 192 | bi_data=FLAGS.bi_data, 193 | num_hosts=1, 194 | num_core_per_host=1, # set to one no matter how many GPUs 195 | perm_size=FLAGS.perm_size, 196 | mask_alpha=FLAGS.mask_alpha, 197 | mask_beta=FLAGS.mask_beta, 198 | uncased=FLAGS.uncased, 199 | num_passes=FLAGS.num_passes, 200 | use_bfloat16=FLAGS.use_bfloat16, 201 | num_predict=FLAGS.num_predict) 202 | 203 | # for key, info in record_info_dict.items(): 204 | tf.logging.info("num of batches {}".format(record_info_dict["num_batch"])) 205 | 206 | ##### Create input tensors / placeholders 207 | bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host 208 | 209 | params = { 210 | "batch_size": FLAGS.train_batch_size # the whole batch 211 | } 212 | train_set = train_input_fn(params) 213 | 214 | example = train_set.make_one_shot_iterator().get_next() 215 | 216 | if FLAGS.num_core_per_host > 1: 217 | examples = [{} for _ in range(FLAGS.num_core_per_host)] 218 | for key in example.keys(): 219 | vals = tf.split(example[key], FLAGS.num_core_per_host, 0) 220 | for device_id in range(FLAGS.num_core_per_host): 221 | examples[device_id][key] = vals[device_id] 222 | else: 223 | examples = [example] 224 | 225 | ##### Create computational graph 226 | tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], [] 227 | 228 | for i in range(FLAGS.num_core_per_host): 229 | reuse = True if i > 0 else None 230 | with tf.device(assign_to_gpu(i, ps_device)), \ 231 | tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 232 | 233 | # The mems for each tower is a dictionary 234 | mems_i = {} 235 | if FLAGS.mem_len: 236 | mems_i["mems"] = create_mems_tf(bsz_per_core) 237 | 238 | loss_i, new_mems_i, grads_and_vars_i = single_core_graph( 239 | is_training=True, 240 | features=examples[i], 241 | mems=mems_i) 242 | 243 | tower_mems.append(mems_i) 244 | tower_losses.append(loss_i) 245 | tower_new_mems.append(new_mems_i) 246 | tower_grads_and_vars.append(grads_and_vars_i) 247 | 248 | ## average losses and gradients across towers 249 | if len(tower_losses) > 1: 250 | loss = tf.add_n(tower_losses) / len(tower_losses) 251 | grads_and_vars = average_grads_and_vars(tower_grads_and_vars) 252 | else: 253 | loss = tower_losses[0] 254 | grads_and_vars = tower_grads_and_vars[0] 255 | 256 | ## get train op 257 | train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None, 258 | grads_and_vars=grads_and_vars) 259 | global_step = tf.train.get_global_step() 260 | 261 | ##### Training loop 262 | # initialize mems 263 | tower_mems_np = [] 264 | for i in range(FLAGS.num_core_per_host): 265 | mems_i_np = {} 266 | for key in tower_mems[i].keys(): 267 | mems_i_np[key] = initialize_mems_np(bsz_per_core) 268 | tower_mems_np.append(mems_i_np) 269 | 270 | saver = tf.train.Saver() 271 | 272 | gpu_options = 
tf.GPUOptions(allow_growth=True) 273 | 274 | model_utils.init_from_checkpoint(FLAGS, global_vars=True) 275 | 276 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, 277 | gpu_options=gpu_options)) as sess: 278 | sess.run(tf.global_variables_initializer()) 279 | 280 | fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op] 281 | 282 | total_loss, prev_step = 0., -1 283 | while True: 284 | feed_dict = {} 285 | for i in range(FLAGS.num_core_per_host): 286 | for key in tower_mems_np[i].keys(): 287 | for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]): 288 | feed_dict[m] = m_np 289 | 290 | fetched = sess.run(fetches, feed_dict=feed_dict) 291 | 292 | loss_np, tower_mems_np, curr_step = fetched[:3] 293 | total_loss += loss_np 294 | 295 | if curr_step > 0 and curr_step % FLAGS.iterations == 0: 296 | curr_loss = total_loss / (curr_step - prev_step) 297 | tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} " 298 | "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format( 299 | curr_step, fetched[-3], fetched[-2], 300 | curr_loss, math.exp(curr_loss), curr_loss / math.log(2))) 301 | total_loss, prev_step = 0., curr_step 302 | 303 | if curr_step > 0 and curr_step % FLAGS.save_steps == 0: 304 | save_path = os.path.join(FLAGS.model_dir, "model.ckpt") 305 | saver.save(sess, save_path) 306 | tf.logging.info("Model saved in path: {}".format(save_path)) 307 | 308 | if curr_step >= FLAGS.train_steps: 309 | break 310 | 311 | 312 | def main(unused_argv): 313 | del unused_argv # Unused 314 | 315 | tf.logging.set_verbosity(tf.logging.INFO) 316 | 317 | # Get corpus info 318 | FLAGS.n_token = data_utils.VOCAB_SIZE 319 | tf.logging.info("n_token {}".format(FLAGS.n_token)) 320 | 321 | if not tf.gfile.Exists(FLAGS.model_dir): 322 | tf.gfile.MakeDirs(FLAGS.model_dir) 323 | 324 | train("/gpu:0") 325 | 326 | 327 | if __name__ == "__main__": 328 | tf.app.run() 329 | -------------------------------------------------------------------------------- /xlnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import json 6 | import os 7 | import tensorflow as tf 8 | import modeling 9 | 10 | 11 | def _get_initializer(FLAGS): 12 | """Get variable initializer.""" 13 | if FLAGS.init == "uniform": 14 | initializer = tf.initializers.random_uniform( 15 | minval=-FLAGS.init_range, 16 | maxval=FLAGS.init_range, 17 | seed=None) 18 | elif FLAGS.init == "normal": 19 | initializer = tf.initializers.random_normal( 20 | stddev=FLAGS.init_std, 21 | seed=None) 22 | else: 23 | raise ValueError("Initializer {} not supported".format(FLAGS.init)) 24 | return initializer 25 | 26 | 27 | class XLNetConfig(object): 28 | """XLNetConfig contains hyperparameters that are specific to a model checkpoint; 29 | i.e., these hyperparameters should be the same between 30 | pretraining and finetuning. 31 | 32 | The following hyperparameters are defined: 33 | n_layer: int, the number of layers. 34 | d_model: int, the hidden size. 35 | n_head: int, the number of attention heads. 36 | d_head: int, the dimension size of each attention head. 37 | d_inner: int, the hidden size in feed-forward layers. 38 | ff_activation: str, "relu" or "gelu". 39 | untie_r: bool, whether to untie the biases in attention. 40 | n_token: int, the vocab size. 41 | """ 42 | 43 | def __init__(self, FLAGS=None, json_path=None): 44 | """Constructing an XLNetConfig. 
45 | One of FLAGS or json_path should be provided.""" 46 | 47 | assert FLAGS is not None or json_path is not None 48 | 49 | self.keys = ["n_layer", "d_model", "n_head", "d_head", "d_inner", 50 | "ff_activation", "untie_r", "n_token"] 51 | 52 | if FLAGS is not None: 53 | self.init_from_flags(FLAGS) 54 | 55 | if json_path is not None: 56 | self.init_from_json(json_path) 57 | 58 | def init_from_flags(self, FLAGS): 59 | for key in self.keys: 60 | setattr(self, key, getattr(FLAGS, key)) 61 | 62 | def init_from_json(self, json_path): 63 | with tf.gfile.Open(json_path) as f: 64 | json_data = json.load(f) 65 | for key in self.keys: 66 | setattr(self, key, json_data[key]) 67 | 68 | def to_json(self, json_path): 69 | """Save XLNetConfig to a json file.""" 70 | json_data = {} 71 | for key in self.keys: 72 | json_data[key] = getattr(self, key) 73 | 74 | json_dir = os.path.dirname(json_path) 75 | if not tf.gfile.Exists(json_dir): 76 | tf.gfile.MakeDirs(json_dir) 77 | with tf.gfile.Open(json_path, "w") as f: 78 | json.dump(json_data, f, indent=4, sort_keys=True) 79 | 80 | 81 | def create_run_config(is_training, is_finetune, FLAGS): 82 | kwargs = dict( 83 | is_training=is_training, 84 | use_tpu=FLAGS.use_tpu, 85 | use_bfloat16=FLAGS.use_bfloat16, 86 | dropout=FLAGS.dropout, 87 | dropatt=FLAGS.dropatt, 88 | init=FLAGS.init, 89 | init_range=FLAGS.init_range, 90 | init_std=FLAGS.init_std, 91 | clamp_len=FLAGS.clamp_len) 92 | 93 | if not is_finetune: 94 | kwargs.update(dict( 95 | mem_len=FLAGS.mem_len, 96 | reuse_len=FLAGS.reuse_len, 97 | bi_data=FLAGS.bi_data, 98 | clamp_len=FLAGS.clamp_len, 99 | same_length=FLAGS.same_length)) 100 | 101 | return RunConfig(**kwargs) 102 | 103 | 104 | class RunConfig(object): 105 | """RunConfig contains hyperparameters that could be different 106 | between pretraining and finetuning. 107 | These hyperparameters can also be changed from run to run. 108 | We store them separately from XLNetConfig for flexibility. 109 | """ 110 | 111 | def __init__(self, is_training, use_tpu, use_bfloat16, dropout, dropatt, 112 | init="normal", init_range=0.1, init_std=0.02, mem_len=None, 113 | reuse_len=None, bi_data=False, clamp_len=-1, same_length=False): 114 | """ 115 | Args: 116 | is_training: bool, whether in training mode. 117 | use_tpu: bool, whether TPUs are used. 118 | use_bfloat16: bool, use bfloat16 instead of float32. 119 | dropout: float, dropout rate. 120 | dropatt: float, dropout rate on attention probabilities. 121 | init: str, the initialization scheme, either "normal" or "uniform". 122 | init_range: float, initialize the parameters with a uniform distribution 123 | in [-init_range, init_range]. Only effective when init="uniform". 124 | init_std: float, initialize the parameters with a normal distribution 125 | with mean 0 and stddev init_std. Only effective when init="normal". 126 | mem_len: int, the number of tokens to cache. 127 | reuse_len: int, the number of tokens in the current batch to be cached 128 | and reused in the future. 129 | bi_data: bool, whether to use bidirectional input pipeline. 130 | Usually set to True during pretraining and False during finetuning. 131 | clamp_len: int, clamp all relative distances larger than clamp_len. 132 | -1 means no clamping. 133 | same_length: bool, whether to use the same attention length for each token. 
134 | """ 135 | 136 | self.init = init 137 | self.init_range = init_range 138 | self.init_std = init_std 139 | self.is_training = is_training 140 | self.dropout = dropout 141 | self.dropatt = dropatt 142 | self.use_tpu = use_tpu 143 | self.use_bfloat16 = use_bfloat16 144 | self.mem_len = mem_len 145 | self.reuse_len = reuse_len 146 | self.bi_data = bi_data 147 | self.clamp_len = clamp_len 148 | self.same_length = same_length 149 | 150 | 151 | class XLNetModel(object): 152 | """A wrapper of the XLNet model used during both pretraining and finetuning.""" 153 | 154 | def __init__(self, xlnet_config, run_config, input_ids, seg_ids, input_mask, 155 | mems=None, perm_mask=None, target_mapping=None, inp_q=None, 156 | **kwargs): 157 | """ 158 | Args: 159 | xlnet_config: XLNetConfig, 160 | run_config: RunConfig, 161 | input_ids: int32 Tensor in shape [len, bsz], the input token IDs. 162 | seg_ids: int32 Tensor in shape [len, bsz], the input segment IDs. 163 | input_mask: float32 Tensor in shape [len, bsz], the input mask. 164 | 0 for real tokens and 1 for padding. 165 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory 166 | from previous batches. The length of the list equals n_layer. 167 | If None, no memory is used. 168 | perm_mask: float32 Tensor in shape [len, len, bsz]. 169 | If perm_mask[i, j, k] = 0, i attend to j in batch k; 170 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k. 171 | If None, each position attends to all the others. 172 | target_mapping: float32 Tensor in shape [num_predict, len, bsz]. 173 | If target_mapping[i, j, k] = 1, the i-th predict in batch k is 174 | on the j-th token. 175 | Only used during pretraining for partial prediction. 176 | Set to None during finetuning. 177 | inp_q: float32 Tensor in shape [len, bsz]. 178 | 1 for tokens with losses and 0 for tokens without losses. 179 | Only used during pretraining for two-stream attention. 180 | Set to None during finetuning. 181 | """ 182 | 183 | initializer = _get_initializer(run_config) 184 | 185 | tfm_args = dict( 186 | n_token=xlnet_config.n_token, 187 | initializer=initializer, 188 | attn_type="bi", 189 | n_layer=xlnet_config.n_layer, 190 | d_model=xlnet_config.d_model, 191 | n_head=xlnet_config.n_head, 192 | d_head=xlnet_config.d_head, 193 | d_inner=xlnet_config.d_inner, 194 | ff_activation=xlnet_config.ff_activation, 195 | untie_r=xlnet_config.untie_r, 196 | 197 | is_training=run_config.is_training, 198 | use_bfloat16=run_config.use_bfloat16, 199 | use_tpu=run_config.use_tpu, 200 | dropout=run_config.dropout, 201 | dropatt=run_config.dropatt, 202 | 203 | mem_len=run_config.mem_len, 204 | reuse_len=run_config.reuse_len, 205 | bi_data=run_config.bi_data, 206 | clamp_len=run_config.clamp_len, 207 | same_length=run_config.same_length 208 | ) 209 | 210 | input_args = dict( 211 | inp_k=input_ids, 212 | seg_id=seg_ids, 213 | input_mask=input_mask, 214 | mems=mems, 215 | perm_mask=perm_mask, 216 | target_mapping=target_mapping, 217 | inp_q=inp_q) 218 | tfm_args.update(input_args) 219 | 220 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 221 | (self.output, self.new_mems, self.lookup_table 222 | ) = modeling.transformer_xl(**tfm_args) 223 | 224 | self.input_mask = input_mask 225 | self.initializer = initializer 226 | self.xlnet_config = xlnet_config 227 | self.run_config = run_config 228 | 229 | def get_pooled_out(self, summary_type, use_summ_proj=True): 230 | """ 231 | Args: 232 | summary_type: str, "last", "first", "mean", or "attn". 
The method 233 | to pool the input to get a vector representation. 234 | use_summ_proj: bool, whether to use a linear projection during pooling. 235 | 236 | Returns: 237 | float32 Tensor in shape [bsz, d_model], the pooled representation. 238 | """ 239 | 240 | xlnet_config = self.xlnet_config 241 | run_config = self.run_config 242 | 243 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE): 244 | summary = modeling.summarize_sequence( 245 | summary_type=summary_type, 246 | hidden=self.output, 247 | d_model=xlnet_config.d_model, 248 | n_head=xlnet_config.n_head, 249 | d_head=xlnet_config.d_head, 250 | dropout=run_config.dropout, 251 | dropatt=run_config.dropatt, 252 | is_training=run_config.is_training, 253 | input_mask=self.input_mask, 254 | initializer=self.initializer, 255 | use_proj=use_summ_proj) 256 | 257 | return summary 258 | 259 | def get_sequence_output(self): 260 | """ 261 | Returns: 262 | float32 Tensor in shape [len, bsz, d_model]. The last layer hidden 263 | representation of XLNet. 264 | """ 265 | 266 | return self.output 267 | 268 | def get_new_memory(self): 269 | """ 270 | Returns: 271 | list of float32 Tensors in shape [mem_len, bsz, d_model], the new 272 | memory that concatenates the previous memory with the current input 273 | representations. 274 | The length of the list equals n_layer. 275 | """ 276 | return self.new_mems 277 | 278 | def get_embedding_table(self): 279 | """ 280 | Returns: 281 | float32 Tensor in shape [n_token, d_model]. The embedding lookup table. 282 | Used for tying embeddings between input and output layers. 283 | """ 284 | return self.lookup_table 285 | 286 | def get_initializer(self): 287 | """ 288 | Returns: 289 | A tf initializer. Used to initialize variables in layers on top of XLNet. 290 | """ 291 | return self.initializer 292 | 293 | --------------------------------------------------------------------------------
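For orientation, here is a minimal usage sketch of the xlnet.py API above, assuming graph-mode TensorFlow 1.x and an extracted xlnet_cased_L-24_H-1024_A-16 checkpoint directory. The shapes, dropout values, and config path below are illustrative assumptions, not repository code; run_classifier.py remains the authoritative end-to-end example.

# Sketch only (not part of the repo): build an XLNetModel for finetuning.
import tensorflow as tf
import xlnet

# Checkpoint-specific hyperparameters (n_layer, d_model, ...) come from the
# JSON config shipped with the pretrained checkpoint (assumed path).
xlnet_config = xlnet.XLNetConfig(
    json_path="xlnet_cased_L-24_H-1024_A-16/xlnet_config.json")

# Run-specific hyperparameters, constructed directly; create_run_config
# expects the finetuning scripts' FLAGS object instead. Dropout values here
# are assumptions, not prescribed defaults.
run_config = xlnet.RunConfig(
    is_training=True, use_tpu=False, use_bfloat16=False,
    dropout=0.1, dropatt=0.1)

# Inputs are time-major, per the XLNetModel docstring: [seq_len, batch_size].
seq_len, bsz = 128, 4
input_ids = tf.placeholder(tf.int32, [seq_len, bsz])
seg_ids = tf.placeholder(tf.int32, [seq_len, bsz])
input_mask = tf.placeholder(tf.float32, [seq_len, bsz])  # 1.0 marks padding

model = xlnet.XLNetModel(
    xlnet_config=xlnet_config,
    run_config=run_config,
    input_ids=input_ids,
    seg_ids=seg_ids,
    input_mask=input_mask)

# Pooled [bsz, d_model] vector for sentence-level heads, or the full
# [seq_len, bsz, d_model] sequence output for span-level tasks.
summary = model.get_pooled_out(summary_type="last", use_summ_proj=True)
seq_out = model.get_sequence_output()

The pooled summary feeds sentence-level heads as in run_classifier.py and run_race.py, while the sequence output feeds span-level heads as in run_squad.py.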