├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── build.sh
├── download_glue_data.py
├── flow
│   ├── __init__.py
│   ├── config
│   │   ├── config_l2_d2_w32.json
│   │   ├── config_l2_d3_w16.json
│   │   ├── config_l2_d3_w32.json
│   │   ├── config_l3_d2_w32.json
│   │   ├── config_l3_d3_w16.json
│   │   ├── config_l3_d3_w32.json
│   │   └── dump_config.py
│   ├── glow_1x1.py
│   ├── glow_init_hook.py
│   └── glow_ops_1x1.py
├── img
│   └── bert-flow.png
├── modeling.py
├── optimization.py
├── optimization_bert_flow.py
├── run_siamese.py
├── scripts
│   ├── eval_stsb.py
│   └── train_siamese.sh
├── siamese_utils.py
├── tokenization.py
└── tokenization_test.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | BERT needs to maintain permanent compatibility with the pre-trained model files,
4 | so we do not plan to make any major changes to this library (other than what was
5 | promised in the README). However, we can accept small patches related to
6 | re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # On the Sentence Embeddings from Pre-trained Language Models
2 |
3 |
4 | ![bert-flow](img/bert-flow.png)
5 |
6 |
7 | This is a TensorFlow implementation of the following [paper](https://arxiv.org/abs/2011.05864):
8 |
9 | ```
10 | On the Sentence Embeddings from Pre-trained Language Models
11 | Bohan Li, Hao Zhou, Junxian He, Mingxuan Wang, Yiming Yang, Lei Li
12 | EMNLP 2020
13 | ```
14 |
15 |
16 |
17 | Model | Spearman's rho
18 | -------------------------------------------- | :-------------:
19 | BERT-large-NLI | 77.80
20 | BERT-large-NLI-last2avg | 78.45
21 | BERT-large-NLI-flow (target, train only) | 80.54
22 | BERT-large-NLI-flow (target, train+dev+test) | 81.18
23 |
24 |
25 | Please contact bohanl1@cs.cmu.edu if you have any questions.
26 |
27 |
28 | ## Requirements
29 |
30 | * Python >= 3.6
31 | * TensorFlow >= 1.14
32 |
33 | ## Preparation
34 |
35 | ### Pretrained BERT models
36 | ```bash
37 | export BERT_PREMODELS="../bert_premodels"
38 | mkdir ${BERT_PREMODELS}; cd ${BERT_PREMODELS}
39 |
40 | # then download the pre-trained BERT models from https://github.com/google-research/bert
41 | curl -O https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
42 | curl -O https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip
43 |
44 | ls ${BERT_PREMODELS}/uncased_L-12_H-768_A-12 # base
45 | ls ${BERT_PREMODELS}/uncased_L-24_H-1024_A-16 # large
46 | ```
47 |
48 | ### GLUE
49 | ```bash
50 | export GLUE_DIR="../glue_data"
51 | python download_glue_data.py --data_dir=${GLUE_DIR}
52 |
53 | # then download the labeled test set of STS-B
54 | cd ../glue_data/STS-B
55 | curl -O https://raw.githubusercontent.com/kawine/usif/master/STSBenchmark/sts-test.csv
56 | ```
57 |
58 | ### SentEval
59 | ```bash
60 | cd ..
61 | git clone https://github.com/facebookresearch/SentEval
62 | ```
63 |
64 | ## Usage
65 |
66 | ### Fine-tune BERT with NLI supervision (optional)
67 | ```bash
68 | export OUTPUT_PARENT_DIR="../exp"
69 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data
70 | mkdir ${CACHED_DIR}
71 |
72 | export RANDOM_SEED=1234
73 | export CUDA_VISIBLE_DEVICES=0
74 | export BERT_NAME="large"
75 | export TASK_NAME="ALLNLI"
76 | unset INIT_CKPT
77 | bash scripts/train_siamese.sh train \
78 | "--exp_name=exp_${BERT_NAME}_${RANDOM_SEED} \
79 | --num_train_epochs=1.0 \
80 | --learning_rate=2e-5 \
81 | --train_batch_size=16 \
82 | --cached_dir=${CACHED_DIR}"
83 |
84 |
85 | # evaluation
86 | export RANDOM_SEED=1234
87 | export CUDA_VISIBLE_DEVICES=0
88 | export TASK_NAME=STS-B
89 | export BERT_NAME=large
90 | export OUTPUT_PARENT_DIR="../exp"
91 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_${BERT_NAME}_${RANDOM_SEED}/model.ckpt-60108
92 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data
93 | export EXP_NAME=exp_${BERT_NAME}_${RANDOM_SEED}_eval
94 | bash scripts/train_siamese.sh predict \
95 | "--exp_name=${EXP_NAME} \
96 | --cached_dir=${CACHED_DIR} \
97 | --sentence_embedding_type=avg \
98 | --flow=0 --flow_loss=0 \
99 | --num_examples=0 \
100 | --num_train_epochs=1e-10"
101 | ```
102 |
103 | Note: You may want to add `--use_xla` to speed up the BERT fine-tuning.
104 |
105 | ### Unsupervised learning of flow-based generative models
106 | ```bash
107 | export CUDA_VISIBLE_DEVICES=0
108 | export TASK_NAME=STS-B
109 | export BERT_NAME=large
110 | export OUTPUT_PARENT_DIR="../exp"
111 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108
112 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data
113 | bash scripts/train_siamese.sh train \
114 | "--exp_name_prefix=exp \
115 | --cached_dir=${CACHED_DIR} \
116 | --sentence_embedding_type=avg-last-2 \
117 | --flow=1 --flow_loss=1 \
118 | --num_examples=0 \
119 | --num_train_epochs=1.0 \
120 | --flow_learning_rate=1e-3 \
121 | --use_full_for_training=1"
122 |
123 | # evaluation
124 | export CUDA_VISIBLE_DEVICES=0
125 | export TASK_NAME=STS-B
126 | export BERT_NAME=large
127 | export OUTPUT_PARENT_DIR="../exp"
128 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108
129 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data
130 | export EXP_NAME=exp_t_STS-B_ep_1.00_lr_5.00e-05_e_avg-last-2_f_11_1.00e-03_allsplits
131 | bash scripts/train_siamese.sh predict \
132 | "--exp_name=${EXP_NAME} \
133 | --cached_dir=${CACHED_DIR} \
134 | --sentence_embedding_type=avg-last-2 \
135 | --flow=1 --flow_loss=1 \
136 | --num_examples=0 \
137 | --num_train_epochs=1.0 \
138 | --flow_learning_rate=1e-3 \
139 | --use_full_for_training=1"
140 | ```
141 |
142 | ### Fit flow with only the training set of STS-B
143 | ```bash
144 | export CUDA_VISIBLE_DEVICES=0
145 | export TASK_NAME=STS-B
146 | export BERT_NAME=large
147 | export OUTPUT_PARENT_DIR="../exp"
148 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108
149 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data
150 | bash scripts/train_siamese.sh train \
151 | "--exp_name_prefix=exp \
152 | --cached_dir=${CACHED_DIR} \
153 | --sentence_embedding_type=avg-last-2 \
154 | --flow=1 --flow_loss=1 \
155 | --num_examples=0 \
156 | --num_train_epochs=1.0 \
157 | --flow_learning_rate=1e-3 \
158 | --use_full_for_training=0"
159 |
160 | # evaluation
161 | export CUDA_VISIBLE_DEVICES=0
162 | export TASK_NAME=STS-B
163 | export BERT_NAME=large
164 | export OUTPUT_PARENT_DIR="../exp"
165 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108
166 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data
167 | export EXP_NAME=exp_t_STS-B_ep_1.00_lr_5.00e-05_e_avg-last-2_f_11_1.00e-03
168 | bash scripts/train_siamese.sh predict \
169 | "--exp_name=${EXP_NAME} \
170 | --cached_dir=${CACHED_DIR} \
171 | --sentence_embedding_type=avg-last-2 \
172 | --flow=1 --flow_loss=1 \
173 | --num_examples=0 \
174 | --num_train_epochs=1.0 \
175 | --flow_learning_rate=1e-3 \
176 | --use_full_for_training=1"
177 | ```
178 |
179 | ## Download our models
180 | Our models are available at https://drive.google.com/file/d/1-vO47t5SPFfzZPKkkhSe4tXhn8u--KLR/view?usp=sharing
181 |
182 | ## Reference
183 |
184 | ```
185 | @inproceedings{li2020emnlp,
186 | title = {On the Sentence Embeddings from Pre-trained Language Models},
187 | author = {Bohan Li and Hao Zhou and Junxian He and Mingxuan Wang and Yiming Yang and Lei Li},
188 | booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
189 | month = {November},
190 | year = {2020}
191 | }
192 |
193 | ```
194 |
195 | ## Acknowledgements
196 |
197 | A large portion of this repo is borrowed from the following projects:
198 | - https://github.com/google-research/bert
199 | - https://github.com/zihangdai/xlnet
200 | - https://github.com/tensorflow/tensor2tensor
201 |
202 |
203 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bohanli/BERT-flow/7fa8f6d4a1a73e2c2f8549799d9bafc9d6048d67/build.sh
--------------------------------------------------------------------------------
/download_glue_data.py:
--------------------------------------------------------------------------------
1 | ''' Script for downloading all GLUE data.
2 |
3 | Note: for legal reasons, we are unable to host MRPC.
4 | You can either use the version hosted by the SentEval team, which is already tokenized,
5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
7 | You should then rename and place specific files in a folder (see below for an example).
8 |
9 | mkdir MRPC
10 | cabextract MSRParaphraseCorpus.msi -d MRPC
11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
13 | rm MRPC/_*
14 | rm MSRParaphraseCorpus.msi
15 |
16 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
17 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
18 | '''
19 |
20 | import os
21 | import sys
22 | import shutil
23 | import argparse
24 | import tempfile
25 | import urllib.request
26 | import zipfile
27 |
28 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
29 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
30 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
31 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
32 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
33 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
34 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
35 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
36 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
37 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
38 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
39 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
40 |
41 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
42 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
43 |
44 | def download_and_extract(task, data_dir):
45 | print("Downloading and extracting %s..." % task)
46 | data_file = "%s.zip" % task
47 | urllib.request.urlretrieve(TASK2PATH[task], data_file)
48 | with zipfile.ZipFile(data_file) as zip_ref:
49 | zip_ref.extractall(data_dir)
50 | os.remove(data_file)
51 | print("\tCompleted!")
52 |
53 | def format_mrpc(data_dir, path_to_data):
54 | print("Processing MRPC...")
55 | mrpc_dir = os.path.join(data_dir, "MRPC")
56 | if not os.path.isdir(mrpc_dir):
57 | os.mkdir(mrpc_dir)
58 | if path_to_data:
59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
61 | else:
62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file
69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
70 |
71 | dev_ids = []
72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
73 | for row in ids_fh:
74 | dev_ids.append(row.strip().split('\t'))
75 |
76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \
77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
79 | header = data_fh.readline()
80 | train_fh.write(header)
81 | dev_fh.write(header)
82 | for row in data_fh:
83 | label, id1, id2, s1, s2 = row.strip().split('\t')
84 | if [id1, id2] in dev_ids:
85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
86 | else:
87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
88 |
89 | with open(mrpc_test_file, encoding="utf8") as data_fh, \
90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
91 | header = data_fh.readline()
92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
93 | for idx, row in enumerate(data_fh):
94 | label, id1, id2, s1, s2 = row.strip().split('\t')
95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
96 | print("\tCompleted!")
97 |
98 | def download_diagnostic(data_dir):
99 | print("Downloading and extracting diagnostic...")
100 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
101 | os.mkdir(os.path.join(data_dir, "diagnostic"))
102 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
103 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
104 | print("\tCompleted!")
105 | return
106 |
107 | def get_tasks(task_names):
108 | task_names = task_names.split(',')
109 | if "all" in task_names:
110 | tasks = TASKS
111 | else:
112 | tasks = []
113 | for task_name in task_names:
114 | assert task_name in TASKS, "Task %s not found!" % task_name
115 | tasks.append(task_name)
116 | return tasks
117 |
118 | def main(arguments):
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
121 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
122 | type=str, default='all')
123 |     parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
124 | type=str, default='')
125 | args = parser.parse_args(arguments)
126 |
127 | if not os.path.isdir(args.data_dir):
128 | os.mkdir(args.data_dir)
129 | tasks = get_tasks(args.tasks)
130 |
131 | for task in tasks:
132 | if task == 'MRPC':
133 | format_mrpc(args.data_dir, args.path_to_mrpc)
134 | elif task == 'diagnostic':
135 | download_diagnostic(args.data_dir)
136 | else:
137 | download_and_extract(task, args.data_dir)
138 |
139 |
140 | if __name__ == '__main__':
141 | sys.exit(main(sys.argv[1:]))
--------------------------------------------------------------------------------
/flow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bohanli/BERT-flow/7fa8f6d4a1a73e2c2f8549799d9bafc9d6048d67/flow/__init__.py
--------------------------------------------------------------------------------
/flow/config/config_l2_d2_w32.json:
--------------------------------------------------------------------------------
1 | {
2 | "latent_architecture": "glow_resnet",
3 | "activation": "relu",
4 | "coupling": "additive",
5 | "coupling_width": 32,
6 | "coupling_dropout": 0.0,
7 | "top_prior": "normal",
8 | "n_levels": 2,
9 | "depth": 2,
10 | "permutation": true,
11 | "use_fp16": false
12 | }
13 |
--------------------------------------------------------------------------------
/flow/config/config_l2_d3_w16.json:
--------------------------------------------------------------------------------
1 | {
2 | "latent_architecture": "glow_resnet",
3 | "activation": "relu",
4 | "coupling": "additive",
5 |     "coupling_width": 16,
6 | "coupling_dropout": 0.0,
7 | "top_prior": "normal",
8 | "n_levels": 2,
9 | "depth": 3,
10 | "permutation": true,
11 | "use_fp16": false
12 | }
13 |
--------------------------------------------------------------------------------
/flow/config/config_l2_d3_w32.json:
--------------------------------------------------------------------------------
1 | {
2 | "latent_architecture": "glow_resnet",
3 | "activation": "relu",
4 | "coupling": "additive",
5 | "coupling_width": 32,
6 | "coupling_dropout": 0.0,
7 | "top_prior": "normal",
8 | "n_levels": 2,
9 | "depth": 3,
10 | "permutation": true,
11 | "use_fp16": false
12 | }
13 |
--------------------------------------------------------------------------------
/flow/config/config_l3_d2_w32.json:
--------------------------------------------------------------------------------
1 | {
2 | "latent_architecture": "glow_resnet",
3 | "activation": "relu",
4 | "coupling": "additive",
5 | "coupling_width": 32,
6 | "coupling_dropout": 0.0,
7 | "top_prior": "normal",
8 | "n_levels": 3,
9 | "depth": 2,
10 | "permutation": true,
11 | "use_fp16": false
12 | }
13 |
--------------------------------------------------------------------------------
/flow/config/config_l3_d3_w16.json:
--------------------------------------------------------------------------------
1 | {
2 | "latent_architecture": "glow_resnet",
3 | "activation": "relu",
4 | "coupling": "additive",
5 | "coupling_width": 16,
6 | "coupling_dropout": 0.0,
7 | "top_prior": "normal",
8 | "n_levels": 3,
9 | "depth": 3,
10 | "permutation": true,
11 | "use_fp16": false
12 | }
13 |
--------------------------------------------------------------------------------
/flow/config/config_l3_d3_w32.json:
--------------------------------------------------------------------------------
1 | {
2 | "latent_architecture": "glow_resnet",
3 | "activation": "relu",
4 | "coupling": "additive",
5 | "coupling_width": 32,
6 | "coupling_dropout": 0.0,
7 | "top_prior": "normal",
8 | "n_levels": 3,
9 | "depth": 3,
10 | "permutation": true,
11 | "use_fp16": false
12 | }
13 |
--------------------------------------------------------------------------------
/flow/config/dump_config.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | model_config = {
4 | "latent_architecture": "glow_resnet",
5 | "activation": "relu", # Activation - Relu or Gatu
6 | "coupling": "affine", # Coupling layer, additive or affine.
7 | "coupling_width": 512,
8 | "coupling_dropout": 0.0,
9 | "top_prior": "normal",
10 | "n_levels": 3,
11 | "depth": 6,
12 |
13 | "use_fp16": False # not implemented yet
14 | }
15 |
16 | with open('/mnt/cephfs_new_wj/mlnlp/libohan.05/text_flow/config/model_config_normal.json', 'w') as jp:
17 | json.dump(model_config, jp, indent=4)
18 |
19 |
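A minimal sketch of consuming one of these JSON configs as attribute-style hparams, mirroring the `AttrDict` helper in `flow/glow_1x1.py` (`load_flow_config` is an illustrative name, not part of the repo):

```python
import json


class AttrDict(dict):
    """Dict whose items are also attributes, as in flow/glow_1x1.py."""
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def load_flow_config(path):
    """Illustrative loader: read one of the config_*.json files above."""
    with open(path) as fp:
        return AttrDict(json.load(fp))


# In-memory example matching config_l3_d3_w32.json:
hparams = AttrDict(json.loads(
    '{"coupling": "additive", "coupling_width": 32, "n_levels": 3, "depth": 3}'))
assert hparams.coupling == "additive" and hparams.depth == 3
```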
--------------------------------------------------------------------------------
/flow/glow_1x1.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope
3 | import flow.glow_ops_1x1 as glow_ops
4 | from flow.glow_ops_1x1 import get_shape_list
5 | import flow.glow_init_hook as glow_init_hook
6 |
7 |
8 | import numpy as np
9 | import os, sys
10 |
11 | arg_scope = tf.contrib.framework.arg_scope
12 | add_arg_scope = tf.contrib.framework.add_arg_scope
13 |
14 | class AttrDict(dict):
15 | def __init__(self, *args, **kwargs):
16 | super(AttrDict, self).__init__(*args, **kwargs)
17 | self.__dict__ = self
18 |
19 | class Glow():
20 |
21 | def __init__(self, hparams):
22 | self.hparams = hparams
23 |
24 | @property
25 | def is_predicting(self):
26 | return not self.is_training
27 |
28 | @staticmethod
29 | def train_hooks():
30 | #del hook_context
31 | return [glow_init_hook.GlowInitHook()]
32 |
33 | def top_prior(self):
34 | """Objective based on the prior over latent z.
35 |
36 | Returns:
37 |       dist: instance of tf.distributions.Normal, prior distribution.
38 | """
39 | return glow_ops.top_prior(
40 | "top_prior", self.z_top_shape, learn_prior=self.hparams.top_prior)
41 |
42 | def body(self, features, is_training):
43 | if is_training:
44 | init_features = features
45 | init_op = self.objective_tower(init_features, init=True)
46 | init_op = tf.Print(
47 | init_op, [init_op], message="Triggering data-dependent init.",
48 | first_n=20)
49 | tf.compat.v1.add_to_collection("glow_init_op", init_op)
50 | return self.objective_tower(features, init=False)
51 |
52 | def objective_tower(self, features, init=True):
53 | """Objective in terms of bits-per-pixel.
54 | """
55 | #features = tf.expand_dims(features, [1, 2])
56 | features = features[:, tf.newaxis, tf.newaxis, :]
57 | x = features
58 |
59 | objective = 0
60 |
61 | # The arg_scope call ensures that the actnorm parameters are set such that
62 | # the per-channel output activations have zero mean and unit variance
63 | # ONLY during the first step. After that the parameters are learned
64 | # through optimisation.
65 | ops = [glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout]
66 | with arg_scope(ops, init=init):
67 | encoder = glow_ops.encoder_decoder
68 |
69 | self.z, encoder_objective, self.eps, _, _ = encoder(
70 | "flow", x, self.hparams, eps=None, reverse=False)
71 | objective += encoder_objective
72 |
73 | self.z_top_shape = get_shape_list(self.z)
74 | prior_dist = self.top_prior()
75 | prior_objective = tf.reduce_sum(
76 | prior_dist.log_prob(self.z), axis=[1, 2, 3])
77 | #self.z_sample = prior_dist.sample()
78 | objective += prior_objective
79 |
80 | # bits per pixel
81 | _, h, w, c = get_shape_list(x)
82 | objective = -objective / (np.log(2) * h * w * c)
83 |
84 | self.z = tf.concat(self.eps + [self.z], axis=-1)
85 | return objective
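The tail of `objective_tower` converts the accumulated log-likelihood (in nats) to bits per dimension by negating and dividing by `log(2) * h * w * c`. A standalone NumPy sketch of that conversion (the function name is illustrative):

```python
import numpy as np


def nats_to_bits_per_dim(log_likelihood, h, w, c):
    """Negative log-likelihood in bits, averaged over the h*w*c dimensions."""
    return -log_likelihood / (np.log(2) * h * w * c)


# A log-likelihood of -log(2)*768 nats on a 1x1x768 "image" is exactly 1 bit/dim.
bpd = nats_to_bits_per_dim(-np.log(2) * 768, h=1, w=1, c=768)
assert abs(bpd - 1.0) < 1e-9
```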
--------------------------------------------------------------------------------
/flow/glow_init_hook.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class GlowInitHook(tf.estimator.SessionRunHook):
5 | """
6 | Hook that runs data-dependent initialization once before the first step.
7 |
8 | The init op is stored in the tf collection glow_init_op. Look at the
9 | "body" in glow.py for more details.
10 | """
11 |
12 | def after_create_session(self, session, coord):
13 | del coord
14 | global_step = session.run(tf.compat.v1.train.get_or_create_global_step())
15 | if global_step == 0:
16 |       ddi = tf.compat.v1.get_collection("glow_init_op")
17 | # In-case of a multi-GPU system, this just runs the first op in the
18 | # collection.
19 | print(ddi)
20 | if ddi:
21 | session.run(ddi[0])
22 |
23 |
--------------------------------------------------------------------------------
/flow/glow_ops_1x1.py:
--------------------------------------------------------------------------------
1 | """
2 | modified from
3 | https://github.com/tensorflow/tensor2tensor/blob/8a084a4d56/tensor2tensor/models/research/glow_ops.py
4 |
5 | modifications are as follows:
6 | 1. replace tfp with tf because neither tfp 0.6 nor 0.7 is compatible with tf 1.14
7 | 2. remove support for video-related operators like conv3d
8 | 3. remove support for conditional distributions
9 | """
10 | import tensorflow as tf
11 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope
12 | # import tensorflow_probability as tfp
13 |
14 | import functools
15 | import numpy as np
16 | import scipy
17 |
18 |
19 | def get_shape_list(x):
20 | """Return list of dims, statically where possible."""
21 | x = tf.convert_to_tensor(x)
22 |
23 | # If unknown rank, return dynamic shape
24 | if x.get_shape().dims is None:
25 | return tf.shape(x)
26 |
27 | static = x.get_shape().as_list()
28 | shape = tf.shape(x)
29 |
30 | ret = []
31 | for i, dim in enumerate(static):
32 | if dim is None:
33 | dim = shape[i]
34 | ret.append(dim)
35 | return ret
36 |
37 | def get_eps(dist, x):
38 | """Z = (X - mu) / sigma."""
39 | return (x - dist.loc) / dist.scale
40 |
41 |
42 | def set_eps(dist, eps):
43 | """Z = eps * sigma + mu."""
44 | return eps * dist.scale + dist.loc
45 |
46 |
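`get_eps` and `set_eps` are exact inverses: they map x to its standardized residual under a Gaussian and back. A NumPy sketch using a minimal stand-in class that exposes only the `.loc`/`.scale` attributes assumed above:

```python
import numpy as np


class Normal:
    """Minimal stand-in exposing the .loc / .scale attributes used above."""
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale


def get_eps(dist, x):
    return (x - dist.loc) / dist.scale


def set_eps(dist, eps):
    return eps * dist.scale + dist.loc


dist = Normal(loc=np.array([1.0, -2.0]), scale=np.array([0.5, 3.0]))
x = np.array([2.0, 4.0])
eps = get_eps(dist, x)                      # standardized residual
assert np.allclose(set_eps(dist, eps), x)   # exact round-trip
```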
47 | # ===============================================
48 | @add_arg_scope
49 | def assign(w, initial_value):
50 | w = w.assign(initial_value)
51 | with tf.control_dependencies([w]):
52 | return w
53 |
54 | @add_arg_scope
55 | def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False,
56 | trainable=True):
57 | """Wrapper for data-dependent initialization."""
58 | # If init is a tf bool: w is assigned dynamically at runtime.
59 | # If init is a python bool: then w is determined during graph construction.
60 | w = tf.compat.v1.get_variable(name, shape, dtype, None, trainable=trainable)
61 | if isinstance(init, bool):
62 | if init:
63 | return assign(w, initial_value)
64 | return w
65 | else:
66 | return tf.cond(init, lambda: assign(w, initial_value), lambda: w)
67 |
68 | @add_arg_scope
69 | def get_dropout(x, rate=0.0, init=True):
70 | """Dropout x with dropout_rate = rate.
71 | Apply zero dropout during init or prediction time.
72 | Args:
73 | x: 4-D Tensor, shape=(NHWC).
74 | rate: Dropout rate.
75 | init: Initialization.
76 | Returns:
77 | x: activations after dropout.
78 | """
79 | if init or rate == 0:
80 | return x
81 | return tf.layers.dropout(x, rate=rate, training=True) # TODO
82 |
83 | def default_initializer(std=0.05):
84 | return tf.random_normal_initializer(0., std)
85 |
86 | # ===============================================
87 |
88 | # Activation normalization
89 | # Convenience function that does centering+scaling
90 |
91 | @add_arg_scope
92 | def actnorm(name, x, logscale_factor=3., reverse=False, init=False,
93 | trainable=True):
94 |   """x_{ij} = s * x_{ij} + b. Per-channel scaling and bias.
95 | If init is set to True, the scaling and bias are initialized such
96 | that the mean and variance of the output activations of the first minibatch
97 | are zero and one respectively.
98 | Args:
99 | name: variable scope.
100 | x: input
101 | logscale_factor: Used in actnorm_scale. Optimizes f(ls*s') instead of f(s)
102 | where s' = s / ls. Helps in faster convergence.
103 | reverse: forward or reverse operation.
104 | init: Whether or not to do data-dependent initialization.
105 | trainable:
106 | Returns:
107 | x: output after adding bias and scaling.
108 |     objective: sum(log|s|)
109 | """
110 | var_arg_scope = arg_scope([get_variable_ddi], trainable=trainable)
111 | var_scope = tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE)
112 |
113 | with var_scope, var_arg_scope:
114 | if not reverse:
115 | x = actnorm_center(name + "_center", x, reverse, init=init)
116 | x, objective = actnorm_scale(
117 | name + "_scale", x, logscale_factor=logscale_factor,
118 | reverse=reverse, init=init)
119 | else:
120 | x, objective = actnorm_scale(
121 | name + "_scale", x, logscale_factor=logscale_factor,
122 | reverse=reverse, init=init)
123 | x = actnorm_center(name + "_center", x, reverse, init=init)
124 | return x, objective
125 |
126 |
127 | @add_arg_scope
128 | def actnorm_center(name, x, reverse=False, init=False):
129 | """Add a bias to x.
130 | Initialize such that the output of the first minibatch is zero centered
131 | per channel.
132 | Args:
133 | name: scope
134 | x: 2-D or 4-D Tensor.
135 | reverse: Forward or backward operation.
136 | init: data-dependent initialization.
137 | Returns:
138 | x_center: (x + b), if reverse is True and (x - b) otherwise.
139 | """
140 | shape = get_shape_list(x)
141 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
142 | assert len(shape) == 2 or len(shape) == 4
143 | if len(shape) == 2:
144 | x_mean = tf.reduce_mean(x, [0], keepdims=True)
145 | b = get_variable_ddi("b", (1, shape[1]), initial_value=-x_mean,
146 | init=init)
147 | elif len(shape) == 4:
148 | x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True)
149 | b = get_variable_ddi(
150 | "b", (1, 1, 1, shape[3]), initial_value=-x_mean, init=init)
151 |
152 | if not reverse:
153 | x += b
154 | else:
155 | x -= b
156 | return x
157 |
158 |
159 | @add_arg_scope
160 | def actnorm_scale(name, x, logscale_factor=3., reverse=False, init=False):
161 | """Per-channel scaling of x."""
162 | x_shape = get_shape_list(x)
163 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
164 |
165 | # Variance initialization logic.
166 | assert len(x_shape) == 2 or len(x_shape) == 4
167 | if len(x_shape) == 2:
168 | x_var = tf.reduce_mean(x**2, [0], keepdims=True)
169 | logdet_factor = 1
170 | var_shape = (1, x_shape[1])
171 | elif len(x_shape) == 4:
172 | x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True)
173 | logdet_factor = x_shape[1]*x_shape[2]
174 | var_shape = (1, 1, 1, x_shape[3])
175 |
176 | init_value = tf.math.log(1.0 / (tf.sqrt(x_var) + 1e-6)) / logscale_factor
177 | logs = get_variable_ddi("logs", var_shape, initial_value=init_value,
178 | init=init)
179 | logs = logs * logscale_factor
180 |
181 | # Function and reverse function.
182 | if not reverse:
183 | x = x * tf.exp(logs)
184 | else:
185 | x = x * tf.exp(-logs)
186 |
187 | # Objective calculation, h * w * sum(log|s|)
188 | dlogdet = tf.reduce_sum(logs) * logdet_factor
189 | if reverse:
190 | dlogdet *= -1
191 | return x, dlogdet
192 |
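With `init=True`, the bias and `logs` above are chosen from the first minibatch so that actnorm's output has zero mean and unit variance per channel. A NumPy sketch of the 2-D case (`logscale_factor` is omitted since it cancels in `logs * logscale_factor`):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=4.0, size=(1024, 8))  # minibatch, 8 channels

# actnorm_center init: bias b = -mean per channel.
x = x - x.mean(axis=0, keepdims=True)

# actnorm_scale init: logs = log(1 / (sqrt(E[x^2]) + 1e-6)).
x_var = (x ** 2).mean(axis=0, keepdims=True)
logs = np.log(1.0 / (np.sqrt(x_var) + 1e-6))
y = x * np.exp(logs)

# The first minibatch now has ~zero mean and ~unit variance per channel.
assert np.allclose(y.mean(axis=0), 0.0, atol=1e-6)
assert np.allclose((y ** 2).mean(axis=0), 1.0, atol=1e-3)
```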
193 |
194 | # ===============================================
195 |
196 |
197 | @add_arg_scope
198 | def invertible_1x1_conv(name, x, reverse=False, permutation=False):
199 | """1X1 convolution on x.
200 | The 1X1 convolution is parametrized as P*L*(U + sign(s)*exp(log(s))) where
201 | 1. P is a permutation matrix.
202 | 2. L is a lower triangular matrix with diagonal entries unity.
203 | 3. U is an upper triangular matrix whose diagonal entries are zero.
204 | 4. s is a vector.
205 | sign(s) and P are fixed and the remaining are optimized. P, L, U and s are
206 | initialized by the PLU decomposition of a random rotation matrix.
207 | Args:
208 | name: scope
209 | x: Input Tensor.
210 | reverse: whether the pass is from z -> x or x -> z.
211 | Returns:
212 | x_conv: x after a 1X1 convolution is applied on x.
213 | objective: sum(log(s))
214 | """
215 | _, height, width, channels = get_shape_list(x)
216 | w_shape = [channels, channels]
217 |
218 | if permutation:
219 | np_w = np.zeros((channels, channels)).astype("float32")
220 | for i in range(channels):
221 | np_w[i][channels-1-i] = 1.
222 |
223 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
224 | w = tf.compat.v1.get_variable("w", initializer=np_w, trainable=False)
225 |
226 | # If height or width cannot be statically determined then they end up as
227 | # tf.int32 tensors, which cannot be directly multiplied with a floating
228 | # point tensor without a cast.
229 | objective = 0.
230 | if not reverse:
231 | w = tf.reshape(w, [1, 1] + w_shape)
232 | x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC")
233 | else:
234 | w_inv = tf.reshape(tf.linalg.inv(w), [1, 1] + w_shape)
235 | x = tf.nn.conv2d(
236 | x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC")
237 | objective *= -1
238 | return x, objective
239 | else:
240 | # Random rotation-matrix Q
241 | random_matrix = np.random.rand(channels, channels)
242 | np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")
243 |
244 | # Initialize P,L,U and s from the LU decomposition of a random rotation matrix
245 | np_p, np_l, np_u = scipy.linalg.lu(np_w)
246 | np_s = np.diag(np_u)
247 | np_sign_s = np.sign(np_s)
248 | np_log_s = np.log(np.abs(np_s))
249 | np_u = np.triu(np_u, k=1)
250 |
251 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
252 | p = tf.compat.v1.get_variable("P", initializer=np_p, trainable=False)
253 | l = tf.compat.v1.get_variable("L", initializer=np_l)
254 | sign_s = tf.compat.v1.get_variable(
255 | "sign_S", initializer=np_sign_s, trainable=False)
256 | log_s = tf.compat.v1.get_variable("log_S", initializer=np_log_s)
257 | u = tf.compat.v1.get_variable("U", initializer=np_u)
258 |
259 | # W = P * L * (U + sign_s * exp(log_s))
260 | l_mask = np.tril(np.ones([channels, channels], dtype=np.float32), -1)
261 | l = l * l_mask + tf.eye(channels, channels)
262 | u = u * np.transpose(l_mask) + tf.linalg.diag(sign_s * tf.exp(log_s))
263 | w = tf.matmul(p, tf.matmul(l, u))
264 |
265 | # If height or width cannot be statically determined then they end up as
266 | # tf.int32 tensors, which cannot be directly multiplied with a floating
267 | # point tensor without a cast.
268 | objective = tf.reduce_sum(log_s) * tf.cast(height * width, log_s.dtype)
269 | if not reverse:
270 | w = tf.reshape(w, [1, 1] + w_shape)
271 | x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC")
272 | else:
273 | w_inv = tf.reshape(tf.linalg.inv(w), [1, 1] + w_shape)
274 | x = tf.nn.conv2d(
275 | x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC")
276 | objective *= -1
277 | return x, objective
278 |
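In the non-permutation branch, W is rebuilt from its PLU factors with the masking shown above, and the log-determinant contribution is `sum(log_s)` scaled by the spatial size. A NumPy/SciPy sketch verifying that the masked reconstruction recovers the initial rotation matrix (for which |det W| = 1, so sum(log|s|) is ~0):

```python
import numpy as np
import scipy.linalg

channels = 8
rng = np.random.default_rng(0)
np_w = scipy.linalg.qr(rng.random((channels, channels)))[0].astype("float32")

# PLU factors, as in the initialization above.
np_p, np_l, np_u = scipy.linalg.lu(np_w)
np_s = np.diag(np_u)
sign_s, log_s = np.sign(np_s), np.log(np.abs(np_s))
np_u = np.triu(np_u, k=1)

# W = P * L * (U + diag(sign_s * exp(log_s))), with the same masks.
l_mask = np.tril(np.ones((channels, channels), dtype=np.float32), -1)
l = np_l * l_mask + np.eye(channels)
u = np_u * l_mask.T + np.diag(sign_s * np.exp(log_s))
w = np_p @ (l @ u)

assert np.allclose(w, np_w, atol=1e-5)  # reconstruction is (numerically) exact
# log|det W| equals sum(log|s|); both are ~0 for a rotation matrix.
assert np.isclose(np.linalg.slogdet(w)[1], log_s.sum(), atol=1e-4)
```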
279 |
280 |
281 |
282 | # ===============================================
283 |
284 | def add_edge_bias(x, filter_size):
285 | """Pads x and concatenates an edge bias across the depth of x.
286 | The edge bias can be thought of as a binary feature which is unity when
287 | the filter is being convolved over an edge and zero otherwise.
288 | Args:
289 | x: Input tensor, shape (NHWC)
290 | filter_size: filter_size to determine padding.
291 | Returns:
292 |     x_pad: Output tensor, shape (NHW(C+1))
293 | """
294 | x_shape = get_shape_list(x)
295 | if filter_size[0] == 1 and filter_size[1] == 1:
296 | return x
297 | a = (filter_size[0] - 1) // 2 # vertical padding size
298 | b = (filter_size[1] - 1) // 2 # horizontal padding size
299 | padding = [[0, 0], [a, a], [b, b], [0, 0]]
300 | x_bias = tf.zeros(x_shape[:-1] + [1])
301 |
302 | x = tf.pad(x, padding)
303 | x_pad = tf.pad(x_bias, padding, constant_values=1)
304 | return tf.concat([x, x_pad], axis=3)
305 |
306 |
307 | @add_arg_scope
308 | def conv(name, x, output_channels, filter_size=None, stride=None,
309 | logscale_factor=3.0, apply_actnorm=True, conv_init="default",
310 | dilations=None):
311 | """Convolutional layer with edge bias padding and optional actnorm.
312 | If x is 5-dimensional, actnorm is applied independently across every
313 | time-step.
314 | Args:
315 | name: variable scope.
316 | x: 4-D Tensor or 5-D Tensor of shape NHWC or NTHWC
317 | output_channels: Number of output channels.
318 | filter_size: list of ints, if None [3, 3] and [2, 3, 3] are defaults for
319 | 4-D and 5-D input tensors respectively.
320 | stride: list of ints, default stride: 1
321 | logscale_factor: see actnorm for parameter meaning.
322 | apply_actnorm: if apply_actnorm the activations of the first minibatch
323 | have zero mean and unit variance. Else, there is no scaling
324 | applied.
325 | conv_init: default or zeros. default is a normal distribution with 0.05 std.
326 | dilations: List of integers, apply dilations.
327 | Returns:
328 | x: actnorm(conv2d(x))
329 | Raises:
330 | ValueError: if init is set to "zeros" and apply_actnorm is set to True.
331 | """
332 | if conv_init == "zeros" and apply_actnorm:
333 | raise ValueError("apply_actnorm is unstable when init is set to zeros.")
334 |
335 | x_shape = get_shape_list(x)
336 | is_2d = len(x_shape) == 4
337 | num_steps = x_shape[1]
338 |
339 | # set filter_size, stride and in_channels
340 | if is_2d:
341 | if filter_size is None:
342 | filter_size = [1, 1] # filter_size = [3, 3]
343 | if stride is None:
344 | stride = [1, 1]
345 | if dilations is None:
346 | dilations = [1, 1, 1, 1]
347 | actnorm_func = actnorm
348 | x = add_edge_bias(x, filter_size=filter_size)
349 | conv_filter = tf.nn.conv2d
350 | else:
351 | raise NotImplementedError('x must be a NHWC 4-D Tensor!')
352 |
353 | in_channels = get_shape_list(x)[-1]
354 | filter_shape = filter_size + [in_channels, output_channels]
355 | stride_shape = [1] + stride + [1]
356 |
357 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
358 |
359 | if conv_init == "default":
360 | initializer = default_initializer()
361 | elif conv_init == "zeros":
362 | initializer = tf.zeros_initializer()
363 |
364 | w = tf.compat.v1.get_variable("W", filter_shape, tf.float32, initializer=initializer)
365 | x = conv_filter(x, w, stride_shape, padding="VALID", dilations=dilations)
366 | if apply_actnorm:
367 | x, _ = actnorm_func("actnorm", x, logscale_factor=logscale_factor)
368 | else:
369 | x += tf.compat.v1.get_variable("b", [1, 1, 1, output_channels],
370 | initializer=tf.zeros_initializer())
371 | logs = tf.compat.v1.get_variable("logs", [1, output_channels],
372 | initializer=tf.zeros_initializer())
373 | x *= tf.exp(logs * logscale_factor)
374 | return x
375 |
376 |
377 | @add_arg_scope
378 | def conv_block(name, x, mid_channels, dilations=None, activation="relu",
379 | dropout=0.0):
380 | """2 layer conv block used in the affine coupling layer.
381 | Args:
382 | name: variable scope.
383 | x: 4-D or 5-D Tensor.
384 | mid_channels: Output channels of the second layer.
385 | dilations: Optional, list of integers.
386 | activation: relu or gatu.
387 | If relu, the second layer is relu(W*x)
388 | If gatu, the second layer is tanh(W1*x) * sigmoid(W2*x)
389 | dropout: Dropout probability.
390 | Returns:
391 | x: 4-D Tensor: Output activations.
392 | """
393 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
394 |
395 | x_shape = get_shape_list(x)
396 | is_2d = len(x_shape) == 4
397 | num_steps = x_shape[1]
398 | if is_2d:
399 | first_filter = [1, 1] # first_filter = [3, 3]
400 | second_filter = [1, 1]
401 | else:
402 | raise NotImplementedError('x must be a NHWC 4-D Tensor!')
403 |
404 | # Edge Padding + conv2d + actnorm + relu:
405 | # [output: 512 channels]
406 | x = conv("1_1", x, output_channels=mid_channels, filter_size=first_filter,
407 | dilations=dilations)
408 | x = tf.nn.relu(x)
409 | x = get_dropout(x, rate=dropout)
410 |
411 | # Padding + conv2d + actnorm + activation.
412 | # [input, output: 512 channels]
413 | if activation == "relu":
414 | x = conv("1_2", x, output_channels=mid_channels,
415 | filter_size=second_filter, dilations=dilations)
416 | x = tf.nn.relu(x)
417 | elif activation == "gatu":
418 | # x = tanh(w1*x) * sigm(w2*x)
419 | x_tanh = conv("1_tanh", x, output_channels=mid_channels,
420 | filter_size=second_filter, dilations=dilations)
421 | x_sigm = conv("1_sigm", x, output_channels=mid_channels,
422 | filter_size=second_filter, dilations=dilations)
423 | x = tf.nn.tanh(x_tanh) * tf.nn.sigmoid(x_sigm)
424 |
425 | x = get_dropout(x, rate=dropout)
426 | return x
427 |
428 |
429 | @add_arg_scope
430 | def conv_stack(name, x, mid_channels, output_channels, dilations=None,
431 | activation="relu", dropout=0.0):
432 | """3-layer convolutional stack.
433 | Args:
434 | name: variable scope.
435 |     x: 4-D Tensor.
436 | mid_channels: Number of output channels of the first layer.
437 | output_channels: Number of output channels.
438 | dilations: Dilations to apply in the first 3x3 layer and the last 3x3 layer.
439 | By default, apply no dilations.
440 | activation: relu or gatu.
441 | If relu, the second layer is relu(W*x)
442 | If gatu, the second layer is tanh(W1*x) * sigmoid(W2*x)
443 | dropout: float, 0.0
444 | Returns:
445 | output: output of 3 layer conv network.
446 | """
447 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
448 |
449 | x = conv_block("conv_block", x, mid_channels=mid_channels,
450 | dilations=dilations, activation=activation,
451 | dropout=dropout)
452 |
453 | # Final layer.
454 | x = conv("zeros", x, apply_actnorm=False, conv_init="zeros",
455 | output_channels=output_channels, dilations=dilations)
456 | return x
457 |
458 |
459 | @add_arg_scope
460 | def additive_coupling(name, x, mid_channels=512, reverse=False,
461 | activation="relu", dropout=0.0):
462 | """Reversible additive coupling layer.
463 | Args:
464 | name: variable scope.
465 | x: 4-D Tensor, shape=(NHWC).
466 | mid_channels: number of channels in the coupling layer.
467 | reverse: Forward or reverse operation.
468 | activation: "relu" or "gatu"
469 | dropout: default, 0.0
470 | Returns:
471 | output: 4-D Tensor, shape=(NHWC)
472 | objective: 0.0
473 | """
474 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
475 | output_channels = get_shape_list(x)[-1] // 2
476 | x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1)
477 |
478 | z1 = x1
479 | shift = conv_stack("nn", x1, mid_channels, output_channels=output_channels,
480 | activation=activation, dropout=dropout)
481 |
482 | if not reverse:
483 | z2 = x2 + shift
484 | else:
485 | z2 = x2 - shift
486 | return tf.concat([z1, z2], axis=3), 0.0
487 |
488 |
489 | @add_arg_scope
490 | def affine_coupling(name, x, mid_channels=512, activation="relu",
491 | reverse=False, dropout=0.0):
492 | """Reversible affine coupling layer.
493 | Args:
494 | name: variable scope.
495 | x: 4-D Tensor.
496 | mid_channels: number of channels in the coupling layer.
497 | activation: Can be either "relu" or "gatu".
498 | reverse: Forward or reverse operation.
499 | dropout: default, 0.0
500 | Returns:
501 | output: x shifted and scaled by an affine transformation.
502 | objective: log-determinant of the jacobian
503 | """
504 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
505 | x_shape = get_shape_list(x)
506 | x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1)
507 |
508 | # scale, shift = NN(x1)
509 |     # If not reverse:
510 |     #   z2 = (x2 + shift) * scale
511 |     # Else:
512 |     #   z2 = x2 / scale - shift
513 | z1 = x1
514 | log_scale_and_shift = conv_stack(
515 | "nn", x1, mid_channels, x_shape[-1], activation=activation,
516 | dropout=dropout)
517 | shift = log_scale_and_shift[:, :, :, 0::2]
518 | scale = tf.nn.sigmoid(log_scale_and_shift[:, :, :, 1::2] + 2.0)
519 | if not reverse:
520 | z2 = (x2 + shift) * scale
521 | else:
522 | z2 = x2 / scale - shift
523 |
524 | objective = tf.reduce_sum(tf.math.log(scale), axis=[1, 2, 3])
525 | if reverse:
526 | objective *= -1
527 | return tf.concat([z1, z2], axis=3), objective
528 |
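The affine coupling is invertible in closed form, and its log-determinant is the sum of `log(scale)` over the transformed half. A NumPy sketch with fixed random tensors standing in for the `conv_stack` output:

```python
import numpy as np

rng = np.random.default_rng(0)
x2 = rng.normal(size=(4, 1, 1, 8))
shift = rng.normal(size=(4, 1, 1, 8))
# scale = sigmoid(raw + 2.0), as above: strictly inside (0, 1).
scale = 1.0 / (1.0 + np.exp(-(rng.normal(size=(4, 1, 1, 8)) + 2.0)))

z2 = (x2 + shift) * scale      # forward pass
x2_rec = z2 / scale - shift    # reverse pass recovers x2 exactly
assert np.allclose(x2_rec, x2)

# Per-example log-det of the forward jacobian: sum of log(scale) over H, W, C.
logdet = np.log(scale).sum(axis=(1, 2, 3))
assert logdet.shape == (4,) and np.all(logdet < 0)  # scale < 1 => negative
```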
529 |
530 | # ===============================================
531 |
532 |
533 | @add_arg_scope
534 | def single_conv_dist(name, x, output_channels=None):
535 | """A 1x1 convolution mapping x to a standard normal distribution at init.
536 | Args:
537 | name: variable scope.
538 | x: 4-D Tensor.
539 | output_channels: number of channels of the mean and std.
540 | """
541 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
542 | x_shape = get_shape_list(x)
543 | if output_channels is None:
544 | output_channels = x_shape[-1]
545 | mean_log_scale = conv("conv2d", x, output_channels=2*output_channels,
546 | conv_init="zeros", apply_actnorm=False)
547 | mean = mean_log_scale[:, :, :, 0::2]
548 | log_scale = mean_log_scale[:, :, :, 1::2]
549 | return tf.distributions.Normal(mean, tf.exp(log_scale))
550 |
551 |
552 | # ===============================================
553 |
554 |
555 | @add_arg_scope
556 | def revnet_step(name, x, hparams, reverse=True):
557 | """One step of glow generative flow.
558 | Actnorm + invertible 1X1 conv + affine_coupling.
559 | Args:
560 | name: used for variable scope.
561 | x: input
562 |     hparams: coupling, coupling_width, activation, coupling_dropout
563 |       and permutation are the hparams used in this function.
564 | reverse: forward or reverse pass.
565 | Returns:
566 | z: Output of one step of reversible flow.
567 | """
568 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
569 | if hparams.coupling == "additive":
570 | coupling_layer = functools.partial(
571 | additive_coupling, name="additive", reverse=reverse,
572 | mid_channels=hparams.coupling_width,
573 | activation=hparams.activation,
574 | dropout=hparams.coupling_dropout if hparams.is_training else 0)
575 | else:
576 | coupling_layer = functools.partial(
577 | affine_coupling, name="affine", reverse=reverse,
578 | mid_channels=hparams.coupling_width,
579 | activation=hparams.activation,
580 | dropout=hparams.coupling_dropout if hparams.is_training else 0)
581 |
582 |     if "permutation" in hparams and hparams["permutation"]:
583 | ops = [
584 | functools.partial(actnorm, name="actnorm", reverse=reverse),
585 | functools.partial(invertible_1x1_conv, name="invertible", reverse=reverse, permutation=True),
586 | coupling_layer]
587 | else:
588 | ops = [
589 | functools.partial(actnorm, name="actnorm", reverse=reverse),
590 | functools.partial(invertible_1x1_conv, name="invertible", reverse=reverse),
591 | coupling_layer]
592 |
593 | if reverse:
594 | ops = ops[::-1]
595 |
596 | objective = 0.0
597 | for op in ops:
598 | x, curr_obj = op(x=x)
599 | objective += curr_obj
600 | return x, objective
601 |
602 |
603 | def revnet(name, x, hparams, reverse=True):
604 | """'hparams.depth' steps of generative flow.
605 | Args:
606 | name: variable scope for the revnet block.
607 | x: 4-D Tensor, shape=(NHWC).
608 | hparams: HParams.
609 | reverse: bool, forward or backward pass.
610 | Returns:
611 | x: 4-D Tensor, shape=(NHWC).
612 | objective: float.
613 | """
614 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
615 | steps = np.arange(hparams.depth)
616 | if reverse:
617 | steps = steps[::-1]
618 |
619 | objective = 0.0
620 | for step in steps:
621 | x, curr_obj = revnet_step(
622 | "revnet_step_%d" % step, x, hparams, reverse=reverse)
623 | objective += curr_obj
624 | return x, objective
625 |
626 | # ===============================================
627 |
628 | @add_arg_scope
629 | def compute_prior(name, z, latent, hparams, condition=False, state=None,
630 | temperature=1.0):
631 | """Distribution on z_t conditioned on z_{t-1} and latent.
632 | Args:
633 | name: variable scope.
634 | z: 4-D Tensor.
635 | latent: optional,
636 | if hparams.latent_dist_encoder == "pointwise", this is a list
637 | of 4-D Tensors of length hparams.num_cond_latents.
638 | else, this is just a 4-D Tensor
639 |       The first three dimensions of the latent should be the same as z.
640 | hparams: next_frame_glow_hparams.
641 | condition: Whether or not to condition the distribution on latent.
642 | state: tf.nn.rnn_cell.LSTMStateTuple.
643 | the current state of a LSTM used to model the distribution. Used
644 | only if hparams.latent_dist_encoder = "conv_lstm".
645 | temperature: float, temperature with which to sample from the Gaussian.
646 | Returns:
647 |     prior_dist: instance of tf.distributions.Normal
648 | state: Returns updated state.
649 | Raises:
650 | ValueError: If hparams.latent_dist_encoder is "pointwise" and if the shape
651 | of latent is different from z.
652 | """
653 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
654 | z_shape = get_shape_list(z)
655 | h = tf.zeros(z_shape, dtype=tf.float32)
656 | prior_dist = tf.distributions.Normal(h, tf.exp(h))
657 | return prior_dist, state
658 |
659 |
660 |
661 | @add_arg_scope
662 | def split(name, x, reverse=False, eps=None, eps_std=None, cond_latents=None,
663 | hparams=None, state=None, condition=False, temperature=1.0):
664 | """Splits / concatenates x into x1 and x2 across number of channels.
665 |   For the forward pass, x2 is assumed to be Gaussian,
666 | i.e P(x2 | x1) ~ N(mu, sigma) where mu and sigma are the outputs of
667 | a network conditioned on x1 and optionally on cond_latents.
668 | For the reverse pass, x2 is determined from mu(x1) and sigma(x1).
669 | This is deterministic/stochastic depending on whether eps is provided.
670 | Args:
671 | name: variable scope.
672 | x: 4-D Tensor, shape (NHWC).
673 | reverse: Forward or reverse pass.
674 | eps: If eps is provided, x2 is set to be mu(x1) + eps * sigma(x1).
675 | eps_std: Sample x2 with the provided eps_std.
676 | cond_latents: optionally condition x2 on cond_latents.
677 | hparams: next_frame_glow hparams.
678 |     state: tf.nn.rnn_cell.LSTMStateTuple. Current state of the LSTM over z_2.
679 | Used only when hparams.latent_dist_encoder == "conv_lstm"
680 | condition: bool, Whether or not to condition the distribution on
681 | cond_latents.
682 | temperature: Temperature with which to sample from the gaussian.
683 | Returns:
684 | If reverse:
685 | x: 4-D Tensor, concats input and x2 across channels.
686 | x2: 4-D Tensor, a sample from N(mu(x1), sigma(x1))
687 | Else:
688 | x1: 4-D Tensor, Output of the split operation.
689 | logpb: log-probability of x2 belonging to mu(x1), sigma(x1)
690 | eps: 4-D Tensor, (x2 - mu(x1)) / sigma(x1)
691 | x2: 4-D Tensor, Latent representation at the current level.
692 | state: Current LSTM state.
693 | 4-D Tensor, only if hparams.latent_dist_encoder is set to conv_lstm.
694 | Raises:
695 | ValueError: If latent is provided and shape is not equal to NHW(C/2)
696 | where (NHWC) is the size of x.
697 | """
698 | # TODO(mechcoder) Change the return type to be a dict.
699 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
700 | if not reverse:
701 | x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1)
702 |
703 | # objective: P(x2 | x1) ~ N(x2; NN(x1))
704 | prior_dist, state = compute_prior(
705 | "prior_on_z2", x1, cond_latents, hparams, condition, state=state)
706 | logpb = tf.reduce_sum(prior_dist.log_prob(x2), axis=[1, 2, 3])
707 | eps = get_eps(prior_dist, x2)
708 | return x1, logpb, eps, x2, state
709 | else:
710 | prior_dist, state = compute_prior(
711 | "prior_on_z2", x, cond_latents, hparams, condition, state=state,
712 | temperature=temperature)
713 | if eps is not None:
714 | x2 = set_eps(prior_dist, eps)
715 | elif eps_std is not None:
716 | x2 = eps_std * tf.random.normal(get_shape_list(x))
717 | else:
718 | x2 = prior_dist.sample()
719 | return tf.concat([x, x2], 3), x2, state
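The `eps` bookkeeping is what makes the split invertible: the forward pass records `eps = (x2 - mu) / sigma`, and the reverse pass reconstructs `x2 = mu + eps * sigma`. A minimal NumPy sketch of that round trip, using a toy stand-in for `compute_prior` (the real prior network is conditioned on `x1` and optionally `cond_latents`):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4, 4, 8))

# Forward: split across channels; mu, sigma would come from a network on x1.
x1, x2 = np.split(x, 2, axis=-1)
mu, log_sigma = 0.1 * x1, 0.01 * x1          # toy stand-in for compute_prior
sigma = np.exp(log_sigma)

eps = (x2 - mu) / sigma                      # what get_eps records

# Reverse: x2 is fully determined by (x1, eps), so the op is invertible.
x2_back = mu + eps * sigma                   # what set_eps computes
assert np.allclose(x2, x2_back)

# logpb: diagonal-Gaussian log-density summed over H, W, C; one value per example.
logp = -0.5 * (np.log(2 * np.pi) + 2 * log_sigma + ((x2 - mu) / sigma) ** 2)
logpb = logp.sum(axis=(1, 2, 3))
assert logpb.shape == (1,)
```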
720 |
721 |
722 | @add_arg_scope
723 | def squeeze(name, x, factor=2, reverse=True):
724 | """Block-wise spatial squeezing of x to increase the number of channels.
725 | Args:
726 | name: Used for variable scoping.
727 | x: 4-D Tensor of shape (batch_size X H X W X C)
728 | factor: Factor by which the spatial dimensions should be squeezed.
729 | reverse: Squeeze or unsqueeze operation.
730 | Returns:
731 | x: 4-D Tensor of shape (batch_size X (H//factor) X (W//factor) X
732 | (C X factor^2)). If reverse is True, factor is treated as (1 / factor).
733 | """
734 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
735 | shape = get_shape_list(x)
736 | if factor == 1:
737 | return x
738 | height = int(shape[1])
739 | width = int(shape[2])
740 | n_channels = int(shape[3])
741 |
742 | if not reverse:
743 | assert height % factor == 0 and width % factor == 0
744 | x = tf.reshape(x, [-1, height//factor, factor,
745 | width//factor, factor, n_channels])
746 | x = tf.transpose(x, [0, 1, 3, 5, 2, 4])
747 | x = tf.reshape(x, [-1, height//factor, width //
748 | factor, n_channels*factor*factor])
749 | else:
750 | x = tf.reshape(
751 | x, (-1, height, width, int(n_channels/factor**2), factor, factor))
752 | x = tf.transpose(x, [0, 1, 4, 2, 5, 3])
753 | x = tf.reshape(x, (-1, int(height*factor),
754 | int(width*factor), int(n_channels/factor**2)))
755 | return x
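The reshape/transpose dance is easier to check outside TF. A NumPy sketch of the same op (shapes only; no variables are created, so the variable scope is irrelevant here), verifying that squeeze followed by unsqueeze is an exact round trip:

```python
import numpy as np

def squeeze_np(x, factor=2, reverse=False):
    """Block-wise spatial squeeze, mirroring the reshapes/transposes above."""
    n, h, w, c = x.shape
    if factor == 1:
        return x
    if not reverse:
        assert h % factor == 0 and w % factor == 0
        x = x.reshape(n, h // factor, factor, w // factor, factor, c)
        x = x.transpose(0, 1, 3, 5, 2, 4)
        return x.reshape(n, h // factor, w // factor, c * factor * factor)
    x = x.reshape(n, h, w, c // factor ** 2, factor, factor)
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(n, h * factor, w * factor, c // factor ** 2)

x = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
y = squeeze_np(x, factor=2)                     # (2, 4, 4, 3) -> (2, 2, 2, 12)
x_back = squeeze_np(y, factor=2, reverse=True)  # unsqueeze restores the input
assert y.shape == (2, 2, 2, 12)
assert np.array_equal(x, x_back)
```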
756 |
757 |
758 | def get_cond_latents_at_level(cond_latents, level, hparams):
759 | """Returns a single or list of conditional latents at level 'level'."""
760 | if cond_latents:
761 | if hparams.latent_dist_encoder in ["conv_net", "conv3d_net"]:
762 | return [cond_latent[level] for cond_latent in cond_latents]
763 | elif hparams.latent_dist_encoder in ["pointwise", "conv_lstm"]:
764 | return cond_latents[level]
765 |
766 |
767 | def check_cond_latents(cond_latents, hparams):
768 | """Shape checking for cond_latents."""
769 | if cond_latents is None:
770 | return
771 | if not isinstance(cond_latents[0], list):
772 | cond_latents = [cond_latents]
773 | exp_num_latents = hparams.num_cond_latents
774 | if hparams.latent_dist_encoder == "conv_net":
775 | exp_num_latents += int(hparams.cond_first_frame)
776 | if len(cond_latents) != exp_num_latents:
777 | raise ValueError("Expected number of cond_latents: %d, got %d" %
778 | (exp_num_latents, len(cond_latents)))
779 | for cond_latent in cond_latents:
780 | if len(cond_latent) != hparams.n_levels - 1:
781 | raise ValueError("Expected level_latents to be %d, got %d" %
782 | (hparams.n_levels - 1, len(cond_latent)))
783 |
784 |
785 | @add_arg_scope
786 | def encoder_decoder(name, x, hparams, eps=None, reverse=False,
787 | cond_latents=None, condition=False, states=None,
788 | temperature=1.0):
789 | """Glow encoder-decoder. n_levels of (Squeeze + Flow + Split.) operations.
790 | Args:
791 | name: variable scope.
792 | x: 4-D Tensor, shape=(NHWC).
793 | hparams: HParams.
794 | eps: Stores (glow(x) - mu) / sigma during the forward pass.
795 | Used only to test if the network is reversible.
796 | reverse: Forward or reverse pass.
797 | cond_latents: list of lists of tensors.
798 | outer length equals hparams.num_cond_latents
799 | inner length equals hparams.n_levels - 1.
800 | condition: If set to True, condition the encoder/decoder on cond_latents.
801 | states: LSTM states, used only if hparams.latent_dist_encoder is set
802 | to "conv_lstm.
803 | temperature: Temperature set during sampling.
804 | Returns:
805 | x: If reverse, decoded image, else the encoded glow latent representation.
806 | objective: log-likelihood.
807 | eps: list of tensors, shape=(num_levels-1).
808 | Stores (glow(x) - mu_level(x)) / sigma_level(x)) for each level.
809 | all_latents: list of tensors, shape=(num_levels-1).
810 | Latent representations for each level.
811 | new_states: list of tensors, shape=(num_levels-1).
812 | useful only if hparams.latent_dist_encoder="conv_lstm", returns
813 | the current state of each level.
814 | """
815 | # TODO(mechcoder) Change return_type to a dict to be backward compatible.
816 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
817 |
818 | if states and len(states) != hparams.n_levels - 1:
819 | raise ValueError("Expected length of states to be %d, got %d" %
820 | (hparams.n_levels - 1, len(states)))
821 | if states is None:
822 | states = [None] * (hparams.n_levels - 1)
823 | if eps and len(eps) != hparams.n_levels - 1:
824 | raise ValueError("Expected length of eps to be %d, got %d" %
825 | (hparams.n_levels - 1, len(eps)))
826 | if eps is None:
827 | eps = [None] * (hparams.n_levels - 1)
828 | check_cond_latents(cond_latents, hparams)
829 |
830 | objective = 0.0
831 | all_eps = []
832 | all_latents = []
833 | new_states = []
834 |
835 | if not reverse:
836 | # Squeeze + Flow + Split
837 | for level in range(hparams.n_levels):
838 | # x = squeeze("squeeze_%d" % level, x, factor=2, reverse=False)
839 |
840 | x, obj = revnet("revnet_%d" % level, x, hparams, reverse=False)
841 | objective += obj
842 |
843 | if level < hparams.n_levels - 1:
844 | curr_cond_latents = get_cond_latents_at_level(
845 | cond_latents, level, hparams)
846 | x, obj, eps, z, state = split("split_%d" % level, x, reverse=False,
847 | cond_latents=curr_cond_latents,
848 | condition=condition,
849 | hparams=hparams, state=states[level])
850 | objective += obj
851 | all_eps.append(eps)
852 | all_latents.append(z)
853 | new_states.append(state)
854 |
855 | return x, objective, all_eps, all_latents, new_states
856 |
857 | else:
858 | for level in reversed(range(hparams.n_levels)):
859 | if level < hparams.n_levels - 1:
860 |
861 | curr_cond_latents = get_cond_latents_at_level(
862 | cond_latents, level, hparams)
863 |
864 | x, latent, state = split("split_%d" % level, x, eps=eps[level],
865 | reverse=True, cond_latents=curr_cond_latents,
866 | condition=condition, hparams=hparams,
867 | state=states[level],
868 | temperature=temperature)
869 | new_states.append(state)
870 | all_latents.append(latent)
871 |
872 | x, obj = revnet(
873 | "revnet_%d" % level, x, hparams=hparams, reverse=True)
874 | objective += obj
875 | # x = squeeze("squeeze_%d" % level, x, reverse=True)
876 | return x, objective, all_latents[::-1], new_states[::-1]
877 |
878 |
879 | # ===============================================
880 |
881 |
882 | @add_arg_scope
883 | def top_prior(name, z_shape, learn_prior="normal", temperature=1.0):
884 | """Unconditional prior distribution.
885 | Args:
886 | name: variable scope
887 | z_shape: Shape of the mean / scale of the prior distribution.
888 | learn_prior: Possible options are "normal" and "single_conv".
889 | If set to "single_conv", the gaussian is parametrized by a
890 | single convolutional layer whose input are an array of zeros
891 | and initialized such that the mean and std are zero and one.
892 | If set to "normal", the prior is just a Gaussian with zero
893 | mean and unit variance.
894 | temperature: Temperature with which to sample from the Gaussian.
895 | Returns:
896 | prior_dist: A tf.distributions.Normal instance with zero mean and unit scale.
897 | Raises:
898 | ValueError: If learn_prior is not "normal" or "single_conv".
899 | """
900 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
901 | h = tf.zeros(z_shape, dtype=tf.float32)
902 | prior_dist = tf.distributions.Normal(h, tf.exp(h))
903 | return prior_dist
--------------------------------------------------------------------------------
/img/bert-flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bohanli/BERT-flow/7fa8f6d4a1a73e2c2f8549799d9bafc9d6048d67/img/bert-flow.png
--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The main BERT model and related functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 | import json
24 | import math
25 | import re
26 | import numpy as np
27 | import six
28 | import tensorflow as tf
29 |
30 |
31 | class BertConfig(object):
32 | """Configuration for `BertModel`."""
33 |
34 | def __init__(self,
35 | vocab_size,
36 | hidden_size=768,
37 | num_hidden_layers=12,
38 | num_attention_heads=12,
39 | intermediate_size=3072,
40 | hidden_act="gelu",
41 | hidden_dropout_prob=0.1,
42 | attention_probs_dropout_prob=0.1,
43 | max_position_embeddings=512,
44 | type_vocab_size=16,
45 | initializer_range=0.02):
46 | """Constructs BertConfig.
47 |
48 | Args:
49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
50 | hidden_size: Size of the encoder layers and the pooler layer.
51 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
52 | num_attention_heads: Number of attention heads for each attention layer in
53 | the Transformer encoder.
54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
55 | layer in the Transformer encoder.
56 | hidden_act: The non-linear activation function (function or string) in the
57 | encoder and pooler.
58 | hidden_dropout_prob: The dropout probability for all fully connected
59 | layers in the embeddings, encoder, and pooler.
60 | attention_probs_dropout_prob: The dropout ratio for the attention
61 | probabilities.
62 | max_position_embeddings: The maximum sequence length that this model might
63 | ever be used with. Typically set this to something large just in case
64 | (e.g., 512 or 1024 or 2048).
65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
66 | `BertModel`.
67 | initializer_range: The stdev of the truncated_normal_initializer for
68 | initializing all weight matrices.
69 | """
70 | self.vocab_size = vocab_size
71 | self.hidden_size = hidden_size
72 | self.num_hidden_layers = num_hidden_layers
73 | self.num_attention_heads = num_attention_heads
74 | self.hidden_act = hidden_act
75 | self.intermediate_size = intermediate_size
76 | self.hidden_dropout_prob = hidden_dropout_prob
77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
78 | self.max_position_embeddings = max_position_embeddings
79 | self.type_vocab_size = type_vocab_size
80 | self.initializer_range = initializer_range
81 |
82 | @classmethod
83 | def from_dict(cls, json_object):
84 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
85 | config = BertConfig(vocab_size=None)
86 | for (key, value) in six.iteritems(json_object):
87 | config.__dict__[key] = value
88 | return config
89 |
90 | @classmethod
91 | def from_json_file(cls, json_file):
92 | """Constructs a `BertConfig` from a json file of parameters."""
93 | with tf.gfile.GFile(json_file, "r") as reader:
94 | text = reader.read()
95 | return cls.from_dict(json.loads(text))
96 |
97 | def to_dict(self):
98 | """Serializes this instance to a Python dictionary."""
99 | output = copy.deepcopy(self.__dict__)
100 | return output
101 |
102 | def to_json_string(self):
103 | """Serializes this instance to a JSON string."""
104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
105 |
106 |
107 | class BertModel(object):
108 | """BERT model ("Bidirectional Encoder Representations from Transformers").
109 |
110 | Example usage:
111 |
112 | ```python
113 | # Already been converted into WordPiece token ids
114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
117 |
118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
120 |
121 | model = modeling.BertModel(config=config, is_training=True,
122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
123 |
124 | label_embeddings = tf.get_variable(...)
125 | pooled_output = model.get_pooled_output()
126 | logits = tf.matmul(pooled_output, label_embeddings)
127 | ...
128 | ```
129 | """
130 |
131 | def __init__(self,
132 | config,
133 | is_training,
134 | input_ids,
135 | input_mask=None,
136 | token_type_ids=None,
137 | use_one_hot_embeddings=False,
138 | scope=None):
139 | """Constructor for BertModel.
140 |
141 | Args:
142 | config: `BertConfig` instance.
143 | is_training: bool. true for training model, false for eval model. Controls
144 | whether dropout will be applied.
145 | input_ids: int32 Tensor of shape [batch_size, seq_length].
146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
149 | embeddings or tf.embedding_lookup() for the word embeddings.
150 | scope: (optional) variable scope. Defaults to "bert".
151 |
152 | Raises:
153 | ValueError: The config is invalid or one of the input tensor shapes
154 | is invalid.
155 | """
156 | config = copy.deepcopy(config)
157 | if not is_training:
158 | config.hidden_dropout_prob = 0.0
159 | config.attention_probs_dropout_prob = 0.0
160 |
161 | input_shape = get_shape_list(input_ids, expected_rank=2)
162 | batch_size = input_shape[0]
163 | seq_length = input_shape[1]
164 |
165 | if input_mask is None:
166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
167 |
168 | if token_type_ids is None:
169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
170 |
171 | with tf.variable_scope(scope, default_name="bert"):
172 | with tf.variable_scope("embeddings"):
173 | # Perform embedding lookup on the word ids.
174 | (self.embedding_output, self.embedding_table) = embedding_lookup(
175 | input_ids=input_ids,
176 | vocab_size=config.vocab_size,
177 | embedding_size=config.hidden_size,
178 | initializer_range=config.initializer_range,
179 | word_embedding_name="word_embeddings",
180 | use_one_hot_embeddings=use_one_hot_embeddings)
181 |
182 | # Add positional embeddings and token type embeddings, then layer
183 | # normalize and perform dropout.
184 | self.embedding_output = embedding_postprocessor(
185 | input_tensor=self.embedding_output,
186 | use_token_type=True,
187 | token_type_ids=token_type_ids,
188 | token_type_vocab_size=config.type_vocab_size,
189 | token_type_embedding_name="token_type_embeddings",
190 | use_position_embeddings=True,
191 | position_embedding_name="position_embeddings",
192 | initializer_range=config.initializer_range,
193 | max_position_embeddings=config.max_position_embeddings,
194 | dropout_prob=config.hidden_dropout_prob)
195 |
196 | with tf.variable_scope("encoder"):
197 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
198 | # mask of shape [batch_size, seq_length, seq_length] which is used
199 | # for the attention scores.
200 | attention_mask = create_attention_mask_from_input_mask(
201 | input_ids, input_mask)
202 |
203 | # Run the stacked transformer.
204 | # `sequence_output` shape = [batch_size, seq_length, hidden_size].
205 | self.all_encoder_layers = transformer_model(
206 | input_tensor=self.embedding_output,
207 | attention_mask=attention_mask,
208 | hidden_size=config.hidden_size,
209 | num_hidden_layers=config.num_hidden_layers,
210 | num_attention_heads=config.num_attention_heads,
211 | intermediate_size=config.intermediate_size,
212 | intermediate_act_fn=get_activation(config.hidden_act),
213 | hidden_dropout_prob=config.hidden_dropout_prob,
214 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
215 | initializer_range=config.initializer_range,
216 | do_return_all_layers=True)
217 |
218 | self.sequence_output = self.all_encoder_layers[-1]
219 | # The "pooler" converts the encoded sequence tensor of shape
220 | # [batch_size, seq_length, hidden_size] to a tensor of shape
221 | # [batch_size, hidden_size]. This is necessary for segment-level
222 | # (or segment-pair-level) classification tasks where we need a fixed
223 | # dimensional representation of the segment.
224 | with tf.variable_scope("pooler"):
225 | # We "pool" the model by simply taking the hidden state corresponding
226 | # to the first token. We assume that this has been pre-trained.
227 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
228 | self.pooled_output = tf.layers.dense(
229 | first_token_tensor,
230 | config.hidden_size,
231 | activation=tf.tanh,
232 | kernel_initializer=create_initializer(config.initializer_range))
233 |
234 | def get_pooled_output(self):
235 | return self.pooled_output
236 |
237 | def get_sequence_output(self):
238 | """Gets final hidden layer of encoder.
239 |
240 | Returns:
241 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
242 | to the final hidden layer of the transformer encoder.
243 | """
244 | return self.sequence_output
245 |
246 | def get_all_encoder_layers(self):
247 | return self.all_encoder_layers
248 |
249 | def get_embedding_output(self):
250 | """Gets output of the embedding lookup (i.e., input to the transformer).
251 |
252 | Returns:
253 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
254 | to the output of the embedding layer, after summing the word
255 | embeddings with the positional embeddings and the token type embeddings,
256 | then performing layer normalization. This is the input to the transformer.
257 | """
258 | return self.embedding_output
259 |
260 | def get_embedding_table(self):
261 | return self.embedding_table
262 |
263 |
264 | def gelu(x):
265 | """Gaussian Error Linear Unit.
266 |
267 | This is a smoother version of the RELU.
268 | Original paper: https://arxiv.org/abs/1606.08415
269 | Args:
270 | x: float Tensor to perform activation.
271 |
272 | Returns:
273 | `x` with the GELU activation applied.
274 | """
275 | cdf = 0.5 * (1.0 + tf.tanh(
276 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
277 | return x * cdf
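The tanh form above is the approximation from the original GELU paper; the exact definition is `x * Phi(x)`, with `Phi` the standard normal CDF. A small NumPy check of how closely the two agree:

```python
import numpy as np
from math import erf

def gelu_np(x):
    # tanh approximation, matching the TF code above
    return x * 0.5 * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x):
    # exact definition: x * Phi(x), with Phi the standard normal CDF
    return np.array([v * 0.5 * (1.0 + erf(v / np.sqrt(2.0))) for v in x])

x = np.linspace(-3.0, 3.0, 13)
# The tanh approximation tracks the exact GELU to within a few 1e-4 here.
assert np.max(np.abs(gelu_np(x) - gelu_exact(x))) < 1e-3
```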
278 |
279 |
280 | def get_activation(activation_string):
281 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
282 |
283 | Args:
284 | activation_string: String name of the activation function.
285 |
286 | Returns:
287 | A Python function corresponding to the activation function. If
288 | `activation_string` is None, empty, or "linear", this will return None.
289 | If `activation_string` is not a string, it will return `activation_string`.
290 |
291 | Raises:
292 | ValueError: The `activation_string` does not correspond to a known
293 | activation.
294 | """
295 |
296 | # We assume that anything that's not a string is already an activation
297 | # function, so we just return it.
298 | if not isinstance(activation_string, six.string_types):
299 | return activation_string
300 |
301 | if not activation_string:
302 | return None
303 |
304 | act = activation_string.lower()
305 | if act == "linear":
306 | return None
307 | elif act == "relu":
308 | return tf.nn.relu
309 | elif act == "gelu":
310 | return gelu
311 | elif act == "tanh":
312 | return tf.tanh
313 | else:
314 | raise ValueError("Unsupported activation: %s" % act)
315 |
316 |
317 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
318 | """Compute the union of the current variables and checkpoint variables."""
319 | assignment_map = {}
320 | initialized_variable_names = {}
321 |
322 | name_to_variable = collections.OrderedDict()
323 | for var in tvars:
324 | name = var.name
325 | m = re.match("^(.*):\\d+$", name)
326 | if m is not None:
327 | name = m.group(1)
328 | name_to_variable[name] = var
329 |
330 | init_vars = tf.train.list_variables(init_checkpoint)
331 |
332 | assignment_map = collections.OrderedDict()
333 | for x in init_vars:
334 | (name, var) = (x[0], x[1])
335 | if name not in name_to_variable:
336 | continue
337 | assignment_map[name] = name
338 | initialized_variable_names[name] = 1
339 | initialized_variable_names[name + ":0"] = 1
340 |
341 | return (assignment_map, initialized_variable_names)
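The regex strips the `:0` output-index suffix that TF appends to variable names, so graph variables can be matched against checkpoint entries (which are stored without the suffix). In isolation:

```python
import re

def strip_device_suffix(name):
    # "bert/encoder/kernel:0" -> "bert/encoder/kernel", as the regex above does
    m = re.match("^(.*):\\d+$", name)
    return m.group(1) if m is not None else name

assert strip_device_suffix("bert/encoder/layer_0/kernel:0") == "bert/encoder/layer_0/kernel"
assert strip_device_suffix("global_step") == "global_step"
```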
342 |
343 |
344 | def dropout(input_tensor, dropout_prob):
345 | """Perform dropout.
346 |
347 | Args:
348 | input_tensor: float Tensor.
349 | dropout_prob: Python float. The probability of dropping out a value (NOT of
350 | *keeping* a dimension as in `tf.nn.dropout`).
351 |
352 | Returns:
353 | A version of `input_tensor` with dropout applied.
354 | """
355 | if dropout_prob is None or dropout_prob == 0.0:
356 | return input_tensor
357 |
358 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
359 | return output
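`tf.nn.dropout` (TF1 signature) takes a *keep* probability and rescales survivors by `1/keep_prob` so the expectation is unchanged; the helper above just flips the argument to a *drop* probability. A NumPy sketch of the same inverted-dropout semantics:

```python
import numpy as np

def dropout_np(x, dropout_prob, rng):
    """Inverted dropout: survivors are scaled by 1/keep_prob, so E[out] == x."""
    keep_prob = 1.0 - dropout_prob
    mask = rng.random(x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)

rng = np.random.default_rng(0)
y = dropout_np(np.ones(10000), dropout_prob=0.1, rng=rng)
assert abs(float(y.mean()) - 1.0) < 0.05   # expectation is preserved
```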
360 |
361 |
362 | def layer_norm(input_tensor, name=None):
363 | """Run layer normalization on the last dimension of the tensor."""
364 | return tf.contrib.layers.layer_norm(
365 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
366 |
367 |
368 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
369 | """Runs layer normalization followed by dropout."""
370 | output_tensor = layer_norm(input_tensor, name)
371 | output_tensor = dropout(output_tensor, dropout_prob)
372 | return output_tensor
373 |
374 |
375 | def create_initializer(initializer_range=0.02):
376 | """Creates a `truncated_normal_initializer` with the given range."""
377 | return tf.truncated_normal_initializer(stddev=initializer_range)
378 |
379 |
380 | def embedding_lookup(input_ids,
381 | vocab_size,
382 | embedding_size=128,
383 | initializer_range=0.02,
384 | word_embedding_name="word_embeddings",
385 | use_one_hot_embeddings=False):
386 | """Looks up words embeddings for id tensor.
387 |
388 | Args:
389 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
390 | ids.
391 | vocab_size: int. Size of the embedding vocabulary.
392 | embedding_size: int. Width of the word embeddings.
393 | initializer_range: float. Embedding initialization range.
394 | word_embedding_name: string. Name of the embedding table.
395 | use_one_hot_embeddings: bool. If True, use one-hot method for word
396 | embeddings. If False, use `tf.gather()`.
397 |
398 | Returns:
399 | float Tensor of shape [batch_size, seq_length, embedding_size].
400 | """
401 | # This function assumes that the input is of shape [batch_size, seq_length,
402 | # num_inputs].
403 | #
404 | # If the input is a 2D tensor of shape [batch_size, seq_length], we
405 | # reshape to [batch_size, seq_length, 1].
406 | if input_ids.shape.ndims == 2:
407 | input_ids = tf.expand_dims(input_ids, axis=[-1])
408 |
409 | embedding_table = tf.get_variable(
410 | name=word_embedding_name,
411 | shape=[vocab_size, embedding_size],
412 | initializer=create_initializer(initializer_range))
413 |
414 | flat_input_ids = tf.reshape(input_ids, [-1])
415 | if use_one_hot_embeddings:
416 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
417 | output = tf.matmul(one_hot_input_ids, embedding_table)
418 | else:
419 | output = tf.gather(embedding_table, flat_input_ids)
420 |
421 | input_shape = get_shape_list(input_ids)
422 |
423 | output = tf.reshape(output,
424 | input_shape[0:-1] + [input_shape[-1] * embedding_size])
425 | return (output, embedding_table)
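The two code paths are numerically equivalent; the one-hot matmul is simply faster on TPUs for small lookups. A NumPy sketch with a toy table (names mirror the function above but this is illustrative only):

```python
import numpy as np

vocab_size, embedding_size = 10, 4
embedding_table = np.random.default_rng(0).normal(size=(vocab_size, embedding_size))
flat_input_ids = np.array([3, 7, 3])

gathered = embedding_table[flat_input_ids]                          # tf.gather path
one_hot = np.eye(vocab_size)[flat_input_ids] @ embedding_table      # one-hot matmul path
assert np.allclose(gathered, one_hot)
```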
426 |
427 |
428 | def embedding_postprocessor(input_tensor,
429 | use_token_type=False,
430 | token_type_ids=None,
431 | token_type_vocab_size=16,
432 | token_type_embedding_name="token_type_embeddings",
433 | use_position_embeddings=True,
434 | position_embedding_name="position_embeddings",
435 | initializer_range=0.02,
436 | max_position_embeddings=512,
437 | dropout_prob=0.1):
438 | """Performs various post-processing on a word embedding tensor.
439 |
440 | Args:
441 | input_tensor: float Tensor of shape [batch_size, seq_length,
442 | embedding_size].
443 | use_token_type: bool. Whether to add embeddings for `token_type_ids`.
444 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
445 | Must be specified if `use_token_type` is True.
446 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
447 | token_type_embedding_name: string. The name of the embedding table variable
448 | for token type ids.
449 | use_position_embeddings: bool. Whether to add position embeddings for the
450 | position of each token in the sequence.
451 | position_embedding_name: string. The name of the embedding table variable
452 | for positional embeddings.
453 | initializer_range: float. Range of the weight initialization.
454 | max_position_embeddings: int. Maximum sequence length that might ever be
455 | used with this model. This can be longer than the sequence length of
456 | input_tensor, but cannot be shorter.
457 | dropout_prob: float. Dropout probability applied to the final output tensor.
458 |
459 | Returns:
460 | float tensor with same shape as `input_tensor`.
461 |
462 | Raises:
463 | ValueError: One of the tensor shapes or input values is invalid.
464 | """
465 | input_shape = get_shape_list(input_tensor, expected_rank=3)
466 | batch_size = input_shape[0]
467 | seq_length = input_shape[1]
468 | width = input_shape[2]
469 |
470 | output = input_tensor
471 |
472 | if use_token_type:
473 | if token_type_ids is None:
474 | raise ValueError("`token_type_ids` must be specified if "
475 | "`use_token_type` is True.")
476 | token_type_table = tf.get_variable(
477 | name=token_type_embedding_name,
478 | shape=[token_type_vocab_size, width],
479 | initializer=create_initializer(initializer_range))
480 | # This vocab will be small so we always do one-hot here, since it is always
481 | # faster for a small vocabulary.
482 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
483 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
484 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
485 | token_type_embeddings = tf.reshape(token_type_embeddings,
486 | [batch_size, seq_length, width])
487 | output += token_type_embeddings
488 |
489 | if use_position_embeddings:
490 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
491 | with tf.control_dependencies([assert_op]):
492 | full_position_embeddings = tf.get_variable(
493 | name=position_embedding_name,
494 | shape=[max_position_embeddings, width],
495 | initializer=create_initializer(initializer_range))
496 | # Since the position embedding table is a learned variable, we create it
497 | # using a (long) sequence length `max_position_embeddings`. The actual
498 | # sequence length might be shorter than this, for faster training of
499 | # tasks that do not have long sequences.
500 | #
501 | # So `full_position_embeddings` is effectively an embedding table
502 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
503 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
504 | # perform a slice.
505 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
506 | [seq_length, -1])
507 | num_dims = len(output.shape.as_list())
508 |
509 | # Only the last two dimensions are relevant (`seq_length` and `width`), so
510 | # we broadcast among the first dimensions, which is typically just
511 | # the batch size.
512 | position_broadcast_shape = []
513 | for _ in range(num_dims - 2):
514 | position_broadcast_shape.append(1)
515 | position_broadcast_shape.extend([seq_length, width])
516 | position_embeddings = tf.reshape(position_embeddings,
517 | position_broadcast_shape)
518 | output += position_embeddings
519 |
520 | output = layer_norm_and_dropout(output, dropout_prob)
521 | return output
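The position-embedding slice-and-broadcast step can be sketched in NumPy (toy sizes; the reshape to `[1, seq_length, width]` is what lets the addition broadcast over the batch dimension):

```python
import numpy as np

batch_size, seq_length, width, max_position_embeddings = 2, 4, 8, 512
full_position_embeddings = np.random.default_rng(0).normal(
    size=(max_position_embeddings, width))

# tf.slice takes the first seq_length rows of the (learned) table;
# reshape to [1, seq_length, width] so addition broadcasts over batch.
position_embeddings = full_position_embeddings[:seq_length]
position_embeddings = position_embeddings.reshape(1, seq_length, width)

output = np.zeros((batch_size, seq_length, width)) + position_embeddings
assert output.shape == (2, 4, 8)
assert np.array_equal(output[0], output[1])  # every example gets the same positions
```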
522 |
523 |
524 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
525 | """Create 3D attention mask from a 2D tensor mask.
526 |
527 | Args:
528 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
529 | to_mask: int32 Tensor of shape [batch_size, to_seq_length].
530 |
531 | Returns:
532 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
533 | """
534 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
535 | batch_size = from_shape[0]
536 | from_seq_length = from_shape[1]
537 |
538 | to_shape = get_shape_list(to_mask, expected_rank=2)
539 | to_seq_length = to_shape[1]
540 |
541 | to_mask = tf.cast(
542 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
543 |
544 | # We don't assume that `from_tensor` is a mask (although it could be). We
545 | # don't actually care if we attend *from* padding tokens (only *to* padding
546 | # tokens), so we create a tensor of all ones.
547 | #
548 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
549 | broadcast_ones = tf.ones(
550 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
551 |
552 | # Here we broadcast along two dimensions to create the mask.
553 | mask = broadcast_ones * to_mask
554 |
555 | return mask
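The broadcast that builds the 3D mask is a one-liner in NumPy; here with one padded position in a toy batch:

```python
import numpy as np

to_mask = np.array([[1, 1, 0]], dtype=np.float32)   # [batch_size, to_seq_length]
batch_size, to_seq_length = to_mask.shape
from_seq_length = 3

broadcast_ones = np.ones((batch_size, from_seq_length, 1), dtype=np.float32)
mask = broadcast_ones * to_mask.reshape(batch_size, 1, to_seq_length)

assert mask.shape == (1, 3, 3)
# Every query position may attend to tokens 0 and 1, never to padded token 2.
assert np.array_equal(mask[0], np.array([[1, 1, 0]] * 3, dtype=np.float32))
```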
556 |
557 |
558 | def attention_layer(from_tensor,
559 | to_tensor,
560 | attention_mask=None,
561 | num_attention_heads=1,
562 | size_per_head=512,
563 | query_act=None,
564 | key_act=None,
565 | value_act=None,
566 | attention_probs_dropout_prob=0.0,
567 | initializer_range=0.02,
568 | do_return_2d_tensor=False,
569 | batch_size=None,
570 | from_seq_length=None,
571 | to_seq_length=None):
572 | """Performs multi-headed attention from `from_tensor` to `to_tensor`.
573 |
574 | This is an implementation of multi-headed attention based on "Attention
575 | is all you Need". If `from_tensor` and `to_tensor` are the same, then
576 | this is self-attention. Each timestep in `from_tensor` attends to the
577 | corresponding sequence in `to_tensor`, and returns a fixed-width vector.
578 |
579 | This function first projects `from_tensor` into a "query" tensor and
580 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list
581 | of tensors of length `num_attention_heads`, where each tensor is of shape
582 | [batch_size, seq_length, size_per_head].
583 |
584 | Then, the query and key tensors are dot-producted and scaled. These are
585 | softmaxed to obtain attention probabilities. The value tensors are then
586 | interpolated by these probabilities, then concatenated back to a single
587 | tensor and returned.
588 |
589 | In practice, the multi-headed attention is done with transposes and
590 | reshapes rather than actual separate tensors.
591 |
592 | Args:
593 | from_tensor: float Tensor of shape [batch_size, from_seq_length,
594 | from_width].
595 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
596 | attention_mask: (optional) int32 Tensor of shape [batch_size,
597 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
598 | attention scores will effectively be set to -infinity for any positions in
599 | the mask that are 0, and will be unchanged for positions that are 1.
600 | num_attention_heads: int. Number of attention heads.
601 | size_per_head: int. Size of each attention head.
602 | query_act: (optional) Activation function for the query transform.
603 | key_act: (optional) Activation function for the key transform.
604 | value_act: (optional) Activation function for the value transform.
605 | attention_probs_dropout_prob: (optional) float. Dropout probability of the
606 | attention probabilities.
607 | initializer_range: float. Range of the weight initializer.
608 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
609 | * from_seq_length, num_attention_heads * size_per_head]. If False, the
610 | output will be of shape [batch_size, from_seq_length, num_attention_heads
611 | * size_per_head].
612 | batch_size: (Optional) int. If the input is 2D, this might be the batch size
613 | of the 3D version of the `from_tensor` and `to_tensor`.
614 | from_seq_length: (Optional) If the input is 2D, this might be the seq length
615 | of the 3D version of the `from_tensor`.
616 | to_seq_length: (Optional) If the input is 2D, this might be the seq length
617 | of the 3D version of the `to_tensor`.
618 |
619 | Returns:
620 | float Tensor of shape [batch_size, from_seq_length,
621 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
622 | true, this will be of shape [batch_size * from_seq_length,
623 | num_attention_heads * size_per_head]).
624 |
625 | Raises:
626 | ValueError: Any of the arguments or tensor shapes are invalid.
627 | """
628 |
629 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
630 | seq_length, width):
631 | output_tensor = tf.reshape(
632 | input_tensor, [batch_size, seq_length, num_attention_heads, width])
633 |
634 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
635 | return output_tensor
636 |
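The `transpose_for_scores` helper above splits the fused `[B, S, N*H]` projection into per-head slices and moves the head axis in front of the sequence axis. A minimal NumPy sketch of the same reshape-and-transpose (the sizes here are illustrative, not from the model config):

```python
import numpy as np

# Hypothetical sizes: batch B=2, seq S=4, heads N=3, per-head width H=5.
B, S, N, H = 2, 4, 3, 5
x = np.arange(B * S * N * H, dtype=np.float32).reshape(B, S, N * H)

# Mirror of transpose_for_scores: split the last axis into heads,
# then move the head axis before the sequence axis.
y = x.reshape(B, S, N, H).transpose(0, 2, 1, 3)

assert y.shape == (B, N, S, H)
# Head 1 of batch 0, timestep 0 is the slice [H:2H] of the fused vector.
assert np.array_equal(y[0, 1, 0], x[0, 0, H:2 * H])
```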
637 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
638 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
639 |
640 | if len(from_shape) != len(to_shape):
641 | raise ValueError(
642 | "The rank of `from_tensor` must match the rank of `to_tensor`.")
643 |
644 | if len(from_shape) == 3:
645 | batch_size = from_shape[0]
646 | from_seq_length = from_shape[1]
647 | to_seq_length = to_shape[1]
648 | elif len(from_shape) == 2:
649 | if (batch_size is None or from_seq_length is None or to_seq_length is None):
650 | raise ValueError(
651 | "When passing in rank 2 tensors to attention_layer, the values "
652 | "for `batch_size`, `from_seq_length`, and `to_seq_length` "
653 | "must all be specified.")
654 |
655 | # Scalar dimensions referenced here:
656 | # B = batch size (number of sequences)
657 | # F = `from_tensor` sequence length
658 | # T = `to_tensor` sequence length
659 | # N = `num_attention_heads`
660 | # H = `size_per_head`
661 |
662 | from_tensor_2d = reshape_to_matrix(from_tensor)
663 | to_tensor_2d = reshape_to_matrix(to_tensor)
664 |
665 | # `query_layer` = [B*F, N*H]
666 | query_layer = tf.layers.dense(
667 | from_tensor_2d,
668 | num_attention_heads * size_per_head,
669 | activation=query_act,
670 | name="query",
671 | kernel_initializer=create_initializer(initializer_range))
672 |
673 | # `key_layer` = [B*T, N*H]
674 | key_layer = tf.layers.dense(
675 | to_tensor_2d,
676 | num_attention_heads * size_per_head,
677 | activation=key_act,
678 | name="key",
679 | kernel_initializer=create_initializer(initializer_range))
680 |
681 | # `value_layer` = [B*T, N*H]
682 | value_layer = tf.layers.dense(
683 | to_tensor_2d,
684 | num_attention_heads * size_per_head,
685 | activation=value_act,
686 | name="value",
687 | kernel_initializer=create_initializer(initializer_range))
688 |
689 | # `query_layer` = [B, N, F, H]
690 | query_layer = transpose_for_scores(query_layer, batch_size,
691 | num_attention_heads, from_seq_length,
692 | size_per_head)
693 |
694 | # `key_layer` = [B, N, T, H]
695 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
696 | to_seq_length, size_per_head)
697 |
698 | # Take the dot product between "query" and "key" to get the raw
699 | # attention scores.
700 | # `attention_scores` = [B, N, F, T]
701 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
702 | attention_scores = tf.multiply(attention_scores,
703 | 1.0 / math.sqrt(float(size_per_head)))
704 |
705 | if attention_mask is not None:
706 | # `attention_mask` = [B, 1, F, T]
707 | attention_mask = tf.expand_dims(attention_mask, axis=[1])
708 |
709 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
710 | # masked positions, this operation will create a tensor which is 0.0 for
711 | # positions we want to attend and -10000.0 for masked positions.
712 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
713 |
714 | # Since we are adding it to the raw scores before the softmax, this is
715 | # effectively the same as removing these entirely.
716 | attention_scores += adder
717 |
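The `-10000.0` adder works because softmax is dominated by relative score differences: a position pushed 10000 below the rest receives (numerically) zero probability. A small NumPy check of that trick, with made-up scores:

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5]])
mask = np.array([[1.0, 1.0, 0.0]])  # last position is padding

# Same construction as in the code: 0.0 where attended, -10000.0 where masked.
adder = (1.0 - mask) * -10000.0
probs = softmax(scores + adder)

# The masked position gets effectively zero probability mass.
assert probs[0, 2] < 1e-30
assert abs(probs.sum() - 1.0) < 1e-6
```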
718 | # Normalize the attention scores to probabilities.
719 | # `attention_probs` = [B, N, F, T]
720 | attention_probs = tf.nn.softmax(attention_scores)
721 |
722 | # This is actually dropping out entire tokens to attend to, which might
723 | # seem a bit unusual, but is taken from the original Transformer paper.
724 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
725 |
726 | # `value_layer` = [B, T, N, H]
727 | value_layer = tf.reshape(
728 | value_layer,
729 | [batch_size, to_seq_length, num_attention_heads, size_per_head])
730 |
731 | # `value_layer` = [B, N, T, H]
732 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
733 |
734 | # `context_layer` = [B, N, F, H]
735 | context_layer = tf.matmul(attention_probs, value_layer)
736 |
737 | # `context_layer` = [B, F, N, H]
738 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
739 |
740 | if do_return_2d_tensor:
741 | # `context_layer` = [B*F, N*H]
742 | context_layer = tf.reshape(
743 | context_layer,
744 | [batch_size * from_seq_length, num_attention_heads * size_per_head])
745 | else:
746 | # `context_layer` = [B, F, N*H]
747 | context_layer = tf.reshape(
748 | context_layer,
749 | [batch_size, from_seq_length, num_attention_heads * size_per_head])
750 |
751 | return context_layer
752 |
753 |
754 | def transformer_model(input_tensor,
755 | attention_mask=None,
756 | hidden_size=768,
757 | num_hidden_layers=12,
758 | num_attention_heads=12,
759 | intermediate_size=3072,
760 | intermediate_act_fn=gelu,
761 | hidden_dropout_prob=0.1,
762 | attention_probs_dropout_prob=0.1,
763 | initializer_range=0.02,
764 | do_return_all_layers=False):
765 | """Multi-headed, multi-layer Transformer from "Attention is All You Need".
766 |
767 | This is almost an exact implementation of the original Transformer encoder.
768 |
769 | See the original paper:
770 | https://arxiv.org/abs/1706.03762
771 |
772 | Also see:
773 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
774 |
775 | Args:
776 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
777 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
778 | seq_length], with 1 for positions that can be attended to and 0 in
779 | positions that should not be.
780 | hidden_size: int. Hidden size of the Transformer.
781 | num_hidden_layers: int. Number of layers (blocks) in the Transformer.
782 | num_attention_heads: int. Number of attention heads in the Transformer.
783 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed
784 | forward) layer.
785 | intermediate_act_fn: function. The non-linear activation function to apply
786 | to the output of the intermediate/feed-forward layer.
787 | hidden_dropout_prob: float. Dropout probability for the hidden layers.
788 | attention_probs_dropout_prob: float. Dropout probability of the attention
789 | probabilities.
790 | initializer_range: float. Range of the initializer (stddev of truncated
791 | normal).
792 | do_return_all_layers: Whether to also return all layers or just the final
793 | layer.
794 |
795 | Returns:
796 | float Tensor of shape [batch_size, seq_length, hidden_size], the final
797 | hidden layer of the Transformer.
798 |
799 | Raises:
800 | ValueError: A Tensor shape or parameter is invalid.
801 | """
802 | if hidden_size % num_attention_heads != 0:
803 | raise ValueError(
804 | "The hidden size (%d) is not a multiple of the number of attention "
805 | "heads (%d)" % (hidden_size, num_attention_heads))
806 |
807 | attention_head_size = int(hidden_size / num_attention_heads)
808 | input_shape = get_shape_list(input_tensor, expected_rank=3)
809 | batch_size = input_shape[0]
810 | seq_length = input_shape[1]
811 | input_width = input_shape[2]
812 |
813 | # The Transformer performs sum residuals on all layers so the input needs
814 | # to be the same as the hidden size.
815 | if input_width != hidden_size:
816 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
817 | (input_width, hidden_size))
818 |
819 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
820 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
821 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
822 | # help the optimizer.
823 | prev_output = reshape_to_matrix(input_tensor)
824 |
825 | all_layer_outputs = []
826 | for layer_idx in range(num_hidden_layers):
827 | with tf.variable_scope("layer_%d" % layer_idx):
828 | layer_input = prev_output
829 |
830 | with tf.variable_scope("attention"):
831 | attention_heads = []
832 | with tf.variable_scope("self"):
833 | attention_head = attention_layer(
834 | from_tensor=layer_input,
835 | to_tensor=layer_input,
836 | attention_mask=attention_mask,
837 | num_attention_heads=num_attention_heads,
838 | size_per_head=attention_head_size,
839 | attention_probs_dropout_prob=attention_probs_dropout_prob,
840 | initializer_range=initializer_range,
841 | do_return_2d_tensor=True,
842 | batch_size=batch_size,
843 | from_seq_length=seq_length,
844 | to_seq_length=seq_length)
845 | attention_heads.append(attention_head)
846 |
847 | attention_output = None
848 | if len(attention_heads) == 1:
849 | attention_output = attention_heads[0]
850 | else:
851 | # In the case where we have other sequences, we just concatenate
852 | # them to the self-attention head before the projection.
853 | attention_output = tf.concat(attention_heads, axis=-1)
854 |
855 | # Run a linear projection of `hidden_size` then add a residual
856 | # with `layer_input`.
857 | with tf.variable_scope("output"):
858 | attention_output = tf.layers.dense(
859 | attention_output,
860 | hidden_size,
861 | kernel_initializer=create_initializer(initializer_range))
862 | attention_output = dropout(attention_output, hidden_dropout_prob)
863 | attention_output = layer_norm(attention_output + layer_input)
864 |
865 | # The activation is only applied to the "intermediate" hidden layer.
866 | with tf.variable_scope("intermediate"):
867 | intermediate_output = tf.layers.dense(
868 | attention_output,
869 | intermediate_size,
870 | activation=intermediate_act_fn,
871 | kernel_initializer=create_initializer(initializer_range))
872 |
873 | # Down-project back to `hidden_size` then add the residual.
874 | with tf.variable_scope("output"):
875 | layer_output = tf.layers.dense(
876 | intermediate_output,
877 | hidden_size,
878 | kernel_initializer=create_initializer(initializer_range))
879 | layer_output = dropout(layer_output, hidden_dropout_prob)
880 | layer_output = layer_norm(layer_output + attention_output)
881 | prev_output = layer_output
882 | all_layer_outputs.append(layer_output)
883 |
884 | if do_return_all_layers:
885 | final_outputs = []
886 | for layer_output in all_layer_outputs:
887 | final_output = reshape_from_matrix(layer_output, input_shape)
888 | final_outputs.append(final_output)
889 | return final_outputs
890 | else:
891 | final_output = reshape_from_matrix(prev_output, input_shape)
892 | return final_output
893 |
894 |
895 | def get_shape_list(tensor, expected_rank=None, name=None):
896 | """Returns a list of the shape of tensor, preferring static dimensions.
897 |
898 | Args:
899 | tensor: A tf.Tensor object to find the shape of.
900 |     expected_rank: (optional) int. The expected rank of `tensor`. If this is
901 |       specified and `tensor` has a different rank, an exception will be
902 |       thrown.
903 | name: Optional name of the tensor for the error message.
904 |
905 | Returns:
906 | A list of dimensions of the shape of tensor. All static dimensions will
907 | be returned as python integers, and dynamic dimensions will be returned
908 | as tf.Tensor scalars.
909 | """
910 | if name is None:
911 | name = tensor.name
912 |
913 | if expected_rank is not None:
914 | assert_rank(tensor, expected_rank, name)
915 |
916 | shape = tensor.shape.as_list()
917 |
918 | non_static_indexes = []
919 | for (index, dim) in enumerate(shape):
920 | if dim is None:
921 | non_static_indexes.append(index)
922 |
923 | if not non_static_indexes:
924 | return shape
925 |
926 | dyn_shape = tf.shape(tensor)
927 | for index in non_static_indexes:
928 | shape[index] = dyn_shape[index]
929 | return shape
930 |
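The contract of `get_shape_list` is that static dimensions come back as plain ints while unknown dimensions fall back to dynamic graph lookups. A pure-Python mimic of that merge (the placeholder string stands in for the `tf.Tensor` scalar a real call would return):

```python
# Static dims stay as ints; None dims are replaced from the dynamic shape.
def shape_list_mimic(static_shape, dynamic_shape):
    return [d if d is not None else dynamic_shape[i]
            for i, d in enumerate(static_shape)]

# Unknown batch dimension resolved dynamically:
assert shape_list_mimic([None, 128, 768], ["dyn_batch", 128, 768]) == \
    ["dyn_batch", 128, 768]
# Fully static shape passes through untouched:
assert shape_list_mimic([8, 128, 768], None) == [8, 128, 768]
```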
931 |
932 | def reshape_to_matrix(input_tensor):
933 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
934 | ndims = input_tensor.shape.ndims
935 | if ndims < 2:
936 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
937 | (input_tensor.shape))
938 | if ndims == 2:
939 | return input_tensor
940 |
941 | width = input_tensor.shape[-1]
942 | output_tensor = tf.reshape(input_tensor, [-1, width])
943 | return output_tensor
944 |
945 |
946 | def reshape_from_matrix(output_tensor, orig_shape_list):
947 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
948 | if len(orig_shape_list) == 2:
949 | return output_tensor
950 |
951 | output_shape = get_shape_list(output_tensor)
952 |
953 | orig_dims = orig_shape_list[0:-1]
954 | width = output_shape[-1]
955 |
956 | return tf.reshape(output_tensor, orig_dims + [width])
957 |
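The two reshape helpers above are inverses of each other: the leading axes are flattened into one for the dense layers, then restored from the recorded original shape. A NumPy round-trip illustrating this, with arbitrary example sizes:

```python
import numpy as np

# Round-trip of reshape_to_matrix / reshape_from_matrix in NumPy.
x = np.random.rand(2, 7, 16).astype(np.float32)
flat = x.reshape(-1, x.shape[-1])          # reshape_to_matrix: [2*7, 16]
back = flat.reshape(2, 7, flat.shape[-1])  # reshape_from_matrix

assert flat.shape == (14, 16)
assert np.array_equal(back, x)
```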
958 |
959 | def assert_rank(tensor, expected_rank, name=None):
960 | """Raises an exception if the tensor rank is not of the expected rank.
961 |
962 | Args:
963 | tensor: A tf.Tensor to check the rank of.
964 | expected_rank: Python integer or list of integers, expected rank.
965 | name: Optional name of the tensor for the error message.
966 |
967 | Raises:
968 | ValueError: If the expected shape doesn't match the actual shape.
969 | """
970 | if name is None:
971 | name = tensor.name
972 |
973 | expected_rank_dict = {}
974 | if isinstance(expected_rank, six.integer_types):
975 | expected_rank_dict[expected_rank] = True
976 | else:
977 | for x in expected_rank:
978 | expected_rank_dict[x] = True
979 |
980 | actual_rank = tensor.shape.ndims
981 | if actual_rank not in expected_rank_dict:
982 | scope_name = tf.get_variable_scope().name
983 | raise ValueError(
984 | "For the tensor `%s` in scope `%s`, the actual rank "
985 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
986 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
987 |
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
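The schedule built above combines linear warmup with linear (power=1.0) polynomial decay to zero. A pure-Python restatement of the learning rate at a given step, for intuition only (it mirrors the graph ops; it is not the TF code itself):

```python
def lr_at_step(step, init_lr, num_train_steps, num_warmup_steps):
    """Plain-Python mirror of linear warmup + linear decay above."""
    if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps          # linear warmup
    frac = min(step, num_train_steps) / num_train_steps   # linear decay to 0
    return init_lr * (1.0 - frac)

assert lr_at_step(0, 1e-4, 1000, 100) == 0.0
assert abs(lr_at_step(50, 1e-4, 1000, 100) - 5e-5) < 1e-12  # mid-warmup
assert lr_at_step(1000, 1e-4, 1000, 100) == 0.0             # fully decayed
```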
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
 98 |     """Constructs an AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 |       # Instead we want to decay the weights in a manner that doesn't interact
144 |       # with the m/v parameters. This is equivalent to adding the square
145 |       # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
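The loop above performs a standard Adam moment update and then adds decoupled weight decay directly to the update, rather than to the loss. A NumPy sketch of a single parameter step with the same arithmetic (scalar hyperparameters here are illustrative defaults):

```python
import numpy as np

def adamw_step(param, grad, m, v, lr=1e-3, beta_1=0.9, beta_2=0.999,
               epsilon=1e-6, weight_decay_rate=0.01):
    """One update mirroring apply_gradients above: decay is applied to the
    raw weights, decoupled from the m/v accumulators."""
    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2
    update = m / (np.sqrt(v) + epsilon) + weight_decay_rate * param
    return param - lr * update, m, v

p, g = np.array([1.0]), np.array([0.5])
p2, m2, v2 = adamw_step(p, g, np.zeros(1), np.zeros(1))
assert p2[0] < p[0]  # gradient and decay both shrink this weight
assert abs(m2[0] - 0.05) < 1e-12
```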
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
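The two private helpers above are small enough to restate directly: TF tensor names carry an `:<output-index>` suffix that must be stripped before naming the `adam_m`/`adam_v` slots, and decay is skipped for any variable whose name matches an exclusion pattern. A standalone check (variable names below are hypothetical):

```python
import re

def strip_output_index(name):
    # Same regex as _get_variable_name: drop a trailing ":<digits>".
    m = re.match(r"^(.*):\d+$", name)
    return m.group(1) if m else name

assert strip_output_index("bert/encoder/layer_0/kernel:0") == \
    "bert/encoder/layer_0/kernel"

exclude = ["LayerNorm", "layer_norm", "bias"]

def uses_decay(name):
    # Same check as _do_use_weight_decay with a nonzero decay rate.
    return not any(re.search(r, name) for r in exclude)

assert uses_decay("bert/encoder/layer_0/attention/self/query/kernel")
assert not uses_decay("bert/encoder/layer_0/attention/output/LayerNorm/gamma")
```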
--------------------------------------------------------------------------------
/optimization_bert_flow.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, flow_loss, init_lr, init_flow_lr,
26 | num_train_steps, num_warmup_steps, use_tpu):
27 | """Creates an optimizer training op."""
28 | global_step = tf.train.get_or_create_global_step()
29 |
30 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
31 |
32 | # Implements linear decay of the learning rate.
33 | learning_rate = tf.train.polynomial_decay(
34 | learning_rate,
35 | global_step,
36 | num_train_steps,
37 | end_learning_rate=0.0,
38 | power=1.0,
39 | cycle=False)
40 |
41 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
42 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
43 | if num_warmup_steps:
44 | global_steps_int = tf.cast(global_step, tf.int32)
45 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
46 |
47 | global_steps_float = tf.cast(global_steps_int, tf.float32)
48 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
49 |
50 | warmup_percent_done = global_steps_float / warmup_steps_float
51 | warmup_learning_rate = init_lr * warmup_percent_done
52 |
53 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
54 | learning_rate = (
55 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
56 |
 57 |   if flow_loss is not None:
58 | # bert
59 | # It is recommended that you use this optimizer for fine tuning, since this
60 | # is how the model was trained (note that the Adam m/v variables are NOT
61 | # loaded from init_checkpoint.)
62 | optimizer = AdamWeightDecayOptimizer(
63 | learning_rate=learning_rate,
64 | weight_decay_rate=0.01,
65 | beta_1=0.9,
66 | beta_2=0.999,
67 | epsilon=1e-6,
68 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
69 | if use_tpu:
70 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
71 |
72 | tvars = [v for v in tf.trainable_variables() if not v.name.startswith("bert/flow")]
73 | grads = tf.gradients(loss, tvars)
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 | train_op = optimizer.apply_gradients(
76 | zip(grads, tvars), global_step=global_step)
77 |
78 | ########################
79 | # flow
80 | flow_optimizer = AdamWeightDecayOptimizer(
 81 |         learning_rate=init_flow_lr,
82 | weight_decay_rate=0.01,
83 | beta_1=0.9,
84 | beta_2=0.999,
85 | epsilon=1e-6,
86 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
87 | if use_tpu:
88 | flow_optimizer = tf.contrib.tpu.CrossShardOptimizer(flow_optimizer)
89 |
90 | flow_tvars = tf.trainable_variables("bert/flow")
91 | flow_grads = tf.gradients(flow_loss, flow_tvars)
92 | (flow_grads, _) = tf.clip_by_global_norm(flow_grads, clip_norm=1.0)
93 | flow_train_op = flow_optimizer.apply_gradients(
94 | zip(flow_grads, flow_tvars), global_step=global_step)
95 |
96 | ########################
97 |
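The split above routes the BERT weights and the flow weights to separate optimizers by keying on the `bert/flow` variable-name prefix. A minimal sketch of that partition with plain strings (the variable names below are hypothetical):

```python
names = [
    "bert/encoder/layer_0/kernel",
    "bert/flow/coupling_0/w",
    "bert/embeddings/word_embeddings",
]

# Same predicate as the tvars / flow_tvars split above.
bert_vars = [n for n in names if not n.startswith("bert/flow")]
flow_vars = [n for n in names if n.startswith("bert/flow")]

assert bert_vars == ["bert/encoder/layer_0/kernel",
                     "bert/embeddings/word_embeddings"]
assert flow_vars == ["bert/flow/coupling_0/w"]
```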
98 | # Normally the global step update is done inside of `apply_gradients`.
99 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
100 | # a different optimizer, you should probably take this line out.
101 | new_global_step = global_step + 1
102 | train_op = tf.group(train_op, flow_train_op, [global_step.assign(new_global_step)])
103 | return train_op
104 | else:
105 | raise NotImplementedError
106 |
107 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
108 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
109 |
110 | def __init__(self,
111 | learning_rate,
112 | weight_decay_rate=0.0,
113 | beta_1=0.9,
114 | beta_2=0.999,
115 | epsilon=1e-6,
116 | exclude_from_weight_decay=None,
117 | name="AdamWeightDecayOptimizer"):
118 |     """Constructs an AdamWeightDecayOptimizer."""
119 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
120 |
121 | self.learning_rate = learning_rate
122 | self.weight_decay_rate = weight_decay_rate
123 | self.beta_1 = beta_1
124 | self.beta_2 = beta_2
125 | self.epsilon = epsilon
126 | self.exclude_from_weight_decay = exclude_from_weight_decay
127 |
128 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
129 | """See base class."""
130 | assignments = []
131 | for (grad, param) in grads_and_vars:
132 | if grad is None or param is None:
133 | continue
134 |
135 | param_name = self._get_variable_name(param.name)
136 |
137 | m = tf.get_variable(
138 | name=param_name + "/adam_m",
139 | shape=param.shape.as_list(),
140 | dtype=tf.float32,
141 | trainable=False,
142 | initializer=tf.zeros_initializer())
143 | v = tf.get_variable(
144 | name=param_name + "/adam_v",
145 | shape=param.shape.as_list(),
146 | dtype=tf.float32,
147 | trainable=False,
148 | initializer=tf.zeros_initializer())
149 |
150 | # Standard Adam update.
151 | next_m = (
152 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
153 | next_v = (
154 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
155 | tf.square(grad)))
156 |
157 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
158 |
159 | # Just adding the square of the weights to the loss function is *not*
160 | # the correct way of using L2 regularization/weight decay with Adam,
161 | # since that will interact with the m and v parameters in strange ways.
162 | #
163 |       # Instead we want to decay the weights in a manner that doesn't interact
164 |       # with the m/v parameters. This is equivalent to adding the square
165 |       # of the weights to the loss with plain (non-momentum) SGD.
166 | if self._do_use_weight_decay(param_name):
167 | update += self.weight_decay_rate * param
168 |
169 | update_with_lr = self.learning_rate * update
170 |
171 | next_param = param - update_with_lr
172 |
173 | assignments.extend(
174 | [param.assign(next_param),
175 | m.assign(next_m),
176 | v.assign(next_v)])
177 | return tf.group(*assignments, name=name)
178 |
179 | def _do_use_weight_decay(self, param_name):
180 | """Whether to use L2 weight decay for `param_name`."""
181 | if not self.weight_decay_rate:
182 | return False
183 | if self.exclude_from_weight_decay:
184 | for r in self.exclude_from_weight_decay:
185 | if re.search(r, param_name) is not None:
186 | return False
187 | return True
188 |
189 | def _get_variable_name(self, param_name):
190 | """Get the variable name from the tensor name."""
191 | m = re.match("^(.*):\\d+$", param_name)
192 | if m is not None:
193 | param_name = m.group(1)
194 | return param_name
195 |
--------------------------------------------------------------------------------
/run_siamese.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
 15 | """BERT finetuning runner for regression tasks.
 16 | A large portion of the code is adapted from
 17 | https://github.com/zihangdai/xlnet/blob/master/run_classifier.py
 18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import warnings
25 | warnings.simplefilter(action='ignore', category=FutureWarning)
26 |
27 | import os
28 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
29 |
30 | import collections
31 | import csv
32 |
33 | import modeling
34 | import optimization
35 | import tokenization
36 |
37 | import tensorflow as tf
38 |
39 | import random
40 | import numpy as np
41 |
42 | from flow.glow_1x1 import AttrDict, Glow
43 | from flow.glow_init_hook import GlowInitHook
44 | import optimization_bert_flow
45 | import json
46 |
47 | from siamese_utils import StsbProcessor, SickRProcessor, MnliProcessor, QqpProcessor, \
48 | SnliTrainProcessor, SnliDevTestProcessor, \
49 | Sts_12_16_Processor, MrpcRegressionProcessor, QnliRegressionProcessor, \
50 | file_based_convert_examples_to_features, file_based_input_fn_builder, \
51 | get_input_mask_segment
52 |
53 | flags = tf.flags
54 |
55 | FLAGS = flags.FLAGS
56 |
57 | # model
58 | flags.DEFINE_string("bert_config_file", None,
59 | "The config json file corresponding to the pre-trained BERT model. "
60 | "This specifies the model architecture.")
61 | flags.DEFINE_integer("max_seq_length", 128,
62 | "The maximum total input sequence length after WordPiece tokenization. "
63 | "Sequences longer than this will be truncated, and sequences shorter "
64 | "than this will be padded.")
65 | flags.DEFINE_string("init_checkpoint", None,
66 | "Initial checkpoint (usually from a pre-trained BERT model).")
67 | flags.DEFINE_string("vocab_file", None,
68 | "The vocabulary file that the BERT model was trained on.")
69 | flags.DEFINE_bool("do_lower_case", True,
70 | "Whether to lower case the input text. Should be True for uncased "
71 | "models and False for cased models.")
72 |
73 |
74 | # task and data
75 | flags.DEFINE_string("task_name", None, "The name of the task to train.")
76 | flags.DEFINE_string("data_dir", None,
77 | "The input data dir. Should contain the .tsv files (or other data files) "
78 | "for the task.")
79 | flags.DEFINE_float("label_min", 0., "Minimum value of the regression labels.")
80 | flags.DEFINE_float("label_max", 5., "Maximum value of the regression labels.")
81 |
82 | # exp
83 | flags.DEFINE_string("output_parent_dir", None, None)
84 | flags.DEFINE_string("exp_name", None, None)
85 | flags.DEFINE_string("exp_name_prefix", None, None)
86 | flags.DEFINE_integer("log_every_step", 10, None)
87 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
88 | "How often to save the model checkpoint.")
89 | flags.DEFINE_bool("use_xla", False, None)
90 | flags.DEFINE_integer("seed", 1234, None)
91 | flags.DEFINE_string("cached_dir", None,
92 | "Path to cached training and dev tfrecord file. "
93 | "The file will be generated if it does not exist.")
94 |
95 | # training
96 | flags.DEFINE_bool("do_train", False, None)
97 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
98 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
99 | flags.DEFINE_float("num_train_epochs", 3.0,
100 | "Total number of training epochs to perform.")
101 | flags.DEFINE_float("warmup_proportion", 0.1,
102 | "Proportion of training to perform linear learning rate warmup for. "
103 | "E.g., 0.1 = 10% of training.")
104 | flags.DEFINE_bool("early_stopping", False, "Stop training early when eval_pearsonr stops improving on the dev set.")
105 |
106 | flags.DEFINE_integer("start_delay_secs", 120, "for tf.estimator.EvalSpec")
107 | flags.DEFINE_integer("throttle_secs", 600, "for tf.estimator.EvalSpec")
108 |
109 | # eval
110 | flags.DEFINE_bool("do_eval", False, None)
111 | flags.DEFINE_bool("do_predict", False, None)
112 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
113 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
114 | flags.DEFINE_bool("predict_pool", False, "If True, predict sentence embeddings instead of cosine similarities.")
115 | flags.DEFINE_bool("do_predict_on_dev", False, None)
116 | flags.DEFINE_bool("do_predict_on_full", False, None)
117 | flags.DEFINE_string("eval_checkpoint_name", None, "filename of a finetuned checkpoint")
118 | flags.DEFINE_bool("auc", False, None)
119 |
120 | # sentence embedding related parameters
121 | flags.DEFINE_string("sentence_embedding_type", "avg", "avg, cls, avg-last-<n>, avg-last-last-<n>, or avg-last-concat-<n>")
122 |
123 | # flow parameters
124 | flags.DEFINE_integer("flow", 0, "use flow or not")
125 | flags.DEFINE_integer("flow_loss", 0, "use flow loss or not")
126 | flags.DEFINE_float("flow_learning_rate", 1e-3, "The initial learning rate for Adam.")
127 | flags.DEFINE_string("flow_model_config", "config_l3_d3_w32", None)
128 |
129 | # unsupervised or semi-supervised related parameters
130 | flags.DEFINE_integer("num_examples", -1, "# of labeled training examples")
131 | flags.DEFINE_integer("use_full_for_training", 0, "If nonzero, also add the dev and test examples to the training set.")
132 | flags.DEFINE_integer("dupe_factor", 1, "Number of times to duplicate the input data (with different masks).")
133 |
134 | # nli related parameters
135 | # flags.DEFINE_integer("use_snli_full", 0, "augment MNLI training data with SNLI")
136 | flags.DEFINE_float("l2_penalty", -1, "penalize l2 norm of sentence embeddings")
137 |
138 | # dimension reduction related parameters
139 | flags.DEFINE_integer("low_dim", -1, "avg pooling over the embedding")
140 |
141 | # senteval
142 | flags.DEFINE_bool("do_senteval", False, None)
143 | flags.DEFINE_string("senteval_tasks", "", None)
144 |
145 | def get_embedding(bert_config, is_training,
146 | input_ids, input_mask, segment_ids, scope=None):
147 |
148 | model = modeling.BertModel(
149 | config=bert_config,
150 | is_training=is_training,
151 | input_ids=input_ids,
152 | input_mask=input_mask,
153 | token_type_ids=segment_ids,
154 | scope=scope)
155 |
156 | if FLAGS.sentence_embedding_type == "avg":
157 | sequence = model.get_sequence_output() # [batch_size, seq_length, hidden_size]
158 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32)
159 | pooled = tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1)
160 | elif FLAGS.sentence_embedding_type == "cls":
161 | pooled = model.get_pooled_output()
162 | elif FLAGS.sentence_embedding_type.startswith("avg-last-last-"):
163 | pooled = 0
164 |     n_last = int(FLAGS.sentence_embedding_type.split("-")[-1])  # allow multi-digit n
165 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32)
166 | sequence = model.all_encoder_layers[-n_last] # [batch_size, seq_length, hidden_size]
167 | pooled += tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1)
168 | elif FLAGS.sentence_embedding_type.startswith("avg-last-"):
169 | pooled = 0
170 |     n_last = int(FLAGS.sentence_embedding_type.split("-")[-1])  # allow multi-digit n
171 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32)
172 | for i in range(n_last):
173 |       sequence = model.all_encoder_layers[-(i + 1)] # [batch_size, seq_length, hidden_size]; -i would pick layer 0 when i == 0
174 | pooled += tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1)
175 | pooled /= float(n_last)
176 | elif FLAGS.sentence_embedding_type.startswith("avg-last-concat-"):
177 | pooled = []
178 |     n_last = int(FLAGS.sentence_embedding_type.split("-")[-1])  # allow multi-digit n
179 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32)
180 | for i in range(n_last):
181 |       sequence = model.all_encoder_layers[-(i + 1)] # [batch_size, seq_length, hidden_size]; -i would pick layer 0 when i == 0
182 | pooled += [tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1)]
183 | pooled = tf.concat(pooled, axis=-1)
184 | else:
185 | raise NotImplementedError
186 |
187 | # flow
188 | embedding = None
189 | flow_loss_batch, flow_loss_example = None, None
190 | if FLAGS.flow:
191 | # load model and train config
192 | with open(os.path.join("./flow/config", FLAGS.flow_model_config + ".json"), 'r') as jp:
193 | flow_model_config = AttrDict(json.load(jp))
194 | flow_model_config.is_training = is_training
195 | flow_model = Glow(flow_model_config)
196 | flow_loss_example = flow_model.body(pooled, is_training) # no l2 normalization here any more
197 | flow_loss_batch = tf.math.reduce_mean(flow_loss_example)
198 | embedding = tf.identity(tf.squeeze(flow_model.z, [1, 2])) # no l2 normalization here any more
199 | else:
200 | embedding = pooled
201 |
202 | if FLAGS.low_dim > 0:
203 | bsz, org_dim = modeling.get_shape_list(embedding)
204 | embedding = tf.reduce_mean(
205 | tf.reshape(embedding, [bsz, FLAGS.low_dim, org_dim // FLAGS.low_dim]), axis=-1)
206 |
207 | return embedding, flow_loss_batch, flow_loss_example
208 |
209 |
210 | def create_model(bert_config, is_regression,
211 | is_training,
212 | input_ids_a, input_mask_a, segment_ids_a,
213 | input_ids_b, input_mask_b, segment_ids_b,
214 | labels, num_labels):
215 | """Creates a classification model."""
216 |
217 | with tf.variable_scope("bert") as scope:
218 | embedding_a, flow_loss_batch_a, flow_loss_example_a = \
219 | get_embedding(bert_config, is_training,
220 | input_ids_a, input_mask_a, segment_ids_a, scope)
221 | with tf.variable_scope("bert", reuse=tf.AUTO_REUSE) as scope:
222 | embedding_b, flow_loss_batch_b, flow_loss_example_b = \
223 | get_embedding(bert_config, is_training,
224 | input_ids_b, input_mask_b, segment_ids_b, scope)
225 |
226 | with tf.variable_scope("loss"):
227 | cos_similarity = tf.reduce_sum(tf.multiply(
228 | tf.nn.l2_normalize(embedding_a, axis=-1),
229 | tf.nn.l2_normalize(embedding_b, axis=-1)), axis=-1)
230 | if is_regression:
231 | # changing cos_similarity into (cos_similarity + 1)/2.0
232 |       # leads to a large performance decrease in practice
233 | per_example_loss = tf.square(cos_similarity - labels)
234 | loss = tf.reduce_mean(per_example_loss)
235 | logits, predictions = None, None
236 | else:
237 | output_layer = tf.concat([
238 | embedding_a, embedding_b, tf.math.abs(embedding_a - embedding_b)
239 | ], axis=-1)
240 | output_size = output_layer.shape[-1].value
241 | output_weights = tf.get_variable(
242 | "output_weights", [num_labels, output_size],
243 | initializer=tf.truncated_normal_initializer(stddev=0.02))
244 |
245 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
246 |
247 | probabilities = tf.nn.softmax(logits, axis=-1)
248 | predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32)
249 | log_probs = tf.nn.log_softmax(logits, axis=-1)
250 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
251 |
252 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
253 | loss = tf.reduce_mean(per_example_loss)
254 |
255 | if FLAGS.num_examples == 0:
256 | per_example_loss = tf.zeros_like(per_example_loss)
257 | loss = tf.zeros_like(loss)
258 | elif FLAGS.num_examples > 0:
259 | per_example_loss = per_example_loss * tf.cast(labels > -1, dtype=tf.float32)
260 | loss = tf.reduce_mean(per_example_loss)
261 |
262 | if FLAGS.l2_penalty > 0:
263 | l2_penalty_loss = tf.norm(embedding_a, axis=-1, keepdims=False)
264 | l2_penalty_loss += tf.norm(embedding_b, axis=-1, keepdims=False)
265 | l2_penalty_loss *= FLAGS.l2_penalty
266 |
267 | per_example_loss += l2_penalty_loss
268 | loss += tf.reduce_mean(l2_penalty_loss)
269 |
270 | model_output = {
271 | "loss": loss,
272 | "per_example_loss": per_example_loss,
273 | "cos_similarity": cos_similarity,
274 | "embedding_a": embedding_a,
275 | "embedding_b": embedding_b,
276 | "logits": logits,
277 | "predictions": predictions,
278 | }
279 |
280 | if FLAGS.flow:
281 | model_output["flow_example_loss"] = flow_loss_example_a + flow_loss_example_b
282 | model_output["flow_loss"] = flow_loss_batch_a + flow_loss_batch_b
283 |
284 | return model_output
285 |
286 |
287 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
288 | num_train_steps, num_warmup_steps, is_regression):
289 | """Returns `model_fn` closure for Estimator."""
290 |
291 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
292 | """The `model_fn` for Estimator."""
293 |
294 | tf.logging.info("*** Features ***")
295 | for name in sorted(features.keys()):
296 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
297 |
298 | input_ids_a = features["input_ids_a"]
299 | input_mask_a = features["input_mask_a"]
300 | segment_ids_a = features["segment_ids_a"]
301 |
302 | input_ids_b = features["input_ids_b"]
303 | input_mask_b = features["input_mask_b"]
304 | segment_ids_b = features["segment_ids_b"]
305 |
306 | label_ids = features["label_ids"]
307 | is_real_example = None
308 | if "is_real_example" in features:
309 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
310 | else:
311 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
312 |
313 | #### Training or Evaluation
314 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
315 |
316 | #### Get loss from inputs
317 | model_output = create_model(
318 | bert_config, is_regression,
319 | is_training,
320 | input_ids_a, input_mask_a, segment_ids_a,
321 | input_ids_b, input_mask_b, segment_ids_b,
322 | label_ids,
323 | num_labels)
324 |
325 | tvars = tf.trainable_variables()
326 | initialized_variable_names = {}
327 | if init_checkpoint:
328 | (assignment_map, initialized_variable_names
329 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
330 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
331 |
332 | tf.logging.info("**** Trainable Variables ****")
333 | for var in tvars:
334 | init_string = ""
335 | if var.name in initialized_variable_names:
336 | init_string = ", *INIT_FROM_CKPT*"
337 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
338 | init_string)
339 | # if "flow" in var.name:
340 | # input()
341 |
342 | output_spec = None
343 | if mode == tf.estimator.ModeKeys.TRAIN:
344 | if FLAGS.flow_loss:
345 | train_op = optimization_bert_flow.create_optimizer(
346 | model_output["loss"], model_output["flow_loss"],
347 | learning_rate, FLAGS.flow_learning_rate,
348 | num_train_steps, num_warmup_steps, use_tpu=False)
349 | tf.summary.scalar("loss", model_output["loss"])
350 | tf.summary.scalar("flow_loss", model_output["flow_loss"])
351 |
352 | output_spec = tf.estimator.EstimatorSpec(
353 | mode=mode,
354 | loss=model_output["loss"] + model_output["flow_loss"],
355 | train_op=train_op)
356 | else:
357 | train_op = optimization.create_optimizer(
358 | model_output["loss"], learning_rate,
359 | num_train_steps, num_warmup_steps, use_tpu=False)
360 | output_spec = tf.estimator.EstimatorSpec(
361 | mode=mode,
362 | loss=model_output["loss"],
363 | train_op=train_op)
364 |
365 | elif mode == tf.estimator.ModeKeys.EVAL:
366 | def metric_fn(model_output, label_ids, is_real_example):
367 | predictions = tf.argmax(model_output["logits"], axis=-1, output_type=tf.int32)
368 | accuracy = tf.metrics.accuracy(
369 | labels=label_ids, predictions=model_output["predictions"],
370 | weights=is_real_example)
371 | loss = tf.metrics.mean(
372 | values=model_output["per_example_loss"], weights=is_real_example)
373 | metric_output = {
374 | "eval_accuracy": accuracy,
375 | "eval_loss": loss,
376 | }
377 |
378 | if "flow_loss" in model_output:
379 | metric_output["eval_loss_flow"] = \
380 | tf.metrics.mean(values=model_output["flow_example_loss"], weights=is_real_example)
381 | metric_output["eval_loss_total"] = \
382 | tf.metrics.mean(
383 | values=model_output["per_example_loss"] + model_output["flow_example_loss"],
384 | weights=is_real_example)
385 |
386 | return metric_output
387 |
388 | def regression_metric_fn(model_output, label_ids, is_real_example):
389 | metric_output = {
390 | "eval_loss": tf.metrics.mean(
391 | values=model_output["per_example_loss"], weights=is_real_example),
392 | "eval_pearsonr": tf.contrib.metrics.streaming_pearson_correlation(
393 | model_output["cos_similarity"], label_ids, weights=is_real_example)
394 | }
395 |
396 | # metric_output["auc"] = tf.compat.v1.metrics.auc(
397 | # label_ids, tf.math.maximum(model_output["cos_similarity"], 0), weights=is_real_example, curve='ROC')
398 |
399 | if "flow_loss" in model_output:
400 | metric_output["eval_loss_flow"] = \
401 | tf.metrics.mean(values=model_output["flow_example_loss"], weights=is_real_example)
402 | metric_output["eval_loss_total"] = \
403 | tf.metrics.mean(
404 | values=model_output["per_example_loss"] + model_output["flow_example_loss"],
405 | weights=is_real_example)
406 |
407 | return metric_output
408 |
409 | if is_regression:
410 | metric_fn = regression_metric_fn
411 |
412 | eval_metrics = metric_fn(model_output, label_ids, is_real_example)
413 |
414 | output_spec = tf.estimator.EstimatorSpec(
415 | mode=mode,
416 | loss=model_output["loss"],
417 | eval_metric_ops=eval_metrics)
418 | else:
419 | output_spec = tf.estimator.EstimatorSpec(
420 | mode=mode,
421 | predictions= {"embedding_a": model_output["embedding_a"],
422 | "embedding_b": model_output["embedding_b"]} if FLAGS.predict_pool else \
423 | {"cos_similarity": model_output["cos_similarity"]})
424 | return output_spec
425 |
426 | return model_fn
427 |
428 |
429 | def main(_):
430 | tf.logging.set_verbosity(tf.logging.INFO)
431 |
432 | # random seed
433 | random.seed(FLAGS.seed)
434 | np.random.seed(FLAGS.seed)
435 | tf.compat.v1.set_random_seed(FLAGS.seed)
436 | print("FLAGS.seed", FLAGS.seed)
437 | # input()
438 |
439 | # prevent double printing of the tf logs
440 | logger = tf.get_logger()
441 | logger.propagate = False
442 |
443 | # get tokenizer
444 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
445 | FLAGS.init_checkpoint)
446 | tokenizer = tokenization.FullTokenizer(
447 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
448 |
449 | # get bert config
450 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
451 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
452 | raise ValueError(
453 | "Cannot use sequence length %d because the BERT model "
454 | "was only trained up to sequence length %d" %
455 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
456 |
457 | # GPU config
458 | run_config = tf.compat.v1.ConfigProto()
459 | if FLAGS.use_xla:
460 | run_config.graph_options.optimizer_options.global_jit_level = \
461 | tf.OptimizerOptions.ON_1
462 |
463 | run_config.gpu_options.allow_growth = True
464 |
465 | if FLAGS.do_senteval:
466 | # Set up logger
467 | import logging
468 | tf.logging.set_verbosity(0)
469 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
470 |
471 | # load senteval
472 | import sys
473 | PATH_TO_SENTEVAL, PATH_TO_DATA = '../SentEval', '../SentEval/data'
474 | sys.path.insert(0, PATH_TO_SENTEVAL)
475 | import senteval
476 |
477 | # model
478 | tf.logging.info("***** Running SentEval *****")
479 | with tf.Graph().as_default():
480 | with tf.variable_scope("bert") as scope:
481 | input_ids = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_ids")
482 | input_mask = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_mask")
483 | segment_ids = tf.placeholder(shape=[None, None], dtype=tf.int32, name="segment_ids")
484 |
485 | embedding, flow_loss_batch, flow_loss_example = \
486 | get_embedding(bert_config, False,
487 | input_ids, input_mask, segment_ids, scope=scope)
488 | embedding = tf.nn.l2_normalize(embedding, axis=-1)
489 |
490 | tvars = tf.trainable_variables()
491 | initialized_variable_names = {}
492 | if FLAGS.init_checkpoint:
493 | (assignment_map, initialized_variable_names
494 | ) = modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_checkpoint)
495 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
496 |
497 | tf.logging.info("**** Trainable Variables ****")
498 | for var in tvars:
499 | init_string = ""
500 | if var.name in initialized_variable_names:
501 | init_string = ", *INIT_FROM_CKPT*"
502 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
503 | init_string)
504 |
505 | with tf.train.MonitoredSession(
506 | session_creator=tf.compat.v1.train.ChiefSessionCreator(config=run_config)) as session:
507 |
508 | # SentEval prepare and batcher
509 | def prepare(params, samples):
510 | return
511 |
512 | def batcher(params, batch):
513 | batch_input_ids, batch_input_mask, batch_segment_ids = [], [], []
514 | for sent in batch:
515 |           if sent and isinstance(sent[0], bytes):  # guard against empty sentences before indexing
516 | sent = [_.decode() for _ in sent]
517 | text = ' '.join(sent) if sent != [] else '.'
518 | # print(text)
519 |
520 | _input_ids, _input_mask, _segment_ids, _tokens = \
521 | get_input_mask_segment(text, FLAGS.max_seq_length, tokenizer)
522 | batch_input_ids.append(_input_ids)
523 | batch_input_mask.append(_input_mask)
524 | batch_segment_ids.append(_segment_ids)
525 |
526 | batch_input_ids = np.asarray(batch_input_ids)
527 | batch_input_mask = np.asarray(batch_input_mask)
528 | batch_segment_ids = np.asarray(batch_segment_ids)
529 |
530 | print(".", end="")
531 |
532 | return session.run(embedding,
533 | {input_ids: batch_input_ids,
534 | input_mask: batch_input_mask,
535 | segment_ids: batch_segment_ids})
536 |
537 | # Set params for SentEval
538 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
539 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
540 | 'tenacity': 3, 'epoch_size': 2}
541 |
542 | # main
543 | se = senteval.engine.SE(params_senteval, batcher, prepare)
544 |
545 | # transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16',
546 | # 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC',
547 | # 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark',
548 | # 'Length', 'WordContent', 'Depth', 'TopConstituents',
549 | # 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber',
550 | # 'OddManOut', 'CoordinationInversion']
551 | #transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
552 | #transfer_tasks = ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC']
553 | transfer_tasks = FLAGS.senteval_tasks.split(",")
554 | results = se.eval(transfer_tasks)
555 | from collections import OrderedDict
556 | results = OrderedDict(results)
557 | for key in sorted(results):
558 | value = results[key]
559 | if key.startswith("STS"):
560 | print("'" + key + "':", value["all"])
561 | else:
562 | print(key, value)
563 |
564 | return
565 |
566 | processors = {
567 | 'sts-b': StsbProcessor,
568 | 'sick-r': SickRProcessor,
569 | 'mnli': MnliProcessor,
570 | 'allnli': MnliProcessor,
571 | 'qqp': QqpProcessor,
572 | 'sts-12-16': Sts_12_16_Processor,
573 | 'sts-12': Sts_12_16_Processor,
574 | 'sts-13': Sts_12_16_Processor,
575 | 'sts-14': Sts_12_16_Processor,
576 | 'sts-15': Sts_12_16_Processor,
577 | 'sts-16': Sts_12_16_Processor,
578 | 'mrpc-regression': MrpcRegressionProcessor,
579 | 'qnli-regression': QnliRegressionProcessor,
580 | }
581 |
582 | task_name = FLAGS.task_name.lower()
583 | if task_name not in processors:
584 | raise ValueError("Task not found: %s" % (task_name))
585 |
586 | if task_name == 'sick-r' or task_name.startswith("sts"):
587 | is_regression = True
588 | label_min, label_max = 0., 5.
589 | elif task_name in ['qqp', 'mrpc-regression', 'qnli-regression']:
590 | is_regression = True
591 | label_min, label_max = 0., 1.
592 | else:
593 | is_regression = False
594 | label_min, label_max = 0., 1.
595 |
596 | dupe_factor = FLAGS.dupe_factor
597 |
598 | processor = processors[task_name]()
599 |
600 | label_list = processor.get_labels()
601 |
602 |   # this block runs before the estimator setup so that epoch_step is available for save_checkpoints_steps
603 | train_examples = None
604 | num_train_steps = None
605 | num_warmup_steps = None
606 |
607 | if task_name == "allnli":
608 | FLAGS.data_dir = os.path.join(os.path.dirname(FLAGS.data_dir), "MNLI")
609 |
610 | if FLAGS.do_train and FLAGS.num_train_epochs > 1e-6:
611 | train_examples = processor.get_train_examples(FLAGS.data_dir)
612 |
613 | if task_name == "allnli":
614 | snli_data_dir = os.path.join(os.path.dirname(FLAGS.data_dir), "SNLI")
615 | train_examples.extend(SnliTrainProcessor().get_train_examples(snli_data_dir))
616 | train_examples.extend(SnliDevTestProcessor().get_dev_examples(snli_data_dir))
617 | train_examples.extend(SnliDevTestProcessor().get_test_examples(snli_data_dir))
618 |
619 | if FLAGS.use_full_for_training:
620 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
621 | predict_examples = processor.get_test_examples(FLAGS.data_dir)
622 | train_examples.extend(eval_examples + predict_examples)
623 |
624 | num_train_steps = int(
625 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
626 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
627 | epoch_step = int(len(train_examples) / FLAGS.train_batch_size)
628 |
629 | if FLAGS.num_examples > 0:
630 | random.shuffle(train_examples)
631 | for i in range(FLAGS.num_examples, len(train_examples)):
632 | train_examples[i].label = -10
633 |
634 | random.shuffle(train_examples)
635 |
636 |
637 | # ==== #
638 |
639 | if FLAGS.early_stopping:
640 | save_checkpoints_steps = epoch_step
641 | else:
642 | save_checkpoints_steps = FLAGS.save_checkpoints_steps
643 |
644 | keep_checkpoint_max = 3
645 | save_summary_steps = log_every_step = FLAGS.log_every_step
646 |
647 | tf.logging.info("save_checkpoints_steps: %d" % save_checkpoints_steps)
648 |
649 | # make exp dir
650 | if FLAGS.exp_name:
651 | output_dir = os.path.join(FLAGS.output_parent_dir, FLAGS.exp_name)
652 | elif FLAGS.exp_name_prefix:
653 | output_dir = os.path.join(FLAGS.output_parent_dir, FLAGS.exp_name_prefix)
654 |
655 | output_dir += "_t_%s" % (FLAGS.task_name)
656 | output_dir += "_ep_%.2f" % (FLAGS.num_train_epochs)
657 | output_dir += "_lr_%.2e" % (FLAGS.learning_rate)
658 |
659 | if FLAGS.train_batch_size != 32:
660 | output_dir += "_bsz_%d" % (FLAGS.train_batch_size)
661 |
662 | if FLAGS.sentence_embedding_type != "avg":
663 | output_dir += "_e_%s" % (FLAGS.sentence_embedding_type)
664 |
665 | if FLAGS.flow > 0:
666 | output_dir += "_f_%d%d" % (FLAGS.flow, FLAGS.flow_loss)
667 |
668 | if FLAGS.flow_loss > 0:
669 | output_dir += "_%.2e" % (FLAGS.flow_learning_rate)
670 |
671 | if FLAGS.use_full_for_training > 0:
672 | output_dir += "_allsplits"
673 |
674 | if FLAGS.flow_model_config != "config_l3_d3_w32":
675 | output_dir += "_%s" % (FLAGS.flow_model_config)
676 |
677 | if FLAGS.num_examples > 0:
678 | output_dir += "_n_%d" % (FLAGS.num_examples)
679 |
680 | if FLAGS.low_dim > -1:
681 | output_dir += "_ld_%d" % (FLAGS.low_dim)
682 |
683 | if FLAGS.l2_penalty > 0:
684 | output_dir += "_l2_%.2e" % (FLAGS.l2_penalty)
685 |
686 | else:
687 |     raise ValueError("either --exp_name or --exp_name_prefix must be set")
688 |
689 | if tf.gfile.Exists(output_dir) and FLAGS.do_train:
690 | tf.io.gfile.rmtree(output_dir)
691 | tf.gfile.MakeDirs(output_dir)
692 |
693 | # set up estimator
694 | run_config = tf.estimator.RunConfig(
695 | model_dir=output_dir,
696 | save_summary_steps=save_summary_steps,
697 | save_checkpoints_steps=save_checkpoints_steps,
698 | keep_checkpoint_max=keep_checkpoint_max,
699 | log_step_count_steps=log_every_step,
700 | session_config=run_config)
701 |
702 | model_fn = model_fn_builder(
703 | bert_config=bert_config,
704 | num_labels=len(label_list),
705 | init_checkpoint=FLAGS.init_checkpoint,
706 | learning_rate=FLAGS.learning_rate,
707 | num_train_steps=num_train_steps,
708 | num_warmup_steps=num_warmup_steps,
709 | is_regression=is_regression)
710 |
711 | estimator = tf.estimator.Estimator(
712 | model_fn=model_fn,
713 | config=run_config,
714 | params={
715 | 'train_batch_size': FLAGS.train_batch_size,
716 | 'eval_batch_size': FLAGS.eval_batch_size,
717 | 'predict_batch_size': FLAGS.predict_batch_size})
718 |
719 | def get_train_input_fn():
720 | cached_dir = FLAGS.cached_dir
721 | if not cached_dir:
722 | cached_dir = output_dir
723 |
724 | data_name = task_name
725 |
726 | if FLAGS.num_examples > 0:
727 | train_file = os.path.join(cached_dir,
728 | data_name + "_n_%d" % (FLAGS.num_examples) \
729 | + "_seed_%d" % (FLAGS.seed) + "_train.tf_record")
730 | elif FLAGS.use_full_for_training > 0:
731 | train_file = os.path.join(cached_dir, data_name + "_allsplits.tf_record")
732 | else:
733 | train_file = os.path.join(cached_dir, data_name + "_train.tf_record")
734 |
735 | if not tf.gfile.Exists(train_file):
736 | file_based_convert_examples_to_features(
737 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file,
738 | dupe_factor, label_min, label_max,
739 | is_training=True)
740 |
741 | tf.logging.info("***** Running training *****")
742 | tf.logging.info(" Num examples = %d", len(train_examples))
743 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
744 | tf.logging.info(" Num steps = %d", num_train_steps)
745 |
746 | train_input_fn = file_based_input_fn_builder(
747 | input_file=train_file,
748 | seq_length=FLAGS.max_seq_length,
749 | is_training=True,
750 | drop_remainder=True,
751 | is_regression=is_regression)
752 |
753 | return train_input_fn
754 |
755 | def get_eval_input_fn():
756 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
757 | num_actual_eval_examples = len(eval_examples)
758 |
759 | cached_dir = FLAGS.cached_dir
760 | if not cached_dir:
761 | cached_dir = output_dir
762 | eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record")
763 |
764 | if not tf.gfile.Exists(eval_file):
765 | file_based_convert_examples_to_features(
766 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file,
767 | dupe_factor, label_min, label_max)
768 |
769 | tf.logging.info("***** Running evaluation *****")
770 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
771 | len(eval_examples), num_actual_eval_examples,
772 | len(eval_examples) - num_actual_eval_examples)
773 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
774 |
775 | # This tells the estimator to run through the entire set.
776 | eval_drop_remainder = False
777 | eval_input_fn = file_based_input_fn_builder(
778 | input_file=eval_file,
779 | seq_length=FLAGS.max_seq_length,
780 | is_training=False,
781 | drop_remainder=eval_drop_remainder,
782 | is_regression=is_regression)
783 |
784 | return eval_input_fn
785 |
786 | def get_predict_input_fn():
787 | predict_examples = None
788 | if FLAGS.do_predict_on_dev:
789 | predict_examples = processor.get_dev_examples(FLAGS.data_dir)
790 | elif FLAGS.do_predict_on_full:
791 | train_examples = processor.get_train_examples(FLAGS.data_dir)
792 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
793 | predict_examples = processor.get_test_examples(FLAGS.data_dir)
794 | predict_examples.extend(eval_examples + train_examples)
795 | else:
796 | predict_examples = processor.get_test_examples(FLAGS.data_dir)
797 | num_actual_predict_examples = len(predict_examples)
798 |
799 | cached_dir = FLAGS.cached_dir
800 | if not cached_dir:
801 | cached_dir = output_dir
802 | predict_file = os.path.join(cached_dir, task_name + "_predict.tf_record")
803 |
804 | file_based_convert_examples_to_features(
805 | predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file,
806 | dupe_factor, label_min, label_max)
807 |
808 |     tf.logging.info("***** Running prediction *****")
809 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
810 | len(predict_examples), num_actual_predict_examples,
811 | len(predict_examples) - num_actual_predict_examples)
812 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
813 |
814 | predict_drop_remainder = False
815 | predict_input_fn = file_based_input_fn_builder(
816 | input_file=predict_file,
817 | seq_length=FLAGS.max_seq_length,
818 | is_training=False,
819 | drop_remainder=predict_drop_remainder,
820 | is_regression=is_regression)
821 |
822 | return predict_input_fn, num_actual_predict_examples
823 |
824 | eval_steps = None
825 |
826 | if FLAGS.do_train and FLAGS.num_train_epochs > 1e-6:
827 | train_input_fn = get_train_input_fn()
828 | if FLAGS.early_stopping:
829 | eval_input_fn = get_eval_input_fn()
830 |       early_stopping_hook = tf.estimator.experimental.stop_if_no_increase_hook(
831 |           estimator, metric_name="eval_pearsonr",  # higher is better, so stop when it stops increasing
832 |           max_steps_without_increase=epoch_step//2, run_every_steps=epoch_step, run_every_secs=None)
833 | train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps,
834 | hooks=[early_stopping_hook])
835 |
836 | start_delay_secs = FLAGS.start_delay_secs
837 | throttle_secs = FLAGS.throttle_secs
838 | tf.logging.info("start_delay_secs: %d; throttle_secs: %d" % (start_delay_secs, throttle_secs))
839 | eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=eval_steps,
840 | start_delay_secs=start_delay_secs, throttle_secs=throttle_secs)
841 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
842 | else:
843 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
844 |
845 | if FLAGS.do_eval:
846 | eval_input_fn = get_eval_input_fn()
847 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
848 |
849 | output_eval_file = os.path.join(output_dir, "eval_results.txt")
850 | with tf.gfile.GFile(output_eval_file, "w") as writer:
851 | tf.logging.info("***** Eval results *****")
852 | for key in sorted(result.keys()):
853 | tf.logging.info(" %s = %s", key, str(result[key]))
854 | writer.write("%s = %s\n" % (key, str(result[key])))
855 |
856 | if FLAGS.do_predict:
857 | predict_input_fn, num_actual_predict_examples = get_predict_input_fn()
858 | checkpoint_path = None
859 | if FLAGS.eval_checkpoint_name:
860 | checkpoint_path = os.path.join(output_dir, FLAGS.eval_checkpoint_name)
861 | result = estimator.predict(input_fn=predict_input_fn,
862 | checkpoint_path=checkpoint_path)
863 |
864 | def round_float_list(values):
865 | values = [round(float(x), 6) for x in values.flat]
866 | return values
867 |
868 | fname = ""
869 | if FLAGS.do_predict_on_full:
870 | fname += "full"
871 | elif FLAGS.do_predict_on_dev:
872 | fname += "dev"
873 | else:
874 | fname += "test"
875 |
876 | if FLAGS.predict_pool:
877 | fname += "_pooled.tsv"
878 | else:
879 | fname += "_results.tsv"
880 |
881 | if FLAGS.eval_checkpoint_name:
882 | fname = FLAGS.eval_checkpoint_name + "." + fname
883 | output_predict_file = os.path.join(output_dir, fname)
884 | with tf.gfile.GFile(output_predict_file, "w") as writer:
885 | num_written_lines = 0
886 | tf.logging.info("***** Predict results *****")
887 | for (i, prediction) in enumerate(result):
888 |
889 | if is_regression:
890 | if FLAGS.predict_pool:
891 | embedding_a = prediction["embedding_a"]
892 | embedding_b = prediction["embedding_b"]
893 |
894 | output_json = collections.OrderedDict()
895 | output_json["embedding_a"] = round_float_list(embedding_a)
896 | output_json["embedding_b"] = round_float_list(embedding_b)
897 |
898 | output_line = json.dumps(output_json) + "\n"
899 | else:
900 | cos_similarity = prediction["cos_similarity"]
901 | if i >= num_actual_predict_examples:
902 | break
903 | output_line = str(cos_similarity) + "\n"
904 | else:
905 | raise NotImplementedError
906 |
907 | writer.write(output_line)
908 | num_written_lines += 1
909 | assert num_written_lines == num_actual_predict_examples
910 |
911 | tf.logging.info("*** output_dir ***")
912 | tf.logging.info(output_dir)
913 |
914 |
915 | if __name__ == "__main__":
916 | flags.mark_flag_as_required("data_dir")
917 | flags.mark_flag_as_required("task_name")
918 | flags.mark_flag_as_required("vocab_file")
919 | flags.mark_flag_as_required("bert_config_file")
920 | tf.app.run()
921 |
--------------------------------------------------------------------------------
/scripts/eval_stsb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import sys
5 | import numpy as np
6 | import scipy
7 | import scipy.stats
8 |
9 | def get_pred(fpath):
10 | with open(fpath) as f:
11 | x = [float(_) for _ in f.readlines()]
12 | return x
13 |
14 | def get_gt(fpath, col, header=False):
15 | with open(fpath) as f:
16 | y = np.asarray([float(_.split('\t')[col]) for _ in f.readlines()[int(header):]])
17 | return y
18 |
19 | def get_correlation(x, y):
20 | print("Pearson: %f" % pearson_r(x, y), end=", ")
21 | print("Spearman: %f" % scipy.stats.spearmanr(x, y).correlation)
22 |
23 | def get_auc(pred, y):
24 | from sklearn import metrics
25 | fpr, tpr, thresholds = metrics.roc_curve(y, pred)
26 | print("AUC: %f" % metrics.auc(fpr, tpr))
27 |
28 | def pearson_r(x, y):
29 | """Compute Pearson correlation coefficient between two arrays."""
30 | corr_mat = np.corrcoef(x, y)
31 | return corr_mat[0, 1]
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser(description=None)
35 | parser.add_argument('--glue_path', type=str, default="../glue_data",
36 |                         help='path to the GLUE data directory')
37 | parser.add_argument('--task_name', type=str, default="sts-b",
38 |                         help='task name (sts-b, sick-r, or mrpc-regression)')
39 | parser.add_argument('--pred_path', type=str,
40 | help='path to predicted sentence vectors')
41 | parser.add_argument('--is_test', type=int, default=0,
42 |                         help='1: test set, 0: dev set, -1: train set')
43 | args = parser.parse_args()
44 |
45 | x = get_pred(args.pred_path)
46 | if args.task_name.lower() == "sts-b":
47 | if args.is_test == 1:
48 | fpath = os.path.join(args.glue_path, "STS-B/sts-test.csv")
49 | y = get_gt(fpath, 4, 0)
50 | elif args.is_test == 0:
51 | fpath = os.path.join(args.glue_path, "STS-B/dev.tsv")
52 | y = get_gt(fpath, 9, 1)
53 | elif args.is_test == -1:
54 | fpath = os.path.join(args.glue_path, "STS-B/train.tsv")
55 | y = get_gt(fpath, 9, 1)
56 | else:
57 | raise NotImplementedError
58 | elif args.task_name.lower() == "sick-r":
59 | fpath = os.path.join(args.glue_path, "SICK-R/SICK_test_annotated.txt")
60 | y = get_gt(fpath, 3, 1)
61 | elif args.task_name.lower() == "mrpc-regression":
62 | fpath = os.path.join(args.glue_path, "MRPC-Regression/msr_paraphrase_test.txt")
63 | y = get_gt(fpath, 0, 1)
64 | else:
65 | raise NotImplementedError
66 |
67 | get_correlation(x, y)
68 | if args.task_name.lower() in ["mrpc-regression", "qnli-regression"]:
69 | get_auc(x, y)
70 |
71 |
--------------------------------------------------------------------------------
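`eval_stsb.py` reports both Pearson and Spearman correlation; the distinction matters because Spearman measures only rank agreement while Pearson is sensitive to linearity. A small illustration on toy scores (not project data):

```python
import numpy as np
import scipy.stats


def pearson_r(x, y):
    """Pearson correlation via the 2x2 correlation matrix, as in eval_stsb.py."""
    return np.corrcoef(x, y)[0, 1]


x = [1.0, 2.0, 3.0, 4.0]
y = [1.0, 4.0, 9.0, 16.0]  # monotone in x, but not linear

print("Pearson: %f" % pearson_r(x, y))
print("Spearman: %f" % scipy.stats.spearmanr(x, y).correlation)
# Spearman is exactly 1.0 (perfect rank agreement); Pearson is lower (~0.984).
```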
/scripts/train_siamese.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CURDIR=$(cd $(dirname $0); cd ..; pwd)
3 |
4 | export BERT_DIR=${BERT_PREMODELS}/uncased_L-12_H-768_A-12
5 |
6 | if [[ $BERT_NAME == "large-wwm" ]];then
7 | export BERT_DIR=${BERT_PREMODELS}/wwm_uncased_L-24_H-1024_A-16
8 | elif [[ $BERT_NAME == "large" ]];then
9 | export BERT_DIR=${BERT_PREMODELS}/uncased_L-24_H-1024_A-16
10 | else
11 | export BERT_DIR=${BERT_PREMODELS}/uncased_L-12_H-768_A-12
12 | fi
13 |
14 | if [ -z "$INIT_CKPT" ]; then
15 | export INIT_CKPT=$BERT_DIR/bert_model.ckpt
16 | fi
17 |
18 | if [ -z "$TASK_NAME" ]; then
19 | export TASK_NAME="STS-B"
20 | fi
21 |
22 |
23 | if [[ $1 == "train" ]];then
24 | echo "train"
25 |
26 | exec python3 ${CURDIR}/run_siamese.py \
27 | --task_name=${TASK_NAME} \
28 | --do_train=true \
29 | --do_eval=true \
30 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
31 | --vocab_file=${BERT_DIR}/vocab.txt \
32 | --bert_config_file=${BERT_DIR}/bert_config.json \
33 | --init_checkpoint=${INIT_CKPT} \
34 | --max_seq_length=64 \
35 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
36 | ${@:2}
37 | elif [[ $1 == "eval" ]];then
38 | echo "eval"
39 | python3 ${CURDIR}/run_siamese.py \
40 | --task_name=${TASK_NAME} \
41 | --do_eval=true \
42 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
43 | --vocab_file=${BERT_DIR}/vocab.txt \
44 | --bert_config_file=${BERT_DIR}/bert_config.json \
45 | --init_checkpoint=${INIT_CKPT} \
46 | --max_seq_length=64 \
47 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
48 | ${@:2}
49 | elif [[ $1 == "predict" ]];then
50 | echo "predict"
51 | python3 ${CURDIR}/run_siamese.py \
52 | --task_name=${TASK_NAME} \
53 | --do_predict=true \
54 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
55 | --vocab_file=${BERT_DIR}/vocab.txt \
56 | --bert_config_file=${BERT_DIR}/bert_config.json \
57 | --init_checkpoint=${INIT_CKPT} \
58 | --max_seq_length=64 \
59 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
60 | ${@:2}
61 |
62 | python3 scripts/eval_stsb.py \
63 | --glue_path=${GLUE_DIR} \
64 | --task_name=${TASK_NAME} \
65 | --pred_path=${OUTPUT_PARENT_DIR}/${EXP_NAME}/test_results.tsv \
66 | --is_test=1
67 |
68 | elif [[ $1 == "predict_pool" ]];then
69 |     echo "predict_pool"
70 | python3 ${CURDIR}/run_siamese.py \
71 | --task_name=${TASK_NAME} \
72 | --do_predict=true \
73 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
74 | --vocab_file=${BERT_DIR}/vocab.txt \
75 | --bert_config_file=${BERT_DIR}/bert_config.json \
76 | --max_seq_length=64 \
77 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
78 | --predict_pool=True \
79 | ${@:2}
80 |
81 | elif [[ $1 == "predict_dev" ]];then
82 | echo "predict_dev"
83 | python3 ${CURDIR}/run_siamese.py \
84 | --task_name=${TASK_NAME} \
85 | --do_predict=true \
86 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
87 | --vocab_file=${BERT_DIR}/vocab.txt \
88 | --bert_config_file=${BERT_DIR}/bert_config.json \
89 | --max_seq_length=64 \
90 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
91 | --do_predict_on_dev=True \
92 | --predict_pool=True \
93 | ${@:2}
94 |
95 | elif [[ $1 == "predict_full" ]];then
96 |     echo "predict_full"
97 | python3 ${CURDIR}/run_siamese.py \
98 | --task_name=${TASK_NAME} \
99 | --do_predict=true \
100 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
101 | --vocab_file=${BERT_DIR}/vocab.txt \
102 | --bert_config_file=${BERT_DIR}/bert_config.json \
103 | --max_seq_length=64 \
104 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
105 | --do_predict_on_full=True \
106 | --predict_pool=True \
107 | ${@:2}
108 |
109 | elif [[ $1 == "do_senteval" ]];then
110 | echo "do_senteval"
111 | python3 ${CURDIR}/run_siamese.py \
112 | --task_name=${TASK_NAME} \
113 | --do_senteval=true \
114 | --data_dir=${GLUE_DIR}/${TASK_NAME} \
115 | --vocab_file=${BERT_DIR}/vocab.txt \
116 | --bert_config_file=${BERT_DIR}/bert_config.json \
117 | --init_checkpoint=${INIT_CKPT} \
118 | --max_seq_length=64 \
119 | --output_parent_dir=${OUTPUT_PARENT_DIR} \
120 | ${@:2}
121 |
122 | else
123 | echo "NotImplementedError"
124 | fi
125 |
126 |
--------------------------------------------------------------------------------
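Every branch of the script forwards the remaining command-line arguments with `${@:2}`, so any extra flags appended after the mode pass straight through to `run_siamese.py`. A tiny standalone demonstration of that bash idiom (the flag names here are only examples):

```shell
#!/bin/bash
# $1 selects the mode; "${@:2}" expands to all remaining arguments,
# exactly as train_siamese.sh forwards them to the python invocation.
dispatch() {
  echo "mode=$1 extra=${@:2}"
}

dispatch train --learning_rate=2e-5 --num_train_epochs=3.0
# prints: mode=train extra=--learning_rate=2e-5 --num_train_epochs=3.0
```

An actual invocation would then look like `bash scripts/train_siamese.sh train --learning_rate=2e-5`, assuming the forwarded flag exists in `run_siamese.py`'s flag set.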
/siamese_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import warnings
6 | warnings.simplefilter(action='ignore', category=FutureWarning)
7 |
8 | import os
9 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
10 |
11 | import collections
12 | import csv
13 |
14 | import modeling
15 | import optimization
16 | import tokenization
17 |
18 | import tensorflow as tf
19 |
20 | import random
21 | import numpy as np
22 |
23 |
24 | class InputExample(object):
25 | """A single training/test example for simple sequence classification."""
26 |
27 | def __init__(self, guid, text_a, text_b=None, label=None):
28 |     """Constructs an InputExample.
29 |
30 | Args:
31 | guid: Unique id for the example.
32 | text_a: string. The untokenized text of the first sequence. For single
33 | sequence tasks, only this sequence must be specified.
34 | text_b: (Optional) string. The untokenized text of the second sequence.
35 |         Must only be specified for sequence pair tasks.
36 | label: (Optional) string. The label of the example. This should be
37 | specified for train and dev examples, but not for test examples.
38 | """
39 | self.guid = guid
40 | self.text_a = text_a
41 | self.text_b = text_b
42 | self.label = label
43 |
44 |
45 | class InputFeatures(object):
46 | """A single set of features of data."""
47 |
48 | def __init__(self,
49 | input_ids_a,
50 | input_mask_a,
51 | segment_ids_a,
52 | input_ids_b,
53 | input_mask_b,
54 | segment_ids_b,
55 | label_id,
56 | is_real_example=True):
57 | self.input_ids_a = input_ids_a
58 | self.input_mask_a = input_mask_a
59 | self.segment_ids_a = segment_ids_a
60 | self.input_ids_b = input_ids_b
61 | self.input_mask_b = input_mask_b
62 | self.segment_ids_b = segment_ids_b
63 | self.label_id = label_id
64 | self.is_real_example = is_real_example
65 |
66 |
67 | class DataProcessor(object):
68 | """Base class for data converters for sequence classification data sets."""
69 |
70 | def get_train_examples(self, data_dir):
71 | """Gets a collection of `InputExample`s for the train set."""
72 | raise NotImplementedError()
73 |
74 | def get_dev_examples(self, data_dir):
75 | """Gets a collection of `InputExample`s for the dev set."""
76 | raise NotImplementedError()
77 |
78 | def get_test_examples(self, data_dir):
79 | """Gets a collection of `InputExample`s for prediction."""
80 | raise NotImplementedError()
81 |
82 | def get_labels(self):
83 | """Gets the list of labels for this data set."""
84 | raise NotImplementedError()
85 |
86 | @classmethod
87 | def _read_tsv(cls, input_file, quotechar=None):
88 | """Reads a tab separated value file."""
89 | with tf.gfile.Open(input_file, "r") as f:
90 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
91 | lines = []
92 | for line in reader:
93 | lines.append(line)
94 | return lines
95 |
96 |
97 | class GLUEProcessor(DataProcessor):
98 | def __init__(self):
99 | self.train_file = "train.tsv"
100 | self.dev_file = "dev.tsv"
101 | self.test_file = "test.tsv"
102 | self.label_column = None
103 | self.text_a_column = None
104 | self.text_b_column = None
105 | self.contains_header = True
106 | self.test_text_a_column = None
107 | self.test_text_b_column = None
108 | self.test_contains_header = True
109 |
110 | def get_train_examples(self, data_dir):
111 | """See base class."""
112 | return self._create_examples(
113 | self._read_tsv(os.path.join(data_dir, self.train_file)), "train")
114 |
115 | def get_dev_examples(self, data_dir):
116 | """See base class."""
117 | return self._create_examples(
118 | self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")
119 |
120 | def get_test_examples(self, data_dir):
121 | """See base class."""
122 | if self.test_text_a_column is None:
123 | self.test_text_a_column = self.text_a_column
124 | if self.test_text_b_column is None:
125 | self.test_text_b_column = self.text_b_column
126 |
127 | return self._create_examples(
128 | self._read_tsv(os.path.join(data_dir, self.test_file)), "test")
129 |
130 | def get_labels(self):
131 | """See base class."""
132 | return ["0", "1"]
133 |
134 | def _create_examples(self, lines, set_type):
135 | """Creates examples for the training and dev sets."""
136 | examples = []
137 | for (i, line) in enumerate(lines):
138 | if i == 0 and self.contains_header and set_type != "test":
139 | continue
140 | if i == 0 and self.test_contains_header and set_type == "test":
141 | continue
142 | guid = "%s-%s" % (set_type, i)
143 |
144 | a_column = (self.text_a_column if set_type != "test" else
145 | self.test_text_a_column)
146 | b_column = (self.text_b_column if set_type != "test" else
147 | self.test_text_b_column)
148 |
149 | # there are some incomplete lines in QNLI
150 | if len(line) <= a_column:
151 | tf.logging.warning('Incomplete line, ignored.')
152 | continue
153 | text_a = line[a_column]
154 |
155 | if b_column is not None:
156 | if len(line) <= b_column:
157 | tf.logging.warning('Incomplete line, ignored.')
158 | continue
159 | text_b = line[b_column]
160 | else:
161 | text_b = None
162 |
163 | if set_type == "test":
164 | label = self.get_labels()[0]
165 | else:
166 | if len(line) <= self.label_column:
167 | tf.logging.warning('Incomplete line, ignored.')
168 | continue
169 | label = line[self.label_column]
170 |         if len(label) == 0:
171 |           raise ValueError("Empty label for example %s" % guid)
172 | examples.append(
173 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
174 | return examples
175 |
176 |
177 | class StsbProcessor(GLUEProcessor):
178 | def __init__(self):
179 | super(StsbProcessor, self).__init__()
180 | self.label_column = 9
181 | self.text_a_column = 7
182 | self.text_b_column = 8
183 |
184 | def get_labels(self):
185 | return [0.]
186 |
187 | def _create_examples(self, lines, set_type):
188 | """Creates examples for the training and dev sets."""
189 | examples = []
190 | for (i, line) in enumerate(lines):
191 | if i == 0 and self.contains_header and set_type != "test":
192 | continue
193 | if i == 0 and self.test_contains_header and set_type == "test":
194 | continue
195 | guid = "%s-%s" % (set_type, i)
196 |
197 | a_column = (self.text_a_column if set_type != "test" else
198 | self.test_text_a_column)
199 | b_column = (self.text_b_column if set_type != "test" else
200 | self.test_text_b_column)
201 |
202 | # there are some incomplete lines in QNLI
203 | if len(line) <= a_column:
204 | tf.logging.warning('Incomplete line, ignored.')
205 | continue
206 | text_a = line[a_column]
207 |
208 | if b_column is not None:
209 | if len(line) <= b_column:
210 | tf.logging.warning('Incomplete line, ignored.')
211 | continue
212 | text_b = line[b_column]
213 | else:
214 | text_b = None
215 |
216 | if set_type == "test":
217 | label = self.get_labels()[0]
218 | else:
219 | if len(line) <= self.label_column:
220 | tf.logging.warning('Incomplete line, ignored.')
221 | continue
222 | label = float(line[self.label_column])
223 | examples.append(
224 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
225 |
226 | return examples
227 |
228 |
229 | class SickRProcessor(StsbProcessor):
230 |   """Processor for the SICK-R sentence relatedness data set."""
231 | def __init__(self):
232 | super(SickRProcessor, self).__init__()
233 | self.train_file = "SICK_train.txt"
234 | self.dev_file = "SICK_trial.txt"
235 | self.test_file = "SICK_test_annotated.txt"
236 | self.label_column = 3
237 | self.text_a_column = 1
238 | self.text_b_column = 2
239 | self.contains_header = True
240 | self.test_text_a_column = None
241 | self.test_text_b_column = None
242 | self.test_contains_header = True
243 |
244 | class Sts_12_16_Processor(GLUEProcessor):
245 | def __init__(self):
246 | super(Sts_12_16_Processor, self).__init__()
247 | self.train_file = "full.txt"
248 | self.dev_file = "full.txt"
249 | self.test_file = "full.txt"
250 | self.text_a_column = 0
251 | self.text_b_column = 1
252 |
253 | def get_labels(self):
254 | return [0.]
255 |
256 | def _create_examples(self, lines, set_type):
257 | """Creates examples for the training and dev sets."""
258 | examples = []
259 | for (i, line) in enumerate(lines):
260 | guid = "%s-%s" % (set_type, i)
261 | text_a = line[self.text_a_column]
262 | text_b = line[self.text_b_column]
263 | label = self.get_labels()[0]
264 | examples.append(
265 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
266 |
267 | return examples
268 |
269 |
270 | class QqpProcessor(GLUEProcessor):
271 |   """Processor for the QQP data set (GLUE version)."""
272 | def __init__(self):
273 | super(QqpProcessor, self).__init__()
274 | self.train_file = "train.tsv"
275 | self.dev_file = "dev.tsv"
276 | self.test_file = "test.tsv"
277 | self.label_column = 5
278 | self.text_a_column = 3
279 | self.text_b_column = 4
280 | self.contains_header = True
281 | self.test_text_a_column = 1
282 | self.test_text_b_column = 2
283 | self.test_contains_header = True
284 |
285 | def get_labels(self):
286 | """See base class."""
287 | return [0.]
288 |
289 |
290 | class MrpcRegressionProcessor(StsbProcessor):
291 | def __init__(self):
292 | super(MrpcRegressionProcessor, self).__init__()
293 | self.label_column = 0
294 | self.text_a_column = 3
295 | self.text_b_column = 4
296 |
297 |
298 | class QnliRegressionProcessor(GLUEProcessor):
299 | def __init__(self):
300 | super(QnliRegressionProcessor, self).__init__()
301 | self.label_column = -1
302 | self.text_a_column = 1
303 | self.text_b_column = 2
304 |
305 | def get_labels(self):
306 | return [0.]
307 |
308 | def _create_examples(self, lines, set_type):
309 | """Creates examples for the training and dev sets."""
310 | examples = []
311 | for (i, line) in enumerate(lines):
312 | if i == 0 and self.contains_header and set_type != "test":
313 | continue
314 | if i == 0 and self.test_contains_header and set_type == "test":
315 | continue
316 | guid = "%s-%s" % (set_type, i)
317 |
318 | a_column = (self.text_a_column if set_type != "test" else
319 | self.test_text_a_column)
320 | b_column = (self.text_b_column if set_type != "test" else
321 | self.test_text_b_column)
322 |
323 | # there are some incomplete lines in QNLI
324 | if len(line) <= a_column:
325 | tf.logging.warning('Incomplete line, ignored.')
326 | continue
327 | text_a = line[a_column]
328 |
329 | if b_column is not None:
330 | if len(line) <= b_column:
331 | tf.logging.warning('Incomplete line, ignored.')
332 | continue
333 | text_b = line[b_column]
334 | else:
335 | text_b = None
336 |
337 | if set_type == "test":
338 | label = self.get_labels()[0]
339 | else:
340 | if len(line) <= self.label_column:
341 | tf.logging.warning('Incomplete line, ignored.')
342 | continue
343 | label_score_map = { "not_entailment": 0, "entailment": 1 }
344 | label = label_score_map[line[self.label_column]]
345 | examples.append(
346 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
347 |
348 | return examples
349 |
350 |
351 | class MnliProcessor(GLUEProcessor):
352 | """Processor for the MultiNLI data set (GLUE version)."""
353 | def __init__(self):
354 | super(MnliProcessor, self).__init__()
355 | self.train_file = "train.tsv"
356 | self.dev_file = "dev_matched.tsv"
357 | self.test_file = "test_matched.tsv"
358 | self.label_column = 10
359 | self.text_a_column = 8
360 | self.text_b_column = 9
361 | self.contains_header = True
362 | self.test_text_a_column = None
363 | self.test_text_b_column = None
364 | self.test_contains_header = True
365 |
366 | def get_labels(self):
367 | """See base class."""
368 | return ["contradiction", "entailment", "neutral"]
369 |
370 |
371 | class SnliTrainProcessor(GLUEProcessor):
372 |   """Processor for the SNLI data set (train split)."""
373 | def __init__(self):
374 | super(SnliTrainProcessor, self).__init__()
375 | self.train_file = "train.tsv"
376 | self.dev_file = "dev.tsv"
377 | self.test_file = "test.tsv"
378 | self.label_column = -1
379 | self.text_a_column = 7
380 | self.text_b_column = 8
381 | self.contains_header = True
382 |
383 | def get_labels(self):
384 | """See base class."""
385 | return ["contradiction", "entailment", "neutral"]
386 |
387 | class SnliDevTestProcessor(SnliTrainProcessor):
388 |   """Processor for the SNLI data set (dev/test splits)."""
389 | def __init__(self):
390 | super(SnliDevTestProcessor, self).__init__()
391 | self.label_column = -1
392 |
393 | def get_input_mask_segment(text,
394 | max_seq_length, tokenizer, random_mask=0):
395 | tokens = tokenizer.tokenize(text)
396 |
397 | # Account for [CLS] and [SEP] with "- 2"
398 | if len(tokens) > max_seq_length - 2:
399 | tokens = tokens[0:(max_seq_length - 2)]
400 |
401 | if random_mask:
402 | tokens[random.randint(0, len(tokens)-1)] = "[MASK]"
403 |
404 | tokens = ["[CLS]"] + tokens + ["[SEP]"]
405 | segment_ids = [0 for _ in tokens]
406 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
407 |
408 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
409 | # tokens are attended to.
410 | input_mask = [1] * len(input_ids)
411 |
412 | # Zero-pad up to the sequence length.
413 | while len(input_ids) < max_seq_length:
414 | input_ids.append(0)
415 | input_mask.append(0)
416 | segment_ids.append(0)
417 |
418 | assert len(input_ids) == max_seq_length
419 | assert len(input_mask) == max_seq_length
420 | assert len(segment_ids) == max_seq_length
421 |
422 | return (input_ids, input_mask, segment_ids, tokens)
423 |
424 | def convert_single_example(ex_index, example, label_list, max_seq_length,
425 | tokenizer, random_mask=0):
426 | """Converts a single `InputExample` into a single `InputFeatures`."""
427 |
428 | label_map = {}
429 | for (i, label) in enumerate(label_list):
430 | label_map[label] = i
431 |
432 | input_ids_a, input_mask_a, segment_ids_a, tokens_a = \
433 | get_input_mask_segment(example.text_a, max_seq_length, tokenizer, random_mask)
434 | input_ids_b, input_mask_b, segment_ids_b, tokens_b = \
435 | get_input_mask_segment(example.text_b, max_seq_length, tokenizer, random_mask)
436 |
437 | if len(label_list) > 1:
438 | label_id = label_map[example.label]
439 | else:
440 | label_id = example.label
441 |
442 | if ex_index < 5:
443 | tf.logging.info("*** Example ***")
444 | tf.logging.info("guid: %s" % (example.guid))
445 | tf.logging.info("tokens_a: %s" % " ".join(
446 | [tokenization.printable_text(x) for x in tokens_a]))
447 | tf.logging.info("tokens_b: %s" % " ".join(
448 | [tokenization.printable_text(x) for x in tokens_b]))
449 | tf.logging.info("input_ids_a: %s" % " ".join([str(x) for x in input_ids_a]))
450 | tf.logging.info("input_mask_a: %s" % " ".join([str(x) for x in input_mask_a]))
451 | tf.logging.info("segment_ids_a: %s" % " ".join([str(x) for x in segment_ids_a]))
452 | tf.logging.info("input_ids_b: %s" % " ".join([str(x) for x in input_ids_b]))
453 | tf.logging.info("input_mask_b: %s" % " ".join([str(x) for x in input_mask_b]))
454 | tf.logging.info("segment_ids_b: %s" % " ".join([str(x) for x in segment_ids_b]))
455 | tf.logging.info("label: %s (id = %s)" % (example.label, label_id))
456 |
457 | feature = InputFeatures(
458 | input_ids_a=input_ids_a,
459 | input_mask_a=input_mask_a,
460 | segment_ids_a=segment_ids_a,
461 | input_ids_b=input_ids_b,
462 | input_mask_b=input_mask_b,
463 | segment_ids_b=segment_ids_b,
464 | label_id=label_id,
465 | is_real_example=True)
466 | return feature
467 |
468 |
469 | def file_based_convert_examples_to_features(
470 | examples, label_list, max_seq_length, tokenizer, output_file,
471 | dupe_factor, label_min, label_max, is_training=False):
472 | """Convert a set of `InputExample`s to a TFRecord file."""
473 |
474 | writer = tf.python_io.TFRecordWriter(output_file)
475 |
476 | for (ex_index, example) in enumerate(examples):
477 | if ex_index % 10000 == 0:
478 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
479 |
480 | for t in range(dupe_factor if is_training else 1):
481 | feature = convert_single_example(ex_index, example, label_list,
482 | max_seq_length, tokenizer,
483 | random_mask=0 if t == 0 else 1)
484 |
485 | def create_int_feature(values):
486 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
487 | return f
488 |
489 | def create_float_feature(values):
490 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
491 | return f
492 |
493 | features = collections.OrderedDict()
494 | features["input_ids_a"] = create_int_feature(feature.input_ids_a)
495 | features["input_mask_a"] = create_int_feature(feature.input_mask_a)
496 | features["segment_ids_a"] = create_int_feature(feature.segment_ids_a)
497 |
498 | features["input_ids_b"] = create_int_feature(feature.input_ids_b)
499 | features["input_mask_b"] = create_int_feature(feature.input_mask_b)
500 | features["segment_ids_b"] = create_int_feature(feature.segment_ids_b)
501 |
502 | if len(label_list) > 1:
503 | features["label_ids"] = create_int_feature([feature.label_id])
504 | else:
505 | features["label_ids"] = create_float_feature(
506 | [(float(feature.label_id) - label_min) / (label_max - label_min)])
507 | features["is_real_example"] = create_int_feature(
508 | [int(feature.is_real_example)])
509 |
510 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
511 | writer.write(tf_example.SerializeToString())
512 |
513 | writer.close()
514 |
515 |
516 | def file_based_input_fn_builder(input_file, seq_length, is_training,
517 | drop_remainder, is_regression):
518 | """Creates an `input_fn` closure to be passed to Estimator."""
519 |
520 | name_to_features = {
521 | "input_ids_a": tf.FixedLenFeature([seq_length], tf.int64),
522 | "input_mask_a": tf.FixedLenFeature([seq_length], tf.int64),
523 | "segment_ids_a": tf.FixedLenFeature([seq_length], tf.int64),
524 | "input_ids_b": tf.FixedLenFeature([seq_length], tf.int64),
525 | "input_mask_b": tf.FixedLenFeature([seq_length], tf.int64),
526 | "segment_ids_b": tf.FixedLenFeature([seq_length], tf.int64),
527 | "label_ids": tf.FixedLenFeature([], tf.int64),
528 | "is_real_example": tf.FixedLenFeature([], tf.int64),
529 | }
530 |
531 | if is_regression:
532 | name_to_features["label_ids"] = tf.FixedLenFeature([], tf.float32)
533 |
534 | def _decode_record(record, name_to_features):
535 | """Decodes a record to a TensorFlow example."""
536 | example = tf.parse_single_example(record, name_to_features)
537 |
538 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
539 | # So cast all int64 to int32.
540 | for name in list(example.keys()):
541 | t = example[name]
542 | if t.dtype == tf.int64:
543 | t = tf.to_int32(t)
544 | example[name] = t
545 |
546 | return example
547 |
548 | def input_fn(mode, params):
549 | """The actual input function."""
550 | if mode == tf.estimator.ModeKeys.TRAIN:
551 | batch_size = params["train_batch_size"]
552 | elif mode == tf.estimator.ModeKeys.EVAL:
553 | batch_size = params["eval_batch_size"]
554 | elif mode == tf.estimator.ModeKeys.PREDICT:
555 | batch_size = params["predict_batch_size"]
556 | else:
557 | raise NotImplementedError
558 |
559 | # For training, we want a lot of parallel reading and shuffling.
560 | # For eval, we want no shuffling and parallel reading doesn't matter.
561 | d = tf.data.TFRecordDataset(input_file)
562 | # if is_training:
563 | # d = d.repeat()
564 | # d = d.shuffle(buffer_size=100)
565 |
566 | if is_training:
567 | d = d.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
568 | d = d.repeat()
569 |
570 | d = d.apply(
571 | tf.contrib.data.map_and_batch(
572 | lambda record: _decode_record(record, name_to_features),
573 | batch_size=batch_size,
574 | drop_remainder=drop_remainder))
575 |
576 | return d
577 |
578 | return input_fn
579 |
580 |
581 |
--------------------------------------------------------------------------------
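`get_input_mask_segment()` above truncates the token list to `max_seq_length - 2`, wraps it in `[CLS]`/`[SEP]`, and zero-pads `input_ids`, `input_mask`, and `segment_ids` to a fixed length. A dependency-free sketch of that padding logic; the integer ids 101/102 are stand-ins for the real `[CLS]`/`[SEP]` vocab ids, since the actual tokenizer comes from `tokenization.FullTokenizer`:

```python
def pad_to_length(token_ids, max_seq_length, cls_id=101, sep_id=102):
    """Truncate, add [CLS]/[SEP], and zero-pad, as in get_input_mask_segment()."""
    # Account for [CLS] and [SEP] with "- 2".
    token_ids = token_ids[: max_seq_length - 2]
    input_ids = [cls_id] + token_ids + [sep_id]

    # 1 for real tokens, 0 for padding; only real tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad = max_seq_length - len(input_ids)
    input_ids += [0] * pad
    input_mask += [0] * pad
    segment_ids = [0] * max_seq_length  # single-sentence input: all zeros
    return input_ids, input_mask, segment_ids


ids, mask, seg = pad_to_length([7, 8, 9], max_seq_length=8)
print(ids)   # [101, 7, 8, 9, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```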
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 |         "how the model was pre-trained. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 |   """Runs end-to-end tokenization."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 |     # words in the English Wikipedia).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 |     # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 |   """Runs WordPiece tokenization."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 |         already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 |   """Checks whether `char` is a whitespace character."""
364 |   # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 |   """Checks whether `char` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 |   """Checks whether `char` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
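The greedy longest-match-first loop in `WordpieceTokenizer.tokenize` above can be sketched as a standalone function, stripped of the TensorFlow and vocabulary-file plumbing. This is an illustrative rewrite, not the library's API, and the tiny vocabulary below is a toy example:

```python
def wordpiece(token, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single token."""
    chars = list(token)
    sub_tokens = []
    start = 0
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        # Try the longest remaining span first, shrinking from the right.
        while start < end:
            substr = "".join(chars[start:end])
            if start > 0:
                substr = "##" + substr  # continuation pieces get the ## prefix
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:
            return [unk_token]  # no piece matched: the whole token is unknown
        sub_tokens.append(cur_substr)
        start = end
    return sub_tokens

vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", vocab))  # ['un', '##aff', '##able']
```

Note that, as in the real implementation, a single unmatched span makes the entire token map to `[UNK]` rather than emitting the pieces matched so far.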
/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import tempfile
21 | import tokenization
22 | import six
23 | import tensorflow as tf
24 |
25 |
26 | class TokenizationTest(tf.test.TestCase):
27 |
28 | def test_full_tokenizer(self):
29 | vocab_tokens = [
30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
31 | "##ing", ","
32 | ]
33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
34 | if six.PY2:
35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
36 | else:
37 | vocab_writer.write("".join(
38 | [x + "\n" for x in vocab_tokens]).encode("utf-8"))
39 |
40 | vocab_file = vocab_writer.name
41 |
42 | tokenizer = tokenization.FullTokenizer(vocab_file)
43 | os.unlink(vocab_file)
44 |
45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
47 |
48 | self.assertAllEqual(
49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
50 |
51 | def test_chinese(self):
52 | tokenizer = tokenization.BasicTokenizer()
53 |
54 | self.assertAllEqual(
55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
56 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
57 |
58 | def test_basic_tokenizer_lower(self):
59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
60 |
61 | self.assertAllEqual(
62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
63 | ["hello", "!", "how", "are", "you", "?"])
64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
65 |
66 | def test_basic_tokenizer_no_lower(self):
67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
68 |
69 | self.assertAllEqual(
70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
71 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
72 |
73 | def test_wordpiece_tokenizer(self):
74 | vocab_tokens = [
75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
76 | "##ing"
77 | ]
78 |
79 | vocab = {}
80 | for (i, token) in enumerate(vocab_tokens):
81 | vocab[token] = i
82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
83 |
84 | self.assertAllEqual(tokenizer.tokenize(""), [])
85 |
86 | self.assertAllEqual(
87 | tokenizer.tokenize("unwanted running"),
88 | ["un", "##want", "##ed", "runn", "##ing"])
89 |
90 | self.assertAllEqual(
91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
92 |
93 | def test_convert_tokens_to_ids(self):
94 | vocab_tokens = [
95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
96 | "##ing"
97 | ]
98 |
99 | vocab = {}
100 | for (i, token) in enumerate(vocab_tokens):
101 | vocab[token] = i
102 |
103 | self.assertAllEqual(
104 | tokenization.convert_tokens_to_ids(
105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
106 |
107 | def test_is_whitespace(self):
108 | self.assertTrue(tokenization._is_whitespace(u" "))
109 | self.assertTrue(tokenization._is_whitespace(u"\t"))
110 | self.assertTrue(tokenization._is_whitespace(u"\r"))
111 | self.assertTrue(tokenization._is_whitespace(u"\n"))
112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
113 |
114 | self.assertFalse(tokenization._is_whitespace(u"A"))
115 | self.assertFalse(tokenization._is_whitespace(u"-"))
116 |
117 | def test_is_control(self):
118 | self.assertTrue(tokenization._is_control(u"\u0005"))
119 |
120 | self.assertFalse(tokenization._is_control(u"A"))
121 | self.assertFalse(tokenization._is_control(u" "))
122 | self.assertFalse(tokenization._is_control(u"\t"))
123 | self.assertFalse(tokenization._is_control(u"\r"))
124 | self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
125 |
126 | def test_is_punctuation(self):
127 | self.assertTrue(tokenization._is_punctuation(u"-"))
128 | self.assertTrue(tokenization._is_punctuation(u"$"))
129 | self.assertTrue(tokenization._is_punctuation(u"`"))
130 | self.assertTrue(tokenization._is_punctuation(u"."))
131 |
132 | self.assertFalse(tokenization._is_punctuation(u"A"))
133 | self.assertFalse(tokenization._is_punctuation(u" "))
134 |
135 |
136 | if __name__ == "__main__":
137 | tf.test.main()
138 |
--------------------------------------------------------------------------------
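The lower-casing path exercised by `test_basic_tokenizer_lower` can be approximated in a few lines of pure Python. This is a simplified sketch, not the real `BasicTokenizer`: it uses only the Unicode `P*` category check for punctuation (the real code also treats non-letter/number ASCII such as `$` and `` ` `` as punctuation), and it skips the control-character and CJK handling:

```python
import unicodedata

def basic_tokenize(text, do_lower_case=True):
    """Whitespace-split, lowercase, strip accents, split off punctuation."""
    tokens = []
    for token in text.strip().split():
        if do_lower_case:
            token = token.lower()
            # Drop combining marks left over after NFD decomposition,
            # mirroring BasicTokenizer._run_strip_accents.
            token = "".join(c for c in unicodedata.normalize("NFD", token)
                            if unicodedata.category(c) != "Mn")
        # Emit punctuation characters as standalone tokens.
        word = ""
        for c in token:
            if unicodedata.category(c).startswith("P"):
                if word:
                    tokens.append(word)
                    word = ""
                tokens.append(c)
            else:
                word += c
        if word:
            tokens.append(word)
    return tokens

print(basic_tokenize(" \tHeLLo!how \n Are yoU? "))
# ['hello', '!', 'how', 'are', 'you', '?']
```

As in the real tokenizer, accent stripping is applied only on the lower-casing path, which is why `H\u00E9llo` becomes `hello` under `do_lower_case=True` but keeps its accent otherwise.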