├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── classifier_utils.py
├── data_utils.py
├── function_builder.py
├── gpu_utils.py
├── misc
│   ├── race_example.md
│   └── slides.pdf
├── model_utils.py
├── modeling.py
├── notebooks
│   └── colab_imdb_gpu.ipynb
├── prepro_utils.py
├── run_classifier.py
├── run_race.py
├── run_squad.py
├── scripts
│   ├── gpu_squad_base.sh
│   ├── prepro_squad.sh
│   ├── tpu_race_large_bsz32.sh
│   ├── tpu_race_large_bsz8.sh
│   └── tpu_squad_large.sh
├── squad_utils.py
├── tpu_estimator.py
├── train.py
├── train_gpu.py
└── xlnet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019 XLNet Authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | **XLNet** is a new unsupervised language representation learning method based on a novel generalized permutation language modeling objective. Additionally, XLNet employs [Transformer-XL](https://arxiv.org/abs/1901.02860) as the backbone model, exhibiting excellent performance for language tasks involving long context. Overall, XLNet achieves state-of-the-art (SOTA) results on various downstream language tasks including question answering, natural language inference, sentiment analysis, and document ranking.
4 |
5 | For a detailed description of technical details and experimental results, please refer to our paper:
6 |
7 | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
8 |
9 | Zhilin Yang\*, Zihang Dai\*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le
10 |
11 | (*: equal contribution)
12 |
13 | Preprint 2019
14 |
15 |
16 |
17 |
18 | ## Release Notes
19 |
20 | * July 16, 2019: XLNet-Base.
21 | * June 19, 2019: initial release with XLNet-Large and code.
22 |
23 | ## Results
24 |
25 | As of June 19, 2019, XLNet outperforms BERT on 20 tasks and achieves state-of-the-art results on 18 tasks. Below are some comparisons between XLNet-Large and BERT-Large, which have similar model sizes:
26 |
27 | ### Results on Reading Comprehension
28 |
29 | Model | [RACE accuracy](http://www.qizhexie.com/data/RACE_leaderboard.html) | SQuAD1.1 EM | SQuAD2.0 EM
30 | --- | --- | --- | ---
31 | BERT-Large | 72.0 | 84.1 | 78.98
32 | XLNet-Base | | | 80.18
33 | XLNet-Large | **81.75** | **88.95** | **86.12**
34 |
35 | We use SQuAD dev results in the table to exclude other factors such as using additional training data or other data augmentation techniques. See [SQuAD leaderboard](https://rajpurkar.github.io/SQuAD-explorer/) for test numbers.
36 |
37 | ### Results on Text Classification
38 |
39 | Model | IMDB | Yelp-2 | Yelp-5 | DBpedia | Amazon-2 | Amazon-5
40 | --- | --- | --- | --- | --- | --- | ---
41 | BERT-Large | 4.51 | 1.89 | 29.32 | 0.64 | 2.63 | 34.17
42 | XLNet-Large | **3.79** | **1.55** | **27.80** | **0.62** | **2.40** | **32.26**
43 |
44 | The above numbers are error rates.
45 |
46 | ### Results on GLUE
47 |
48 | Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
49 | --- | --- | --- | --- | --- | --- | --- | --- | ---
50 | BERT-Large | 86.6 | 92.3 | 91.3 | 70.4 | 93.2 | 88.0 | 60.6 | 90.0
51 | XLNet-Base | 86.8 | 91.7 | 91.4 | 74.0 | 94.7 | 88.2 | 60.2 | 89.5
52 | XLNet-Large | **89.8** | **93.9** | **91.8** | **83.8** | **95.6** | **89.2** | **63.6** | **91.8**
53 |
54 | We use single-task dev results in the table to exclude other factors such as multi-task learning or using ensembles.
55 |
56 | ## Pre-trained models
57 |
58 | ### Released Models
59 |
60 | As of July 16, 2019, the following models have been made available:
61 | * **[`XLNet-Large, Cased`](https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip)**: 24-layer, 1024-hidden, 16-heads
62 | * **[`XLNet-Base, Cased`](https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip)**: 12-layer, 768-hidden, 12-heads. This model is trained on full data (different from the one in the paper).
63 |
64 | We only release cased models for now because on the tasks we consider, we found: (1) for the base setting, cased and uncased models have similar performance; (2) for the large setting, cased models are a bit better in some tasks.
65 |
66 | Each .zip file contains three items:
67 | * A TensorFlow checkpoint (`xlnet_model.ckpt`) containing the pre-trained weights (which is actually 3 files).
68 | * A [Sentence Piece](https://github.com/google/sentencepiece) model (`spiece.model`) used for (de)tokenization.
69 | * A config file (`xlnet_config.json`) which specifies the hyperparameters of the model.
70 |
71 |
72 | ### Future Release Plan
73 |
74 | We also plan to continuously release more pretrained models under different settings, including:
75 | * A pretrained model that is **finetuned on Wikipedia**. This can be used for tasks with Wikipedia text such as SQuAD and HotpotQA.
76 | * Pretrained models with other hyperparameter configurations, targeting specific downstream tasks.
77 | * Pretrained models that benefit from new techniques.
78 |
79 | ### Subscribing to XLNet on Google Groups
80 |
81 | To receive notifications about updates, announcements and new releases, we recommend subscribing to the XLNet group on [Google Groups](https://groups.google.com/forum/#!forum/xlnet).
82 |
83 |
84 |
85 | ## Fine-tuning with XLNet
86 |
87 | As of June 19, 2019, this code base has been tested with TensorFlow 1.13.1 under Python 2.
88 |
89 | ### Memory Issue during Finetuning
90 |
91 | - Most of the SOTA results in our paper were produced on TPUs, which generally have more RAM than common GPUs. As a result, it is currently very difficult (costly) to reproduce most of the `XLNet-Large` SOTA results in the paper using GPUs with 12GB-16GB of RAM, because a 16GB GPU can only hold a single sequence of length 512 for `XLNet-Large`. Therefore, a large number of GPUs (ranging from 32 to 128, equal to `batch_size`) are required to reproduce many results in the paper.
92 | - We are experimenting with gradient accumulation to potentially relieve the memory burden, which could be included in a near-future update.
93 | - **Alternative methods** of finetuning XLNet on **constrained hardware** have been presented in [renatoviolin's repo](https://github.com/renatoviolin/xlnet), which obtained 86.24 F1 on SQuAD2.0 with an 8GB GPU.
94 |
95 | Given the memory issue mentioned above, using the default finetuning scripts (`run_classifier.py` and `run_squad.py`), we benchmarked the maximum batch size on a single **16GB** GPU with TensorFlow **1.13.1**:
96 |
97 | | System | Seq Length | Max Batch Size |
98 | | ------------- | ---------- | -------------- |
99 | | `XLNet-Base` | 64 | 120 |
100 | | ... | 128 | 56 |
101 | | ... | 256 | 24 |
102 | | ... | 512 | 8 |
103 | | `XLNet-Large` | 64 | 16 |
104 | | ... | 128 | 8 |
105 | | ... | 256 | 2 |
106 | | ... | 512 | 1 |
107 |
108 | In most cases, it is possible to reduce the batch size (`train_batch_size`) or the maximum sequence length (`max_seq_length`) to fit the given hardware. The resulting decrease in performance depends on the task and the available resources.
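If even these reduced settings do not fit, the gradient accumulation mentioned above can trade compute for memory. Below is a minimal TensorFlow 1.x sketch of the idea; it is not part of this code base, and `loss` is assumed to come from your own model.

```python
import tensorflow as tf

# Minimal gradient-accumulation sketch (assumes `loss` is already defined).
# Gradients from several micro-batches are summed into non-trainable
# accumulators; one weight update is applied every `accum_steps` micro-batches.
accum_steps = 4
tvars = tf.trainable_variables()
grads = tf.gradients(loss, tvars)

accum_vars = [tf.Variable(tf.zeros_like(v), trainable=False) for v in tvars]
zero_ops = [a.assign(tf.zeros_like(a)) for a in accum_vars]
accum_ops = [a.assign_add(g) for a, g in zip(accum_vars, grads)]

opt = tf.train.AdamOptimizer(learning_rate=1e-5)
# Average the accumulated gradients and apply them as a single update.
apply_op = opt.apply_gradients(
    [(a / accum_steps, v) for a, v in zip(accum_vars, tvars)])
```

In the session loop, run `zero_ops`, then `accum_ops` once per micro-batch, and finally `apply_op`; this emulates a batch `accum_steps` times larger at the cost of extra compute time.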
109 |
110 |
111 | ### Text Classification/Regression
112 |
113 | The code used to perform classification/regression finetuning is in `run_classifier.py`. It also contains examples for standard one-document classification, one-document regression, and document pair classification. Here, we provide two concrete examples of how `run_classifier.py` can be used.
114 |
115 | From here on, we assume XLNet-Large and XLNet-Base have been downloaded to `$LARGE_DIR` and `$BASE_DIR` respectively.
116 |
117 |
118 | #### (1) STS-B: sentence pair relevance regression (with GPUs)
119 |
120 | - Download the [GLUE data](https://gluebenchmark.com/tasks) by running [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) and unpack it to some directory `$GLUE_DIR`.
121 |
122 | - Perform **multi-GPU** (4 V100 GPUs) finetuning with XLNet-Large by running
123 |
124 | ```shell
125 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_classifier.py \
126 | --do_train=True \
127 | --do_eval=False \
128 | --task_name=sts-b \
129 | --data_dir=${GLUE_DIR}/STS-B \
130 | --output_dir=proc_data/sts-b \
131 | --model_dir=exp/sts-b \
132 | --uncased=False \
133 | --spiece_model_file=${LARGE_DIR}/spiece.model \
134 | --model_config_path=${LARGE_DIR}/xlnet_config.json \
135 | --init_checkpoint=${LARGE_DIR}/xlnet_model.ckpt \
136 | --max_seq_length=128 \
137 | --train_batch_size=8 \
138 | --num_hosts=1 \
139 | --num_core_per_host=4 \
140 | --learning_rate=5e-5 \
141 | --train_steps=1200 \
142 | --warmup_steps=120 \
143 | --save_steps=600 \
144 | --is_regression=True
145 | ```
146 |
147 | - Evaluate the finetuning results with a single GPU by
148 |
149 | ```shell
150 | CUDA_VISIBLE_DEVICES=0 python run_classifier.py \
151 | --do_train=False \
152 | --do_eval=True \
153 | --task_name=sts-b \
154 | --data_dir=${GLUE_DIR}/STS-B \
155 | --output_dir=proc_data/sts-b \
156 | --model_dir=exp/sts-b \
157 | --uncased=False \
158 | --spiece_model_file=${LARGE_DIR}/spiece.model \
159 | --model_config_path=${LARGE_DIR}/xlnet_config.json \
160 | --max_seq_length=128 \
161 | --eval_batch_size=8 \
162 | --num_hosts=1 \
163 | --num_core_per_host=1 \
164 | --eval_all_ckpt=True \
165 | --is_regression=True
166 |
167 | # Expected performance: "eval_pearsonr 0.916+ "
168 | ```
169 |
170 | **Notes**:
171 |
172 | - In the context of GPU training, `num_core_per_host` denotes the number of GPUs to use.
173 | - In the multi-GPU setting, `train_batch_size` refers to the per-GPU batch size, so the command above trains with an effective batch size of 8 × 4 = 32.
174 | - `eval_all_ckpt` allows one to evaluate all saved checkpoints (save frequency is controlled by `save_steps`) after training finishes and choose the best model based on dev performance.
175 | - `data_dir` and `output_dir` refer to the directories of the "raw data" and "preprocessed tfrecords" respectively, while `model_dir` is the working directory for saving checkpoints and tensorflow events. **`model_dir` should be a separate folder from `init_checkpoint`.**
176 | - To try out XLNet-Base, one can simply set `--train_batch_size=32` and `--num_core_per_host=1`, along with corresponding changes to `init_checkpoint` and `model_config_path` (see the sketch after this list).
177 | - For GPUs with less RAM, please proportionally decrease `train_batch_size` and increase `num_core_per_host` to keep the same effective training setting.
178 | - **Important**: we separate the training and evaluation into "two phases", as using multi GPUs to perform evaluation is tricky (one has to correctly separate the data across GPUs). To ensure correctness, we only support single-GPU evaluation for now.
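Below is a sketch of the flag changes for the XLNet-Base run mentioned in the notes; only the flags that differ from the XLNet-Large training command above are shown, and everything else stays the same.

```shell
# XLNet-Base on a single GPU: point at the base checkpoint/config and
# use a larger per-GPU batch size.
  --spiece_model_file=${BASE_DIR}/spiece.model \
  --model_config_path=${BASE_DIR}/xlnet_config.json \
  --init_checkpoint=${BASE_DIR}/xlnet_model.ckpt \
  --train_batch_size=32 \
  --num_core_per_host=1
```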
179 |
180 |
181 | #### (2) IMDB: movie review sentiment classification (with TPU V3-8)
182 |
183 | - Download and unpack the IMDB dataset by running
184 |
185 | ```shell
186 | wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
187 | tar zxvf aclImdb_v1.tar.gz
188 | ```
189 |
190 | - Launch a Google cloud TPU V3-8 instance (see the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for how to set up Cloud TPUs).
191 |
192 | - Set up your Google storage bucket path `$GS_ROOT` and move the IMDB dataset and pretrained checkpoint into your Google storage.
193 |
194 | - Perform TPU finetuning with XLNet-Large by running
195 |
196 | ```shell
197 | python run_classifier.py \
198 | --use_tpu=True \
199 | --tpu=${TPU_NAME} \
200 | --do_train=True \
201 | --do_eval=True \
202 | --eval_all_ckpt=True \
203 | --task_name=imdb \
204 | --data_dir=${IMDB_DIR} \
205 | --output_dir=${GS_ROOT}/proc_data/imdb \
206 | --model_dir=${GS_ROOT}/exp/imdb \
207 | --uncased=False \
208 | --spiece_model_file=${LARGE_DIR}/spiece.model \
209 | --model_config_path=${GS_ROOT}/${LARGE_DIR}/xlnet_config.json \
210 | --init_checkpoint=${GS_ROOT}/${LARGE_DIR}/xlnet_model.ckpt \
211 | --max_seq_length=512 \
212 | --train_batch_size=32 \
213 | --eval_batch_size=8 \
214 | --num_hosts=1 \
215 | --num_core_per_host=8 \
216 | --learning_rate=2e-5 \
217 | --train_steps=4000 \
218 | --warmup_steps=500 \
219 | --save_steps=500 \
220 | --iterations=500
221 |
222 | # Expected performance: "eval_accuracy 0.962+ "
223 | ```
224 |
225 | **Notes**:
226 |
227 | - To obtain the SOTA on the IMDB dataset, using sequence length 512 is **necessary**. Therefore, we show how this can be done with a TPU V3-8.
228 | - Alternatively, one can use a sequence length shorter than 512, a smaller batch size, or switch to XLNet-Base to train on GPUs, but a performance drop is expected.
229 | - Notice that `data_dir` and `spiece_model_file` both use a local path rather than a Google Storage path. The reason is that data preprocessing is actually performed locally, so using local paths leads to faster preprocessing.
230 |
231 | ### SQuAD2.0
232 |
233 | The code for the SQuAD dataset is included in `run_squad.py`.
234 |
235 | To run the code:
236 |
237 | (1) Download the SQuAD2.0 dataset into `$SQUAD_DIR` by:
238 |
239 | ```shell
240 | mkdir -p ${SQUAD_DIR} && cd ${SQUAD_DIR}
241 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
242 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
243 | ```
244 |
245 | (2) Perform data preprocessing using the script `scripts/prepro_squad.sh`.
246 |
247 | - This will take quite some time in order to accurately map character positions (raw data) to sentence piece positions (used for training).
248 |
249 | - For faster parallel preprocessing, please refer to the flags `--num_proc` and `--proc_id` in `run_squad.py`, as sketched below.
250 |
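For example, here is a sketch of sharding the preprocessing across four local processes. Only `--num_proc` and `--proc_id` are documented here; the remaining flags mirror `scripts/prepro_squad.sh` and should be verified against that script.

```shell
# Sketch: 4 parallel preprocessing shards; each process handles the shard
# selected by --proc_id.
for i in 0 1 2 3; do
  python run_squad.py \
    --use_tpu=False \
    --do_prepro=True \
    --spiece_model_file=${LARGE_DIR}/spiece.model \
    --train_file=${SQUAD_DIR}/train-v2.0.json \
    --output_dir=proc_data/squad \
    --uncased=False \
    --max_seq_length=512 \
    --num_proc=4 \
    --proc_id=${i} &
done
wait
```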
251 | (3) Perform training and evaluation.
252 |
253 | For the best performance, XLNet-Large uses sequence length 512 and batch size 48 for training.
254 |
255 | - As a result, reproducing the best result with GPUs is quite difficult.
256 |
257 | - For training with one TPU v3-8, one can simply run the script `scripts/tpu_squad_large.sh` after both the TPU and Google storage have been set up.
258 | - `run_squad.py` will automatically perform threshold searching on the dev set of SQuAD and output the score. With `scripts/tpu_squad_large.sh`, the expected F1 score should be around 88.6 (median of our multiple runs).
259 |
260 | Alternatively, one can use XLNet-Base with GPUs (e.g. three V100s). One set of reasonable hyper-parameters can be found in the script `scripts/gpu_squad_base.sh`.
261 |
262 |
263 | ### RACE reading comprehension
264 |
265 | The code for the reading comprehension task [RACE](https://www.cs.cmu.edu/~glai1/data/race/) is included in `run_race.py`.
266 |
267 | - Notably, the average length of the passages in RACE is over 300 tokens (not SentencePiece pieces), which is significantly longer than in other popular reading comprehension datasets such as SQuAD.
268 | - Also, many questions can be very difficult and require complex reasoning for machines to solve (see [one example here](misc/race_example.md)).
269 |
270 |
271 | To run the code:
272 |
273 | (1) Download the RACE dataset from the [official website](https://www.cs.cmu.edu/~glai1/data/race/) and unpack the raw data to `$RACE_DIR`.
274 |
275 | (2) Perform training and evaluation:
276 |
277 | - The SOTA performance (accuracy 81.75) on RACE is produced using XLNet-Large with sequence length 512 and batch size 32, which requires a TPU v3-32 in the pod setting. Please refer to the script `scripts/tpu_race_large_bsz32.sh` for this setting.
278 | - Using XLNet-Large with sequence length 512 and batch size 8 on a TPU v3-8 can give you an accuracy of around 80.3 (see `scripts/tpu_race_large_bsz8.sh`).
279 |
280 | ### Using Google Colab
281 |
282 | [An example](notebooks/colab_imdb_gpu.ipynb) of using Google Colab with GPUs has been provided. Note that since the hardware in the example is constrained, the results are worse than the best we can get. The notebook mainly serves as an example and should be modified accordingly to maximize performance.
283 |
284 |
285 | ## Custom Usage of XLNet
286 |
287 | ### XLNet Abstraction
288 |
289 | For finetuning, it is likely that you will be able to modify existing files such as `run_classifier.py`, `run_squad.py` and `run_race.py` for your task at hand. However, we also provide an abstraction of XLNet to enable more flexible usage. Below is an example:
290 |
291 | ```python
292 | import xlnet
293 |
294 | # some code omitted here...
295 | # initialize FLAGS
296 | # initialize instances of tf.Tensor, including input_ids, seg_ids, and input_mask
297 |
298 | # XLNetConfig contains hyperparameters that are specific to a model checkpoint.
299 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
300 |
301 | # RunConfig contains hyperparameters that could be different between pretraining and finetuning.
302 | run_config = xlnet.create_run_config(is_training=True, is_finetune=True, FLAGS=FLAGS)
303 |
304 | # Construct an XLNet model
305 | xlnet_model = xlnet.XLNetModel(
306 | xlnet_config=xlnet_config,
307 | run_config=run_config,
308 | input_ids=input_ids,
309 | seg_ids=seg_ids,
310 | input_mask=input_mask)
311 |
312 | # Get a summary of the sequence using the last hidden state
313 | summary = xlnet_model.get_pooled_out(summary_type="last")
314 |
315 | # Get a sequence output
316 | seq_out = xlnet_model.get_sequence_output()
317 |
318 | # build your applications based on `summary` or `seq_out`
319 | ```
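
As a sketch of building on `summary`, the snippet below adds a simple classification head, mirroring what `function_builder.py` does; `n_class` and the integer `labels` tensor are assumed to come from your own pipeline, and the scope name is arbitrary.

```python
# Continues the snippet above; `summary` and `xlnet_model` are defined there.
with tf.variable_scope("my_classification_head"):
  # Project the pooled summary to class logits, reusing the model's
  # initializer as function_builder.py does.
  logits = tf.layers.dense(
      summary, n_class, kernel_initializer=xlnet_model.get_initializer())
  per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  loss = tf.reduce_mean(per_example_loss)
```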
320 |
321 | ### Tokenization
322 |
323 | Below is an example of doing tokenization in XLNet:
324 | ```python
325 | import sentencepiece as spm
326 | from prepro_utils import preprocess_text, encode_ids
327 |
328 | # some code omitted here...
329 | # initialize FLAGS
330 |
331 | text = "An input text string."
332 |
333 | sp_model = spm.SentencePieceProcessor()
334 | sp_model.Load(FLAGS.spiece_model_file)
335 | text = preprocess_text(text, lower=FLAGS.uncased)
336 | ids = encode_ids(sp_model, text)
337 | ```
338 | where `FLAGS.spiece_model_file` is the SentencePiece model file in the same zip as the pretrained model, and `FLAGS.uncased` is a bool indicating whether to lowercase the text.
339 |
340 |
341 | ## Pretraining with XLNet
342 |
343 | Refer to `train.py` for pretraining on TPUs and `train_gpu.py` for pretraining on GPUs. First, we need to preprocess the text data into tfrecords:
344 |
345 | ```shell
346 | python data_utils.py \
347 | --bsz_per_host=32 \
348 | --num_core_per_host=16 \
349 | --seq_len=512 \
350 | --reuse_len=256 \
351 | --input_glob=*.txt \
352 | --save_dir=${SAVE_DIR} \
353 | --num_passes=20 \
354 | --bi_data=True \
355 | --sp_path=spiece.model \
356 | --mask_alpha=6 \
357 | --mask_beta=1 \
358 | --num_predict=85
359 | ```
360 |
361 | where `input_glob` defines all input text files, `save_dir` is the output directory for tfrecords, and `sp_path` is a [Sentence Piece](https://github.com/google/sentencepiece) model. Here is our script to train the SentencePiece model:
362 |
363 | ```bash
364 | spm_train \
365 | --input=$INPUT \
366 | --model_prefix=sp10m.cased.v3 \
367 | --vocab_size=32000 \
368 | --character_coverage=0.99995 \
369 | --model_type=unigram \
370 | --control_symbols=<cls>,<sep>,<pad>,<mask>,<eod> \
371 | --user_defined_symbols=<eop>,.,(,),",-,–,£,€ \
372 | --shuffle_input_sentence \
373 | --input_sentence_size=10000000
374 | ```
375 |
376 | Special symbols are used, including `control_symbols` and `user_defined_symbols`. We use `<eop>` and `<eod>` to denote End of Paragraph and End of Document respectively.
377 |
378 | The input text files to `data_utils.py` must use the following format:
379 | * Each line is a sentence.
380 | * An empty line means End of Document.
381 | * (Optional) If one also wants to model paragraph structures, `<eop>` can be inserted at the end of certain lines (without any space) to indicate that the corresponding sentence ends a paragraph.
382 |
383 | For example, the text input file could be:
384 | ```
385 | This is the first sentence.
386 | This is the second sentence and also the end of the paragraph.
387 | Another paragraph.
388 |
389 | Another document starts here.
390 | ```
391 |
392 | After preprocessing, we are ready to pretrain an XLNet. Below are the hyperparameters used for pretraining XLNet-Large:
393 |
394 | ```shell
395 | python train.py \
396 | --record_info_dir=$DATA/tfrecords \
397 | --train_batch_size=2048 \
398 | --seq_len=512 \
399 | --reuse_len=256 \
400 | --mem_len=384 \
401 | --perm_size=256 \
402 | --n_layer=24 \
403 | --d_model=1024 \
404 | --d_embed=1024 \
405 | --n_head=16 \
406 | --d_head=64 \
407 | --d_inner=4096 \
408 | --untie_r=True \
409 | --mask_alpha=6 \
410 | --mask_beta=1 \
411 | --num_predict=85
412 | ```
413 |
414 | Here we only list the most important flags; the other flags can be adjusted based on specific use cases.
415 |
416 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zihangdai/xlnet/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/__init__.py
--------------------------------------------------------------------------------
/classifier_utils.py:
--------------------------------------------------------------------------------
1 | from absl import flags
2 |
3 | import re
4 | import numpy as np
5 |
6 | import tensorflow as tf
7 | from data_utils import SEP_ID, CLS_ID
8 |
9 | FLAGS = flags.FLAGS
10 |
11 | SEG_ID_A = 0
12 | SEG_ID_B = 1
13 | SEG_ID_CLS = 2
14 | SEG_ID_SEP = 3
15 | SEG_ID_PAD = 4
16 |
17 | class PaddingInputExample(object):
18 | """Fake example so the num input examples is a multiple of the batch size.
19 | When running eval/predict on the TPU, we need to pad the number of examples
20 | to be a multiple of the batch size, because the TPU requires a fixed batch
21 | size. The alternative is to drop the last batch, which is bad because it means
22 | the entire output data won't be generated.
23 | We use this class instead of `None` because treating `None` as padding
24 | batches could cause silent errors.
25 | """
26 |
27 |
28 | class InputFeatures(object):
29 | """A single set of features of data."""
30 |
31 | def __init__(self,
32 | input_ids,
33 | input_mask,
34 | segment_ids,
35 | label_id,
36 | is_real_example=True):
37 | self.input_ids = input_ids
38 | self.input_mask = input_mask
39 | self.segment_ids = segment_ids
40 | self.label_id = label_id
41 | self.is_real_example = is_real_example
42 |
43 |
44 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
45 | """Truncates a sequence pair in place to the maximum length."""
46 |
47 | # This is a simple heuristic which will always truncate the longer sequence
48 | # one token at a time. This makes more sense than truncating an equal percent
49 | # of tokens from each, since if one sequence is very short then each token
50 | # that's truncated likely contains more information than a longer sequence.
51 | while True:
52 | total_length = len(tokens_a) + len(tokens_b)
53 | if total_length <= max_length:
54 | break
55 | if len(tokens_a) > len(tokens_b):
56 | tokens_a.pop()
57 | else:
58 | tokens_b.pop()
59 |
60 |
61 | def convert_single_example(ex_index, example, label_list, max_seq_length,
62 | tokenize_fn):
63 | """Converts a single `InputExample` into a single `InputFeatures`."""
64 |
65 | if isinstance(example, PaddingInputExample):
66 | return InputFeatures(
67 | input_ids=[0] * max_seq_length,
68 | input_mask=[1] * max_seq_length,
69 | segment_ids=[0] * max_seq_length,
70 | label_id=0,
71 | is_real_example=False)
72 |
73 | if label_list is not None:
74 | label_map = {}
75 | for (i, label) in enumerate(label_list):
76 | label_map[label] = i
77 |
78 | tokens_a = tokenize_fn(example.text_a)
79 | tokens_b = None
80 | if example.text_b:
81 | tokens_b = tokenize_fn(example.text_b)
82 |
83 | if tokens_b:
84 | # Modifies `tokens_a` and `tokens_b` in place so that the total
85 | # length is less than the specified length.
86 | # Account for two [SEP] & one [CLS] with "- 3"
87 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
88 | else:
89 | # Account for one [SEP] & one [CLS] with "- 2"
90 | if len(tokens_a) > max_seq_length - 2:
91 | tokens_a = tokens_a[:max_seq_length - 2]
92 |
93 | tokens = []
94 | segment_ids = []
95 | for token in tokens_a:
96 | tokens.append(token)
97 | segment_ids.append(SEG_ID_A)
98 | tokens.append(SEP_ID)
99 | segment_ids.append(SEG_ID_A)
100 |
101 | if tokens_b:
102 | for token in tokens_b:
103 | tokens.append(token)
104 | segment_ids.append(SEG_ID_B)
105 | tokens.append(SEP_ID)
106 | segment_ids.append(SEG_ID_B)
107 |
108 | tokens.append(CLS_ID)
109 | segment_ids.append(SEG_ID_CLS)
110 |
111 | input_ids = tokens
112 |
113 | # The mask has 0 for real tokens and 1 for padding tokens. Only real
114 | # tokens are attended to.
115 | input_mask = [0] * len(input_ids)
116 |
117 | # Zero-pad up to the sequence length. Padding is prepended so CLS stays at the end.
118 | if len(input_ids) < max_seq_length:
119 | delta_len = max_seq_length - len(input_ids)
120 | input_ids = [0] * delta_len + input_ids
121 | input_mask = [1] * delta_len + input_mask
122 | segment_ids = [SEG_ID_PAD] * delta_len + segment_ids
123 |
124 | assert len(input_ids) == max_seq_length
125 | assert len(input_mask) == max_seq_length
126 | assert len(segment_ids) == max_seq_length
127 |
128 | if label_list is not None:
129 | label_id = label_map[example.label]
130 | else:
131 | label_id = example.label
132 | if ex_index < 5:
133 | tf.logging.info("*** Example ***")
134 | tf.logging.info("guid: %s" % (example.guid))
135 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
136 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
137 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
138 | tf.logging.info("label: {} (id = {})".format(example.label, label_id))
139 |
140 | feature = InputFeatures(
141 | input_ids=input_ids,
142 | input_mask=input_mask,
143 | segment_ids=segment_ids,
144 | label_id=label_id)
145 | return feature
146 |
147 |
148 |
149 |
--------------------------------------------------------------------------------
/function_builder.py:
--------------------------------------------------------------------------------
1 | """doc."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import functools
7 | import os
8 | import tensorflow as tf
9 | import modeling
10 | import xlnet
11 |
12 |
13 | def construct_scalar_host_call(
14 | monitor_dict,
15 | model_dir,
16 | prefix="",
17 | reduce_fn=None):
18 | """
19 | Construct host calls to monitor training progress on TPUs.
20 | """
21 |
22 | metric_names = list(monitor_dict.keys())
23 |
24 | def host_call_fn(global_step, *args):
25 | """actual host call function."""
26 | step = global_step[0]
27 | with tf.contrib.summary.create_file_writer(
28 | logdir=model_dir, filename_suffix=".host_call").as_default():
29 | with tf.contrib.summary.always_record_summaries():
30 | for i, name in enumerate(metric_names):
31 | if reduce_fn is None:
32 | scalar = args[i][0]
33 | else:
34 | scalar = reduce_fn(args[i])
35 | with tf.contrib.summary.record_summaries_every_n_global_steps(
36 | 100, global_step=step):
37 | tf.contrib.summary.scalar(prefix + name, scalar, step=step)
38 |
39 | return tf.contrib.summary.all_summary_ops()
40 |
41 | global_step_tensor = tf.reshape(tf.train.get_or_create_global_step(), [1])
42 | other_tensors = [tf.reshape(monitor_dict[key], [1]) for key in metric_names]
43 |
44 | return host_call_fn, [global_step_tensor] + other_tensors
45 |
46 |
47 | def two_stream_loss(FLAGS, features, labels, mems, is_training):
48 | """Pretraining loss with two-stream attention Transformer-XL."""
49 |
50 | #### Unpack input
51 | mem_name = "mems"
52 | mems = mems.get(mem_name, None)
53 |
54 | inp_k = tf.transpose(features["input_k"], [1, 0])
55 | inp_q = tf.transpose(features["input_q"], [1, 0])
56 |
57 | seg_id = tf.transpose(features["seg_id"], [1, 0])
58 |
59 | inp_mask = None
60 | perm_mask = tf.transpose(features["perm_mask"], [1, 2, 0])
61 |
62 | if FLAGS.num_predict is not None:
63 | # [num_predict x tgt_len x bsz]
64 | target_mapping = tf.transpose(features["target_mapping"], [1, 2, 0])
65 | else:
66 | target_mapping = None
67 |
68 | # target for LM loss
69 | tgt = tf.transpose(features["target"], [1, 0])
70 |
71 | # target mask for LM loss
72 | tgt_mask = tf.transpose(features["target_mask"], [1, 0])
73 |
74 | # construct xlnet config and save to model_dir
75 | xlnet_config = xlnet.XLNetConfig(FLAGS=FLAGS)
76 | xlnet_config.to_json(os.path.join(FLAGS.model_dir, "config.json"))
77 |
78 | # construct run config from FLAGS
79 | run_config = xlnet.create_run_config(is_training, False, FLAGS)
80 |
81 | xlnet_model = xlnet.XLNetModel(
82 | xlnet_config=xlnet_config,
83 | run_config=run_config,
84 | input_ids=inp_k,
85 | seg_ids=seg_id,
86 | input_mask=inp_mask,
87 | mems=mems,
88 | perm_mask=perm_mask,
89 | target_mapping=target_mapping,
90 | inp_q=inp_q)
91 |
92 | output = xlnet_model.get_sequence_output()
93 | new_mems = {mem_name: xlnet_model.get_new_memory()}
94 | lookup_table = xlnet_model.get_embedding_table()
95 |
96 | initializer = xlnet_model.get_initializer()
97 |
98 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
99 | # LM loss
100 | lm_loss = modeling.lm_loss(
101 | hidden=output,
102 | target=tgt,
103 | n_token=xlnet_config.n_token,
104 | d_model=xlnet_config.d_model,
105 | initializer=initializer,
106 | lookup_table=lookup_table,
107 | tie_weight=True,
108 | bi_data=run_config.bi_data,
109 | use_tpu=run_config.use_tpu)
110 |
111 | #### Quantity to monitor
112 | monitor_dict = {}
113 |
114 | if FLAGS.use_bfloat16:
115 | tgt_mask = tf.cast(tgt_mask, tf.float32)
116 | lm_loss = tf.cast(lm_loss, tf.float32)
117 |
118 | total_loss = tf.reduce_sum(lm_loss * tgt_mask) / tf.reduce_sum(tgt_mask)
119 | monitor_dict["total_loss"] = total_loss
120 |
121 | return total_loss, new_mems, monitor_dict
122 |
123 |
124 | def get_loss(FLAGS, features, labels, mems, is_training):
125 | """Pretraining loss with two-stream attention Transformer-XL."""
126 | if FLAGS.use_bfloat16:
127 | with tf.tpu.bfloat16_scope():
128 | return two_stream_loss(FLAGS, features, labels, mems, is_training)
129 | else:
130 | return two_stream_loss(FLAGS, features, labels, mems, is_training)
131 |
132 |
133 | def get_classification_loss(
134 | FLAGS, features, n_class, is_training):
135 | """Loss for downstream classification tasks."""
136 |
137 | bsz_per_core = tf.shape(features["input_ids"])[0]
138 |
139 | inp = tf.transpose(features["input_ids"], [1, 0])
140 | seg_id = tf.transpose(features["segment_ids"], [1, 0])
141 | inp_mask = tf.transpose(features["input_mask"], [1, 0])
142 | label = tf.reshape(features["label_ids"], [bsz_per_core])
143 |
144 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
145 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
146 |
147 | xlnet_model = xlnet.XLNetModel(
148 | xlnet_config=xlnet_config,
149 | run_config=run_config,
150 | input_ids=inp,
151 | seg_ids=seg_id,
152 | input_mask=inp_mask)
153 |
154 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
155 |
156 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
157 |
158 | if FLAGS.cls_scope is not None and FLAGS.cls_scope:
159 | cls_scope = "classification_{}".format(FLAGS.cls_scope)
160 | else:
161 | cls_scope = "classification_{}".format(FLAGS.task_name.lower())
162 |
163 | per_example_loss, logits = modeling.classification_loss(
164 | hidden=summary,
165 | labels=label,
166 | n_class=n_class,
167 | initializer=xlnet_model.get_initializer(),
168 | scope=cls_scope,
169 | return_logits=True)
170 |
171 | total_loss = tf.reduce_mean(per_example_loss)
172 |
173 | return total_loss, per_example_loss, logits
174 |
175 |
176 | def get_regression_loss(
177 | FLAGS, features, is_training):
178 | """Loss for downstream regression tasks."""
179 |
180 | bsz_per_core = tf.shape(features["input_ids"])[0]
181 |
182 | inp = tf.transpose(features["input_ids"], [1, 0])
183 | seg_id = tf.transpose(features["segment_ids"], [1, 0])
184 | inp_mask = tf.transpose(features["input_mask"], [1, 0])
185 | label = tf.reshape(features["label_ids"], [bsz_per_core])
186 |
187 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
188 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
189 |
190 | xlnet_model = xlnet.XLNetModel(
191 | xlnet_config=xlnet_config,
192 | run_config=run_config,
193 | input_ids=inp,
194 | seg_ids=seg_id,
195 | input_mask=inp_mask)
196 |
197 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
198 |
199 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
200 | per_example_loss, logits = modeling.regression_loss(
201 | hidden=summary,
202 | labels=label,
203 | initializer=xlnet_model.get_initializer(),
204 | scope="regression_{}".format(FLAGS.task_name.lower()),
205 | return_logits=True)
206 |
207 | total_loss = tf.reduce_mean(per_example_loss)
208 |
209 | return total_loss, per_example_loss, logits
210 |
211 |
212 | def get_qa_outputs(FLAGS, features, is_training):
213 | """Loss for downstream span-extraction QA tasks such as SQuAD."""
214 |
215 | inp = tf.transpose(features["input_ids"], [1, 0])
216 | seg_id = tf.transpose(features["segment_ids"], [1, 0])
217 | inp_mask = tf.transpose(features["input_mask"], [1, 0])
218 | cls_index = tf.reshape(features["cls_index"], [-1])
219 |
220 | seq_len = tf.shape(inp)[0]
221 |
222 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
223 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
224 |
225 | xlnet_model = xlnet.XLNetModel(
226 | xlnet_config=xlnet_config,
227 | run_config=run_config,
228 | input_ids=inp,
229 | seg_ids=seg_id,
230 | input_mask=inp_mask)
231 | output = xlnet_model.get_sequence_output()
232 | initializer = xlnet_model.get_initializer()
233 |
234 | return_dict = {}
235 |
236 | # invalid position mask such as query and special symbols (PAD, SEP, CLS)
237 | p_mask = features["p_mask"]
238 |
239 | # logit of the start position
240 | with tf.variable_scope("start_logits"):
241 | start_logits = tf.layers.dense(
242 | output,
243 | 1,
244 | kernel_initializer=initializer)
245 | start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
246 | start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
247 | start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
248 |
249 | # logit of the end position
250 | with tf.variable_scope("end_logits"):
251 | if is_training:
252 | # during training, compute the end logits based on the
253 | # ground truth of the start position
254 |
255 | start_positions = tf.reshape(features["start_positions"], [-1])
256 | start_index = tf.one_hot(start_positions, depth=seq_len, axis=-1,
257 | dtype=tf.float32)
258 | start_features = tf.einsum("lbh,bl->bh", output, start_index)
259 | start_features = tf.tile(start_features[None], [seq_len, 1, 1])
260 | end_logits = tf.layers.dense(
261 | tf.concat([output, start_features], axis=-1), xlnet_config.d_model,
262 | kernel_initializer=initializer, activation=tf.tanh, name="dense_0")
263 | end_logits = tf.contrib.layers.layer_norm(
264 | end_logits, begin_norm_axis=-1)
265 |
266 | end_logits = tf.layers.dense(
267 | end_logits, 1,
268 | kernel_initializer=initializer,
269 | name="dense_1")
270 | end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
271 | end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
272 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
273 | else:
274 | # during inference, compute the end logits based on beam search
275 |
276 | start_top_log_probs, start_top_index = tf.nn.top_k(
277 | start_log_probs, k=FLAGS.start_n_top)
278 | start_index = tf.one_hot(start_top_index,
279 | depth=seq_len, axis=-1, dtype=tf.float32)
280 | start_features = tf.einsum("lbh,bkl->bkh", output, start_index)
281 | end_input = tf.tile(output[:, :, None],
282 | [1, 1, FLAGS.start_n_top, 1])
283 | start_features = tf.tile(start_features[None],
284 | [seq_len, 1, 1, 1])
285 | end_input = tf.concat([end_input, start_features], axis=-1)
286 | end_logits = tf.layers.dense(
287 | end_input,
288 | xlnet_config.d_model,
289 | kernel_initializer=initializer,
290 | activation=tf.tanh,
291 | name="dense_0")
292 | end_logits = tf.contrib.layers.layer_norm(end_logits,
293 | begin_norm_axis=-1)
294 | end_logits = tf.layers.dense(
295 | end_logits,
296 | 1,
297 | kernel_initializer=initializer,
298 | name="dense_1")
299 | end_logits = tf.reshape(end_logits, [seq_len, -1, FLAGS.start_n_top])
300 | end_logits = tf.transpose(end_logits, [1, 2, 0])
301 | end_logits_masked = end_logits * (
302 | 1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
303 | end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
304 | end_top_log_probs, end_top_index = tf.nn.top_k(
305 | end_log_probs, k=FLAGS.end_n_top)
306 | end_top_log_probs = tf.reshape(
307 | end_top_log_probs,
308 | [-1, FLAGS.start_n_top * FLAGS.end_n_top])
309 | end_top_index = tf.reshape(
310 | end_top_index,
311 | [-1, FLAGS.start_n_top * FLAGS.end_n_top])
312 |
313 | if is_training:
314 | return_dict["start_log_probs"] = start_log_probs
315 | return_dict["end_log_probs"] = end_log_probs
316 | else:
317 | return_dict["start_top_log_probs"] = start_top_log_probs
318 | return_dict["start_top_index"] = start_top_index
319 | return_dict["end_top_log_probs"] = end_top_log_probs
320 | return_dict["end_top_index"] = end_top_index
321 |
322 | # an additional layer to predict answerability
323 | with tf.variable_scope("answer_class"):
324 | # get the representation of CLS
325 | cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
326 | cls_feature = tf.einsum("lbh,bl->bh", output, cls_index)
327 |
328 | # get the representation of START
329 | start_p = tf.nn.softmax(start_logits_masked, axis=-1,
330 | name="softmax_start")
331 | start_feature = tf.einsum("lbh,bl->bh", output, start_p)
332 |
333 | # note(zhiliny): no dependency on end_feature so that we can obtain
334 | # one single `cls_logits` for each sample
335 | ans_feature = tf.concat([start_feature, cls_feature], -1)
336 | ans_feature = tf.layers.dense(
337 | ans_feature,
338 | xlnet_config.d_model,
339 | activation=tf.tanh,
340 | kernel_initializer=initializer, name="dense_0")
341 | ans_feature = tf.layers.dropout(ans_feature, FLAGS.dropout,
342 | training=is_training)
343 | cls_logits = tf.layers.dense(
344 | ans_feature,
345 | 1,
346 | kernel_initializer=initializer,
347 | name="dense_1",
348 | use_bias=False)
349 | cls_logits = tf.squeeze(cls_logits, -1)
350 |
351 | return_dict["cls_logits"] = cls_logits
352 |
353 | return return_dict
354 |
355 |
356 | def get_race_loss(FLAGS, features, is_training):
357 | """Loss for downstream multi-choice QA tasks such as RACE."""
358 |
359 | bsz_per_core = tf.shape(features["input_ids"])[0]
360 |
361 | def _transform_features(feature):
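# Flatten the 4 answer choices into the batch dimension:
# [bsz, 4 * seqlen] -> [seqlen, bsz * 4], matching the [seqlen, batch]
# layout used for model inputs, so all 4 choices share one forward pass.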
362 | out = tf.reshape(feature, [bsz_per_core, 4, -1])
363 | out = tf.transpose(out, [2, 0, 1])
364 | out = tf.reshape(out, [-1, bsz_per_core * 4])
365 | return out
366 |
367 | inp = _transform_features(features["input_ids"])
368 | seg_id = _transform_features(features["segment_ids"])
369 | inp_mask = _transform_features(features["input_mask"])
370 | label = tf.reshape(features["label_ids"], [bsz_per_core])
371 |
372 | xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
373 | run_config = xlnet.create_run_config(is_training, True, FLAGS)
374 |
375 | xlnet_model = xlnet.XLNetModel(
376 | xlnet_config=xlnet_config,
377 | run_config=run_config,
378 | input_ids=inp,
379 | seg_ids=seg_id,
380 | input_mask=inp_mask)
381 | summary = xlnet_model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
382 |
383 | with tf.variable_scope("logits"):
384 | logits = tf.layers.dense(summary, 1,
385 | kernel_initializer=xlnet_model.get_initializer())
386 | logits = tf.reshape(logits, [bsz_per_core, 4])
387 |
388 | one_hot_target = tf.one_hot(label, 4)
389 | per_example_loss = -tf.reduce_sum(
390 | tf.nn.log_softmax(logits) * one_hot_target, -1)
391 | total_loss = tf.reduce_mean(per_example_loss)
392 |
393 | return total_loss, per_example_loss, logits
394 |
--------------------------------------------------------------------------------
/gpu_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import tensorflow as tf
7 |
8 | def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
9 | def _assign(op):
10 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def
11 | if node_def.op == "Variable":
12 | return ps_dev
13 | else:
14 | return "/gpu:%d" % gpu
15 | return _assign
16 |
17 |
18 | def average_grads_and_vars(tower_grads_and_vars):
19 | def average_dense(grad_and_vars):
20 | if len(grad_and_vars) == 1:
21 | return grad_and_vars[0][0]
22 |
23 | grad = grad_and_vars[0][0]
24 | for g, _ in grad_and_vars[1:]:
25 | grad += g
26 | return grad / len(grad_and_vars)
27 |
28 | def average_sparse(grad_and_vars):
29 | if len(grad_and_vars) == 1:
30 | return grad_and_vars[0][0]
31 |
32 | indices = []
33 | values = []
34 | for g, _ in grad_and_vars:
35 | indices += [g.indices]
36 | values += [g.values]
37 | indices = tf.concat(indices, 0)
38 | values = tf.concat(values, 0) / len(grad_and_vars)
39 | return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)
40 |
41 | average_grads_and_vars = []
42 | for grad_and_vars in zip(*tower_grads_and_vars):
43 | if grad_and_vars[0][0] is None:
44 | grad = None
45 | elif isinstance(grad_and_vars[0][0], tf.IndexedSlices):
46 | grad = average_sparse(grad_and_vars)
47 | else:
48 | grad = average_dense(grad_and_vars)
49 | # Keep in mind that the Variables are redundant because they are shared
50 | # across towers. So .. we will just return the first tower's pointer to
51 | # the Variable.
52 | v = grad_and_vars[0][1]
53 | grad_and_var = (grad, v)
54 | average_grads_and_vars.append(grad_and_var)
55 | return average_grads_and_vars
56 |
57 |
58 | def load_from_checkpoint(saver, logdir):
59 | sess = tf.get_default_session()
60 | ckpt = tf.train.get_checkpoint_state(logdir)
61 | if ckpt and ckpt.model_checkpoint_path:
62 | if os.path.isabs(ckpt.model_checkpoint_path):
63 | # Restores from checkpoint with absolute path.
64 | saver.restore(sess, ckpt.model_checkpoint_path)
65 | else:
66 | # Restores from checkpoint with relative path.
67 | saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path))
68 | return True
69 | return False
70 |
--------------------------------------------------------------------------------
/misc/race_example.md:
--------------------------------------------------------------------------------
1 | ## A RACE Example
2 |
3 | **Paragraph**
4 |
5 | It was a cold night. The taxi driver didn't take even one passenger all day.
6 | When he went by the railway station, he saw a young man coming out with two bags in his hands.
7 | So he drove to him and asked, "Where are you going, sir?"
8 | "To the Red Hotel," the young man answered.
9 | When the taxi driver heard this, he didn't feel happy any more.
10 | The young man would give him only three dollars because the hotel was near the railway station.
11 | But suddenly, he had an idea. He took the young man through many streets of the big city.
12 | After a long time, they arrived at the hotel. "Here we are! You should pay me fifteen dollars, please," the taxi driver said to the young man.
13 | "What? Fifteen dollars! Do you think I'm a fool? Only last week, I took a taxi from the railway station to this hotel and I only gave the driver thirteen dollars.
14 | I know how much I have to pay for the trip."
15 |
16 | **Question 1**
17 |
18 | Maybe the taxi driver got _ dollars at last.
19 |
20 | **Answer Candidates**
21 |
22 | - 3
23 | - 2
24 | - 13
25 | - 15
26 |
27 | **Question 2**
28 |
29 | Which of the following is TRUE?
30 |
31 | **Answer Candidates**
32 |
33 | - The two taxi drivers were both honest.
34 | - The two taxi drivers cheated the young man.
35 | - It is very far from the railway station to the Red Hotel.
36 | - The young man knew how far it was from the railway station to the hotel.
37 |
--------------------------------------------------------------------------------
/misc/slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zihangdai/xlnet/bbaa3a6fa0b3a2ee694e8cf66167434f9eca9660/misc/slides.pdf
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import os
7 | import re
8 | import numpy as np
9 | import six
10 | from os.path import join
11 | from six.moves import zip
12 |
13 | from absl import flags
14 |
15 | import tensorflow as tf
16 |
17 |
18 | def configure_tpu(FLAGS):
19 | if FLAGS.use_tpu:
20 | tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
21 | FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
22 | master = tpu_cluster.get_master()
23 | else:
24 | tpu_cluster = None
25 | master = FLAGS.master
26 |
27 | session_config = tf.ConfigProto(allow_soft_placement=True)
28 |   # Uncomment the following line to let GPU memory usage grow on demand
29 | # session_config.gpu_options.allow_growth = True
30 |
31 | if FLAGS.use_tpu:
32 | strategy = None
33 | tf.logging.info('Use TPU without distribute strategy.')
34 | elif FLAGS.num_core_per_host == 1:
35 | strategy = None
36 | tf.logging.info('Single device mode.')
37 | else:
38 | strategy = tf.contrib.distribute.MirroredStrategy(
39 | num_gpus=FLAGS.num_core_per_host)
40 | tf.logging.info('Use MirroredStrategy with %d devices.',
41 | strategy.num_replicas_in_sync)
42 |
43 | per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
44 | run_config = tf.contrib.tpu.RunConfig(
45 | master=master,
46 | model_dir=FLAGS.model_dir,
47 | session_config=session_config,
48 | tpu_config=tf.contrib.tpu.TPUConfig(
49 | iterations_per_loop=FLAGS.iterations,
50 | num_shards=FLAGS.num_hosts * FLAGS.num_core_per_host,
51 | per_host_input_for_training=per_host_input),
52 | keep_checkpoint_max=FLAGS.max_save,
53 | save_checkpoints_secs=None,
54 | save_checkpoints_steps=FLAGS.save_steps,
55 | train_distribute=strategy
56 | )
57 | return run_config
58 |
59 |
60 | def init_from_checkpoint(FLAGS, global_vars=False):
61 | tvars = tf.global_variables() if global_vars else tf.trainable_variables()
62 | initialized_variable_names = {}
63 | scaffold_fn = None
64 | if FLAGS.init_checkpoint is not None:
65 | if FLAGS.init_checkpoint.endswith("latest"):
66 | ckpt_dir = os.path.dirname(FLAGS.init_checkpoint)
67 | init_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
68 | else:
69 | init_checkpoint = FLAGS.init_checkpoint
70 |
71 |     tf.logging.info("Initializing from checkpoint {}".format(init_checkpoint))
72 |
73 | (assignment_map, initialized_variable_names
74 | ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
75 | if FLAGS.use_tpu:
76 | def tpu_scaffold():
77 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
78 | return tf.train.Scaffold()
79 |
80 | scaffold_fn = tpu_scaffold
81 | else:
82 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
83 |
84 | # Log customized initialization
85 | tf.logging.info("**** Global Variables ****")
86 | for var in tvars:
87 | init_string = ""
88 | if var.name in initialized_variable_names:
89 | init_string = ", *INIT_FROM_CKPT*"
90 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
91 | init_string)
92 | return scaffold_fn
93 |
94 |
95 | def get_train_op(FLAGS, total_loss, grads_and_vars=None):
96 | global_step = tf.train.get_or_create_global_step()
97 |
98 | # increase the learning rate linearly
99 | if FLAGS.warmup_steps > 0:
100 | warmup_lr = (tf.cast(global_step, tf.float32)
101 | / tf.cast(FLAGS.warmup_steps, tf.float32)
102 | * FLAGS.learning_rate)
103 | else:
104 | warmup_lr = 0.0
105 |
106 | # decay the learning rate
107 | if FLAGS.decay_method == "poly":
108 | decay_lr = tf.train.polynomial_decay(
109 | FLAGS.learning_rate,
110 | global_step=global_step - FLAGS.warmup_steps,
111 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
112 | end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio)
113 | elif FLAGS.decay_method == "cos":
114 | decay_lr = tf.train.cosine_decay(
115 | FLAGS.learning_rate,
116 | global_step=global_step - FLAGS.warmup_steps,
117 | decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
118 | alpha=FLAGS.min_lr_ratio)
119 | else:
120 | raise ValueError(FLAGS.decay_method)
121 |
122 | learning_rate = tf.where(global_step < FLAGS.warmup_steps,
123 | warmup_lr, decay_lr)
124 |
125 | if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu and
126 | FLAGS.num_core_per_host > 1):
127 |     raise ValueError("`weight_decay > 0` is not supported with multi-GPU "
128 |                      "training so far.")
129 |
130 | if FLAGS.weight_decay == 0:
131 | optimizer = tf.train.AdamOptimizer(
132 | learning_rate=learning_rate,
133 | epsilon=FLAGS.adam_epsilon)
134 | else:
135 | optimizer = AdamWeightDecayOptimizer(
136 | learning_rate=learning_rate,
137 | epsilon=FLAGS.adam_epsilon,
138 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
139 | weight_decay_rate=FLAGS.weight_decay)
140 |
141 | if FLAGS.use_tpu:
142 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
143 |
144 | if grads_and_vars is None:
145 | grads_and_vars = optimizer.compute_gradients(total_loss)
146 | gradients, variables = zip(*grads_and_vars)
147 | clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)
148 |
149 | if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0:
150 | n_layer = 0
151 | for i in range(len(clipped)):
152 | m = re.search(r"model/transformer/layer_(\d+?)/", variables[i].name)
153 | if not m: continue
154 | n_layer = max(n_layer, int(m.group(1)) + 1)
155 |
156 | for i in range(len(clipped)):
157 | for l in range(n_layer):
158 | if "model/transformer/layer_{}/".format(l) in variables[i].name:
159 | abs_rate = FLAGS.lr_layer_decay_rate ** (n_layer - 1 - l)
160 | clipped[i] *= abs_rate
161 | tf.logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
162 | abs_rate, l, variables[i].name))
163 | break
164 |
165 | train_op = optimizer.apply_gradients(
166 | zip(clipped, variables), global_step=global_step)
167 |
168 | # Manually increment `global_step` for AdamWeightDecayOptimizer
169 | if FLAGS.weight_decay > 0:
170 | new_global_step = global_step + 1
171 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
172 |
173 | return train_op, learning_rate, gnorm
174 |
175 |
176 | def clean_ckpt(_):
177 | input_ckpt = FLAGS.clean_input_ckpt
178 | output_model_dir = FLAGS.clean_output_model_dir
179 |
180 | tf.reset_default_graph()
181 |
182 | var_list = tf.contrib.framework.list_variables(input_ckpt)
183 | var_values, var_dtypes = {}, {}
184 | for (name, shape) in var_list:
185 | if not name.startswith("global_step") and "adam" not in name.lower():
186 | var_values[name] = None
187 | tf.logging.info("Include {}".format(name))
188 | else:
189 | tf.logging.info("Exclude {}".format(name))
190 |
191 | tf.logging.info("Loading from {}".format(input_ckpt))
192 | reader = tf.contrib.framework.load_checkpoint(input_ckpt)
193 | for name in var_values:
194 | tensor = reader.get_tensor(name)
195 | var_dtypes[name] = tensor.dtype
196 | var_values[name] = tensor
197 |
198 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
199 | tf_vars = [
200 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
201 | for v in var_values
202 | ]
203 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
204 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
205 | global_step = tf.Variable(
206 | 0, name="global_step", trainable=False, dtype=tf.int64)
207 | saver = tf.train.Saver(tf.all_variables())
208 |
209 | if not tf.gfile.Exists(output_model_dir):
210 | tf.gfile.MakeDirs(output_model_dir)
211 |
212 |   # Build a model consisting only of variables, set them to the loaded values.
213 | with tf.Session() as sess:
214 | sess.run(tf.initialize_all_variables())
215 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
216 | six.iteritems(var_values)):
217 | sess.run(assign_op, {p: value})
218 |
219 |     # Use the built saver to save the cleaned checkpoint.
220 | saver.save(sess, join(output_model_dir, "model.ckpt"),
221 | global_step=global_step)
222 |
223 |
224 | def avg_checkpoints(model_dir, output_model_dir, last_k):
225 | tf.reset_default_graph()
226 |
227 | checkpoint_state = tf.train.get_checkpoint_state(model_dir)
228 | checkpoints = checkpoint_state.all_model_checkpoint_paths[- last_k:]
229 | var_list = tf.contrib.framework.list_variables(checkpoints[0])
230 | var_values, var_dtypes = {}, {}
231 | for (name, shape) in var_list:
232 | if not name.startswith("global_step"):
233 | var_values[name] = np.zeros(shape)
234 | for checkpoint in checkpoints:
235 | reader = tf.contrib.framework.load_checkpoint(checkpoint)
236 | for name in var_values:
237 | tensor = reader.get_tensor(name)
238 | var_dtypes[name] = tensor.dtype
239 | var_values[name] += tensor
240 | tf.logging.info("Read from checkpoint %s", checkpoint)
241 | for name in var_values: # Average.
242 | var_values[name] /= len(checkpoints)
243 |
244 | with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
245 | tf_vars = [
246 | tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
247 | for v in var_values
248 | ]
249 | placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
250 | assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
251 | global_step = tf.Variable(
252 | 0, name="global_step", trainable=False, dtype=tf.int64)
253 | saver = tf.train.Saver(tf.all_variables())
254 |
255 | # Build a model consisting only of variables, set them to the average values.
256 | with tf.Session() as sess:
257 | sess.run(tf.initialize_all_variables())
258 | for p, assign_op, (name, value) in zip(placeholders, assign_ops,
259 | six.iteritems(var_values)):
260 | sess.run(assign_op, {p: value})
261 | # Use the built saver to save the averaged checkpoint.
262 | saver.save(sess, join(output_model_dir, "model.ckpt"),
263 | global_step=global_step)
264 |
265 |
266 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
267 | """Compute the union of the current variables and checkpoint variables."""
268 | assignment_map = {}
269 | initialized_variable_names = {}
270 |
271 | name_to_variable = collections.OrderedDict()
272 | for var in tvars:
273 | name = var.name
274 | m = re.match("^(.*):\\d+$", name)
275 | if m is not None:
276 | name = m.group(1)
277 | name_to_variable[name] = var
278 |
279 | init_vars = tf.train.list_variables(init_checkpoint)
280 |
281 | assignment_map = collections.OrderedDict()
282 | for x in init_vars:
283 | (name, var) = (x[0], x[1])
284 | # tf.logging.info('original name: %s', name)
285 | if name not in name_to_variable:
286 | continue
287 | # assignment_map[name] = name
288 | assignment_map[name] = name_to_variable[name]
289 | initialized_variable_names[name] = 1
290 | initialized_variable_names[name + ":0"] = 1
291 |
292 | return (assignment_map, initialized_variable_names)
293 |
294 |
295 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
296 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
297 |
298 | def __init__(self,
299 | learning_rate,
300 | weight_decay_rate=0.0,
301 | beta_1=0.9,
302 | beta_2=0.999,
303 | epsilon=1e-6,
304 | exclude_from_weight_decay=None,
305 | include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"],
306 | name="AdamWeightDecayOptimizer"):
307 |     """Constructs an AdamWeightDecayOptimizer."""
308 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
309 |
310 | self.learning_rate = learning_rate
311 | self.weight_decay_rate = weight_decay_rate
312 | self.beta_1 = beta_1
313 | self.beta_2 = beta_2
314 | self.epsilon = epsilon
315 | self.exclude_from_weight_decay = exclude_from_weight_decay
316 | self.include_in_weight_decay = include_in_weight_decay
317 |
318 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
319 | """See base class."""
320 | assignments = []
321 | for (grad, param) in grads_and_vars:
322 | if grad is None or param is None:
323 | continue
324 |
325 | param_name = self._get_variable_name(param.name)
326 |
327 | m = tf.get_variable(
328 | name=param_name + "/adam_m",
329 | shape=param.shape.as_list(),
330 | dtype=tf.float32,
331 | trainable=False,
332 | initializer=tf.zeros_initializer())
333 | v = tf.get_variable(
334 | name=param_name + "/adam_v",
335 | shape=param.shape.as_list(),
336 | dtype=tf.float32,
337 | trainable=False,
338 | initializer=tf.zeros_initializer())
339 |
340 | # Standard Adam update.
341 | next_m = (
342 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
343 | next_v = (
344 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
345 | tf.square(grad)))
346 |
347 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
348 |
349 | # Just adding the square of the weights to the loss function is *not*
350 | # the correct way of using L2 regularization/weight decay with Adam,
351 | # since that will interact with the m and v parameters in strange ways.
352 | #
353 |       # Instead we want to decay the weights in a manner that doesn't interact
354 | # with the m/v parameters. This is equivalent to adding the square
355 | # of the weights to the loss with plain (non-momentum) SGD.
356 | if self._do_use_weight_decay(param_name):
357 | update += self.weight_decay_rate * param
358 |
359 | update_with_lr = self.learning_rate * update
360 |
361 | next_param = param - update_with_lr
362 |
363 | assignments.extend(
364 | [param.assign(next_param),
365 | m.assign(next_m),
366 | v.assign(next_v)])
367 |
368 | return tf.group(*assignments, name=name)
369 |
370 | def _do_use_weight_decay(self, param_name):
371 | """Whether to use L2 weight decay for `param_name`."""
372 | if not self.weight_decay_rate:
373 | return False
374 | for r in self.include_in_weight_decay:
375 | if re.search(r, param_name) is not None:
376 | return True
377 |
378 | if self.exclude_from_weight_decay:
379 | for r in self.exclude_from_weight_decay:
380 | if re.search(r, param_name) is not None:
381 | tf.logging.info('Adam WD excludes {}'.format(param_name))
382 | return False
383 | return True
384 |
385 | def _get_variable_name(self, param_name):
386 | """Get the variable name from the tensor name."""
387 | m = re.match("^(.*):\\d+$", param_name)
388 | if m is not None:
389 | param_name = m.group(1)
390 | return param_name
391 |
392 |
393 | if __name__ == "__main__":
394 | flags.DEFINE_string("clean_input_ckpt", "", "input ckpt for cleaning")
395 | flags.DEFINE_string("clean_output_model_dir", "", "output dir for cleaned ckpt")
396 |
397 | FLAGS = flags.FLAGS
398 |
399 | tf.app.run(clean_ckpt)
400 |
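401 |
402 | # --- Editor's note (not part of the original file) ---
403 | # A worked example of the warmup + decay schedule built in `get_train_op`,
404 | # under hypothetical flag values: learning_rate=2e-5, warmup_steps=500,
405 | # train_steps=4000, min_lr_ratio=0.0, decay_method="poly" (linear decay,
406 | # the default power for `tf.train.polynomial_decay`):
407 | #
408 | #   step  250: warmup -> 250/500 * 2e-5         = 1.0e-5
409 | #   step  500: peak   -> 2e-5
410 | #   step 2250: decay  -> 2e-5 * (1 - 1750/3500) = 1.0e-5
411 | #   step 4000: end    -> 2e-5 * min_lr_ratio    = 0.0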
--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 |
9 | def gelu(x):
10 | """Gaussian Error Linear Unit.
11 |
12 |   This is a smoother version of the ReLU.
13 | Original paper: https://arxiv.org/abs/1606.08415
14 | Args:
15 | x: float Tensor to perform activation.
16 |
17 | Returns:
18 | `x` with the GELU activation applied.
19 | """
20 | cdf = 0.5 * (1.0 + tf.tanh(
21 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
22 | return x * cdf
23 |
24 |
25 | def embedding_lookup(x, n_token, d_embed, initializer, use_tpu=True,
26 | scope='embedding', reuse=None, dtype=tf.float32):
27 | """TPU and GPU embedding_lookup function."""
28 | with tf.variable_scope(scope, reuse=reuse):
29 | lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
30 | dtype=dtype, initializer=initializer)
31 | if use_tpu:
32 | one_hot_idx = tf.one_hot(x, n_token, dtype=dtype)
33 | if one_hot_idx.shape.ndims == 2:
34 | return tf.einsum('in,nd->id', one_hot_idx, lookup_table), lookup_table
35 | else:
36 | return tf.einsum('ibn,nd->ibd', one_hot_idx, lookup_table), lookup_table
37 | else:
38 | return tf.nn.embedding_lookup(lookup_table, x), lookup_table
39 |
40 |
41 | def positional_embedding(pos_seq, inv_freq, bsz=None):
42 | sinusoid_inp = tf.einsum('i,d->id', pos_seq, inv_freq)
43 | pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
44 | pos_emb = pos_emb[:, None, :]
45 |
46 | if bsz is not None:
47 | pos_emb = tf.tile(pos_emb, [1, bsz, 1])
48 |
49 | return pos_emb
50 |
51 |
52 | def positionwise_ffn(inp, d_model, d_inner, dropout, kernel_initializer,
53 | activation_type='relu', scope='ff', is_training=True,
54 | reuse=None):
55 | """Position-wise Feed-forward Network."""
56 | if activation_type == 'relu':
57 | activation = tf.nn.relu
58 | elif activation_type == 'gelu':
59 | activation = gelu
60 | else:
61 | raise ValueError('Unsupported activation type {}'.format(activation_type))
62 |
63 | output = inp
64 | with tf.variable_scope(scope, reuse=reuse):
65 | output = tf.layers.dense(output, d_inner, activation=activation,
66 | kernel_initializer=kernel_initializer,
67 | name='layer_1')
68 | output = tf.layers.dropout(output, dropout, training=is_training,
69 | name='drop_1')
70 | output = tf.layers.dense(output, d_model,
71 | kernel_initializer=kernel_initializer,
72 | name='layer_2')
73 | output = tf.layers.dropout(output, dropout, training=is_training,
74 | name='drop_2')
75 | output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1,
76 | scope='LayerNorm')
77 | return output
78 |
79 |
80 | def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
81 | """Project hidden states to a specific head with a 4D-shape."""
82 | proj_weight = tf.get_variable('{}/kernel'.format(name),
83 | [d_model, n_head, d_head], dtype=h.dtype,
84 | initializer=kernel_initializer)
85 | head = tf.einsum('ibh,hnd->ibnd', h, proj_weight)
86 |
87 | return head
88 |
89 |
90 | def post_attention(h, attn_vec, d_model, n_head, d_head, dropout, is_training,
91 | kernel_initializer, residual=True):
92 | """Post-attention processing."""
93 | # post-attention projection (back to `d_model`)
94 | proj_o = tf.get_variable('o/kernel', [d_model, n_head, d_head],
95 | dtype=h.dtype, initializer=kernel_initializer)
96 | attn_out = tf.einsum('ibnd,hnd->ibh', attn_vec, proj_o)
97 |
98 | attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
99 | if residual:
100 | output = tf.contrib.layers.layer_norm(attn_out + h, begin_norm_axis=-1,
101 | scope='LayerNorm')
102 | else:
103 | output = tf.contrib.layers.layer_norm(attn_out, begin_norm_axis=-1,
104 | scope='LayerNorm')
105 |
106 | return output
107 |
108 |
109 | def abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt, is_training,
110 | scale):
111 | """Core absolute positional attention operations."""
112 |
113 | attn_score = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head)
114 | attn_score *= scale
115 | if attn_mask is not None:
116 | attn_score = attn_score - 1e30 * attn_mask
117 |
118 | # attention probability
119 | attn_prob = tf.nn.softmax(attn_score, 1)
120 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)
121 |
122 | # attention output
123 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head)
124 |
125 | return attn_vec
126 |
127 |
128 | def rel_attn_core(q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
129 | r_w_bias, r_r_bias, r_s_bias, attn_mask, dropatt, is_training,
130 | scale):
131 | """Core relative positional attention operations."""
132 |
133 | # content based attention score
134 | ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
135 |
136 | # position based attention score
137 | bd = tf.einsum('ibnd,jbnd->ijbn', q_head + r_r_bias, k_head_r)
138 | bd = rel_shift(bd, klen=tf.shape(ac)[1])
139 |
140 | # segment based attention score
141 | if seg_mat is None:
142 | ef = 0
143 | else:
144 | ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
145 | ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
146 |
147 | # merge attention scores and perform masking
148 | attn_score = (ac + bd + ef) * scale
149 | if attn_mask is not None:
150 | # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
151 | attn_score = attn_score - 1e30 * attn_mask
152 |
153 | # attention probability
154 | attn_prob = tf.nn.softmax(attn_score, 1)
155 | attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)
156 |
157 | # attention output
158 | attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, v_head_h)
159 |
160 | return attn_vec
161 |
162 |
163 | def rel_shift(x, klen=-1):
164 | """perform relative shift to form the relative attention score."""
165 | x_size = tf.shape(x)
166 |
167 | x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
168 | x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
169 | x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
170 | x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
171 |
172 | return x
173 |
174 |
175 | def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
176 | """create causal attention mask."""
177 | attn_mask = tf.ones([qlen, qlen], dtype=dtype)
178 | mask_u = tf.matrix_band_part(attn_mask, 0, -1)
179 | mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
180 | attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
181 | ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
182 | if same_length:
183 | mask_l = tf.matrix_band_part(attn_mask, -1, 0)
184 | ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
185 |
186 | return ret
187 |
188 |
189 | def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
190 | """cache hidden states into memory."""
191 | if mem_len is None or mem_len == 0:
192 | return None
193 | else:
194 | if reuse_len is not None and reuse_len > 0:
195 | curr_out = curr_out[:reuse_len]
196 |
197 | if prev_mem is None:
198 | new_mem = curr_out[-mem_len:]
199 | else:
200 | new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
201 |
202 | return tf.stop_gradient(new_mem)
203 |
204 |
205 | def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,
206 | bi_data, bsz=None, dtype=None):
207 | """create relative positional encoding."""
208 | freq_seq = tf.range(0, d_model, 2.0)
209 | if dtype is not None and dtype != tf.float32:
210 | freq_seq = tf.cast(freq_seq, dtype=dtype)
211 | inv_freq = 1 / (10000 ** (freq_seq / d_model))
212 |
213 | if attn_type == 'bi':
214 | # beg, end = klen - 1, -qlen
215 | beg, end = klen, -qlen
216 | elif attn_type == 'uni':
217 | # beg, end = klen - 1, -1
218 | beg, end = klen, -1
219 | else:
220 | raise ValueError('Unknown `attn_type` {}.'.format(attn_type))
221 |
222 | if bi_data:
223 | fwd_pos_seq = tf.range(beg, end, -1.0)
224 | bwd_pos_seq = tf.range(-beg, -end, 1.0)
225 |
226 | if dtype is not None and dtype != tf.float32:
227 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
228 | bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
229 |
230 | if clamp_len > 0:
231 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
232 | bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -clamp_len, clamp_len)
233 |
234 | if bsz is not None:
235 | # With bi_data, the batch size should be divisible by 2.
236 | assert bsz%2 == 0
237 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
238 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
239 | else:
240 | fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq)
241 | bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq)
242 |
243 | pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
244 | else:
245 | fwd_pos_seq = tf.range(beg, end, -1.0)
246 | if dtype is not None and dtype != tf.float32:
247 | fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
248 | if clamp_len > 0:
249 | fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -clamp_len, clamp_len)
250 | pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz)
251 |
252 | return pos_emb
253 |
254 |
255 | def multihead_attn(q, k, v, attn_mask, d_model, n_head, d_head, dropout,
256 | dropatt, is_training, kernel_initializer, residual=True,
257 | scope='abs_attn', reuse=None):
258 | """Standard multi-head attention with absolute positional embedding."""
259 |
260 | scale = 1 / (d_head ** 0.5)
261 | with tf.variable_scope(scope, reuse=reuse):
262 | # attention heads
263 | q_head = head_projection(
264 | q, d_model, n_head, d_head, kernel_initializer, 'q')
265 | k_head = head_projection(
266 | k, d_model, n_head, d_head, kernel_initializer, 'k')
267 | v_head = head_projection(
268 | v, d_model, n_head, d_head, kernel_initializer, 'v')
269 |
270 | # attention vector
271 | attn_vec = abs_attn_core(q_head, k_head, v_head, attn_mask, dropatt,
272 | is_training, scale)
273 |
274 | # post processing
275 | output = post_attention(v, attn_vec, d_model, n_head, d_head, dropout,
276 | is_training, kernel_initializer, residual)
277 |
278 | return output
279 |
280 |
281 |
282 | def rel_multihead_attn(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
283 | attn_mask, mems, d_model, n_head, d_head, dropout,
284 | dropatt, is_training, kernel_initializer,
285 | scope='rel_attn', reuse=None):
286 | """Multi-head attention with relative positional encoding."""
287 |
288 | scale = 1 / (d_head ** 0.5)
289 | with tf.variable_scope(scope, reuse=reuse):
290 | if mems is not None and mems.shape.ndims > 1:
291 | cat = tf.concat([mems, h], 0)
292 | else:
293 | cat = h
294 |
295 | # content heads
296 | q_head_h = head_projection(
297 | h, d_model, n_head, d_head, kernel_initializer, 'q')
298 | k_head_h = head_projection(
299 | cat, d_model, n_head, d_head, kernel_initializer, 'k')
300 | v_head_h = head_projection(
301 | cat, d_model, n_head, d_head, kernel_initializer, 'v')
302 |
303 | # positional heads
304 | k_head_r = head_projection(
305 | r, d_model, n_head, d_head, kernel_initializer, 'r')
306 |
307 | # core attention ops
308 | attn_vec = rel_attn_core(
309 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
310 | r_r_bias, r_s_bias, attn_mask, dropatt, is_training, scale)
311 |
312 | # post processing
313 | output = post_attention(h, attn_vec, d_model, n_head, d_head, dropout,
314 | is_training, kernel_initializer)
315 |
316 | return output
317 |
318 |
319 | def two_stream_rel_attn(h, g, r, mems, r_w_bias, r_r_bias, seg_mat, r_s_bias,
320 | seg_embed, attn_mask_h, attn_mask_g, target_mapping,
321 | d_model, n_head, d_head, dropout, dropatt, is_training,
322 | kernel_initializer, scope='rel_attn'):
323 | """Two-stream attention with relative positional encoding."""
324 |
325 | scale = 1 / (d_head ** 0.5)
326 | with tf.variable_scope(scope, reuse=False):
327 |
328 | # content based attention score
329 | if mems is not None and mems.shape.ndims > 1:
330 | cat = tf.concat([mems, h], 0)
331 | else:
332 | cat = h
333 |
334 | # content-based key head
335 | k_head_h = head_projection(
336 | cat, d_model, n_head, d_head, kernel_initializer, 'k')
337 |
338 | # content-based value head
339 | v_head_h = head_projection(
340 | cat, d_model, n_head, d_head, kernel_initializer, 'v')
341 |
342 | # position-based key head
343 | k_head_r = head_projection(
344 | r, d_model, n_head, d_head, kernel_initializer, 'r')
345 |
346 | ##### h-stream
347 | # content-stream query head
348 | q_head_h = head_projection(
349 | h, d_model, n_head, d_head, kernel_initializer, 'q')
350 |
351 | # core attention ops
352 | attn_vec_h = rel_attn_core(
353 | q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
354 | r_r_bias, r_s_bias, attn_mask_h, dropatt, is_training, scale)
355 |
356 | # post processing
357 | output_h = post_attention(h, attn_vec_h, d_model, n_head, d_head, dropout,
358 | is_training, kernel_initializer)
359 |
360 | with tf.variable_scope(scope, reuse=True):
361 | ##### g-stream
362 | # query-stream query head
363 | q_head_g = head_projection(
364 | g, d_model, n_head, d_head, kernel_initializer, 'q')
365 |
366 | # core attention ops
367 | if target_mapping is not None:
368 | q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
369 | attn_vec_g = rel_attn_core(
370 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
371 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale)
372 | attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
373 | else:
374 | attn_vec_g = rel_attn_core(
375 | q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
376 | r_r_bias, r_s_bias, attn_mask_g, dropatt, is_training, scale)
377 |
378 | # post processing
379 | output_g = post_attention(g, attn_vec_g, d_model, n_head, d_head, dropout,
380 | is_training, kernel_initializer)
381 |
382 | return output_h, output_g
383 |
384 |
385 | def transformer_xl(inp_k, n_token, n_layer, d_model, n_head,
386 | d_head, d_inner, dropout, dropatt, attn_type,
387 | bi_data, initializer, is_training, mem_len=None,
388 | inp_q=None, mems=None,
389 | same_length=False, clamp_len=-1, untie_r=False,
390 | use_tpu=True, input_mask=None,
391 | perm_mask=None, seg_id=None, reuse_len=None,
392 | ff_activation='relu', target_mapping=None,
393 | use_bfloat16=False, scope='transformer', **kwargs):
394 | """
395 | Defines a Transformer-XL computation graph with additional
396 | support for XLNet.
397 |
398 | Args:
399 |
400 | inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
401 | seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
402 | input_mask: float32 Tensor in shape [len, bsz], the input mask.
403 | 0 for real tokens and 1 for padding.
404 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
405 | from previous batches. The length of the list equals n_layer.
406 | If None, no memory is used.
407 | perm_mask: float32 Tensor in shape [len, len, bsz].
408 | If perm_mask[i, j, k] = 0, i attend to j in batch k;
409 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
410 | If None, each position attends to all the others.
411 | target_mapping: float32 Tensor in shape [num_predict, len, bsz].
412 | If target_mapping[i, j, k] = 1, the i-th predict in batch k is
413 | on the j-th token.
414 | Only used during pretraining for partial prediction.
415 | Set to None during finetuning.
416 | inp_q: float32 Tensor in shape [len, bsz].
417 | 1 for tokens with losses and 0 for tokens without losses.
418 | Only used during pretraining for two-stream attention.
419 | Set to None during finetuning.
420 |
421 | n_layer: int, the number of layers.
422 | d_model: int, the hidden size.
423 | n_head: int, the number of attention heads.
424 | d_head: int, the dimension size of each attention head.
425 | d_inner: int, the hidden size in feed-forward layers.
426 | ff_activation: str, "relu" or "gelu".
427 | untie_r: bool, whether to untie the biases in attention.
428 | n_token: int, the vocab size.
429 |
430 | is_training: bool, whether in training mode.
431 | use_tpu: bool, whether TPUs are used.
432 | use_bfloat16: bool, use bfloat16 instead of float32.
433 | dropout: float, dropout rate.
434 | dropatt: float, dropout rate on attention probabilities.
435 | init: str, the initialization scheme, either "normal" or "uniform".
436 | init_range: float, initialize the parameters with a uniform distribution
437 | in [-init_range, init_range]. Only effective when init="uniform".
438 | init_std: float, initialize the parameters with a normal distribution
439 | with mean 0 and stddev init_std. Only effective when init="normal".
440 | mem_len: int, the number of tokens to cache.
441 |     reuse_len: int, the number of tokens in the current batch to be cached
442 | and reused in the future.
443 | bi_data: bool, whether to use bidirectional input pipeline.
444 | Usually set to True during pretraining and False during finetuning.
445 | clamp_len: int, clamp all relative distances larger than clamp_len.
446 | -1 means no clamping.
447 | same_length: bool, whether to use the same attention length for each token.
448 | summary_type: str, "last", "first", "mean", or "attn". The method
449 | to pool the input to get a vector representation.
450 | initializer: A tf initializer.
451 | scope: scope name for the computation graph.
452 | """
453 | tf.logging.info('memory input {}'.format(mems))
454 | tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
455 | tf.logging.info('Use float type {}'.format(tf_float))
456 |
457 | new_mems = []
458 | with tf.variable_scope(scope):
459 | if untie_r:
460 | r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
461 | dtype=tf_float, initializer=initializer)
462 | r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
463 | dtype=tf_float, initializer=initializer)
464 | else:
465 | r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
466 | dtype=tf_float, initializer=initializer)
467 | r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
468 | dtype=tf_float, initializer=initializer)
469 |
470 | bsz = tf.shape(inp_k)[1]
471 | qlen = tf.shape(inp_k)[0]
472 | mlen = tf.shape(mems[0])[0] if mems is not None else 0
473 | klen = mlen + qlen
474 |
475 | ##### Attention mask
476 | # causal attention mask
477 | if attn_type == 'uni':
478 | attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
479 | attn_mask = attn_mask[:, :, None, None]
480 | elif attn_type == 'bi':
481 | attn_mask = None
482 | else:
483 | raise ValueError('Unsupported attention type: {}'.format(attn_type))
484 |
485 | # data mask: input mask & perm mask
486 | if input_mask is not None and perm_mask is not None:
487 | data_mask = input_mask[None] + perm_mask
488 | elif input_mask is not None and perm_mask is None:
489 | data_mask = input_mask[None]
490 | elif input_mask is None and perm_mask is not None:
491 | data_mask = perm_mask
492 | else:
493 | data_mask = None
494 |
495 | if data_mask is not None:
496 | # all mems can be attended to
497 | mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
498 | dtype=tf_float)
499 | data_mask = tf.concat([mems_mask, data_mask], 1)
500 | if attn_mask is None:
501 | attn_mask = data_mask[:, :, :, None]
502 | else:
503 | attn_mask += data_mask[:, :, :, None]
504 |
505 | if attn_mask is not None:
506 | attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
507 |
508 | if attn_mask is not None:
509 | non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
510 | non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),
511 | non_tgt_mask], axis=-1)
512 | non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
513 | dtype=tf_float)
514 | else:
515 | non_tgt_mask = None
516 |
517 | ##### Word embedding
518 | word_emb_k, lookup_table = embedding_lookup(
519 | x=inp_k,
520 | n_token=n_token,
521 | d_embed=d_model,
522 | initializer=initializer,
523 | use_tpu=use_tpu,
524 | dtype=tf_float,
525 | scope='word_embedding')
526 |
527 | if inp_q is not None:
528 | with tf.variable_scope('mask_emb'):
529 | mask_emb = tf.get_variable('mask_emb', [1, 1, d_model], dtype=tf_float)
530 | if target_mapping is not None:
531 | word_emb_q = tf.tile(mask_emb, [tf.shape(target_mapping)[0], bsz, 1])
532 | else:
533 | inp_q_ext = inp_q[:, :, None]
534 | word_emb_q = inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
535 | output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
536 | if inp_q is not None:
537 | output_g = tf.layers.dropout(word_emb_q, dropout, training=is_training)
538 |
539 | ##### Segment embedding
540 | if seg_id is not None:
541 | if untie_r:
542 | r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
543 | dtype=tf_float, initializer=initializer)
544 | else:
545 | # default case (tie)
546 | r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
547 | dtype=tf_float, initializer=initializer)
548 |
549 | seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
550 | dtype=tf_float, initializer=initializer)
551 |
552 | # Convert `seg_id` to one-hot `seg_mat`
553 | mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
554 | cat_ids = tf.concat([mem_pad, seg_id], 0)
555 |
556 | # `1` indicates not in the same segment [qlen x klen x bsz]
557 | seg_mat = tf.cast(
558 | tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
559 | tf.int32)
560 | seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
561 | else:
562 | seg_mat = None
563 |
564 | ##### Positional encoding
565 | pos_emb = relative_positional_encoding(
566 | qlen, klen, d_model, clamp_len, attn_type, bi_data,
567 | bsz=bsz, dtype=tf_float)
568 | pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)
569 |
570 | ##### Attention layers
571 | if mems is None:
572 | mems = [None] * n_layer
573 |
574 | for i in range(n_layer):
575 | # cache new mems
576 | new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))
577 |
578 | # segment bias
579 | if seg_id is None:
580 | r_s_bias_i = None
581 | seg_embed_i = None
582 | else:
583 | r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
584 | seg_embed_i = seg_embed[i]
585 |
586 | with tf.variable_scope('layer_{}'.format(i)):
587 | if inp_q is not None:
588 | output_h, output_g = two_stream_rel_attn(
589 | h=output_h,
590 | g=output_g,
591 | r=pos_emb,
592 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
593 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
594 | seg_mat=seg_mat,
595 | r_s_bias=r_s_bias_i,
596 | seg_embed=seg_embed_i,
597 | attn_mask_h=non_tgt_mask,
598 | attn_mask_g=attn_mask,
599 | mems=mems[i],
600 | target_mapping=target_mapping,
601 | d_model=d_model,
602 | n_head=n_head,
603 | d_head=d_head,
604 | dropout=dropout,
605 | dropatt=dropatt,
606 | is_training=is_training,
607 | kernel_initializer=initializer)
608 | reuse = True
609 | else:
610 | reuse = False
611 |
612 | output_h = rel_multihead_attn(
613 | h=output_h,
614 | r=pos_emb,
615 | r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
616 | r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
617 | seg_mat=seg_mat,
618 | r_s_bias=r_s_bias_i,
619 | seg_embed=seg_embed_i,
620 | attn_mask=non_tgt_mask,
621 | mems=mems[i],
622 | d_model=d_model,
623 | n_head=n_head,
624 | d_head=d_head,
625 | dropout=dropout,
626 | dropatt=dropatt,
627 | is_training=is_training,
628 | kernel_initializer=initializer,
629 | reuse=reuse)
630 |
631 | if inp_q is not None:
632 | output_g = positionwise_ffn(
633 | inp=output_g,
634 | d_model=d_model,
635 | d_inner=d_inner,
636 | dropout=dropout,
637 | kernel_initializer=initializer,
638 | activation_type=ff_activation,
639 | is_training=is_training)
640 |
641 | output_h = positionwise_ffn(
642 | inp=output_h,
643 | d_model=d_model,
644 | d_inner=d_inner,
645 | dropout=dropout,
646 | kernel_initializer=initializer,
647 | activation_type=ff_activation,
648 | is_training=is_training,
649 | reuse=reuse)
650 |
651 | if inp_q is not None:
652 | output = tf.layers.dropout(output_g, dropout, training=is_training)
653 | else:
654 | output = tf.layers.dropout(output_h, dropout, training=is_training)
655 |
656 | return output, new_mems, lookup_table
657 |
658 |
659 | def lm_loss(hidden, target, n_token, d_model, initializer, lookup_table=None,
660 | tie_weight=False, bi_data=True, use_tpu=False):
661 |   """Compute the language modeling loss over the vocabulary."""
662 |
663 | with tf.variable_scope('lm_loss'):
664 | if tie_weight:
665 | assert lookup_table is not None, \
666 | 'lookup_table cannot be None for tie_weight'
667 | softmax_w = lookup_table
668 | else:
669 | softmax_w = tf.get_variable('weight', [n_token, d_model],
670 | dtype=hidden.dtype, initializer=initializer)
671 |
672 | softmax_b = tf.get_variable('bias', [n_token], dtype=hidden.dtype,
673 | initializer=tf.zeros_initializer())
674 |
675 | logits = tf.einsum('ibd,nd->ibn', hidden, softmax_w) + softmax_b
676 |
677 | if use_tpu:
678 | one_hot_target = tf.one_hot(target, n_token, dtype=logits.dtype)
679 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
680 | else:
681 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
682 | logits=logits)
683 |
684 | return loss
685 |
686 |
687 | def summarize_sequence(summary_type, hidden, d_model, n_head, d_head, dropout,
688 | dropatt, input_mask, is_training, initializer,
689 | scope=None, reuse=None, use_proj=True):
690 |
691 | """
692 |   Different classification tasks may or may not share the same parameters
693 | to summarize the sequence features.
694 |
695 | If shared, one can keep the `scope` to the default value `None`.
696 | Otherwise, one should specify a different `scope` for each task.
697 | """
698 |
699 |   with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):  # sic: scope name kept as-is for checkpoint compatibility
700 | if summary_type == 'last':
701 | summary = hidden[-1]
702 | elif summary_type == 'first':
703 | summary = hidden[0]
704 | elif summary_type == 'mean':
705 | summary = tf.reduce_mean(hidden, axis=0)
706 | elif summary_type == 'attn':
707 | bsz = tf.shape(hidden)[1]
708 |
709 | summary_bias = tf.get_variable('summary_bias', [d_model],
710 | dtype=hidden.dtype,
711 | initializer=initializer)
712 | summary_bias = tf.tile(summary_bias[None, None], [1, bsz, 1])
713 |
714 | if input_mask is not None:
715 | input_mask = input_mask[None, :, :, None]
716 |
717 | summary = multihead_attn(summary_bias, hidden, hidden, input_mask,
718 | d_model, n_head, d_head, dropout, dropatt,
719 | is_training, initializer, residual=False)
720 | summary = summary[0]
721 | else:
722 | raise ValueError('Unsupported summary type {}'.format(summary_type))
723 |
724 | # use another projection as in BERT
725 | if use_proj:
726 | summary = tf.layers.dense(
727 | summary,
728 | d_model,
729 | activation=tf.tanh,
730 | kernel_initializer=initializer,
731 | name='summary')
732 |
733 | # dropout
734 | summary = tf.layers.dropout(
735 | summary, dropout, training=is_training,
736 | name='dropout')
737 |
738 | return summary
739 |
740 |
741 | def classification_loss(hidden, labels, n_class, initializer, scope, reuse=None,
742 | return_logits=False):
743 | """
744 | Different classification tasks should use different scope names to ensure
745 | different dense layers (parameters) are used to produce the logits.
746 |
747 | An exception will be in transfer learning, where one hopes to transfer
748 | the classification weights.
749 | """
750 |
751 | with tf.variable_scope(scope, reuse=reuse):
752 | logits = tf.layers.dense(
753 | hidden,
754 | n_class,
755 | kernel_initializer=initializer,
756 | name='logit')
757 |
758 | one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype)
759 | loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
760 |
761 | if return_logits:
762 | return loss, logits
763 |
764 | return loss
765 |
766 |
767 | def regression_loss(hidden, labels, initializer, scope, reuse=None,
768 | return_logits=False):
769 | with tf.variable_scope(scope, reuse=reuse):
770 | logits = tf.layers.dense(
771 | hidden,
772 | 1,
773 | kernel_initializer=initializer,
774 | name='logit')
775 |
776 | logits = tf.squeeze(logits, axis=-1)
777 | loss = tf.square(logits - labels)
778 |
779 | if return_logits:
780 | return loss, logits
781 |
782 | return loss
783 |
784 |
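785 |
786 | # --- Editor's usage sketch (not part of the original file) ---
787 | # A minimal graph-construction example for `transformer_xl` in a
788 | # finetuning-style setting (single stream, no memory, no segment IDs).
789 | # All sizes below are hypothetical.
790 | def _example_transformer_xl():
791 |   seq_len, bsz = 16, 2
792 |   inp_k = tf.zeros([seq_len, bsz], dtype=tf.int32)
793 |   initializer = tf.random_normal_initializer(stddev=0.02)
794 |   output, new_mems, lookup_table = transformer_xl(
795 |       inp_k=inp_k, n_token=32000, n_layer=2, d_model=64, n_head=4,
796 |       d_head=16, d_inner=256, dropout=0.1, dropatt=0.1, attn_type='bi',
797 |       bi_data=False, initializer=initializer, is_training=False,
798 |       mem_len=None, use_tpu=False)
799 |   # `output` has shape [seq_len, bsz, d_model]; with mem_len=None every
800 |   # entry of `new_mems` is None.
801 |   return output, new_mems, lookup_table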
--------------------------------------------------------------------------------
/notebooks/colab_imdb_gpu.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "<!-- Open in Colab badge (image link removed) -->"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "colab_type": "text",
17 | "id": "fnOHnctkG6kW"
18 | },
19 | "source": [
20 | "# XLNet IMDB movie review classification project\n",
21 | "\n",
22 | "This notebook is for classifying the [imdb sentiment dataset](https://ai.stanford.edu/~amaas/data/sentiment/). It is easy to edit this notebook to run any of the classification tasks referenced in the [XLNet paper](https://arxiv.org/abs/1906.08237). While you cannot expect to obtain the state-of-the-art results from the paper on a GPU, this model will still score very highly."
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {
28 | "colab_type": "text",
29 | "id": "2mBzLdrdzodb"
30 | },
31 | "source": [
32 | "## Setup\n",
33 | "Install dependencies"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 0,
39 | "metadata": {
40 | "colab": {},
41 | "colab_type": "code",
42 | "id": "hRHRPImGUth7"
43 | },
44 | "outputs": [],
45 | "source": [
46 | "! pip install sentencepiece"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {
52 | "colab_type": "text",
53 | "id": "jy8gUsPuJNyw"
54 | },
55 | "source": [
56 | "Download and unzip the pretrained XLNet model"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 0,
62 | "metadata": {
63 | "colab": {},
64 | "colab_type": "code",
65 | "id": "HfPDGsUtHKG0"
66 | },
67 | "outputs": [],
68 | "source": [
69 | "# only needs to be done once\n",
70 | "! wget https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip\n",
71 | "! unzip cased_L-24_H-1024_A-16.zip "
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {
77 | "colab_type": "text",
78 | "id": "4uUwjq3BJRbu"
79 | },
80 | "source": [
81 | "Download and extract the IMDB dataset"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 0,
87 | "metadata": {
88 | "colab": {},
89 | "colab_type": "code",
90 | "id": "QOGRICbOIsU8"
91 | },
92 | "outputs": [],
93 | "source": [
94 | "! wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
95 | "! tar zxf aclImdb_v1.tar.gz"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {
101 | "colab_type": "text",
102 | "id": "yGY_ggUUMrwU"
103 | },
104 | "source": [
105 | "Clone the XLNet repo for access to run_classifier.py and the rest of the xlnet module"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 0,
111 | "metadata": {
112 | "colab": {},
113 | "colab_type": "code",
114 | "id": "-r190eYVMpiG"
115 | },
116 | "outputs": [],
117 | "source": [
118 | "! git clone https://github.com/zihangdai/xlnet.git"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {
124 | "colab_type": "text",
125 | "id": "jDP-IaVuPC-z"
126 | },
127 | "source": [
128 | "## Define Variables\n",
129 | "Define all the dirs: data, xlnet scripts & pretrained model. \n",
130 | "If you would like to save models, you can authenticate a GCP account and use it for the OUTPUT_DIR & CHECKPOINT_DIR - you will need a large amount of storage to fit these models. \n",
131 | "\n",
132 | "Alternatively, it is easy to integrate a Google Drive account; check out this guide for [I/O in Colab](https://colab.research.google.com/notebooks/io.ipynb), but remember these will take up a large amount of storage. \n"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 0,
138 | "metadata": {
139 | "colab": {},
140 | "colab_type": "code",
141 | "id": "y7N_xVwavQlV"
142 | },
143 | "outputs": [],
144 | "source": [
145 | "SCRIPTS_DIR = 'xlnet' #@param {type:\"string\"}\n",
146 | "DATA_DIR = 'aclImdb' #@param {type:\"string\"}\n",
147 | "OUTPUT_DIR = 'proc_data/imdb' #@param {type:\"string\"}\n",
148 | "PRETRAINED_MODEL_DIR = 'xlnet_cased_L-24_H-1024_A-16' #@param {type:\"string\"}\n",
149 | "CHECKPOINT_DIR = 'exp/imdb' #@param {type:\"string\"}"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "colab_type": "text",
156 | "id": "jR6euqwL1KBV"
157 | },
158 | "source": [
159 | "## Run Model\n",
160 | "This will set off the fine tuning of XLNet. There are a few things to note here:\n",
161 | "\n",
162 | "\n",
163 | "1. This script will train and evaluate the model\n",
164 | "2. This will store the results locally on colab and will be lost when you are disconnected from the runtime\n",
165 | "3. This uses the large version of the model (the base model has not been released at the time of writing)\n",
166 | "4. We are using a max seq length of 128 with a batch size of 8; please refer to the [README](https://github.com/zihangdai/xlnet#memory-issue-during-finetuning) for why this is.\n",
167 | "5. This will take approximately 4 hours to run on a GPU.\n",
168 | "\n"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 0,
174 | "metadata": {
175 | "colab": {},
176 | "colab_type": "code",
177 | "id": "CEMuT6LU0avg"
178 | },
179 | "outputs": [],
180 | "source": [
181 | "train_command = \"python xlnet/run_classifier.py \\\n",
182 | " --do_train=True \\\n",
183 | " --do_eval=True \\\n",
184 | " --eval_all_ckpt=True \\\n",
185 | " --task_name=imdb \\\n",
186 | " --data_dir=\"+DATA_DIR+\" \\\n",
187 | " --output_dir=\"+OUTPUT_DIR+\" \\\n",
188 | " --model_dir=\"+CHECKPOINT_DIR+\" \\\n",
189 | " --uncased=False \\\n",
190 | " --spiece_model_file=\"+PRETRAINED_MODEL_DIR+\"/spiece.model \\\n",
191 | " --model_config_path=\"+PRETRAINED_MODEL_DIR+\"/xlnet_config.json \\\n",
192 | " --init_checkpoint=\"+PRETRAINED_MODEL_DIR+\"/xlnet_model.ckpt \\\n",
193 | " --max_seq_length=128 \\\n",
194 | " --train_batch_size=8 \\\n",
195 | " --eval_batch_size=8 \\\n",
196 | " --num_hosts=1 \\\n",
197 | " --num_core_per_host=1 \\\n",
198 | " --learning_rate=2e-5 \\\n",
199 | " --train_steps=4000 \\\n",
200 | " --warmup_steps=500 \\\n",
201 | " --save_steps=500 \\\n",
202 | " --iterations=500\"\n",
203 | "\n",
204 | "! {train_command}\n"
205 | ]
206 | },
207 | {
208 | "cell_type": "markdown",
209 | "metadata": {
210 | "colab_type": "text",
211 | "id": "VvhqD-sO0Kyh"
212 | },
213 | "source": [
214 | "## Running & Results\n",
215 | "These are the results I got from running this experiment:\n",
216 | "### Params\n",
217 | "* --max_seq_length=128\n",
218 | "* --train_batch_size=8\n",
219 | "\n",
220 | "### Times\n",
221 | "* Training: 1hr 11mins\n",
222 | "* Evaluation: 2.5hr\n",
223 | "\n",
224 | "### Results\n",
225 | "* Most accurate model on final step\n",
226 | "* Accuracy: 0.92416, eval_loss: 0.31708\n"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "metadata": {
232 | "colab_type": "text",
233 | "id": "XUW2avFM_fi_"
234 | },
235 | "source": [
236 | "### Model\n",
237 | "\n",
238 | "* The trained model checkpoints can be found in 'exp/imdb'\n",
239 | "\n"
240 | ]
241 | }
242 | ],
243 | "metadata": {
244 | "accelerator": "GPU",
245 | "colab": {
246 | "collapsed_sections": [],
247 | "include_colab_link": true,
248 | "name": "XLNet-imdb-GPU.ipynb",
249 | "provenance": [],
250 | "toc_visible": true,
251 | "version": "0.3.2"
252 | },
253 | "kernelspec": {
254 | "display_name": "Python 3",
255 | "language": "python",
256 | "name": "python3"
257 | }
258 | },
259 | "nbformat": 4,
260 | "nbformat_minor": 1
261 | }
262 |
--------------------------------------------------------------------------------
/prepro_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import unicodedata
7 | import six
8 | from functools import partial
9 |
10 |
11 | SPIECE_UNDERLINE = '▁'
12 |
13 |
14 | def printable_text(text):
15 | """Returns text encoded in a way suitable for print or `tf.logging`."""
16 |
17 | # These functions want `str` for both Python2 and Python3, but in one case
18 | # it's a Unicode string and in the other it's a byte string.
19 | if six.PY3:
20 | if isinstance(text, str):
21 | return text
22 | elif isinstance(text, bytes):
23 | return text.decode("utf-8", "ignore")
24 | else:
25 | raise ValueError("Unsupported string type: %s" % (type(text)))
26 | elif six.PY2:
27 | if isinstance(text, str):
28 | return text
29 | elif isinstance(text, unicode):
30 | return text.encode("utf-8")
31 | else:
32 | raise ValueError("Unsupported string type: %s" % (type(text)))
33 | else:
34 |     raise ValueError("Not running on Python 2 or Python 3?")
35 |
36 |
37 | def print_(*args):
38 | new_args = []
39 | for arg in args:
40 | if isinstance(arg, list):
41 | s = [printable_text(i) for i in arg]
42 | s = ' '.join(s)
43 | new_args.append(s)
44 | else:
45 | new_args.append(printable_text(arg))
46 | print(*new_args)
47 |
48 |
49 | def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
50 | if remove_space:
51 | outputs = ' '.join(inputs.strip().split())
52 | else:
53 | outputs = inputs
54 | outputs = outputs.replace("``", '"').replace("''", '"')
55 |
56 | if six.PY2 and isinstance(outputs, str):
57 | outputs = outputs.decode('utf-8')
58 |
59 | if not keep_accents:
60 | outputs = unicodedata.normalize('NFKD', outputs)
61 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
62 | if lower:
63 | outputs = outputs.lower()
64 |
65 | return outputs
66 |
67 |
68 | def encode_pieces(sp_model, text, return_unicode=True, sample=False):
69 | # return_unicode is used only for py2
70 |
71 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2
72 | if six.PY2 and isinstance(text, unicode):
73 | text = text.encode('utf-8')
74 |
75 | if not sample:
76 | pieces = sp_model.EncodeAsPieces(text)
77 | else:
78 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
79 | new_pieces = []
80 | for piece in pieces:
81 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
82 | cur_pieces = sp_model.EncodeAsPieces(
83 | piece[:-1].replace(SPIECE_UNDERLINE, ''))
84 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
85 | if len(cur_pieces[0]) == 1:
86 | cur_pieces = cur_pieces[1:]
87 | else:
88 | cur_pieces[0] = cur_pieces[0][1:]
89 | cur_pieces.append(piece[-1])
90 | new_pieces.extend(cur_pieces)
91 | else:
92 | new_pieces.append(piece)
93 |
94 | # note(zhiliny): convert back to unicode for py2
95 | if six.PY2 and return_unicode:
96 | ret_pieces = []
97 | for piece in new_pieces:
98 | if isinstance(piece, str):
99 | piece = piece.decode('utf-8')
100 | ret_pieces.append(piece)
101 | new_pieces = ret_pieces
102 |
103 | return new_pieces
104 |
105 |
106 | def encode_ids(sp_model, text, sample=False):
107 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
108 | ids = [sp_model.PieceToId(piece) for piece in pieces]
109 | return ids
110 |
111 |
112 | if __name__ == '__main__':
113 | import sentencepiece as spm
114 |
115 | sp = spm.SentencePieceProcessor()
116 | sp.load('sp10m.uncased.v3.model')
117 |
118 | print_(u'I was born in 2000, and this is falsé.')
119 | print_(u'ORIGINAL', sp.EncodeAsPieces(u'I was born in 2000, and this is falsé.'))
120 | print_(u'OURS', encode_pieces(sp, u'I was born in 2000, and this is falsé.'))
121 | print(encode_ids(sp, u'I was born in 2000, and this is falsé.'))
122 | print_('')
123 | prepro_func = partial(preprocess_text, lower=True)
124 | print_(prepro_func('I was born in 2000, and this is falsé.'))
125 | print_('ORIGINAL', sp.EncodeAsPieces(prepro_func('I was born in 2000, and this is falsé.')))
126 | print_('OURS', encode_pieces(sp, prepro_func('I was born in 2000, and this is falsé.')))
127 | print(encode_ids(sp, prepro_func('I was born in 2000, and this is falsé.')))
128 | print_('')
129 | print_('I was born in 2000, and this is falsé.')
130 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 2000, and this is falsé.'))
131 | print_('OURS', encode_pieces(sp, 'I was born in 2000, and this is falsé.'))
132 | print(encode_ids(sp, 'I was born in 2000, and this is falsé.'))
133 | print_('')
134 | print_('I was born in 92000, and this is falsé.')
135 | print_('ORIGINAL', sp.EncodeAsPieces('I was born in 92000, and this is falsé.'))
136 | print_('OURS', encode_pieces(sp, 'I was born in 92000, and this is falsé.'))
137 | print(encode_ids(sp, 'I was born in 92000, and this is falsé.'))
138 |
139 |
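140 |   # Editor's addition (a sketch, not in the original file): passing
141 |   # sample=True makes `encode_pieces` draw a stochastic segmentation via
142 |   # SentencePiece sampling, so the pieces vary from run to run.
143 |   print_('SAMPLED', encode_pieces(sp, 'I was born in 2000, and this is falsé.',
144 |                                   sample=True))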
--------------------------------------------------------------------------------
/run_classifier.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from os.path import join
6 | from absl import flags
7 | import os
8 | import sys
9 | import csv
10 | import collections
11 | import numpy as np
12 | import time
13 | import math
14 | import json
15 | import random
16 | from copy import copy
17 | from collections import defaultdict as dd
18 |
19 | import absl.logging as _logging # pylint: disable=unused-import
20 | import tensorflow as tf
21 |
22 | import sentencepiece as spm
23 |
24 | from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID
25 | import model_utils
26 | import function_builder
27 | from classifier_utils import PaddingInputExample
28 | from classifier_utils import convert_single_example
29 | from prepro_utils import preprocess_text, encode_ids
30 |
31 |
32 | # Model
33 | flags.DEFINE_string("model_config_path", default=None,
34 | help="Model config path.")
35 | flags.DEFINE_float("dropout", default=0.1,
36 | help="Dropout rate.")
37 | flags.DEFINE_float("dropatt", default=0.1,
38 | help="Attention dropout rate.")
39 | flags.DEFINE_integer("clamp_len", default=-1,
40 | help="Clamp length")
41 | flags.DEFINE_string("summary_type", default="last",
42 | help="Method used to summarize a sequence into a compact vector.")
43 | flags.DEFINE_bool("use_summ_proj", default=True,
44 | help="Whether to use projection for summarizing sequences.")
45 | flags.DEFINE_bool("use_bfloat16", False,
46 | help="Whether to use bfloat16.")
47 |
48 | # Parameter initialization
49 | flags.DEFINE_enum("init", default="normal",
50 | enum_values=["normal", "uniform"],
51 | help="Initialization method.")
52 | flags.DEFINE_float("init_std", default=0.02,
53 | help="Initialization std when init is normal.")
54 | flags.DEFINE_float("init_range", default=0.1,
55 |       help="Initialization range when init is uniform.")
56 |
57 | # I/O paths
58 | flags.DEFINE_bool("overwrite_data", default=False,
59 | help="If False, will use cached data if available.")
60 | flags.DEFINE_string("init_checkpoint", default=None,
61 | help="checkpoint path for initializing the model. "
62 | "Could be a pretrained model or a finetuned model.")
63 | flags.DEFINE_string("output_dir", default="",
64 | help="Output dir for TF records.")
65 | flags.DEFINE_string("spiece_model_file", default="",
66 |       help="SentencePiece model path.")
67 | flags.DEFINE_string("model_dir", default="",
68 | help="Directory for saving the finetuned model.")
69 | flags.DEFINE_string("data_dir", default="",
70 | help="Directory for input data.")
71 |
72 | # TPUs and machines
73 | flags.DEFINE_bool("use_tpu", default=False, help="whether to use TPU.")
74 | flags.DEFINE_integer("num_hosts", default=1, help="How many TPU hosts.")
75 | flags.DEFINE_integer("num_core_per_host", default=8,
76 | help="8 for TPU v2 and v3-8, 16 for larger TPU v3 pod. In the context "
77 | "of GPU training, it refers to the number of GPUs used.")
78 | flags.DEFINE_string("tpu_job_name", default=None, help="TPU worker job name.")
79 | flags.DEFINE_string("tpu", default=None, help="TPU name.")
80 | flags.DEFINE_string("tpu_zone", default=None, help="TPU zone.")
81 | flags.DEFINE_string("gcp_project", default=None, help="gcp project.")
82 | flags.DEFINE_string("master", default=None, help="master")
83 | flags.DEFINE_integer("iterations", default=1000,
84 | help="number of iterations per TPU training loop.")
85 |
86 | # training
87 | flags.DEFINE_bool("do_train", default=False, help="whether to do training")
88 | flags.DEFINE_integer("train_steps", default=1000,
89 | help="Number of training steps")
90 | flags.DEFINE_integer("warmup_steps", default=0, help="number of warmup steps")
91 | flags.DEFINE_float("learning_rate", default=1e-5, help="initial learning rate")
92 | flags.DEFINE_float("lr_layer_decay_rate", 1.0,
93 |                    "Top layer: lr[L] = FLAGS.learning_rate. "
94 | "Low layer: lr[l-1] = lr[l] * lr_layer_decay_rate.")
95 | flags.DEFINE_float("min_lr_ratio", default=0.0,
96 | help="min lr ratio for cos decay.")
97 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping")
98 | flags.DEFINE_integer("max_save", default=0,
99 | help="Max number of checkpoints to save. Use 0 to save all.")
100 | flags.DEFINE_integer("save_steps", default=None,
101 |       help="Save a checkpoint every save_steps steps. "
102 |       "If None, no checkpoints are saved.")
103 | flags.DEFINE_integer("train_batch_size", default=8,
104 | help="Batch size for training")
105 | flags.DEFINE_float("weight_decay", default=0.00, help="Weight decay rate")
106 | flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon")
107 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos")
108 |
109 | # evaluation
110 | flags.DEFINE_bool("do_eval", default=False, help="whether to do eval")
111 | flags.DEFINE_bool("do_predict", default=False, help="whether to do prediction")
112 | flags.DEFINE_float("predict_threshold", default=0,
113 | help="Threshold for binary prediction.")
114 | flags.DEFINE_string("eval_split", default="dev", help="could be dev or test")
115 | flags.DEFINE_integer("eval_batch_size", default=128,
116 | help="batch size for evaluation")
117 | flags.DEFINE_integer("predict_batch_size", default=128,
118 | help="batch size for prediction.")
119 | flags.DEFINE_string("predict_dir", default=None,
120 | help="Dir for saving prediction files.")
121 | flags.DEFINE_bool("eval_all_ckpt", default=False,
122 | help="Eval all ckpts. If False, only evaluate the last one.")
123 | flags.DEFINE_string("predict_ckpt", default=None,
124 | help="Ckpt path for do_predict. If None, use the last one.")
125 |
126 | # task specific
127 | flags.DEFINE_string("task_name", default=None, help="Task name")
128 | flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length")
129 | flags.DEFINE_integer("shuffle_buffer", default=2048,
130 | help="Buffer size used for shuffle.")
131 | flags.DEFINE_integer("num_passes", default=1,
132 | help="Num passes for processing training data. "
133 |       "This is used to batch data without loss for TPUs.")
134 | flags.DEFINE_bool("uncased", default=False,
135 | help="Use uncased.")
136 | flags.DEFINE_string("cls_scope", default=None,
137 | help="Classifier layer scope.")
138 | flags.DEFINE_bool("is_regression", default=False,
139 | help="Whether it's a regression task.")
140 |
141 | FLAGS = flags.FLAGS
142 |
143 |
144 | class InputExample(object):
145 | """A single training/test example for simple sequence classification."""
146 |
147 | def __init__(self, guid, text_a, text_b=None, label=None):
148 | """Constructs a InputExample.
149 | Args:
150 | guid: Unique id for the example.
151 | text_a: string. The untokenized text of the first sequence. For single
152 | sequence tasks, only this sequence must be specified.
153 | text_b: (Optional) string. The untokenized text of the second sequence.
154 | Only must be specified for sequence pair tasks.
155 | label: (Optional) string. The label of the example. This should be
156 | specified for train and dev examples, but not for test examples.
157 | """
158 | self.guid = guid
159 | self.text_a = text_a
160 | self.text_b = text_b
161 | self.label = label
162 |
163 |
164 | class DataProcessor(object):
165 | """Base class for data converters for sequence classification data sets."""
166 |
167 | def get_train_examples(self, data_dir):
168 | """Gets a collection of `InputExample`s for the train set."""
169 | raise NotImplementedError()
170 |
171 | def get_dev_examples(self, data_dir):
172 | """Gets a collection of `InputExample`s for the dev set."""
173 | raise NotImplementedError()
174 |
175 | def get_test_examples(self, data_dir):
176 | """Gets a collection of `InputExample`s for prediction."""
177 | raise NotImplementedError()
178 |
179 | def get_labels(self):
180 | """Gets the list of labels for this data set."""
181 | raise NotImplementedError()
182 |
183 | @classmethod
184 | def _read_tsv(cls, input_file, quotechar=None):
185 | """Reads a tab separated value file."""
186 | with tf.gfile.Open(input_file, "r") as f:
187 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
188 | lines = []
189 | for line in reader:
190 | if len(line) == 0: continue
191 | lines.append(line)
192 | return lines
193 |
194 |
195 | class GLUEProcessor(DataProcessor):
196 | def __init__(self):
197 | self.train_file = "train.tsv"
198 | self.dev_file = "dev.tsv"
199 | self.test_file = "test.tsv"
200 | self.label_column = None
201 | self.text_a_column = None
202 | self.text_b_column = None
203 | self.contains_header = True
204 | self.test_text_a_column = None
205 | self.test_text_b_column = None
206 | self.test_contains_header = True
207 |
208 | def get_train_examples(self, data_dir):
209 | """See base class."""
210 | return self._create_examples(
211 | self._read_tsv(os.path.join(data_dir, self.train_file)), "train")
212 |
213 | def get_dev_examples(self, data_dir):
214 | """See base class."""
215 | return self._create_examples(
216 | self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")
217 |
218 | def get_test_examples(self, data_dir):
219 | """See base class."""
220 | if self.test_text_a_column is None:
221 | self.test_text_a_column = self.text_a_column
222 | if self.test_text_b_column is None:
223 | self.test_text_b_column = self.text_b_column
224 |
225 | return self._create_examples(
226 | self._read_tsv(os.path.join(data_dir, self.test_file)), "test")
227 |
228 | def get_labels(self):
229 | """See base class."""
230 | return ["0", "1"]
231 |
232 | def _create_examples(self, lines, set_type):
233 | """Creates examples for the training and dev sets."""
234 | examples = []
235 | for (i, line) in enumerate(lines):
236 | if i == 0 and self.contains_header and set_type != "test":
237 | continue
238 | if i == 0 and self.test_contains_header and set_type == "test":
239 | continue
240 | guid = "%s-%s" % (set_type, i)
241 |
242 | a_column = (self.text_a_column if set_type != "test" else
243 | self.test_text_a_column)
244 | b_column = (self.text_b_column if set_type != "test" else
245 | self.test_text_b_column)
246 |
247 | # there are some incomplete lines in QNLI
248 | if len(line) <= a_column:
249 | tf.logging.warning('Incomplete line, ignored.')
250 | continue
251 | text_a = line[a_column]
252 |
253 | if b_column is not None:
254 | if len(line) <= b_column:
255 | tf.logging.warning('Incomplete line, ignored.')
256 | continue
257 | text_b = line[b_column]
258 | else:
259 | text_b = None
260 |
261 | if set_type == "test":
262 | label = self.get_labels()[0]
263 | else:
264 | if len(line) <= self.label_column:
265 | tf.logging.warning('Incomplete line, ignored.')
266 | continue
267 | label = line[self.label_column]
268 | examples.append(
269 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
270 | return examples
271 |
272 |
273 | class Yelp5Processor(DataProcessor):
274 | def get_train_examples(self, data_dir):
275 | return self._create_examples(os.path.join(data_dir, "train.csv"))
276 |
277 | def get_dev_examples(self, data_dir):
278 | return self._create_examples(os.path.join(data_dir, "test.csv"))
279 |
280 | def get_labels(self):
281 | """See base class."""
282 | return ["1", "2", "3", "4", "5"]
283 |
284 | def _create_examples(self, input_file):
285 | """Creates examples for the training and dev sets."""
286 | examples = []
287 | with tf.gfile.Open(input_file) as f:
288 | reader = csv.reader(f)
289 | for i, line in enumerate(reader):
290 |
291 | label = line[0]
292 | text_a = line[1].replace('""', '"').replace('\\"', '"')
293 | examples.append(
294 | InputExample(guid=str(i), text_a=text_a, text_b=None, label=label))
295 | return examples
296 |
297 |
298 | class ImdbProcessor(DataProcessor):
299 | def get_labels(self):
300 | return ["neg", "pos"]
301 |
302 | def get_train_examples(self, data_dir):
303 | return self._create_examples(os.path.join(data_dir, "train"))
304 |
305 | def get_dev_examples(self, data_dir):
306 | return self._create_examples(os.path.join(data_dir, "test"))
307 |
308 | def _create_examples(self, data_dir):
309 | examples = []
310 | for label in ["neg", "pos"]:
311 | cur_dir = os.path.join(data_dir, label)
312 | for filename in tf.gfile.ListDirectory(cur_dir):
313 | if not filename.endswith("txt"): continue
314 |
315 | path = os.path.join(cur_dir, filename)
316 | with tf.gfile.Open(path) as f:
317 |           text = f.read().strip().replace("<br />", " ")
318 | examples.append(InputExample(
319 | guid="unused_id", text_a=text, text_b=None, label=label))
320 | return examples
321 |
322 |
323 | class MnliMatchedProcessor(GLUEProcessor):
324 | def __init__(self):
325 | super(MnliMatchedProcessor, self).__init__()
326 | self.dev_file = "dev_matched.tsv"
327 | self.test_file = "test_matched.tsv"
328 | self.label_column = -1
329 | self.text_a_column = 8
330 | self.text_b_column = 9
331 |
332 | def get_labels(self):
333 | return ["contradiction", "entailment", "neutral"]
334 |
335 |
336 | class MnliMismatchedProcessor(MnliMatchedProcessor):
337 | def __init__(self):
338 | super(MnliMismatchedProcessor, self).__init__()
339 | self.dev_file = "dev_mismatched.tsv"
340 | self.test_file = "test_mismatched.tsv"
341 |
342 |
343 | class StsbProcessor(GLUEProcessor):
344 | def __init__(self):
345 | super(StsbProcessor, self).__init__()
346 | self.label_column = 9
347 | self.text_a_column = 7
348 | self.text_b_column = 8
349 |
350 | def get_labels(self):
351 | return [0.0]
352 |
353 | def _create_examples(self, lines, set_type):
354 | """Creates examples for the training and dev sets."""
355 | examples = []
356 | for (i, line) in enumerate(lines):
357 | if i == 0 and self.contains_header and set_type != "test":
358 | continue
359 | if i == 0 and self.test_contains_header and set_type == "test":
360 | continue
361 | guid = "%s-%s" % (set_type, i)
362 |
363 | a_column = (self.text_a_column if set_type != "test" else
364 | self.test_text_a_column)
365 | b_column = (self.text_b_column if set_type != "test" else
366 | self.test_text_b_column)
367 |
368 | # there are some incomplete lines in QNLI
369 | if len(line) <= a_column:
370 | tf.logging.warning('Incomplete line, ignored.')
371 | continue
372 | text_a = line[a_column]
373 |
374 | if b_column is not None:
375 | if len(line) <= b_column:
376 | tf.logging.warning('Incomplete line, ignored.')
377 | continue
378 | text_b = line[b_column]
379 | else:
380 | text_b = None
381 |
382 | if set_type == "test":
383 | label = self.get_labels()[0]
384 | else:
385 | if len(line) <= self.label_column:
386 | tf.logging.warning('Incomplete line, ignored.')
387 | continue
388 | label = float(line[self.label_column])
389 | examples.append(
390 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
391 |
392 | return examples
393 |
394 |
395 | def file_based_convert_examples_to_features(
396 | examples, label_list, max_seq_length, tokenize_fn, output_file,
397 | num_passes=1):
398 | """Convert a set of `InputExample`s to a TFRecord file."""
399 |
400 | # do not create duplicated records
401 | if tf.gfile.Exists(output_file) and not FLAGS.overwrite_data:
402 |     tf.logging.info("Tfrecord {} already exists; skipping.".format(output_file))
403 | return
404 |
405 | tf.logging.info("Create new tfrecord {}.".format(output_file))
406 |
407 | writer = tf.python_io.TFRecordWriter(output_file)
408 |
409 | if num_passes > 1:
410 | examples *= num_passes
411 |
412 | for (ex_index, example) in enumerate(examples):
413 | if ex_index % 10000 == 0:
414 | tf.logging.info("Writing example {} of {}".format(ex_index,
415 | len(examples)))
416 |
417 | feature = convert_single_example(ex_index, example, label_list,
418 | max_seq_length, tokenize_fn)
419 |
420 | def create_int_feature(values):
421 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
422 | return f
423 |
424 | def create_float_feature(values):
425 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
426 | return f
427 |
428 | features = collections.OrderedDict()
429 | features["input_ids"] = create_int_feature(feature.input_ids)
430 | features["input_mask"] = create_float_feature(feature.input_mask)
431 | features["segment_ids"] = create_int_feature(feature.segment_ids)
432 | if label_list is not None:
433 | features["label_ids"] = create_int_feature([feature.label_id])
434 | else:
435 | features["label_ids"] = create_float_feature([float(feature.label_id)])
436 | features["is_real_example"] = create_int_feature(
437 | [int(feature.is_real_example)])
438 |
439 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
440 | writer.write(tf_example.SerializeToString())
441 | writer.close()
442 |
443 |
444 | def file_based_input_fn_builder(input_file, seq_length, is_training,
445 | drop_remainder):
446 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
447 |
448 |
449 | name_to_features = {
450 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
451 | "input_mask": tf.FixedLenFeature([seq_length], tf.float32),
452 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
453 | "label_ids": tf.FixedLenFeature([], tf.int64),
454 | "is_real_example": tf.FixedLenFeature([], tf.int64),
455 | }
456 | if FLAGS.is_regression:
457 | name_to_features["label_ids"] = tf.FixedLenFeature([], tf.float32)
458 |
459 | tf.logging.info("Input tfrecord file {}".format(input_file))
460 |
461 | def _decode_record(record, name_to_features):
462 | """Decodes a record to a TensorFlow example."""
463 | example = tf.parse_single_example(record, name_to_features)
464 |
465 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
466 | # So cast all int64 to int32.
467 | for name in list(example.keys()):
468 | t = example[name]
469 | if t.dtype == tf.int64:
470 | t = tf.cast(t, tf.int32)
471 | example[name] = t
472 |
473 | return example
474 |
475 | def input_fn(params, input_context=None):
476 | """The actual input function."""
477 | if FLAGS.use_tpu:
478 | batch_size = params["batch_size"]
479 | elif is_training:
480 | batch_size = FLAGS.train_batch_size
481 | elif FLAGS.do_eval:
482 | batch_size = FLAGS.eval_batch_size
483 | else:
484 | batch_size = FLAGS.predict_batch_size
485 |
486 | d = tf.data.TFRecordDataset(input_file)
487 |     # Shard the dataset across different devices
488 | if input_context is not None:
489 | tf.logging.info("Input pipeline id %d out of %d",
490 | input_context.input_pipeline_id, input_context.num_replicas_in_sync)
491 | d = d.shard(input_context.num_input_pipelines,
492 | input_context.input_pipeline_id)
493 |
494 | # For training, we want a lot of parallel reading and shuffling.
495 | # For eval, we want no shuffling and parallel reading doesn't matter.
496 | if is_training:
497 | d = d.shuffle(buffer_size=FLAGS.shuffle_buffer)
498 | d = d.repeat()
499 |
500 | d = d.apply(
501 | tf.contrib.data.map_and_batch(
502 | lambda record: _decode_record(record, name_to_features),
503 | batch_size=batch_size,
504 | drop_remainder=drop_remainder))
505 |
506 | return d
507 |
508 | return input_fn
509 |
510 |
511 | def get_model_fn(n_class):
512 | def model_fn(features, labels, mode, params):
513 | #### Training or Evaluation
514 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
515 |
516 | #### Get loss from inputs
517 | if FLAGS.is_regression:
518 | (total_loss, per_example_loss, logits
519 | ) = function_builder.get_regression_loss(FLAGS, features, is_training)
520 | else:
521 | (total_loss, per_example_loss, logits
522 | ) = function_builder.get_classification_loss(
523 | FLAGS, features, n_class, is_training)
524 |
525 | #### Check model parameters
526 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
527 | tf.logging.info('#params: {}'.format(num_params))
528 |
529 | #### load pretrained models
530 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS)
531 |
532 | #### Evaluation mode
533 | if mode == tf.estimator.ModeKeys.EVAL:
534 | assert FLAGS.num_hosts == 1
535 |
536 | def metric_fn(per_example_loss, label_ids, logits, is_real_example):
537 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
538 | eval_input_dict = {
539 | 'labels': label_ids,
540 | 'predictions': predictions,
541 | 'weights': is_real_example
542 | }
543 | accuracy = tf.metrics.accuracy(**eval_input_dict)
544 |
545 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
546 | return {
547 | 'eval_accuracy': accuracy,
548 | 'eval_loss': loss}
549 |
550 | def regression_metric_fn(
551 | per_example_loss, label_ids, logits, is_real_example):
552 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
553 | pearsonr = tf.contrib.metrics.streaming_pearson_correlation(
554 | logits, label_ids, weights=is_real_example)
555 | return {'eval_loss': loss, 'eval_pearsonr': pearsonr}
556 |
557 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
558 |
559 |       #### Constructing evaluation TPUEstimatorSpec with new cache.
560 | label_ids = tf.reshape(features['label_ids'], [-1])
561 |
562 | if FLAGS.is_regression:
563 | metric_fn = regression_metric_fn
564 | else:
565 |         metric_fn = metric_fn  # no-op: keep the classification metric_fn
566 | metric_args = [per_example_loss, label_ids, logits, is_real_example]
567 |
568 | if FLAGS.use_tpu:
569 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
570 | mode=mode,
571 | loss=total_loss,
572 | eval_metrics=(metric_fn, metric_args),
573 | scaffold_fn=scaffold_fn)
574 | else:
575 | eval_spec = tf.estimator.EstimatorSpec(
576 | mode=mode,
577 | loss=total_loss,
578 | eval_metric_ops=metric_fn(*metric_args))
579 |
580 | return eval_spec
581 |
582 | elif mode == tf.estimator.ModeKeys.PREDICT:
583 | label_ids = tf.reshape(features["label_ids"], [-1])
584 |
585 | predictions = {
586 | "logits": logits,
587 | "labels": label_ids,
588 | "is_real": features["is_real_example"]
589 | }
590 |
591 | if FLAGS.use_tpu:
592 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
593 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
594 | else:
595 | output_spec = tf.estimator.EstimatorSpec(
596 | mode=mode, predictions=predictions)
597 | return output_spec
598 |
599 | #### Configuring the optimizer
600 | train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)
601 |
602 | monitor_dict = {}
603 | monitor_dict["lr"] = learning_rate
604 |
605 |     #### Constructing training TPUEstimatorSpec with new cache.
606 | if FLAGS.use_tpu:
607 | #### Creating host calls
608 | if not FLAGS.is_regression:
609 | label_ids = tf.reshape(features['label_ids'], [-1])
610 | predictions = tf.argmax(logits, axis=-1, output_type=label_ids.dtype)
611 | is_correct = tf.equal(predictions, label_ids)
612 | accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
613 |
614 | monitor_dict["accuracy"] = accuracy
615 |
616 | host_call = function_builder.construct_scalar_host_call(
617 | monitor_dict=monitor_dict,
618 | model_dir=FLAGS.model_dir,
619 | prefix="train/",
620 | reduce_fn=tf.reduce_mean)
621 | else:
622 | host_call = None
623 |
624 | train_spec = tf.contrib.tpu.TPUEstimatorSpec(
625 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call,
626 | scaffold_fn=scaffold_fn)
627 | else:
628 | train_spec = tf.estimator.EstimatorSpec(
629 | mode=mode, loss=total_loss, train_op=train_op)
630 |
631 | return train_spec
632 |
633 | return model_fn
634 |
635 |
636 | def main(_):
637 | tf.logging.set_verbosity(tf.logging.INFO)
638 |
639 | #### Validate flags
640 | if FLAGS.save_steps is not None:
641 | FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps)
642 |
643 | if FLAGS.do_predict:
644 | predict_dir = FLAGS.predict_dir
645 | if not tf.gfile.Exists(predict_dir):
646 | tf.gfile.MakeDirs(predict_dir)
647 |
648 | processors = {
649 | "mnli_matched": MnliMatchedProcessor,
650 | "mnli_mismatched": MnliMismatchedProcessor,
651 | 'sts-b': StsbProcessor,
652 | 'imdb': ImdbProcessor,
653 | "yelp5": Yelp5Processor
654 | }
655 |
656 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
657 | raise ValueError(
658 |         "At least one of `do_train`, `do_eval` or "
659 |         "`do_predict` must be True.")
660 |
661 | if not tf.gfile.Exists(FLAGS.output_dir):
662 | tf.gfile.MakeDirs(FLAGS.output_dir)
663 |
664 | task_name = FLAGS.task_name.lower()
665 |
666 | if task_name not in processors:
667 | raise ValueError("Task not found: %s" % (task_name))
668 |
669 | processor = processors[task_name]()
670 | label_list = processor.get_labels() if not FLAGS.is_regression else None
671 |
672 | sp = spm.SentencePieceProcessor()
673 | sp.Load(FLAGS.spiece_model_file)
674 | def tokenize_fn(text):
675 | text = preprocess_text(text, lower=FLAGS.uncased)
676 | return encode_ids(sp, text)
677 |
678 | run_config = model_utils.configure_tpu(FLAGS)
679 |
680 | model_fn = get_model_fn(len(label_list) if label_list is not None else None)
681 |
682 | spm_basename = os.path.basename(FLAGS.spiece_model_file)
683 |
684 | # If TPU is not available, this will fall back to normal Estimator on CPU
685 | # or GPU.
686 | if FLAGS.use_tpu:
687 | estimator = tf.contrib.tpu.TPUEstimator(
688 | use_tpu=FLAGS.use_tpu,
689 | model_fn=model_fn,
690 | config=run_config,
691 | train_batch_size=FLAGS.train_batch_size,
692 | predict_batch_size=FLAGS.predict_batch_size,
693 | eval_batch_size=FLAGS.eval_batch_size)
694 | else:
695 | estimator = tf.estimator.Estimator(
696 | model_fn=model_fn,
697 | config=run_config)
698 |
699 | if FLAGS.do_train:
700 | train_file_base = "{}.len-{}.train.tf_record".format(
701 | spm_basename, FLAGS.max_seq_length)
702 | train_file = os.path.join(FLAGS.output_dir, train_file_base)
703 | tf.logging.info("Use tfrecord file {}".format(train_file))
704 |
705 | train_examples = processor.get_train_examples(FLAGS.data_dir)
706 | np.random.shuffle(train_examples)
707 | tf.logging.info("Num of train samples: {}".format(len(train_examples)))
708 |
709 | file_based_convert_examples_to_features(
710 | train_examples, label_list, FLAGS.max_seq_length, tokenize_fn,
711 | train_file, FLAGS.num_passes)
712 |
713 | train_input_fn = file_based_input_fn_builder(
714 | input_file=train_file,
715 | seq_length=FLAGS.max_seq_length,
716 | is_training=True,
717 | drop_remainder=True)
718 |
719 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
720 |
721 | if FLAGS.do_eval or FLAGS.do_predict:
722 | if FLAGS.eval_split == "dev":
723 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
724 | else:
725 | eval_examples = processor.get_test_examples(FLAGS.data_dir)
726 |
727 | tf.logging.info("Num of eval samples: {}".format(len(eval_examples)))
728 |
729 | if FLAGS.do_eval:
730 | # TPU requires a fixed batch size for all batches, therefore the number
731 | # of examples must be a multiple of the batch size, or else examples
732 | # will get dropped. So we pad with fake examples which are ignored
733 | # later on. These do NOT count towards the metric (all tf.metrics
734 | # support a per-instance weight, and these get a weight of 0.0).
735 | #
736 | # Modified in XL: We also adopt the same mechanism for GPUs.
737 | while len(eval_examples) % FLAGS.eval_batch_size != 0:
738 | eval_examples.append(PaddingInputExample())
739 |
740 | eval_file_base = "{}.len-{}.{}.eval.tf_record".format(
741 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
742 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
743 |
744 | file_based_convert_examples_to_features(
745 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn,
746 | eval_file)
747 |
748 | assert len(eval_examples) % FLAGS.eval_batch_size == 0
749 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
750 |
751 | eval_input_fn = file_based_input_fn_builder(
752 | input_file=eval_file,
753 | seq_length=FLAGS.max_seq_length,
754 | is_training=False,
755 | drop_remainder=True)
756 |
757 |     # Collect all checkpoints in the directory
758 | steps_and_files = []
759 | filenames = tf.gfile.ListDirectory(FLAGS.model_dir)
760 |
761 | for filename in filenames:
762 | if filename.endswith(".index"):
763 | ckpt_name = filename[:-6]
764 | cur_filename = join(FLAGS.model_dir, ckpt_name)
765 | global_step = int(cur_filename.split("-")[-1])
766 | tf.logging.info("Add {} to eval list.".format(cur_filename))
767 | steps_and_files.append([global_step, cur_filename])
768 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0])
769 |
770 | # Decide whether to evaluate all ckpts
771 | if not FLAGS.eval_all_ckpt:
772 | steps_and_files = steps_and_files[-1:]
773 |
774 | eval_results = []
775 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]):
776 | ret = estimator.evaluate(
777 | input_fn=eval_input_fn,
778 | steps=eval_steps,
779 | checkpoint_path=filename)
780 |
781 | ret["step"] = global_step
782 | ret["path"] = filename
783 |
784 | eval_results.append(ret)
785 |
786 | tf.logging.info("=" * 80)
787 | log_str = "Eval result | "
788 | for key, val in sorted(ret.items(), key=lambda x: x[0]):
789 | log_str += "{} {} | ".format(key, val)
790 | tf.logging.info(log_str)
791 |
792 | key_name = "eval_pearsonr" if FLAGS.is_regression else "eval_accuracy"
793 | eval_results.sort(key=lambda x: x[key_name], reverse=True)
794 |
795 | tf.logging.info("=" * 80)
796 | log_str = "Best result | "
797 | for key, val in sorted(eval_results[0].items(), key=lambda x: x[0]):
798 | log_str += "{} {} | ".format(key, val)
799 | tf.logging.info(log_str)
800 |
801 | if FLAGS.do_predict:
802 | eval_file_base = "{}.len-{}.{}.predict.tf_record".format(
803 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
804 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
805 |
806 | file_based_convert_examples_to_features(
807 | eval_examples, label_list, FLAGS.max_seq_length, tokenize_fn,
808 | eval_file)
809 |
810 | pred_input_fn = file_based_input_fn_builder(
811 | input_file=eval_file,
812 | seq_length=FLAGS.max_seq_length,
813 | is_training=False,
814 | drop_remainder=False)
815 |
816 | predict_results = []
817 | with tf.gfile.Open(os.path.join(predict_dir, "{}.tsv".format(
818 | task_name)), "w") as fout:
819 | fout.write("index\tprediction\n")
820 |
821 | for pred_cnt, result in enumerate(estimator.predict(
822 | input_fn=pred_input_fn,
823 | yield_single_examples=True,
824 | checkpoint_path=FLAGS.predict_ckpt)):
825 | if pred_cnt % 1000 == 0:
826 | tf.logging.info("Predicting submission for example: {}".format(
827 | pred_cnt))
828 |
829 | logits = [float(x) for x in result["logits"].flat]
830 | predict_results.append(logits)
831 |
832 | if len(logits) == 1:
833 | label_out = logits[0]
834 | elif len(logits) == 2:
835 | if logits[1] - logits[0] > FLAGS.predict_threshold:
836 | label_out = label_list[1]
837 | else:
838 | label_out = label_list[0]
839 | elif len(logits) > 2:
840 | max_index = np.argmax(np.array(logits, dtype=np.float32))
841 | label_out = label_list[max_index]
842 | else:
843 | raise NotImplementedError
844 |
845 | fout.write("{}\t{}\n".format(pred_cnt, label_out))
846 |
847 | predict_json_path = os.path.join(predict_dir, "{}.logits.json".format(
848 | task_name))
849 |
850 | with tf.gfile.Open(predict_json_path, "w") as fp:
851 | json.dump(predict_results, fp, indent=4)
852 |
853 |
854 | if __name__ == "__main__":
855 | tf.app.run()
856 |
--------------------------------------------------------------------------------
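run_classifier.py above parameterizes TSV reading through GLUEProcessor's column indices, with MnliMatchedProcessor and StsbProcessor as thin subclasses. A hypothetical sketch of adding a new sentence-pair task the same way (MyPairProcessor and the "my_pair" key are illustrative names, not part of this repo):

# Hypothetical: a new sentence-pair task for run_classifier.py, following
# the MnliMatchedProcessor pattern above (all names here are illustrative).
class MyPairProcessor(GLUEProcessor):
  def __init__(self):
    super(MyPairProcessor, self).__init__()
    self.label_column = 0    # label in the first TSV column
    self.text_a_column = 1   # first sentence
    self.text_b_column = 2   # second sentence

  def get_labels(self):
    return ["0", "1"]

# Then register it in main() so --task_name=my_pair resolves:
#   processors["my_pair"] = MyPairProcessor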
/run_race.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from os.path import join
6 | from absl import flags
7 | import os
8 | import csv
9 | import collections
10 | import numpy as np
11 | import time
12 | import math
13 | import json
14 | import random
15 | from copy import copy
16 | from collections import defaultdict as dd
17 |
18 | from scipy.stats import pearsonr, spearmanr
19 | from sklearn.metrics import matthews_corrcoef, f1_score
20 |
21 | import absl.logging as _logging # pylint: disable=unused-import
22 |
23 | import tensorflow as tf
24 | import sentencepiece as spm
25 |
26 | from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID
27 | import model_utils
28 | import function_builder
29 | from classifier_utils import PaddingInputExample
30 | from classifier_utils import convert_single_example
31 | from prepro_utils import preprocess_text, encode_ids
32 |
33 | # Model
34 | flags.DEFINE_string("model_config_path", default=None,
35 | help="Model config path.")
36 | flags.DEFINE_float("dropout", default=0.1,
37 | help="Dropout rate.")
38 | flags.DEFINE_float("dropatt", default=0.1,
39 | help="Attention dropout rate.")
40 | flags.DEFINE_integer("clamp_len", default=-1,
41 | help="Clamp length")
42 | flags.DEFINE_string("summary_type", default="last",
43 | help="Method used to summarize a sequence into a compact vector.")
44 | flags.DEFINE_bool("use_summ_proj", default=True,
45 | help="Whether to use projection for summarizing sequences.")
46 | flags.DEFINE_bool("use_bfloat16", default=False,
47 | help="Whether to use bfloat16.")
48 |
49 | # Parameter initialization
50 | flags.DEFINE_enum("init", default="normal",
51 | enum_values=["normal", "uniform"],
52 | help="Initialization method.")
53 | flags.DEFINE_float("init_std", default=0.02,
54 | help="Initialization std when init is normal.")
55 | flags.DEFINE_float("init_range", default=0.1,
56 |       help="Initialization range when init is uniform.")
57 |
58 | # I/O paths
59 | flags.DEFINE_bool("overwrite_data", default=False,
60 | help="If False, will use cached data if available.")
61 | flags.DEFINE_string("init_checkpoint", default=None,
62 | help="checkpoint path for initializing the model. "
63 | "Could be a pretrained model or a finetuned model.")
64 | flags.DEFINE_string("output_dir", default="",
65 | help="Output dir for TF records.")
66 | flags.DEFINE_string("spiece_model_file", default="",
67 |       help="SentencePiece model path.")
68 | flags.DEFINE_string("model_dir", default="",
69 | help="Directory for saving the finetuned model.")
70 | flags.DEFINE_string("data_dir", default="",
71 | help="Directory for input data.")
72 |
73 | # TPUs and machines
74 | flags.DEFINE_bool("use_tpu", default=False, help="whether to use TPU.")
75 | flags.DEFINE_integer("num_hosts", default=1, help="How many TPU hosts.")
76 | flags.DEFINE_integer("num_core_per_host", default=8,
77 | help="8 for TPU v2 and v3-8, 16 for larger TPU v3 pod. In the context "
78 | "of GPU training, it refers to the number of GPUs used.")
79 | flags.DEFINE_string("tpu_job_name", default=None, help="TPU worker job name.")
80 | flags.DEFINE_string("tpu", default=None, help="TPU name.")
81 | flags.DEFINE_string("tpu_zone", default=None, help="TPU zone.")
82 | flags.DEFINE_string("gcp_project", default=None, help="gcp project.")
83 | flags.DEFINE_string("master", default=None, help="master")
84 | flags.DEFINE_integer("iterations", default=1000,
85 | help="number of iterations per TPU training loop.")
86 |
87 | # Training
88 | flags.DEFINE_bool("do_train", default=False, help="whether to do training")
89 | flags.DEFINE_integer("train_steps", default=12000,
90 | help="Number of training steps")
91 | flags.DEFINE_integer("warmup_steps", default=0, help="number of warmup steps")
92 | flags.DEFINE_float("learning_rate", default=2e-5, help="initial learning rate")
93 | flags.DEFINE_float("lr_layer_decay_rate", 1.0,
94 |                    "Top layer: lr[L] = FLAGS.learning_rate. "
95 | "Low layer: lr[l-1] = lr[l] * lr_layer_decay_rate.")
96 | flags.DEFINE_float("min_lr_ratio", default=0.0,
97 | help="min lr ratio for cos decay.")
98 | flags.DEFINE_float("clip", default=1.0, help="Gradient clipping")
99 | flags.DEFINE_integer("max_save", default=0,
100 | help="Max number of checkpoints to save. Use 0 to save all.")
101 | flags.DEFINE_integer("save_steps", default=None,
102 |       help="Save a checkpoint every save_steps steps. "
103 |       "If None, no checkpoints are saved.")
104 | flags.DEFINE_integer("train_batch_size", default=8,
105 | help="Batch size for training. Note that batch size 1 corresponds to "
106 |       "4 sequences: one paragraph + one question + 4 candidate answers.")
107 | flags.DEFINE_float("weight_decay", default=0.00, help="weight decay rate")
108 | flags.DEFINE_float("adam_epsilon", default=1e-6, help="adam epsilon")
109 | flags.DEFINE_string("decay_method", default="poly", help="poly or cos")
110 |
111 | # Evaluation
112 | flags.DEFINE_bool("do_eval", default=False, help="whether to do eval")
113 | flags.DEFINE_string("eval_split", default="dev",
114 | help="could be dev or test")
115 | flags.DEFINE_integer("eval_batch_size", default=32,
116 | help="Batch size for evaluation.")
117 |
118 | # Data config
119 | flags.DEFINE_integer("max_seq_length", default=512,
120 | help="Max length for the paragraph.")
121 | flags.DEFINE_integer("max_qa_length", default=128,
122 | help="Max length for the concatenated question and answer.")
123 | flags.DEFINE_integer("shuffle_buffer", default=2048,
124 | help="Buffer size used for shuffle.")
125 | flags.DEFINE_bool("uncased", default=False,
126 | help="Use uncased.")
127 | flags.DEFINE_bool("high_only", default=False,
128 | help="Evaluate on high school only.")
129 | flags.DEFINE_bool("middle_only", default=False,
130 | help="Evaluate on middle school only.")
131 |
132 | FLAGS = flags.FLAGS
133 |
134 | SEG_ID_A = 0
135 | SEG_ID_B = 1
136 | SEG_ID_CLS = 2
137 | SEG_ID_SEP = 3
138 | SEG_ID_PAD = 4
139 |
140 |
141 | class PaddingInputExample(object):
142 |   """Fake example so the number of input examples is a multiple of the batch size.
143 | When running eval/predict on the TPU, we need to pad the number of examples
144 | to be a multiple of the batch size, because the TPU requires a fixed batch
145 | size. The alternative is to drop the last batch, which is bad because it means
146 | the entire output data won't be generated.
147 | We use this class instead of `None` because treating `None` as padding
148 |   batches could cause silent errors.
149 | """
150 |
151 |
152 | class InputFeatures(object):
153 | """A single set of features of data."""
154 |
155 | def __init__(self,
156 | input_ids,
157 | input_mask,
158 | segment_ids,
159 | label_id,
160 | is_real_example=True):
161 | self.input_ids = input_ids
162 | self.input_mask = input_mask
163 | self.segment_ids = segment_ids
164 | self.label_id = label_id
165 | self.is_real_example = is_real_example
166 |
167 |
168 | def convert_single_example(example, tokenize_fn):
169 | """Converts a single `InputExample` into a single `InputFeatures`."""
170 |
171 | if isinstance(example, PaddingInputExample):
172 | return InputFeatures(
173 | input_ids=[0] * FLAGS.max_seq_length * 4,
174 | input_mask=[1] * FLAGS.max_seq_length * 4,
175 | segment_ids=[0] * FLAGS.max_seq_length * 4,
176 | label_id=0,
177 | is_real_example=False)
178 |
179 | input_ids, input_mask, all_seg_ids = [], [], []
180 | tokens_context = tokenize_fn(example.context)
181 | for i in range(len(example.qa_list)):
182 | tokens_qa = tokenize_fn(example.qa_list[i])
183 | if len(tokens_qa) > FLAGS.max_qa_length:
184 | tokens_qa = tokens_qa[- FLAGS.max_qa_length:]
185 |
186 | if len(tokens_context) + len(tokens_qa) > FLAGS.max_seq_length - 3:
187 | tokens = tokens_context[: FLAGS.max_seq_length - 3 - len(tokens_qa)]
188 | else:
189 | tokens = tokens_context
190 |
191 | segment_ids = [SEG_ID_A] * len(tokens)
192 |
193 | tokens.append(SEP_ID)
194 | segment_ids.append(SEG_ID_A)
195 |
196 | tokens.extend(tokens_qa)
197 | segment_ids.extend([SEG_ID_B] * len(tokens_qa))
198 |
199 | tokens.append(SEP_ID)
200 | segment_ids.append(SEG_ID_B)
201 |
202 | tokens.append(CLS_ID)
203 | segment_ids.append(SEG_ID_CLS)
204 |
205 | cur_input_ids = tokens
206 | cur_input_mask = [0] * len(cur_input_ids)
207 |
208 | if len(cur_input_ids) < FLAGS.max_seq_length:
209 | delta_len = FLAGS.max_seq_length - len(cur_input_ids)
210 | cur_input_ids = [0] * delta_len + cur_input_ids
211 | cur_input_mask = [1] * delta_len + cur_input_mask
212 | segment_ids = [SEG_ID_PAD] * delta_len + segment_ids
213 |
214 | assert len(cur_input_ids) == FLAGS.max_seq_length
215 | assert len(cur_input_mask) == FLAGS.max_seq_length
216 | assert len(segment_ids) == FLAGS.max_seq_length
217 |
218 | input_ids.extend(cur_input_ids)
219 | input_mask.extend(cur_input_mask)
220 | all_seg_ids.extend(segment_ids)
221 |
222 | label_id = example.label
223 |
224 | feature = InputFeatures(
225 | input_ids=input_ids,
226 | input_mask=input_mask,
227 | segment_ids=all_seg_ids,
228 | label_id=label_id)
229 | return feature
230 |
231 |
232 | class InputExample(object):
233 | def __init__(self, context, qa_list, label, level):
234 | self.context = context
235 | self.qa_list = qa_list
236 | self.label = label
237 | self.level = level
238 |
239 |
240 | def get_examples(data_dir, set_type):
241 | examples = []
242 |
243 | for level in ["middle", "high"]:
244 | if level == "middle" and FLAGS.high_only: continue
245 | if level == "high" and FLAGS.middle_only: continue
246 |
247 | cur_dir = os.path.join(data_dir, set_type, level)
248 | for filename in tf.gfile.ListDirectory(cur_dir):
249 | cur_path = os.path.join(cur_dir, filename)
250 | with tf.gfile.Open(cur_path) as f:
251 | cur_data = json.load(f)
252 |
253 | answers = cur_data["answers"]
254 | options = cur_data["options"]
255 | questions = cur_data["questions"]
256 | context = cur_data["article"]
257 |
258 | for i in range(len(answers)):
259 | label = ord(answers[i]) - ord("A")
260 | qa_list = []
261 |
262 | question = questions[i]
263 | for j in range(4):
264 | option = options[i][j]
265 |
266 | if "_" in question:
267 | qa_cat = question.replace("_", option)
268 | else:
269 | qa_cat = " ".join([question, option])
270 |
271 | qa_list.append(qa_cat)
272 |
273 | examples.append(InputExample(context, qa_list, label, level))
274 |
275 | return examples
276 |
277 |
278 | def file_based_convert_examples_to_features(examples, tokenize_fn, output_file):
279 | if tf.gfile.Exists(output_file) and not FLAGS.overwrite_data:
280 | return
281 |
282 | tf.logging.info("Start writing tfrecord %s.", output_file)
283 | writer = tf.python_io.TFRecordWriter(output_file)
284 |
285 | for ex_index, example in enumerate(examples):
286 | if ex_index % 10000 == 0:
287 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
288 |
289 | feature = convert_single_example(example, tokenize_fn)
290 |
291 | def create_int_feature(values):
292 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
293 | return f
294 |
295 | def create_float_feature(values):
296 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
297 | return f
298 |
299 | features = collections.OrderedDict()
300 | features["input_ids"] = create_int_feature(feature.input_ids)
301 | features["input_mask"] = create_float_feature(feature.input_mask)
302 | features["segment_ids"] = create_int_feature(feature.segment_ids)
303 | features["label_ids"] = create_int_feature([feature.label_id])
304 | features["is_real_example"] = create_int_feature(
305 | [int(feature.is_real_example)])
306 |
307 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
308 | writer.write(tf_example.SerializeToString())
309 | writer.close()
310 |
311 |
312 | def file_based_input_fn_builder(input_file, seq_length, is_training,
313 | drop_remainder):
314 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
315 |
316 | name_to_features = {
317 | "input_ids": tf.FixedLenFeature([seq_length * 4], tf.int64),
318 | "input_mask": tf.FixedLenFeature([seq_length * 4], tf.float32),
319 | "segment_ids": tf.FixedLenFeature([seq_length * 4], tf.int64),
320 | "label_ids": tf.FixedLenFeature([], tf.int64),
321 | "is_real_example": tf.FixedLenFeature([], tf.int64),
322 | }
323 |
324 | tf.logging.info("Input tfrecord file {}".format(input_file))
325 |
326 | def _decode_record(record, name_to_features):
327 | """Decodes a record to a TensorFlow example."""
328 | example = tf.parse_single_example(record, name_to_features)
329 |
330 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
331 | # So cast all int64 to int32.
332 | for name in list(example.keys()):
333 | t = example[name]
334 | if t.dtype == tf.int64:
335 | t = tf.cast(t, tf.int32)
336 | example[name] = t
337 |
338 | return example
339 |
340 | def input_fn(params):
341 | """The actual input function."""
342 | if FLAGS.use_tpu:
343 | batch_size = params["batch_size"]
344 | elif is_training:
345 | batch_size = FLAGS.train_batch_size
346 | elif FLAGS.do_eval:
347 | batch_size = FLAGS.eval_batch_size
348 |
349 | # For training, we want a lot of parallel reading and shuffling.
350 | # For eval, we want no shuffling and parallel reading doesn't matter.
351 | d = tf.data.TFRecordDataset(input_file)
352 | if is_training:
353 | d = d.shuffle(buffer_size=FLAGS.shuffle_buffer)
354 | d = d.repeat()
355 | # d = d.shuffle(buffer_size=100)
356 |
357 | d = d.apply(
358 | tf.contrib.data.map_and_batch(
359 | lambda record: _decode_record(record, name_to_features),
360 | batch_size=batch_size,
361 | drop_remainder=drop_remainder))
362 |
363 | return d
364 |
365 | return input_fn
366 |
367 |
368 | def get_model_fn():
369 | def model_fn(features, labels, mode, params):
370 | #### Training or Evaluation
371 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
372 |
373 | total_loss, per_example_loss, logits = function_builder.get_race_loss(
374 | FLAGS, features, is_training)
375 |
376 | #### Check model parameters
377 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
378 | tf.logging.info('#params: {}'.format(num_params))
379 |
380 | #### load pretrained models
381 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS)
382 |
383 | #### Evaluation mode
384 | if mode == tf.estimator.ModeKeys.EVAL:
385 | assert FLAGS.num_hosts == 1
386 |
387 | def metric_fn(per_example_loss, label_ids, logits, is_real_example):
388 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
389 | eval_input_dict = {
390 | 'labels': label_ids,
391 | 'predictions': predictions,
392 | 'weights': is_real_example
393 | }
394 | accuracy = tf.metrics.accuracy(**eval_input_dict)
395 |
396 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)
397 | return {
398 | 'eval_accuracy': accuracy,
399 | 'eval_loss': loss}
400 |
401 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
402 |
403 |       #### Constructing evaluation TPUEstimatorSpec with new cache.
404 | label_ids = tf.reshape(features['label_ids'], [-1])
405 | metric_args = [per_example_loss, label_ids, logits, is_real_example]
406 |
407 | if FLAGS.use_tpu:
408 | eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
409 | mode=mode,
410 | loss=total_loss,
411 | eval_metrics=(metric_fn, metric_args),
412 | scaffold_fn=scaffold_fn)
413 | else:
414 | eval_spec = tf.estimator.EstimatorSpec(
415 | mode=mode,
416 | loss=total_loss,
417 | eval_metric_ops=metric_fn(*metric_args))
418 |
419 | return eval_spec
420 |
421 |
422 | #### Configuring the optimizer
423 | train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)
424 |
425 | monitor_dict = {}
426 | monitor_dict["lr"] = learning_rate
427 |
428 |     #### Constructing training TPUEstimatorSpec with new cache.
429 | if FLAGS.use_tpu:
430 | #### Creating host calls
431 | host_call = None
432 |
433 | train_spec = tf.contrib.tpu.TPUEstimatorSpec(
434 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call,
435 | scaffold_fn=scaffold_fn)
436 | else:
437 | train_spec = tf.estimator.EstimatorSpec(
438 | mode=mode, loss=total_loss, train_op=train_op)
439 |
440 | return train_spec
441 |
442 | return model_fn
443 |
444 |
445 | def main(_):
446 | tf.logging.set_verbosity(tf.logging.INFO)
447 |
448 | #### Validate flags
449 | if FLAGS.save_steps is not None:
450 | FLAGS.iterations = min(FLAGS.iterations, FLAGS.save_steps)
451 |
452 | if not FLAGS.do_train and not FLAGS.do_eval:
453 | raise ValueError(
454 | "At least one of `do_train` or `do_eval` must be True.")
455 |
456 | if not tf.gfile.Exists(FLAGS.output_dir):
457 | tf.gfile.MakeDirs(FLAGS.output_dir)
458 |
459 | sp = spm.SentencePieceProcessor()
460 | sp.Load(FLAGS.spiece_model_file)
461 | def tokenize_fn(text):
462 | text = preprocess_text(text, lower=FLAGS.uncased)
463 | return encode_ids(sp, text)
464 |
465 | # TPU Configuration
466 | run_config = model_utils.configure_tpu(FLAGS)
467 |
468 | model_fn = get_model_fn()
469 |
470 | spm_basename = os.path.basename(FLAGS.spiece_model_file)
471 |
472 | # If TPU is not available, this will fall back to normal Estimator on CPU
473 | # or GPU.
474 | if FLAGS.use_tpu:
475 | estimator = tf.contrib.tpu.TPUEstimator(
476 | use_tpu=FLAGS.use_tpu,
477 | model_fn=model_fn,
478 | config=run_config,
479 | train_batch_size=FLAGS.train_batch_size,
480 | eval_batch_size=FLAGS.eval_batch_size)
481 | else:
482 | estimator = tf.estimator.Estimator(
483 | model_fn=model_fn,
484 | config=run_config)
485 |
486 | if FLAGS.do_train:
487 | train_file_base = "{}.len-{}.train.tf_record".format(
488 | spm_basename, FLAGS.max_seq_length)
489 | train_file = os.path.join(FLAGS.output_dir, train_file_base)
490 |
491 | if not tf.gfile.Exists(train_file) or FLAGS.overwrite_data:
492 | train_examples = get_examples(FLAGS.data_dir, "train")
493 | random.shuffle(train_examples)
494 | file_based_convert_examples_to_features(
495 | train_examples, tokenize_fn, train_file)
496 |
497 | train_input_fn = file_based_input_fn_builder(
498 | input_file=train_file,
499 | seq_length=FLAGS.max_seq_length,
500 | is_training=True,
501 | drop_remainder=True)
502 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
503 |
504 | if FLAGS.do_eval:
505 | eval_examples = get_examples(FLAGS.data_dir, FLAGS.eval_split)
506 | tf.logging.info("Num of eval samples: {}".format(len(eval_examples)))
507 |
508 | # TPU requires a fixed batch size for all batches, therefore the number
509 | # of examples must be a multiple of the batch size, or else examples
510 | # will get dropped. So we pad with fake examples which are ignored
511 | # later on. These do NOT count towards the metric (all tf.metrics
512 | # support a per-instance weight, and these get a weight of 0.0).
513 | #
514 | # Modified in XL: We also adopt the same mechanism for GPUs.
515 |
516 | while len(eval_examples) % FLAGS.eval_batch_size != 0:
517 | eval_examples.append(PaddingInputExample())
518 |
519 | eval_file_base = "{}.len-{}.{}.tf_record".format(
520 | spm_basename, FLAGS.max_seq_length, FLAGS.eval_split)
521 |
522 | if FLAGS.high_only:
523 | eval_file_base = "high." + eval_file_base
524 | elif FLAGS.middle_only:
525 | eval_file_base = "middle." + eval_file_base
526 |
527 | eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
528 | file_based_convert_examples_to_features(
529 | eval_examples, tokenize_fn, eval_file)
530 |
531 | assert len(eval_examples) % FLAGS.eval_batch_size == 0
532 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
533 |
534 | eval_input_fn = file_based_input_fn_builder(
535 | input_file=eval_file,
536 | seq_length=FLAGS.max_seq_length,
537 | is_training=False,
538 | drop_remainder=True)
539 |
540 | ret = estimator.evaluate(
541 | input_fn=eval_input_fn,
542 | steps=eval_steps)
543 |
544 | # Log current result
545 | tf.logging.info("=" * 80)
546 | log_str = "Eval | "
547 | for key, val in ret.items():
548 | log_str += "{} {} | ".format(key, val)
549 | tf.logging.info(log_str)
550 | tf.logging.info("=" * 80)
551 |
552 |
553 | if __name__ == "__main__":
554 | tf.app.run()
555 |
--------------------------------------------------------------------------------
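As convert_single_example in run_race.py shows, each RACE example is flattened into 4 back-to-back sequences (context + question/answer for each of choices A-D), each left-padded to max_seq_length, which is why the input_fn parses features of length seq_length * 4. A quick shape check under the default flags:

# Shape sanity check for RACE features (defaults taken from the flags above).
max_seq_length = 512  # FLAGS.max_seq_length
num_choices = 4       # candidate answers A-D
flat_len = max_seq_length * num_choices
assert flat_len == 2048  # length of input_ids, input_mask and segment_ids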
/scripts/gpu_squad_base.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #### local path
4 | SQUAD_DIR=data/squad
5 | INIT_CKPT_DIR=xlnet_cased_L-12_H-768_A-12
6 | PROC_DATA_DIR=proc_data/squad
7 | MODEL_DIR=experiment/squad
8 |
9 | #### Use 3 GPUs, each with 8 seqlen-512 samples
10 |
11 | python run_squad.py \
12 | --use_tpu=False \
13 | --num_hosts=1 \
14 | --num_core_per_host=3 \
15 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \
16 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \
17 | --output_dir=${PROC_DATA_DIR} \
18 | --init_checkpoint=${INIT_CKPT_DIR}/xlnet_model.ckpt \
19 | --model_dir=${MODEL_DIR} \
20 | --train_file=${SQUAD_DIR}/train-v2.0.json \
21 | --predict_file=${SQUAD_DIR}/dev-v2.0.json \
22 | --uncased=False \
23 | --max_seq_length=512 \
24 | --do_train=True \
25 | --train_batch_size=8 \
26 | --do_predict=True \
27 | --predict_batch_size=32 \
28 | --learning_rate=2e-5 \
29 | --adam_epsilon=1e-6 \
30 | --iterations=1000 \
31 | --save_steps=1000 \
32 | --train_steps=12000 \
33 | --warmup_steps=1000 \
34 | $@
35 |
--------------------------------------------------------------------------------
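Per its own comment, gpu_squad_base.sh runs 3 GPUs with 8 sequences of length 512 each, so one optimizer step consumes 24 sequences. A trivial arithmetic check (the per-GPU reading of train_batch_size follows the script's comment and is not verified against the training code):

# Per-step workload implied by gpu_squad_base.sh (sanity check only).
num_gpus = 3       # --num_core_per_host=3
per_gpu_batch = 8  # --train_batch_size=8, read as a per-GPU size
seq_len = 512      # --max_seq_length=512
print(num_gpus * per_gpu_batch)            # 24 sequences per step
print(num_gpus * per_gpu_batch * seq_len)  # 12288 tokens per step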
/scripts/prepro_squad.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #### local path
4 | SQUAD_DIR=data/squad
5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16
6 |
7 | #### google storage path
8 | GS_ROOT=
9 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/squad
10 |
11 | python run_squad.py \
12 | --use_tpu=False \
13 | --do_prepro=True \
14 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \
15 | --train_file=${SQUAD_DIR}/train-v2.0.json \
16 | --output_dir=${GS_PROC_DATA_DIR} \
17 | --uncased=False \
18 | --max_seq_length=512 \
19 | $@
20 |
21 | #### Potential multi-processing version
22 | # NUM_PROC=8
23 | # for i in `seq 0 $((NUM_PROC - 1))`; do
24 | # python run_squad.py \
25 | # --use_tpu=False \
26 | # --do_prepro=True \
27 | # --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \
28 | # --train_file=${SQUAD_DIR}/train-v2.0.json \
29 | # --output_dir=${GS_PROC_DATA_DIR} \
30 | # --uncased=False \
31 | # --max_seq_length=512 \
32 | # --num_proc=${NUM_PROC} \
33 | # --proc_id=${i} \
34 | # $@ &
35 | # done
36 |
--------------------------------------------------------------------------------
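The commented-out block in prepro_squad.sh fans preprocessing out to NUM_PROC parallel run_squad.py invocations, each tagged with its own proc_id. A (num_proc, proc_id) flag pair like this is commonly used for round-robin sharding by example index; the sketch below shows that pattern only and is an assumption, since the actual assignment lives in run_squad.py outside this section:

# Round-robin sharding sketch for a (num_proc, proc_id) flag pair (assumed
# pattern, not taken from run_squad.py).
def shard(examples, num_proc, proc_id):
  # Keep every num_proc-th example, offset by this worker's id.
  return [ex for i, ex in enumerate(examples) if i % num_proc == proc_id]

workers = [shard(range(10), 4, pid) for pid in range(4)]
assert sorted(sum(workers, [])) == list(range(10))  # full, disjoint cover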
/scripts/tpu_race_large_bsz32.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #### local path
4 | RACE_DIR=data/RACE
5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16
6 |
7 | #### google storage path
8 | GS_ROOT=
9 | GS_INIT_CKPT_DIR=${GS_ROOT}/${INIT_CKPT_DIR}
10 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/race
11 | GS_MODEL_DIR=${GS_ROOT}/experiment/race
12 |
13 | # TPU name in google cloud
14 | TPU_NAME=
15 |
16 | python run_race.py \
17 | --use_tpu=True \
18 | --tpu=${TPU_NAME} \
19 | --num_hosts=4 \
20 | --num_core_per_host=8 \
21 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \
22 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \
23 | --output_dir=${GS_PROC_DATA_DIR} \
24 | --init_checkpoint=${GS_INIT_CKPT_DIR}/xlnet_model.ckpt \
25 | --model_dir=${GS_MODEL_DIR} \
26 | --data_dir=${RACE_DIR} \
27 | --max_seq_length=512 \
28 | --max_qa_length=128 \
29 | --uncased=False \
30 | --do_train=True \
31 | --train_batch_size=32 \
32 | --do_eval=True \
33 | --eval_batch_size=32 \
34 | --train_steps=12000 \
35 | --save_steps=1000 \
36 | --iterations=1000 \
37 | --warmup_steps=1000 \
38 | --learning_rate=2e-5 \
39 | --weight_decay=0 \
40 | --adam_epsilon=1e-6 \
41 | $@
42 |
--------------------------------------------------------------------------------
/scripts/tpu_race_large_bsz8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #### local path
4 | RACE_DIR=data/RACE
5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16
6 |
7 | #### google storage path
8 | GS_ROOT=
9 | GS_INIT_CKPT_DIR=${GS_ROOT}/${INIT_CKPT_DIR}
10 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/race
11 | GS_MODEL_DIR=${GS_ROOT}/experiment/race
12 |
13 | # TPU name in google cloud
14 | TPU_NAME=
15 |
16 | python run_race.py \
17 | --use_tpu=True \
18 | --tpu=${TPU_NAME} \
19 | --num_hosts=1 \
20 | --num_core_per_host=8 \
21 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \
22 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \
23 | --output_dir=${GS_PROC_DATA_DIR} \
24 | --init_checkpoint=${GS_INIT_CKPT_DIR}/xlnet_model.ckpt \
25 | --model_dir=${GS_MODEL_DIR} \
26 | --data_dir=${RACE_DIR} \
27 | --max_seq_length=512 \
28 | --max_qa_length=128 \
29 | --uncased=False \
30 | --do_train=True \
31 | --train_batch_size=8 \
32 | --do_eval=True \
33 | --eval_batch_size=32 \
34 | --train_steps=12000 \
35 | --save_steps=1000 \
36 | --iterations=1000 \
37 | --warmup_steps=1000 \
38 | --learning_rate=2e-5 \
39 | --weight_decay=0 \
40 | --adam_epsilon=1e-6 \
41 | $@
42 |
--------------------------------------------------------------------------------
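The two RACE scripts differ only in scale: tpu_race_large_bsz32.sh spans 4 TPU hosts (4 x 8 = 32 cores) with batch size 32, while tpu_race_large_bsz8.sh uses a single 8-core host with batch size 8. Both work out to one example per core per step, and recall that each RACE example is itself 4 sequences:

# Per-core load for the two RACE TPU scripts (arithmetic check only).
for hosts, batch in [(4, 32), (1, 8)]:
  cores = hosts * 8      # --num_core_per_host=8 in both scripts
  print(batch // cores)  # 1 example (= 4 sequences) per core per step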
/scripts/tpu_squad_large.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #### local path
4 | SQUAD_DIR=data/squad
5 | INIT_CKPT_DIR=xlnet_cased_L-24_H-1024_A-16
6 |
7 | #### google storage path
8 | GS_ROOT=
9 | GS_INIT_CKPT_DIR=${GS_ROOT}/${INIT_CKPT_DIR}
10 | GS_PROC_DATA_DIR=${GS_ROOT}/proc_data/squad
11 | GS_MODEL_DIR=${GS_ROOT}/experiment/squad
12 |
13 | # TPU name in google cloud
14 | TPU_NAME=
15 |
16 | python run_squad.py \
17 | --use_tpu=True \
18 | --tpu=${TPU_NAME} \
19 | --num_hosts=1 \
20 | --num_core_per_host=8 \
21 | --model_config_path=${INIT_CKPT_DIR}/xlnet_config.json \
22 | --spiece_model_file=${INIT_CKPT_DIR}/spiece.model \
23 | --output_dir=${GS_PROC_DATA_DIR} \
24 | --init_checkpoint=${GS_INIT_CKPT_DIR}/xlnet_model.ckpt \
25 | --model_dir=${GS_MODEL_DIR} \
26 | --train_file=${SQUAD_DIR}/train-v2.0.json \
27 | --predict_file=${SQUAD_DIR}/dev-v2.0.json \
28 | --uncased=False \
29 | --max_seq_length=512 \
30 | --do_train=True \
31 | --train_batch_size=48 \
32 | --do_predict=True \
33 | --predict_batch_size=32 \
34 | --learning_rate=3e-5 \
35 | --adam_epsilon=1e-6 \
36 | --iterations=1000 \
37 | --save_steps=1000 \
38 | --train_steps=8000 \
39 | --warmup_steps=1000 \
40 | $@
41 |
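42 | # Note: with 8 TPU cores, train_batch_size=48 comes to 6 sequences per core
43 | # (assuming an even split). This script both finetunes (do_train=True) and
44 | # writes predictions (do_predict=True). Fill in GS_ROOT and TPU_NAME above
45 | # before launching.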
--------------------------------------------------------------------------------
/squad_utils.py:
--------------------------------------------------------------------------------
1 | """Official evaluation script for SQuAD version 2.0.
2 |
3 | In addition to basic functionality, we also compute additional statistics and
4 | plot precision-recall curves if an additional na_prob.json file is provided.
5 | This file is expected to map question IDs to the model's predicted probability
6 | that a question is unanswerable.
7 | """
8 | import argparse
9 | import collections
10 | import json
11 | import numpy as np
12 | import os
13 | import re
14 | import string
15 | import sys
16 |
17 | OPTS = None
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.')
21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.')
22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.')
23 | parser.add_argument('--out-file', '-o', metavar='eval.json',
24 | help='Write accuracy metrics to file (default is stdout).')
25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json',
26 | help='Model estimates of probability of no answer.')
27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0,
28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).')
29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None,
30 | help='Save precision-recall curves to directory.')
31 | parser.add_argument('--verbose', '-v', action='store_true')
32 | if len(sys.argv) == 1:
33 | parser.print_help()
34 | sys.exit(1)
35 | return parser.parse_args()
36 |
37 | def make_qid_to_has_ans(dataset):
38 | qid_to_has_ans = {}
39 | for article in dataset:
40 | for p in article['paragraphs']:
41 | for qa in p['qas']:
42 | qid_to_has_ans[qa['id']] = bool(qa['answers'])
43 | return qid_to_has_ans
44 |
45 | def normalize_answer(s):
46 | """Lower text and remove punctuation, articles and extra whitespace."""
47 | def remove_articles(text):
48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
49 | return re.sub(regex, ' ', text)
50 | def white_space_fix(text):
51 | return ' '.join(text.split())
52 | def remove_punc(text):
53 | exclude = set(string.punctuation)
54 | return ''.join(ch for ch in text if ch not in exclude)
55 | def lower(text):
56 | return text.lower()
57 | return white_space_fix(remove_articles(remove_punc(lower(s))))
58 |
59 | def get_tokens(s):
60 | if not s: return []
61 | return normalize_answer(s).split()
62 |
63 | def compute_exact(a_gold, a_pred):
64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred))
65 |
66 | def compute_f1(a_gold, a_pred):
67 | gold_toks = get_tokens(a_gold)
68 | pred_toks = get_tokens(a_pred)
69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
70 | num_same = sum(common.values())
71 | if len(gold_toks) == 0 or len(pred_toks) == 0:
72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
73 | return int(gold_toks == pred_toks)
74 | if num_same == 0:
75 | return 0
76 | precision = 1.0 * num_same / len(pred_toks)
77 | recall = 1.0 * num_same / len(gold_toks)
78 | f1 = (2 * precision * recall) / (precision + recall)
79 | return f1
80 |
81 | def get_raw_scores(dataset, preds):
82 | exact_scores = {}
83 | f1_scores = {}
84 | for article in dataset:
85 | for p in article['paragraphs']:
86 | for qa in p['qas']:
87 | qid = qa['id']
88 | gold_answers = [a['text'] for a in qa['answers']
89 | if normalize_answer(a['text'])]
90 | if not gold_answers:
91 | # For unanswerable questions, only correct answer is empty string
92 | gold_answers = ['']
93 | if qid not in preds:
94 | print('Missing prediction for %s' % qid)
95 | continue
96 | a_pred = preds[qid]
97 | # Take max over all gold answers
98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
100 | return exact_scores, f1_scores
101 |
102 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
103 | new_scores = {}
104 | for qid, s in scores.items():
105 | pred_na = na_probs[qid] > na_prob_thresh
106 | if pred_na:
107 | new_scores[qid] = float(not qid_to_has_ans[qid])
108 | else:
109 | new_scores[qid] = s
110 | return new_scores
111 |
112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None):
113 | if not qid_list:
114 | total = len(exact_scores)
115 | return collections.OrderedDict([
116 | ('exact', 100.0 * sum(exact_scores.values()) / total),
117 | ('f1', 100.0 * sum(f1_scores.values()) / total),
118 | ('total', total),
119 | ])
120 | else:
121 | total = len(qid_list)
122 | return collections.OrderedDict([
123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
125 | ('total', total),
126 | ])
127 |
128 | def merge_eval(main_eval, new_eval, prefix):
129 | for k in new_eval:
130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k]
131 |
132 | def plot_pr_curve(precisions, recalls, out_image, title):
133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
135 | plt.xlabel('Recall')
136 | plt.ylabel('Precision')
137 | plt.xlim([0.0, 1.05])
138 | plt.ylim([0.0, 1.05])
139 | plt.title(title)
140 | plt.savefig(out_image)
141 | plt.clf()
142 |
143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
144 | out_image=None, title=None):
145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
146 | true_pos = 0.0
147 | cur_p = 1.0
148 | cur_r = 0.0
149 | precisions = [1.0]
150 | recalls = [0.0]
151 | avg_prec = 0.0
152 | for i, qid in enumerate(qid_list):
153 | if qid_to_has_ans[qid]:
154 | true_pos += scores[qid]
155 | cur_p = true_pos / float(i+1)
156 | cur_r = true_pos / float(num_true_pos)
157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
158 | # i.e., if we can put a threshold after this point
159 | avg_prec += cur_p * (cur_r - recalls[-1])
160 | precisions.append(cur_p)
161 | recalls.append(cur_r)
162 | if out_image:
163 | plot_pr_curve(precisions, recalls, out_image, title)
164 | return {'ap': 100.0 * avg_prec}
165 |
166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
167 | qid_to_has_ans, out_image_dir):
168 | if out_image_dir and not os.path.exists(out_image_dir):
169 | os.makedirs(out_image_dir)
170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
171 | if num_true_pos == 0:
172 | return
173 | pr_exact = make_precision_recall_eval(
174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans,
175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'),
176 | title='Precision-Recall curve for Exact Match score')
177 | pr_f1 = make_precision_recall_eval(
178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans,
179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'),
180 | title='Precision-Recall curve for F1 score')
181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
182 | pr_oracle = make_precision_recall_eval(
183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
186 | merge_eval(main_eval, pr_exact, 'pr_exact')
187 | merge_eval(main_eval, pr_f1, 'pr_f1')
188 | merge_eval(main_eval, pr_oracle, 'pr_oracle')
189 |
190 | def histogram_na_prob(na_probs, qid_list, image_dir, name):
191 | if not qid_list:
192 | return
193 | x = [na_probs[k] for k in qid_list]
194 | weights = np.ones_like(x) / float(len(x))
195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
196 | plt.xlabel('Model probability of no-answer')
197 | plt.ylabel('Proportion of dataset')
198 | plt.title('Histogram of no-answer probability: %s' % name)
199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
200 | plt.clf()
201 |
202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
204 | cur_score = num_no_ans
205 | best_score = cur_score
206 | best_thresh = 0.0
207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
208 | for i, qid in enumerate(qid_list):
209 | if qid not in scores: continue
210 | if qid_to_has_ans[qid]:
211 | diff = scores[qid]
212 | else:
213 | if preds[qid]:
214 | diff = -1
215 | else:
216 | diff = 0
217 | cur_score += diff
218 | if cur_score > best_score:
219 | best_score = cur_score
220 | best_thresh = na_probs[qid]
221 | return 100.0 * best_score / len(scores), best_thresh
222 |
223 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
224 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
225 | cur_score = num_no_ans
226 | best_score = cur_score
227 | best_thresh = 0.0
228 | qid_list = sorted(na_probs, key=lambda k: na_probs[k])
229 | for i, qid in enumerate(qid_list):
230 | if qid not in scores: continue
231 | if qid_to_has_ans[qid]:
232 | diff = scores[qid]
233 | else:
234 | if preds[qid]:
235 | diff = -1
236 | else:
237 | diff = 0
238 | cur_score += diff
239 | if cur_score > best_score:
240 | best_score = cur_score
241 | best_thresh = na_probs[qid]
242 |
243 | has_ans_score, has_ans_cnt = 0, 0
244 | for qid in qid_list:
245 | if not qid_to_has_ans[qid]: continue
246 | has_ans_cnt += 1
247 |
248 | if qid not in scores: continue
249 | has_ans_score += scores[qid]
250 |
251 | return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
252 |
253 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
254 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
255 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
256 | main_eval['best_exact'] = best_exact
257 | main_eval['best_exact_thresh'] = exact_thresh
258 | main_eval['best_f1'] = best_f1
259 | main_eval['best_f1_thresh'] = f1_thresh
260 |
261 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
262 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
263 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
264 | main_eval['best_exact'] = best_exact
265 | main_eval['best_exact_thresh'] = exact_thresh
266 | main_eval['best_f1'] = best_f1
267 | main_eval['best_f1_thresh'] = f1_thresh
268 | main_eval['has_ans_exact'] = has_ans_exact
269 | main_eval['has_ans_f1'] = has_ans_f1
270 |
271 | def main():
272 | with open(OPTS.data_file) as f:
273 | dataset_json = json.load(f)
274 | dataset = dataset_json['data']
275 | with open(OPTS.pred_file) as f:
276 | preds = json.load(f)
277 |
278 | new_orig_data = []
279 | for article in dataset:
280 | for p in article['paragraphs']:
281 | for qa in p['qas']:
282 | if qa['id'] in preds:
283 | new_para = {'qas': [qa]}
284 | new_article = {'paragraphs': [new_para]}
285 | new_orig_data.append(new_article)
286 | dataset = new_orig_data
287 |
288 | if OPTS.na_prob_file:
289 | with open(OPTS.na_prob_file) as f:
290 | na_probs = json.load(f)
291 | else:
292 | na_probs = {k: 0.0 for k in preds}
293 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False
294 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
295 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
296 | exact_raw, f1_raw = get_raw_scores(dataset, preds)
297 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
298 | OPTS.na_prob_thresh)
299 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
300 | OPTS.na_prob_thresh)
301 | out_eval = make_eval_dict(exact_thresh, f1_thresh)
302 | if has_ans_qids:
303 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
304 | merge_eval(out_eval, has_ans_eval, 'HasAns')
305 | if no_ans_qids:
306 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
307 | merge_eval(out_eval, no_ans_eval, 'NoAns')
308 | if OPTS.na_prob_file:
309 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
310 | if OPTS.na_prob_file and OPTS.out_image_dir:
311 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
312 | qid_to_has_ans, OPTS.out_image_dir)
313 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
314 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
315 | if OPTS.out_file:
316 | with open(OPTS.out_file, 'w') as f:
317 | json.dump(out_eval, f)
318 | else:
319 | print(json.dumps(out_eval, indent=2))
320 |
321 | if __name__ == '__main__':
322 | OPTS = parse_args()
323 | if OPTS.out_image_dir:
324 | import matplotlib
325 | matplotlib.use('Agg')
326 | import matplotlib.pyplot as plt
327 | main()
328 |
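329 | # Example invocation (a sketch; the file names below are placeholders):
330 | #   python squad_utils.py dev-v2.0.json predictions.json \
331 | #     --na-prob-file na_prob.json --out-file eval.json --out-image-dir out_images
332 | # Note that matplotlib is imported lazily in __main__, and only when
333 | # --out-image-dir is given.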
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Pretraining on TPUs."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import os
7 |
8 | from absl import app
9 | from absl import flags
10 | import absl.logging as _logging # pylint: disable=unused-import
11 |
12 | import numpy as np
13 |
14 | import tensorflow as tf
15 | import model_utils
16 | import tpu_estimator
17 | import function_builder
18 | import data_utils
19 |
20 | # TPU parameters
21 | flags.DEFINE_string("master", default=None,
22 | help="master")
23 | flags.DEFINE_string("tpu", default=None,
24 | help="The Cloud TPU to use for training. This should be either the name "
25 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
26 | flags.DEFINE_string("gcp_project", default=None,
27 | help="Project name for the Cloud TPU-enabled project. If not specified, "
28 | "we will attempt to automatically detect the GCE project from metadata.")
29 | flags.DEFINE_string("tpu_zone",default=None,
30 | help="GCE zone where the Cloud TPU is located in. If not specified, we "
31 | "will attempt to automatically detect the GCE project from metadata.")
32 | flags.DEFINE_bool("use_tpu", default=True,
33 | help="Use TPUs rather than plain CPUs.")
34 | flags.DEFINE_integer("num_hosts", default=1,
35 | help="number of TPU hosts")
36 | flags.DEFINE_integer("num_core_per_host", default=8,
37 | help="number of cores per host")
38 | flags.DEFINE_bool("track_mean", default=False,
39 | help="Whether to track mean loss.")
40 |
41 | # Experiment (data/checkpoint/directory) config
42 | flags.DEFINE_integer("num_passes", default=1,
43 | help="Number of passed used for training.")
44 | flags.DEFINE_string("record_info_dir", default=None,
45 | help="Path to local directory containing `record_info-lm.json`.")
46 | flags.DEFINE_string("model_dir", default=None,
47 | help="Estimator model_dir.")
48 | flags.DEFINE_string("init_checkpoint", default=None,
49 | help="Checkpoint path for initializing the model.")
50 |
51 | # Optimization config
52 | flags.DEFINE_float("learning_rate", default=1e-4,
53 | help="Maximum learning rate.")
54 | flags.DEFINE_float("clip", default=1.0,
55 | help="Gradient clipping value.")
56 | # lr decay
57 | flags.DEFINE_float("min_lr_ratio", default=0.001,
58 | help="Minimum ratio learning rate.")
59 | flags.DEFINE_integer("warmup_steps", default=0,
60 | help="Number of steps for linear lr warmup.")
61 | flags.DEFINE_float("adam_epsilon", default=1e-8,
62 | help="Adam epsilon.")
63 | flags.DEFINE_string("decay_method", default="poly",
64 | help="Poly or cos.")
65 | flags.DEFINE_float("weight_decay", default=0.0,
66 | help="Weight decay rate.")
67 |
68 | # Training config
69 | flags.DEFINE_integer("train_batch_size", default=16,
70 | help="Size of the train batch across all hosts.")
71 | flags.DEFINE_integer("train_steps", default=100000,
72 | help="Total number of training steps.")
73 | flags.DEFINE_integer("iterations", default=1000,
74 | help="Number of iterations per repeat loop.")
75 | flags.DEFINE_integer("save_steps", default=None,
76 | help="Number of steps for model checkpointing. "
77 | "None for not saving checkpoints")
78 | flags.DEFINE_integer("max_save", default=100000,
79 | help="Maximum number of checkpoints to save.")
80 |
81 | # Data config
82 | flags.DEFINE_integer("seq_len", default=0,
83 | help="Sequence length for pretraining.")
84 | flags.DEFINE_integer("reuse_len", default=0,
85 | help="How many tokens to be reused in the next batch. "
86 | "Could be half of `seq_len`.")
87 | flags.DEFINE_bool("uncased", False,
88 | help="Use uncased inputs or not.")
89 | flags.DEFINE_integer("perm_size", 0,
90 | help="Window size of permutation.")
91 | flags.DEFINE_bool("bi_data", default=True,
92 | help="Use bidirectional data streams, i.e., forward & backward.")
93 | flags.DEFINE_integer("mask_alpha", default=6,
94 | help="How many tokens to form a group.")
95 | flags.DEFINE_integer("mask_beta", default=1,
96 | help="How many tokens to mask within each group.")
97 | flags.DEFINE_integer("num_predict", default=None,
98 | help="Number of tokens to predict in partial prediction.")
99 | flags.DEFINE_integer("n_token", 32000, help="Vocab size")
100 |
101 | # Model config
102 | flags.DEFINE_integer("mem_len", default=0,
103 | help="Number of steps to cache")
104 | flags.DEFINE_bool("same_length", default=False,
105 | help="Same length attention")
106 | flags.DEFINE_integer("clamp_len", default=-1,
107 | help="Clamp length")
108 |
109 | flags.DEFINE_integer("n_layer", default=6,
110 | help="Number of layers.")
111 | flags.DEFINE_integer("d_model", default=32,
112 | help="Dimension of the model.")
113 | flags.DEFINE_integer("d_embed", default=32,
114 | help="Dimension of the embeddings.")
115 | flags.DEFINE_integer("n_head", default=4,
116 | help="Number of attention heads.")
117 | flags.DEFINE_integer("d_head", default=8,
118 | help="Dimension of each attention head.")
119 | flags.DEFINE_integer("d_inner", default=32,
120 | help="Dimension of inner hidden size in positionwise feed-forward.")
121 | flags.DEFINE_float("dropout", default=0.0,
122 | help="Dropout rate.")
123 | flags.DEFINE_float("dropatt", default=0.0,
124 | help="Attention dropout rate.")
125 | flags.DEFINE_bool("untie_r", default=False,
126 | help="Untie r_w_bias and r_r_bias")
127 | flags.DEFINE_string("summary_type", default="last",
128 | help="Method used to summarize a sequence into a compact vector.")
129 | flags.DEFINE_string("ff_activation", default="relu",
130 | help="Activation type used in position-wise feed-forward.")
131 | flags.DEFINE_bool("use_bfloat16", False,
132 | help="Whether to use bfloat16.")
133 |
134 | # Parameter initialization
135 | flags.DEFINE_enum("init", default="normal",
136 | enum_values=["normal", "uniform"],
137 | help="Initialization method.")
138 | flags.DEFINE_float("init_std", default=0.02,
139 | help="Initialization std when init is normal.")
140 | flags.DEFINE_float("init_range", default=0.1,
141 | help="Initialization std when init is uniform.")
142 |
143 | FLAGS = flags.FLAGS
144 |
145 |
146 | def get_model_fn():
147 | """doc."""
148 | def model_fn(features, labels, mode, params):
149 | """doc."""
150 | #### Training or Evaluation
151 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
152 | assert is_training
153 |
154 | #### Retrieve `mems` from `params["cache"]`
155 | mems = {}
156 | idx = 0
157 | if FLAGS.mem_len > 0:
158 | mems["mems"] = params["cache"]
159 |
160 | #### Get loss from inputs
161 | total_loss, new_mems, monitor_dict = function_builder.get_loss(
162 | FLAGS, features, labels, mems, is_training)
163 |
164 | #### Turn `new_mems` into `new_cache`
165 | new_cache = []
166 | if FLAGS.mem_len > 0:
167 | new_cache += new_mems["mems"]
168 |
169 | #### Check model parameters
170 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
171 | tf.logging.info("#params: {}".format(num_params))
172 |
173 | #### Configuring the optimizer
174 | train_op, learning_rate, gnorm = model_utils.get_train_op(
175 | FLAGS, total_loss)
176 | monitor_dict["lr"] = learning_rate
177 | monitor_dict["gnorm"] = gnorm
178 |
179 | #### Customized initial checkpoint
180 | scaffold_fn = model_utils.init_from_checkpoint(FLAGS, global_vars=True)
181 |
182 | #### Creating host calls
183 | host_call = function_builder.construct_scalar_host_call(
184 | monitor_dict=monitor_dict,
185 | model_dir=FLAGS.model_dir,
186 | prefix="train/",
187 | reduce_fn=tf.reduce_mean)
188 |
189 |     #### Constructing training TPUEstimatorSpec with new cache.
190 | train_spec = tf.contrib.tpu.TPUEstimatorSpec(
191 | mode=mode, loss=total_loss, train_op=train_op, host_call=host_call,
192 | scaffold_fn=scaffold_fn)
193 |
194 | train_spec.cache = new_cache
195 |
196 | return train_spec
197 |
198 | return model_fn
199 |
200 |
201 | def get_cache_fn(mem_len):
202 | """doc."""
203 | tf_float = tf.bfloat16 if FLAGS.use_bfloat16 else tf.float32
204 | def cache_fn(batch_size):
205 | mems = []
206 | if FLAGS.mem_len > 0:
207 | for _ in range(FLAGS.n_layer):
208 | zeros = tf.zeros(
209 | [mem_len, batch_size, FLAGS.d_model],
210 | dtype=tf_float)
211 | mems.append(zeros)
212 |
213 | return mems
214 |
215 | if mem_len > 0:
216 | return cache_fn
217 | else:
218 | return None
219 |
220 |
221 | def get_input_fn(split):
222 | """doc."""
223 | assert split == "train"
224 | batch_size = FLAGS.train_batch_size
225 |
226 | input_fn, record_info_dict = data_utils.get_input_fn(
227 | tfrecord_dir=FLAGS.record_info_dir,
228 | split=split,
229 | bsz_per_host=batch_size // FLAGS.num_hosts,
230 | seq_len=FLAGS.seq_len,
231 | reuse_len=FLAGS.reuse_len,
232 | bi_data=FLAGS.bi_data,
233 | num_hosts=FLAGS.num_hosts,
234 | num_core_per_host=FLAGS.num_core_per_host,
235 | perm_size=FLAGS.perm_size,
236 | mask_alpha=FLAGS.mask_alpha,
237 | mask_beta=FLAGS.mask_beta,
238 | uncased=FLAGS.uncased,
239 | num_passes=FLAGS.num_passes,
240 | use_bfloat16=FLAGS.use_bfloat16,
241 | num_predict=FLAGS.num_predict)
242 |
243 | return input_fn, record_info_dict
244 |
245 |
246 | def main(unused_argv):
247 | del unused_argv # Unused
248 |
249 | tf.logging.set_verbosity(tf.logging.INFO)
250 |
251 | assert FLAGS.seq_len > 0
252 | assert FLAGS.perm_size > 0
253 |
254 | FLAGS.n_token = data_utils.VOCAB_SIZE
255 | tf.logging.info("n_token {}".format(FLAGS.n_token))
256 |
257 | if not tf.gfile.Exists(FLAGS.model_dir):
258 | tf.gfile.MakeDirs(FLAGS.model_dir)
259 |
260 | # Get train input function
261 | train_input_fn, train_record_info_dict = get_input_fn("train")
262 |
263 | tf.logging.info("num of batches {}".format(
264 | train_record_info_dict["num_batch"]))
265 |
266 | # Get train cache function
267 | train_cache_fn = get_cache_fn(FLAGS.mem_len)
268 |
269 | ##### Get model function
270 | model_fn = get_model_fn()
271 |
272 | ##### Create TPUEstimator
273 | # TPU Configuration
274 | run_config = model_utils.configure_tpu(FLAGS)
275 |
276 | # TPU Estimator
277 | estimator = tpu_estimator.TPUEstimator(
278 | model_fn=model_fn,
279 | train_cache_fn=train_cache_fn,
280 | use_tpu=FLAGS.use_tpu,
281 | config=run_config,
282 | params={"track_mean": FLAGS.track_mean},
283 | train_batch_size=FLAGS.train_batch_size,
284 | eval_on_tpu=FLAGS.use_tpu)
285 |
286 | #### Training
287 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
288 |
289 |
290 | if __name__ == "__main__":
291 | app.run(main)
292 |
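293 | # Example invocation (a sketch; the paths and hyperparameter values below are
294 | # placeholders, not recommended settings):
295 | #   python train.py \
296 | #     --tpu=${TPU_NAME} \
297 | #     --record_info_dir=${GS_ROOT}/proc_data/tfrecords \
298 | #     --model_dir=${GS_ROOT}/experiment/pretrain \
299 | #     --seq_len=512 --reuse_len=256 --perm_size=256 \
300 | #     --mem_len=384 --num_predict=85 \
301 | #     --train_batch_size=16 --train_steps=100000 --save_steps=10000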
--------------------------------------------------------------------------------
/train_gpu.py:
--------------------------------------------------------------------------------
1 | """Pretraining on GPUs."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import os, sys
7 | import math
8 | import json
9 | import time
10 | import numpy as np
11 |
12 | from absl import flags
13 | import absl.logging as _logging # pylint: disable=unused-import
14 |
15 | import tensorflow as tf
16 |
17 | import data_utils
18 | import model_utils
19 | from gpu_utils import assign_to_gpu, average_grads_and_vars
20 | import function_builder
21 |
22 |
23 | # GPU config
24 | flags.DEFINE_integer("num_hosts", default=1,
25 | help="Number of hosts")
26 | flags.DEFINE_integer("num_core_per_host", default=8,
27 | help="Number of cores per host")
28 | flags.DEFINE_bool("use_tpu", default=False,
29 | help="Whether to use TPUs for training.")
30 |
31 | # Experiment (data/checkpoint/directory) config
32 | flags.DEFINE_integer("num_passes", default=1,
33 | help="Number of passed used for training.")
34 | flags.DEFINE_string("record_info_dir", default=None,
35 | help="Path to local directory containing `record_info-lm.json`.")
36 | flags.DEFINE_string("model_dir", default=None,
37 | help="Estimator model_dir.")
38 | flags.DEFINE_string("init_checkpoint", default=None,
39 | help="checkpoint path for initializing the model.")
40 |
41 | # Optimization config
42 | flags.DEFINE_float("learning_rate", default=1e-4,
43 | help="Maximum learning rate.")
44 | flags.DEFINE_float("clip", default=1.0,
45 | help="Gradient clipping value.")
46 | # for cosine decay
47 | flags.DEFINE_float("min_lr_ratio", default=0.001,
48 | help="Minimum ratio learning rate.")
49 | flags.DEFINE_integer("warmup_steps", default=0,
50 | help="Number of steps for linear lr warmup.")
51 | flags.DEFINE_float("adam_epsilon", default=1e-8,
52 | help="Adam epsilon")
53 | flags.DEFINE_string("decay_method", default="poly",
54 | help="poly or cos")
55 | flags.DEFINE_float("weight_decay", default=0.0,
56 | help="weight decay")
57 |
58 | # Training config
59 | flags.DEFINE_integer("train_batch_size", default=16,
60 | help="Size of train batch.")
61 | flags.DEFINE_integer("train_steps", default=100000,
62 | help="Total number of training steps.")
63 | flags.DEFINE_integer("iterations", default=1000,
64 | help="Number of iterations per repeat loop.")
65 | flags.DEFINE_integer("save_steps", default=None,
66 | help="number of steps for model checkpointing.")
67 |
68 | # Data config
69 | flags.DEFINE_integer("seq_len", default=0,
70 |       help="Sequence length for pretraining.")
71 | flags.DEFINE_integer("reuse_len", default=0,
72 |       help="How many tokens to be reused in the next batch. "
73 |       "Could be half of `seq_len`.")
74 | flags.DEFINE_bool("bi_data", default=True,
75 | help="Use bidirectional data streams, i.e., forward & backward.")
76 | flags.DEFINE_integer("mask_alpha", default=6,
77 | help="How many tokens to form a group.")
78 | flags.DEFINE_integer("mask_beta", default=1,
79 | help="How many tokens to mask within each group.")
80 | flags.DEFINE_integer("num_predict", default=None,
81 | help="Number of tokens to predict in partial prediction.")
82 | flags.DEFINE_integer("perm_size", default=None,
83 |       help="Window size of permutation.")
84 | flags.DEFINE_bool("uncased", False,
85 | help="Use uncased inputs or not.")
86 | flags.DEFINE_integer("n_token", 32000, help="Vocab size")
87 |
88 | # Model config
89 | flags.DEFINE_integer("mem_len", default=0,
90 | help="Number of steps to cache")
91 | flags.DEFINE_bool("same_length", default=False,
92 | help="Same length attention")
93 | flags.DEFINE_integer("clamp_len", default=-1,
94 | help="Clamp length")
95 |
96 | flags.DEFINE_integer("n_layer", default=6,
97 | help="Number of layers.")
98 | flags.DEFINE_integer("d_model", default=32,
99 | help="Dimension of the model.")
100 | flags.DEFINE_integer("d_embed", default=32,
101 | help="Dimension of the embeddings.")
102 | flags.DEFINE_integer("n_head", default=4,
103 | help="Number of attention heads.")
104 | flags.DEFINE_integer("d_head", default=8,
105 | help="Dimension of each attention head.")
106 | flags.DEFINE_integer("d_inner", default=32,
107 | help="Dimension of inner hidden size in positionwise feed-forward.")
108 | flags.DEFINE_float("dropout", default=0.0,
109 | help="Dropout rate.")
110 | flags.DEFINE_float("dropatt", default=0.0,
111 | help="Attention dropout rate.")
112 | flags.DEFINE_bool("untie_r", default=False,
113 | help="Untie r_w_bias and r_r_bias")
114 | flags.DEFINE_string("summary_type", default="last",
115 | help="Method used to summarize a sequence into a compact vector.")
116 | flags.DEFINE_string("ff_activation", default="relu",
117 | help="Activation type used in position-wise feed-forward.")
118 | flags.DEFINE_bool("use_bfloat16", False,
119 | help="Whether to use bfloat16.")
120 |
121 | # Parameter initialization
122 | flags.DEFINE_enum("init", default="normal",
123 | enum_values=["normal", "uniform"],
124 | help="Initialization method.")
125 | flags.DEFINE_float("init_std", default=0.02,
126 | help="Initialization std when init is normal.")
127 | flags.DEFINE_float("init_range", default=0.1,
128 | help="Initialization std when init is uniform.")
129 |
130 |
131 | FLAGS = flags.FLAGS
132 |
133 |
134 | def get_model_fn():
135 | def model_fn(features, labels, mems, is_training):
136 | #### Get loss from inputs
137 | total_loss, new_mems, monitor_dict = function_builder.get_loss(
138 | FLAGS, features, labels, mems, is_training)
139 |
140 | #### Check model parameters
141 | num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
142 | tf.logging.info('#params: {}'.format(num_params))
143 |
144 | # GPU
145 | assert is_training
146 | all_vars = tf.trainable_variables()
147 | grads = tf.gradients(total_loss, all_vars)
148 | grads_and_vars = list(zip(grads, all_vars))
149 |
150 | return total_loss, new_mems, grads_and_vars
151 |
152 | return model_fn
153 |
154 |
155 | def single_core_graph(is_training, features, mems):
156 | model_fn = get_model_fn()
157 |
158 | model_ret = model_fn(
159 | features=features,
160 | labels=None,
161 | mems=mems,
162 | is_training=is_training)
163 |
164 | return model_ret
165 |
166 |
167 | def create_mems_tf(bsz_per_core):
168 | mems = [tf.placeholder(dtype=tf.float32,
169 | shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model])
170 | for layer in range(FLAGS.n_layer)]
171 |
172 | return mems
173 |
174 |
175 | def initialize_mems_np(bsz_per_core):
176 | mems_np = [np.zeros(shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model],
177 | dtype=np.float32)
178 | for layer in range(FLAGS.n_layer)]
179 |
180 | return mems_np
181 |
182 |
183 | def train(ps_device):
184 | ##### Get input function and model function
185 |
186 | train_input_fn, record_info_dict = data_utils.get_input_fn(
187 | tfrecord_dir=FLAGS.record_info_dir,
188 | split="train",
189 | bsz_per_host=FLAGS.train_batch_size,
190 | seq_len=FLAGS.seq_len,
191 | reuse_len=FLAGS.reuse_len,
192 | bi_data=FLAGS.bi_data,
193 | num_hosts=1,
194 | num_core_per_host=1, # set to one no matter how many GPUs
195 | perm_size=FLAGS.perm_size,
196 | mask_alpha=FLAGS.mask_alpha,
197 | mask_beta=FLAGS.mask_beta,
198 | uncased=FLAGS.uncased,
199 | num_passes=FLAGS.num_passes,
200 | use_bfloat16=FLAGS.use_bfloat16,
201 | num_predict=FLAGS.num_predict)
202 |
203 | # for key, info in record_info_dict.items():
204 | tf.logging.info("num of batches {}".format(record_info_dict["num_batch"]))
205 |
206 | ##### Create input tensors / placeholders
207 | bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host
208 |
209 | params = {
210 | "batch_size": FLAGS.train_batch_size # the whole batch
211 | }
212 | train_set = train_input_fn(params)
213 |
214 | example = train_set.make_one_shot_iterator().get_next()
215 |
216 | if FLAGS.num_core_per_host > 1:
217 | examples = [{} for _ in range(FLAGS.num_core_per_host)]
218 | for key in example.keys():
219 | vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
220 | for device_id in range(FLAGS.num_core_per_host):
221 | examples[device_id][key] = vals[device_id]
222 | else:
223 | examples = [example]
224 |
225 | ##### Create computational graph
226 | tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []
227 |
228 | for i in range(FLAGS.num_core_per_host):
229 | reuse = True if i > 0 else None
230 | with tf.device(assign_to_gpu(i, ps_device)), \
231 | tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
232 |
233 | # The mems for each tower is a dictionary
234 | mems_i = {}
235 | if FLAGS.mem_len:
236 | mems_i["mems"] = create_mems_tf(bsz_per_core)
237 |
238 | loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
239 | is_training=True,
240 | features=examples[i],
241 | mems=mems_i)
242 |
243 | tower_mems.append(mems_i)
244 | tower_losses.append(loss_i)
245 | tower_new_mems.append(new_mems_i)
246 | tower_grads_and_vars.append(grads_and_vars_i)
247 |
248 | ## average losses and gradients across towers
249 | if len(tower_losses) > 1:
250 | loss = tf.add_n(tower_losses) / len(tower_losses)
251 | grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
252 | else:
253 | loss = tower_losses[0]
254 | grads_and_vars = tower_grads_and_vars[0]
255 |
256 | ## get train op
257 | train_op, learning_rate, gnorm = model_utils.get_train_op(FLAGS, None,
258 | grads_and_vars=grads_and_vars)
259 | global_step = tf.train.get_global_step()
260 |
261 | ##### Training loop
262 | # initialize mems
263 | tower_mems_np = []
264 | for i in range(FLAGS.num_core_per_host):
265 | mems_i_np = {}
266 | for key in tower_mems[i].keys():
267 | mems_i_np[key] = initialize_mems_np(bsz_per_core)
268 | tower_mems_np.append(mems_i_np)
269 |
270 | saver = tf.train.Saver()
271 |
272 | gpu_options = tf.GPUOptions(allow_growth=True)
273 |
274 | model_utils.init_from_checkpoint(FLAGS, global_vars=True)
275 |
276 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
277 | gpu_options=gpu_options)) as sess:
278 | sess.run(tf.global_variables_initializer())
279 |
280 | fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]
281 |
282 | total_loss, prev_step = 0., -1
283 | while True:
284 | feed_dict = {}
285 | for i in range(FLAGS.num_core_per_host):
286 | for key in tower_mems_np[i].keys():
287 | for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
288 | feed_dict[m] = m_np
289 |
290 | fetched = sess.run(fetches, feed_dict=feed_dict)
291 |
292 | loss_np, tower_mems_np, curr_step = fetched[:3]
293 | total_loss += loss_np
294 |
295 | if curr_step > 0 and curr_step % FLAGS.iterations == 0:
296 | curr_loss = total_loss / (curr_step - prev_step)
297 | tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
298 | "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
299 | curr_step, fetched[-3], fetched[-2],
300 | curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
301 | total_loss, prev_step = 0., curr_step
302 |
303 |       if FLAGS.save_steps is not None and curr_step > 0 and curr_step % FLAGS.save_steps == 0:
304 | save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
305 | saver.save(sess, save_path)
306 | tf.logging.info("Model saved in path: {}".format(save_path))
307 |
308 | if curr_step >= FLAGS.train_steps:
309 | break
310 |
311 |
312 | def main(unused_argv):
313 | del unused_argv # Unused
314 |
315 | tf.logging.set_verbosity(tf.logging.INFO)
316 |
317 | # Get corpus info
318 | FLAGS.n_token = data_utils.VOCAB_SIZE
319 | tf.logging.info("n_token {}".format(FLAGS.n_token))
320 |
321 | if not tf.gfile.Exists(FLAGS.model_dir):
322 | tf.gfile.MakeDirs(FLAGS.model_dir)
323 |
324 | train("/gpu:0")
325 |
326 |
327 | if __name__ == "__main__":
328 | tf.app.run()
329 |
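330 | # Example invocation (a sketch; the paths and hyperparameter values below are
331 | # placeholders). Set save_steps explicitly if you want periodic checkpoints,
332 | # since the training loop only saves when save_steps is not None:
333 | #   python train_gpu.py \
334 | #     --record_info_dir=proc_data/tfrecords \
335 | #     --model_dir=experiment/pretrain_gpu \
336 | #     --num_core_per_host=4 \
337 | #     --seq_len=128 --reuse_len=64 --perm_size=64 \
338 | #     --train_batch_size=8 --save_steps=1000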
--------------------------------------------------------------------------------
/xlnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import json
6 | import os
7 | import tensorflow as tf
8 | import modeling
9 |
10 |
11 | def _get_initializer(FLAGS):
12 | """Get variable intializer."""
13 | if FLAGS.init == "uniform":
14 | initializer = tf.initializers.random_uniform(
15 | minval=-FLAGS.init_range,
16 | maxval=FLAGS.init_range,
17 | seed=None)
18 | elif FLAGS.init == "normal":
19 | initializer = tf.initializers.random_normal(
20 | stddev=FLAGS.init_std,
21 | seed=None)
22 | else:
23 | raise ValueError("Initializer {} not supported".format(FLAGS.init))
24 | return initializer
25 |
26 |
27 | class XLNetConfig(object):
28 | """XLNetConfig contains hyperparameters that are specific to a model checkpoint;
29 | i.e., these hyperparameters should be the same between
30 | pretraining and finetuning.
31 |
32 | The following hyperparameters are defined:
33 | n_layer: int, the number of layers.
34 | d_model: int, the hidden size.
35 | n_head: int, the number of attention heads.
36 | d_head: int, the dimension size of each attention head.
37 | d_inner: int, the hidden size in feed-forward layers.
38 | ff_activation: str, "relu" or "gelu".
39 | untie_r: bool, whether to untie the biases in attention.
40 | n_token: int, the vocab size.
41 | """
42 |
43 | def __init__(self, FLAGS=None, json_path=None):
44 | """Constructing an XLNetConfig.
45 | One of FLAGS or json_path should be provided."""
46 |
47 | assert FLAGS is not None or json_path is not None
48 |
49 | self.keys = ["n_layer", "d_model", "n_head", "d_head", "d_inner",
50 | "ff_activation", "untie_r", "n_token"]
51 |
52 | if FLAGS is not None:
53 | self.init_from_flags(FLAGS)
54 |
55 | if json_path is not None:
56 | self.init_from_json(json_path)
57 |
58 | def init_from_flags(self, FLAGS):
59 | for key in self.keys:
60 | setattr(self, key, getattr(FLAGS, key))
61 |
62 | def init_from_json(self, json_path):
63 | with tf.gfile.Open(json_path) as f:
64 | json_data = json.load(f)
65 | for key in self.keys:
66 | setattr(self, key, json_data[key])
67 |
68 | def to_json(self, json_path):
69 | """Save XLNetConfig to a json file."""
70 | json_data = {}
71 | for key in self.keys:
72 | json_data[key] = getattr(self, key)
73 |
74 | json_dir = os.path.dirname(json_path)
75 | if not tf.gfile.Exists(json_dir):
76 | tf.gfile.MakeDirs(json_dir)
77 | with tf.gfile.Open(json_path, "w") as f:
78 | json.dump(json_data, f, indent=4, sort_keys=True)
79 |
80 |
81 | def create_run_config(is_training, is_finetune, FLAGS):
82 | kwargs = dict(
83 | is_training=is_training,
84 | use_tpu=FLAGS.use_tpu,
85 | use_bfloat16=FLAGS.use_bfloat16,
86 | dropout=FLAGS.dropout,
87 | dropatt=FLAGS.dropatt,
88 | init=FLAGS.init,
89 | init_range=FLAGS.init_range,
90 | init_std=FLAGS.init_std,
91 | clamp_len=FLAGS.clamp_len)
92 |
93 | if not is_finetune:
94 | kwargs.update(dict(
95 | mem_len=FLAGS.mem_len,
96 | reuse_len=FLAGS.reuse_len,
97 | bi_data=FLAGS.bi_data,
98 | clamp_len=FLAGS.clamp_len,
99 | same_length=FLAGS.same_length))
100 |
101 | return RunConfig(**kwargs)
102 |
103 |
104 | class RunConfig(object):
105 | """RunConfig contains hyperparameters that could be different
106 | between pretraining and finetuning.
107 | These hyperparameters can also be changed from run to run.
108 | We store them separately from XLNetConfig for flexibility.
109 | """
110 |
111 | def __init__(self, is_training, use_tpu, use_bfloat16, dropout, dropatt,
112 | init="normal", init_range=0.1, init_std=0.02, mem_len=None,
113 | reuse_len=None, bi_data=False, clamp_len=-1, same_length=False):
114 | """
115 | Args:
116 | is_training: bool, whether in training mode.
117 | use_tpu: bool, whether TPUs are used.
118 | use_bfloat16: bool, use bfloat16 instead of float32.
119 | dropout: float, dropout rate.
120 | dropatt: float, dropout rate on attention probabilities.
121 | init: str, the initialization scheme, either "normal" or "uniform".
122 | init_range: float, initialize the parameters with a uniform distribution
123 | in [-init_range, init_range]. Only effective when init="uniform".
124 | init_std: float, initialize the parameters with a normal distribution
125 | with mean 0 and stddev init_std. Only effective when init="normal".
126 | mem_len: int, the number of tokens to cache.
127 |       reuse_len: int, the number of tokens in the current batch to be cached
128 | and reused in the future.
129 | bi_data: bool, whether to use bidirectional input pipeline.
130 | Usually set to True during pretraining and False during finetuning.
131 | clamp_len: int, clamp all relative distances larger than clamp_len.
132 | -1 means no clamping.
133 | same_length: bool, whether to use the same attention length for each token.
134 | """
135 |
136 | self.init = init
137 | self.init_range = init_range
138 | self.init_std = init_std
139 | self.is_training = is_training
140 | self.dropout = dropout
141 | self.dropatt = dropatt
142 | self.use_tpu = use_tpu
143 | self.use_bfloat16 = use_bfloat16
144 | self.mem_len = mem_len
145 | self.reuse_len = reuse_len
146 | self.bi_data = bi_data
147 | self.clamp_len = clamp_len
148 | self.same_length = same_length
149 |
150 |
151 | class XLNetModel(object):
152 | """A wrapper of the XLNet model used during both pretraining and finetuning."""
153 |
154 | def __init__(self, xlnet_config, run_config, input_ids, seg_ids, input_mask,
155 | mems=None, perm_mask=None, target_mapping=None, inp_q=None,
156 | **kwargs):
157 | """
158 | Args:
159 | xlnet_config: XLNetConfig,
160 | run_config: RunConfig,
161 | input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
162 | seg_ids: int32 Tensor in shape [len, bsz], the input segment IDs.
163 | input_mask: float32 Tensor in shape [len, bsz], the input mask.
164 | 0 for real tokens and 1 for padding.
165 | mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
166 | from previous batches. The length of the list equals n_layer.
167 | If None, no memory is used.
168 | perm_mask: float32 Tensor in shape [len, len, bsz].
169 |         If perm_mask[i, j, k] = 0, i attends to j in batch k;
170 | if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
171 | If None, each position attends to all the others.
172 | target_mapping: float32 Tensor in shape [num_predict, len, bsz].
173 | If target_mapping[i, j, k] = 1, the i-th predict in batch k is
174 | on the j-th token.
175 | Only used during pretraining for partial prediction.
176 | Set to None during finetuning.
177 | inp_q: float32 Tensor in shape [len, bsz].
178 | 1 for tokens with losses and 0 for tokens without losses.
179 | Only used during pretraining for two-stream attention.
180 | Set to None during finetuning.
181 | """
182 |
183 | initializer = _get_initializer(run_config)
184 |
185 | tfm_args = dict(
186 | n_token=xlnet_config.n_token,
187 | initializer=initializer,
188 | attn_type="bi",
189 | n_layer=xlnet_config.n_layer,
190 | d_model=xlnet_config.d_model,
191 | n_head=xlnet_config.n_head,
192 | d_head=xlnet_config.d_head,
193 | d_inner=xlnet_config.d_inner,
194 | ff_activation=xlnet_config.ff_activation,
195 | untie_r=xlnet_config.untie_r,
196 |
197 | is_training=run_config.is_training,
198 | use_bfloat16=run_config.use_bfloat16,
199 | use_tpu=run_config.use_tpu,
200 | dropout=run_config.dropout,
201 | dropatt=run_config.dropatt,
202 |
203 | mem_len=run_config.mem_len,
204 | reuse_len=run_config.reuse_len,
205 | bi_data=run_config.bi_data,
206 | clamp_len=run_config.clamp_len,
207 | same_length=run_config.same_length
208 | )
209 |
210 | input_args = dict(
211 | inp_k=input_ids,
212 | seg_id=seg_ids,
213 | input_mask=input_mask,
214 | mems=mems,
215 | perm_mask=perm_mask,
216 | target_mapping=target_mapping,
217 | inp_q=inp_q)
218 | tfm_args.update(input_args)
219 |
220 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
221 | (self.output, self.new_mems, self.lookup_table
222 | ) = modeling.transformer_xl(**tfm_args)
223 |
224 | self.input_mask = input_mask
225 | self.initializer = initializer
226 | self.xlnet_config = xlnet_config
227 | self.run_config = run_config
228 |
229 | def get_pooled_out(self, summary_type, use_summ_proj=True):
230 | """
231 | Args:
232 | summary_type: str, "last", "first", "mean", or "attn". The method
233 | to pool the input to get a vector representation.
234 | use_summ_proj: bool, whether to use a linear projection during pooling.
235 |
236 | Returns:
237 | float32 Tensor in shape [bsz, d_model], the pooled representation.
238 | """
239 |
240 | xlnet_config = self.xlnet_config
241 | run_config = self.run_config
242 |
243 | with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
244 | summary = modeling.summarize_sequence(
245 | summary_type=summary_type,
246 | hidden=self.output,
247 | d_model=xlnet_config.d_model,
248 | n_head=xlnet_config.n_head,
249 | d_head=xlnet_config.d_head,
250 | dropout=run_config.dropout,
251 | dropatt=run_config.dropatt,
252 | is_training=run_config.is_training,
253 | input_mask=self.input_mask,
254 | initializer=self.initializer,
255 | use_proj=use_summ_proj)
256 |
257 | return summary
258 |
259 | def get_sequence_output(self):
260 | """
261 | Returns:
262 | float32 Tensor in shape [len, bsz, d_model]. The last layer hidden
263 | representation of XLNet.
264 | """
265 |
266 | return self.output
267 |
268 | def get_new_memory(self):
269 | """
270 | Returns:
271 | list of float32 Tensors in shape [mem_len, bsz, d_model], the new
272 | memory that concatenates the previous memory with the current input
273 | representations.
274 | The length of the list equals n_layer.
275 | """
276 | return self.new_mems
277 |
278 | def get_embedding_table(self):
279 | """
280 | Returns:
281 | float32 Tensor in shape [n_token, d_model]. The embedding lookup table.
282 | Used for tying embeddings between input and output layers.
283 | """
284 | return self.lookup_table
285 |
286 | def get_initializer(self):
287 | """
288 | Returns:
289 | A tf initializer. Used to initialize variables in layers on top of XLNet.
290 | """
291 | return self.initializer
292 |
293 |
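294 | # Example usage (a sketch for finetuning; FLAGS and the input tensors are
295 | # assumed to come from the caller, e.g. run_classifier.py):
296 | #   xlnet_config = XLNetConfig(json_path=FLAGS.model_config_path)
297 | #   run_config = create_run_config(is_training=True, is_finetune=True, FLAGS=FLAGS)
298 | #   xlnet_model = XLNetModel(
299 | #       xlnet_config=xlnet_config,
300 | #       run_config=run_config,
301 | #       input_ids=input_ids,      # int32 Tensor [len, bsz]
302 | #       seg_ids=seg_ids,          # int32 Tensor [len, bsz]
303 | #       input_mask=input_mask)    # float32 Tensor [len, bsz], 1.0 marks padding
304 | #   summary = xlnet_model.get_pooled_out(summary_type="last", use_summ_proj=True)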
--------------------------------------------------------------------------------