├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── build.sh ├── download_glue_data.py ├── flow ├── __init__.py ├── config │ ├── config_l2_d2_w32.json │ ├── config_l2_d3_w16.json │ ├── config_l2_d3_w32.json │ ├── config_l3_d2_w32.json │ ├── config_l3_d3_w16.json │ ├── config_l3_d3_w32.json │ └── dump_config.py ├── glow_1x1.py ├── glow_init_hook.py └── glow_ops_1x1.py ├── img └── bert-flow.png ├── modeling.py ├── optimization.py ├── optimization_bert_flow.py ├── run_siamese.py ├── scripts ├── eval_stsb.py └── train_siamese.sh ├── siamese_utils.py ├── tokenization.py └── tokenization_test.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributes, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 
27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # On the Sentence Embeddings from Pre-trained Language Models 2 | 3 |

4 | 5 |

6 | 7 | This is a TensorFlow implementation of the following [paper](https://arxiv.org/abs/2011.05864): 8 | 9 | ``` 10 | On the Sentence Embeddings from Pre-trained Language Models 11 | Bohan Li, Hao Zhou, Junxian He, Mingxuan Wang, Yiming Yang, Lei Li 12 | EMNLP 2020 13 | ``` 14 | 15 | 16 | 17 | Model | Spearman's rho 18 | -------------------------------------------- | :-------------: 19 | BERT-large-NLI | 77.80 20 | BERT-large-NLI-last2avg | 78.45 21 | BERT-large-NLI-flow (target, train only) | 80.54 22 | BERT-large-NLI-flow (target, train+dev+test) | 81.18 23 | 24 | 25 | Please contact bohanl1@cs.cmu.edu if you have any questions. 26 | 27 | 28 | ## Requirements 29 | 30 | * Python >= 3.6 31 | * TensorFlow >= 1.14 32 | 33 | ## Preparation 34 | 35 | ### Pretrained BERT models 36 | ```bash 37 | export BERT_PREMODELS="../bert_premodels" 38 | mkdir ${BERT_PREMODELS}; cd ${BERT_PREMODELS} 39 | 40 | # then download the pre-trained BERT models from https://github.com/google-research/bert 41 | curl -O https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip 42 | curl -O https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip 43 | 44 | ls ${BERT_PREMODELS}/uncased_L-12_H-768_A-12 # base 45 | ls ${BERT_PREMODELS}/uncased_L-24_H-1024_A-16 # large 46 | ``` 47 | 48 | ### GLUE 49 | ```bash 50 | export GLUE_DIR="../glue_data" 51 | python download_glue_data.py --data_dir=${GLUE_DIR} 52 | 53 | # then download the labeled test set of STS-B 54 | cd ../glue_data/STS-B 55 | curl -O https://raw.githubusercontent.com/kawine/usif/master/STSBenchmark/sts-test.csv 56 | ``` 57 | 58 | ### SentEval 59 | ```bash 60 | cd .. 
61 | git clone https://github.com/facebookresearch/SentEval 62 | ``` 63 | 64 | ## Usage 65 | 66 | ### Fine-tune BERT with NLI supervision (optional) 67 | ```bash 68 | export OUTPUT_PARENT_DIR="../exp" 69 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data 70 | mkdir ${CACHED_DIR} 71 | 72 | export RANDOM_SEED=1234 73 | export CUDA_VISIBLE_DEVICES=0 74 | export BERT_NAME="large" 75 | export TASK_NAME="ALLNLI" 76 | unset INIT_CKPT 77 | bash scripts/train_siamese.sh train \ 78 | "--exp_name=exp_${BERT_NAME}_${RANDOM_SEED} \ 79 | --num_train_epochs=1.0 \ 80 | --learning_rate=2e-5 \ 81 | --train_batch_size=16 \ 82 | --cached_dir=${CACHED_DIR}" 83 | 84 | 85 | # evaluation 86 | export RANDOM_SEED=1234 87 | export CUDA_VISIBLE_DEVICES=0 88 | export TASK_NAME=STS-B 89 | export BERT_NAME=large 90 | export OUTPUT_PARENT_DIR="../exp" 91 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_${BERT_NAME}_${RANDOM_SEED}/model.ckpt-60108 92 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data 93 | export EXP_NAME=exp_${BERT_NAME}_${RANDOM_SEED}_eval 94 | bash scripts/train_siamese.sh predict \ 95 | "--exp_name=${EXP_NAME} \ 96 | --cached_dir=${CACHED_DIR} \ 97 | --sentence_embedding_type=avg \ 98 | --flow=0 --flow_loss=0 \ 99 | --num_examples=0 \ 100 | --num_train_epochs=1e-10" 101 | ``` 102 | 103 | Note: You may want to add `--use_xla` to speed up the BERT fine-tuning. 
104 | 105 | ### Unsupervised learning of flow-based generative models 106 | ```bash 107 | export CUDA_VISIBLE_DEVICES=0 108 | export TASK_NAME=STS-B 109 | export BERT_NAME=large 110 | export OUTPUT_PARENT_DIR="../exp" 111 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108 112 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data 113 | bash scripts/train_siamese.sh train \ 114 | "--exp_name_prefix=exp \ 115 | --cached_dir=${CACHED_DIR} \ 116 | --sentence_embedding_type=avg-last-2 \ 117 | --flow=1 --flow_loss=1 \ 118 | --num_examples=0 \ 119 | --num_train_epochs=1.0 \ 120 | --flow_learning_rate=1e-3 \ 121 | --use_full_for_training=1" 122 | 123 | # evaluation 124 | export CUDA_VISIBLE_DEVICES=0 125 | export TASK_NAME=STS-B 126 | export BERT_NAME=large 127 | export OUTPUT_PARENT_DIR="../exp" 128 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108 129 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data 130 | export EXP_NAME=exp_t_STS-B_ep_1.00_lr_5.00e-05_e_avg-last-2_f_11_1.00e-03_allsplits 131 | bash scripts/train_siamese.sh predict \ 132 | "--exp_name=${EXP_NAME} \ 133 | --cached_dir=${CACHED_DIR} \ 134 | --sentence_embedding_type=avg-last-2 \ 135 | --flow=1 --flow_loss=1 \ 136 | --num_examples=0 \ 137 | --num_train_epochs=1.0 \ 138 | --flow_learning_rate=1e-3 \ 139 | --use_full_for_training=1" 140 | ``` 141 | 142 | ### Fit flow with only the training set of STS-B 143 | ```bash 144 | export CUDA_VISIBLE_DEVICES=0 145 | export TASK_NAME=STS-B 146 | export BERT_NAME=large 147 | export OUTPUT_PARENT_DIR="../exp" 148 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108 149 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data 150 | bash scripts/train_siamese.sh train \ 151 | "--exp_name_prefix=exp \ 152 | --cached_dir=${CACHED_DIR} \ 153 | --sentence_embedding_type=avg-last-2 \ 154 | --flow=1 --flow_loss=1 \ 155 | --num_examples=0 \ 156 | --num_train_epochs=1.0 \ 157 | --flow_learning_rate=1e-3 \ 158 | 
--use_full_for_training=0" 159 | 160 | # evaluation 161 | export CUDA_VISIBLE_DEVICES=0 162 | export TASK_NAME=STS-B 163 | export BERT_NAME=large 164 | export OUTPUT_PARENT_DIR="../exp" 165 | export INIT_CKPT=${OUTPUT_PARENT_DIR}/exp_large_1234/model.ckpt-60108 166 | export CACHED_DIR=${OUTPUT_PARENT_DIR}/cached_data 167 | export EXP_NAME=exp_t_STS-B_ep_1.00_lr_5.00e-05_e_avg-last-2_f_11_1.00e-03 168 | bash scripts/train_siamese.sh predict \ 169 | "--exp_name=${EXP_NAME} \ 170 | --cached_dir=${CACHED_DIR} \ 171 | --sentence_embedding_type=avg-last-2 \ 172 | --flow=1 --flow_loss=1 \ 173 | --num_examples=0 \ 174 | --num_train_epochs=1.0 \ 175 | --flow_learning_rate=1e-3 \ 176 | --use_full_for_training=1" 177 | ``` 178 | 179 | ## Download our models 180 | Our models are available at https://drive.google.com/file/d/1-vO47t5SPFfzZPKkkhSe4tXhn8u--KLR/view?usp=sharing 181 | 182 | ## Reference 183 | 184 | ``` 185 | @inproceedings{li2020emnlp, 186 | title = {On the Sentence Embeddings from Pre-trained Language Models}, 187 | author = {Bohan Li and Hao Zhou and Junxian He and Mingxuan Wang and Yiming Yang and Lei Li}, 188 | booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 189 | month = {November}, 190 | year = {2020} 191 | } 192 | 193 | ``` 194 | 195 | ## Acknowledgements 196 | 197 | A large portion of this repo is borrowed from the following projects: 198 | - https://github.com/google-research/bert 199 | - https://github.com/zihangdai/xlnet 200 | - https://github.com/tensorflow/tensor2tensor 201 | 202 | 203 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bohanli/BERT-flow/7fa8f6d4a1a73e2c2f8549799d9bafc9d6048d67/build.sh -------------------------------------------------------------------------------- /download_glue_data.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | 3 | Note: for legal reasons, we are unable to host MRPC. 4 | You can either use the version hosted by the SentEval team, which is already tokenized, 5 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 6 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 7 | You should then rename and place specific files in a folder (see below for an example). 
8 | 9 | mkdir MRPC 10 | cabextract MSRParaphraseCorpus.msi -d MRPC 11 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 12 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 13 | rm MRPC/_* 14 | rm MSRParaphraseCorpus.msi 15 | 16 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 17 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 18 | ''' 19 | 20 | import os 21 | import sys 22 | import shutil 23 | import argparse 24 | import tempfile 25 | import urllib.request 26 | import zipfile 27 | 28 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 29 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 30 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 31 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 32 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 33 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 34 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 35 | 
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 36 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 37 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 38 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 39 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 40 | 41 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 42 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 43 | 44 | def download_and_extract(task, data_dir): 45 | print("Downloading and extracting %s..." 
% task) 46 | data_file = "%s.zip" % task 47 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 48 | with zipfile.ZipFile(data_file) as zip_ref: 49 | zip_ref.extractall(data_dir) 50 | os.remove(data_file) 51 | print("\tCompleted!") 52 | 53 | def format_mrpc(data_dir, path_to_data): 54 | print("Processing MRPC...") 55 | mrpc_dir = os.path.join(data_dir, "MRPC") 56 | if not os.path.isdir(mrpc_dir): 57 | os.mkdir(mrpc_dir) 58 | if path_to_data: 59 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 60 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 61 | else: 62 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 63 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 64 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 65 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 66 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 67 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 68 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 69 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 70 | 71 | dev_ids = [] 72 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 73 | for row in ids_fh: 74 | dev_ids.append(row.strip().split('\t')) 75 | 76 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 77 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 78 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 79 | header = data_fh.readline() 80 | train_fh.write(header) 81 | dev_fh.write(header) 82 | for row in data_fh: 83 | label, id1, id2, s1, s2 = row.strip().split('\t') 84 | if [id1, id2] in dev_ids: 85 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 86 | else: 87 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 88 | 89 
| with open(mrpc_test_file, encoding="utf8") as data_fh, \ 90 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 91 | header = data_fh.readline() 92 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 93 | for idx, row in enumerate(data_fh): 94 | label, id1, id2, s1, s2 = row.strip().split('\t') 95 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 96 | print("\tCompleted!") 97 | 98 | def download_diagnostic(data_dir): 99 | print("Downloading and extracting diagnostic...") 100 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 101 | os.mkdir(os.path.join(data_dir, "diagnostic")) 102 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 103 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 104 | print("\tCompleted!") 105 | return 106 | 107 | def get_tasks(task_names): 108 | task_names = task_names.split(',') 109 | if "all" in task_names: 110 | tasks = TASKS 111 | else: 112 | tasks = [] 113 | for task_name in task_names: 114 | assert task_name in TASKS, "Task %s not found!" 
% task_name 115 | tasks.append(task_name) 116 | return tasks 117 | 118 | def main(arguments): 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 121 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 122 | type=str, default='all') 123 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 124 | type=str, default='') 125 | args = parser.parse_args(arguments) 126 | 127 | if not os.path.isdir(args.data_dir): 128 | os.mkdir(args.data_dir) 129 | tasks = get_tasks(args.tasks) 130 | 131 | for task in tasks: 132 | if task == 'MRPC': 133 | format_mrpc(args.data_dir, args.path_to_mrpc) 134 | elif task == 'diagnostic': 135 | download_diagnostic(args.data_dir) 136 | else: 137 | download_and_extract(task, args.data_dir) 138 | 139 | 140 | if __name__ == '__main__': 141 | sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bohanli/BERT-flow/7fa8f6d4a1a73e2c2f8549799d9bafc9d6048d67/flow/__init__.py -------------------------------------------------------------------------------- /flow/config/config_l2_d2_w32.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_architecture": "glow_resnet", 3 | "activation": "relu", 4 | "coupling": "additive", 5 | "coupling_width": 32, 6 | "coupling_dropout": 0.0, 7 | "top_prior": "normal", 8 | "n_levels": 2, 9 | "depth": 2, 10 | "permutation": true, 11 | "use_fp16": false 12 | } 13 | -------------------------------------------------------------------------------- /flow/config/config_l2_d3_w16.json: -------------------------------------------------------------------------------- 
1 | { 2 | "latent_architecture": "glow_resnet", 3 | "activation": "relu", 4 | "coupling": "additive", 5 | "coupling_width": 16, 6 | "coupling_dropout": 0.0, 7 | "top_prior": "normal", 8 | "n_levels": 2, 9 | "depth": 3, 10 | "permutation": true, 11 | "use_fp16": false 12 | } 13 | -------------------------------------------------------------------------------- /flow/config/config_l2_d3_w32.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_architecture": "glow_resnet", 3 | "activation": "relu", 4 | "coupling": "additive", 5 | "coupling_width": 32, 6 | "coupling_dropout": 0.0, 7 | "top_prior": "normal", 8 | "n_levels": 2, 9 | "depth": 3, 10 | "permutation": true, 11 | "use_fp16": false 12 | } 13 | -------------------------------------------------------------------------------- /flow/config/config_l3_d2_w32.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_architecture": "glow_resnet", 3 | "activation": "relu", 4 | "coupling": "additive", 5 | "coupling_width": 32, 6 | "coupling_dropout": 0.0, 7 | "top_prior": "normal", 8 | "n_levels": 3, 9 | "depth": 2, 10 | "permutation": true, 11 | "use_fp16": false 12 | } 13 | -------------------------------------------------------------------------------- /flow/config/config_l3_d3_w16.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_architecture": "glow_resnet", 3 | "activation": "relu", 4 | "coupling": "additive", 5 | "coupling_width": 16, 6 | "coupling_dropout": 0.0, 7 | "top_prior": "normal", 8 | "n_levels": 3, 9 | "depth": 3, 10 | "permutation": true, 11 | "use_fp16": false 12 | } 13 | -------------------------------------------------------------------------------- /flow/config/config_l3_d3_w32.json: -------------------------------------------------------------------------------- 1 | { 2 | "latent_architecture": "glow_resnet", 3 | "activation": "relu", 4 | 
"coupling": "additive", 5 | "coupling_width": 32, 6 | "coupling_dropout": 0.0, 7 | "top_prior": "normal", 8 | "n_levels": 3, 9 | "depth": 3, 10 | "permutation": true, 11 | "use_fp16": false 12 | } 13 | -------------------------------------------------------------------------------- /flow/config/dump_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | model_config = { 4 | "latent_architecture": "glow_resnet", 5 | "activation": "relu", # Activation - Relu or Gatu 6 | "coupling": "affine", # Coupling layer, additive or affine. 7 | "coupling_width": 512, 8 | "coupling_dropout": 0.0, 9 | "top_prior": "normal", 10 | "n_levels": 3, 11 | "depth": 6, 12 | 13 | "use_fp16": False # not implemented yet 14 | } 15 | 16 | with open('/mnt/cephfs_new_wj/mlnlp/libohan.05/text_flow/config/model_config_normal.json', 'w') as jp: 17 | json.dump(model_config, jp, indent=4) 18 | 19 | -------------------------------------------------------------------------------- /flow/glow_1x1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope 3 | import flow.glow_ops_1x1 as glow_ops 4 | from flow.glow_ops_1x1 import get_shape_list 5 | import flow.glow_init_hook 6 | 7 | 8 | import numpy as np 9 | import os, sys 10 | 11 | arg_scope = tf.contrib.framework.arg_scope 12 | add_arg_scope = tf.contrib.framework.add_arg_scope 13 | 14 | class AttrDict(dict): 15 | def __init__(self, *args, **kwargs): 16 | super(AttrDict, self).__init__(*args, **kwargs) 17 | self.__dict__ = self 18 | 19 | class Glow(): 20 | 21 | def __init__(self, hparams): 22 | self.hparams = hparams 23 | 24 | @property 25 | def is_predicting(self): 26 | return not self.is_training 27 | 28 | @staticmethod 29 | def train_hooks(): 30 | #del hook_context 31 | return [glow_init_hook.GlowInitHook()] 32 | 33 | def top_prior(self): 34 | """Objective 
based on the prior over latent z. 35 | 36 | Returns: 37 | dist: instance of tfp.distributions.Normal, prior distribution. 38 | """ 39 | return glow_ops.top_prior( 40 | "top_prior", self.z_top_shape, learn_prior=self.hparams.top_prior) 41 | 42 | def body(self, features, is_training): 43 | if is_training: 44 | init_features = features 45 | init_op = self.objective_tower(init_features, init=True) 46 | init_op = tf.Print( 47 | init_op, [init_op], message="Triggering data-dependent init.", 48 | first_n=20) 49 | tf.compat.v1.add_to_collection("glow_init_op", init_op) 50 | return self.objective_tower(features, init=False) 51 | 52 | def objective_tower(self, features, init=True): 53 | """Objective in terms of bits-per-pixel. 54 | """ 55 | #features = tf.expand_dims(features, [1, 2]) 56 | features = features[:, tf.newaxis, tf.newaxis, :] 57 | x = features 58 | 59 | objective = 0 60 | 61 | # The arg_scope call ensures that the actnorm parameters are set such that 62 | # the per-channel output activations have zero mean and unit variance 63 | # ONLY during the first step. After that the parameters are learned 64 | # through optimisation. 
65 | ops = [glow_ops.get_variable_ddi, glow_ops.actnorm, glow_ops.get_dropout] 66 | with arg_scope(ops, init=init): 67 | encoder = glow_ops.encoder_decoder 68 | 69 | self.z, encoder_objective, self.eps, _, _ = encoder( 70 | "flow", x, self.hparams, eps=None, reverse=False) 71 | objective += encoder_objective 72 | 73 | self.z_top_shape = get_shape_list(self.z) 74 | prior_dist = self.top_prior() 75 | prior_objective = tf.reduce_sum( 76 | prior_dist.log_prob(self.z), axis=[1, 2, 3]) 77 | #self.z_sample = prior_dist.sample() 78 | objective += prior_objective 79 | 80 | # bits per pixel 81 | _, h, w, c = get_shape_list(x) 82 | objective = -objective / (np.log(2) * h * w * c) 83 | 84 | self.z = tf.concat(self.eps + [self.z], axis=-1) 85 | return objective -------------------------------------------------------------------------------- /flow/glow_init_hook.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class GlowInitHook(tf.estimator.SessionRunHook): 5 | """ 6 | Hook that runs data-dependent initialization once before the first step. 7 | 8 | The init op is stored in the tf collection glow_init_op. Look at the 9 | "body" in glow.py for more details. 10 | """ 11 | 12 | def after_create_session(self, session, coord): 13 | del coord 14 | global_step = session.run(tf.compat.v1.train.get_or_create_global_step()) 15 | if global_step == 0: 16 | ddi = tf.get_collection("glow_init_op") 17 | # In-case of a multi-GPU system, this just runs the first op in the 18 | # collection. 19 | print(ddi) 20 | if ddi: 21 | session.run(ddi[0]) 22 | #input() 23 | -------------------------------------------------------------------------------- /flow/glow_ops_1x1.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from 3 | https://github.com/tensorflow/tensor2tensor/blob/8a084a4d56/tensor2tensor/models/research/glow_ops.py 4 | 5 | modifications are as follows: 6 | 1. 
replace tfp with tf because neither tfp 0.6 or 0.7 is compatible with tf 1.14 7 | 2. remove support for video-related operators like conv3d 8 | 3. remove support for conditional distributions 9 | """ 10 | import tensorflow as tf 11 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope 12 | # import tensorflow_probability as tfp 13 | 14 | import functools 15 | import numpy as np 16 | import scipy 17 | 18 | 19 | def get_shape_list(x): 20 | """Return list of dims, statically where possible.""" 21 | x = tf.convert_to_tensor(x) 22 | 23 | # If unknown rank, return dynamic shape 24 | if x.get_shape().dims is None: 25 | return tf.shape(x) 26 | 27 | static = x.get_shape().as_list() 28 | shape = tf.shape(x) 29 | 30 | ret = [] 31 | for i, dim in enumerate(static): 32 | if dim is None: 33 | dim = shape[i] 34 | ret.append(dim) 35 | return ret 36 | 37 | def get_eps(dist, x): 38 | """Z = (X - mu) / sigma.""" 39 | return (x - dist.loc) / dist.scale 40 | 41 | 42 | def set_eps(dist, eps): 43 | """Z = eps * sigma + mu.""" 44 | return eps * dist.scale + dist.loc 45 | 46 | 47 | # =============================================== 48 | @add_arg_scope 49 | def assign(w, initial_value): 50 | w = w.assign(initial_value) 51 | with tf.control_dependencies([w]): 52 | return w 53 | 54 | @add_arg_scope 55 | def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False, 56 | trainable=True): 57 | """Wrapper for data-dependent initialization.""" 58 | # If init is a tf bool: w is assigned dynamically at runtime. 59 | # If init is a python bool: then w is determined during graph construction. 60 | w = tf.compat.v1.get_variable(name, shape, dtype, None, trainable=trainable) 61 | if isinstance(init, bool): 62 | if init: 63 | return assign(w, initial_value) 64 | return w 65 | else: 66 | return tf.cond(init, lambda: assign(w, initial_value), lambda: w) 67 | 68 | @add_arg_scope 69 | def get_dropout(x, rate=0.0, init=True): 70 | """Dropout x with dropout_rate = rate. 
71 | Apply zero dropout during init or prediction time. 72 | Args: 73 | x: 4-D Tensor, shape=(NHWC). 74 | rate: Dropout rate. 75 | init: Initialization. 76 | Returns: 77 | x: activations after dropout. 78 | """ 79 | if init or rate == 0: 80 | return x 81 | return tf.layers.dropout(x, rate=rate, training=True) # TODO 82 | 83 | def default_initializer(std=0.05): 84 | return tf.random_normal_initializer(0., std) 85 | 86 | # =============================================== 87 | 88 | # Activation normalization 89 | # Convenience function that does centering+scaling 90 | 91 | @add_arg_scope 92 | def actnorm(name, x, logscale_factor=3., reverse=False, init=False, 93 | trainable=True): 94 | """x_{ij} = s x x_{ij} + b. Per-channel scaling and bias. 95 | If init is set to True, the scaling and bias are initialized such 96 | that the mean and variance of the output activations of the first minibatch 97 | are zero and one respectively. 98 | Args: 99 | name: variable scope. 100 | x: input 101 | logscale_factor: Used in actnorm_scale. Optimizes f(ls*s') instead of f(s) 102 | where s' = s / ls. Helps in faster convergence. 103 | reverse: forward or reverse operation. 104 | init: Whether or not to do data-dependent initialization. 105 | trainable: 106 | Returns: 107 | x: output after adding bias and scaling. 
108 | objective: log(sum(s)) 109 | """ 110 | var_arg_scope = arg_scope([get_variable_ddi], trainable=trainable) 111 | var_scope = tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE) 112 | 113 | with var_scope, var_arg_scope: 114 | if not reverse: 115 | x = actnorm_center(name + "_center", x, reverse, init=init) 116 | x, objective = actnorm_scale( 117 | name + "_scale", x, logscale_factor=logscale_factor, 118 | reverse=reverse, init=init) 119 | else: 120 | x, objective = actnorm_scale( 121 | name + "_scale", x, logscale_factor=logscale_factor, 122 | reverse=reverse, init=init) 123 | x = actnorm_center(name + "_center", x, reverse, init=init) 124 | return x, objective 125 | 126 | 127 | @add_arg_scope 128 | def actnorm_center(name, x, reverse=False, init=False): 129 | """Add a bias to x. 130 | Initialize such that the output of the first minibatch is zero centered 131 | per channel. 132 | Args: 133 | name: scope 134 | x: 2-D or 4-D Tensor. 135 | reverse: Forward or backward operation. 136 | init: data-dependent initialization. 137 | Returns: 138 | x_center: (x + b), if reverse is True and (x - b) otherwise. 
139 | """ 140 | shape = get_shape_list(x) 141 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 142 | assert len(shape) == 2 or len(shape) == 4 143 | if len(shape) == 2: 144 | x_mean = tf.reduce_mean(x, [0], keepdims=True) 145 | b = get_variable_ddi("b", (1, shape[1]), initial_value=-x_mean, 146 | init=init) 147 | elif len(shape) == 4: 148 | x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True) 149 | b = get_variable_ddi( 150 | "b", (1, 1, 1, shape[3]), initial_value=-x_mean, init=init) 151 | 152 | if not reverse: 153 | x += b 154 | else: 155 | x -= b 156 | return x 157 | 158 | 159 | @add_arg_scope 160 | def actnorm_scale(name, x, logscale_factor=3., reverse=False, init=False): 161 | """Per-channel scaling of x.""" 162 | x_shape = get_shape_list(x) 163 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 164 | 165 | # Variance initialization logic. 166 | assert len(x_shape) == 2 or len(x_shape) == 4 167 | if len(x_shape) == 2: 168 | x_var = tf.reduce_mean(x**2, [0], keepdims=True) 169 | logdet_factor = 1 170 | var_shape = (1, x_shape[1]) 171 | elif len(x_shape) == 4: 172 | x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True) 173 | logdet_factor = x_shape[1]*x_shape[2] 174 | var_shape = (1, 1, 1, x_shape[3]) 175 | 176 | init_value = tf.math.log(1.0 / (tf.sqrt(x_var) + 1e-6)) / logscale_factor 177 | logs = get_variable_ddi("logs", var_shape, initial_value=init_value, 178 | init=init) 179 | logs = logs * logscale_factor 180 | 181 | # Function and reverse function. 182 | if not reverse: 183 | x = x * tf.exp(logs) 184 | else: 185 | x = x * tf.exp(-logs) 186 | 187 | # Objective calculation, h * w * sum(log|s|) 188 | dlogdet = tf.reduce_sum(logs) * logdet_factor 189 | if reverse: 190 | dlogdet *= -1 191 | return x, dlogdet 192 | 193 | 194 | # =============================================== 195 | 196 | 197 | @add_arg_scope 198 | def invertible_1x1_conv(name, x, reverse=False, permutation=False): 199 | """1X1 convolution on x. 
@add_arg_scope
def invertible_1x1_conv(name, x, reverse=False, permutation=False):
  """1X1 convolution on x.

  The 1X1 convolution is parametrized as P*L*(U + sign(s)*exp(log(s))) where
  1. P is a permutation matrix.
  2. L is a lower triangular matrix with diagonal entries unity.
  3. U is a upper triangular matrix where the diagonal entries zero.
  4. s is a vector.
  sign(s) and P are fixed and the remaining are optimized. P, L, U and s are
  initialized by the PLU decomposition of a random rotation matrix.

  Args:
    name: scope
    x: Input Tensor.
    reverse: whether the pass is from z -> x or x -> z.
    permutation: if True, skip the learned PLU weight entirely and use a
      fixed, non-trainable channel-reversal permutation matrix instead.
  Returns:
    x_conv: x after a 1X1 convolution is applied on x.
    objective: sum(log(s)) (zero in the fixed-permutation case).
  """
  _, height, width, channels = get_shape_list(x)
  w_shape = [channels, channels]

  if permutation:
    # Fixed anti-diagonal permutation: channel i maps to channel C-1-i.
    np_w = np.zeros((channels, channels)).astype("float32")
    for i in range(channels):
      np_w[i][channels-1-i] = 1.

    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
      w = tf.compat.v1.get_variable("w", initializer=np_w, trainable=False)

      # If height or width cannot be statically determined then they end up as
      # tf.int32 tensors, which cannot be directly multiplied with a floating
      # point tensor without a cast.
      # A permutation matrix has |det| = 1, so its log-det objective is 0.
      objective = 0.
      if not reverse:
        w = tf.reshape(w, [1, 1] + w_shape)
        x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC")
      else:
        # The inverse of a permutation is its transpose; tf.linalg.inv
        # computes it generically here.
        w_inv = tf.reshape(tf.linalg.inv(w), [1, 1] + w_shape)
        x = tf.nn.conv2d(
            x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC")
        objective *= -1
      return x, objective
  else:
    # Random rotation-matrix Q
    random_matrix = np.random.rand(channels, channels)
    np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")

    # Initialize P,L,U and s from the LU decomposition of a random rotation matrix
    np_p, np_l, np_u = scipy.linalg.lu(np_w)
    np_s = np.diag(np_u)
    np_sign_s = np.sign(np_s)
    np_log_s = np.log(np.abs(np_s))
    np_u = np.triu(np_u, k=1)

    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
      # P and sign(s) stay fixed; L, U, log(s) are trained.
      p = tf.compat.v1.get_variable("P", initializer=np_p, trainable=False)
      l = tf.compat.v1.get_variable("L", initializer=np_l)
      sign_s = tf.compat.v1.get_variable(
          "sign_S", initializer=np_sign_s, trainable=False)
      log_s = tf.compat.v1.get_variable("log_S", initializer=np_log_s)
      u = tf.compat.v1.get_variable("U", initializer=np_u)

      # W = P * L * (U + sign_s * exp(log_s))
      # Masking re-imposes the triangular structure after optimization steps.
      l_mask = np.tril(np.ones([channels, channels], dtype=np.float32), -1)
      l = l * l_mask + tf.eye(channels, channels)
      u = u * np.transpose(l_mask) + tf.linalg.diag(sign_s * tf.exp(log_s))
      w = tf.matmul(p, tf.matmul(l, u))

      # If height or width cannot be statically determined then they end up as
      # tf.int32 tensors, which cannot be directly multiplied with a floating
      # point tensor without a cast.
      objective = tf.reduce_sum(log_s) * tf.cast(height * width, log_s.dtype)
      if not reverse:
        w = tf.reshape(w, [1, 1] + w_shape)
        x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC")
      else:
        w_inv = tf.reshape(tf.linalg.inv(w), [1, 1] + w_shape)
        x = tf.nn.conv2d(
            x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC")
        objective *= -1
      return x, objective
268 | objective = tf.reduce_sum(log_s) * tf.cast(height * width, log_s.dtype) 269 | if not reverse: 270 | w = tf.reshape(w, [1, 1] + w_shape) 271 | x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC") 272 | else: 273 | w_inv = tf.reshape(tf.linalg.inv(w), [1, 1] + w_shape) 274 | x = tf.nn.conv2d( 275 | x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC") 276 | objective *= -1 277 | return x, objective 278 | 279 | 280 | 281 | 282 | # =============================================== 283 | 284 | def add_edge_bias(x, filter_size): 285 | """Pad x and concatenates an edge bias across the depth of x. 286 | The edge bias can be thought of as a binary feature which is unity when 287 | the filter is being convolved over an edge and zero otherwise. 288 | Args: 289 | x: Input tensor, shape (NHWC) 290 | filter_size: filter_size to determine padding. 291 | Returns: 292 | x_pad: Input tensor, shape (NHW(c+1)) 293 | """ 294 | x_shape = get_shape_list(x) 295 | if filter_size[0] == 1 and filter_size[1] == 1: 296 | return x 297 | a = (filter_size[0] - 1) // 2 # vertical padding size 298 | b = (filter_size[1] - 1) // 2 # horizontal padding size 299 | padding = [[0, 0], [a, a], [b, b], [0, 0]] 300 | x_bias = tf.zeros(x_shape[:-1] + [1]) 301 | 302 | x = tf.pad(x, padding) 303 | x_pad = tf.pad(x_bias, padding, constant_values=1) 304 | return tf.concat([x, x_pad], axis=3) 305 | 306 | 307 | @add_arg_scope 308 | def conv(name, x, output_channels, filter_size=None, stride=None, 309 | logscale_factor=3.0, apply_actnorm=True, conv_init="default", 310 | dilations=None): 311 | """Convolutional layer with edge bias padding and optional actnorm. 312 | If x is 5-dimensional, actnorm is applied independently across every 313 | time-step. 314 | Args: 315 | name: variable scope. 316 | x: 4-D Tensor or 5-D Tensor of shape NHWC or NTHWC 317 | output_channels: Number of output channels. 
@add_arg_scope
def conv(name, x, output_channels, filter_size=None, stride=None,
         logscale_factor=3.0, apply_actnorm=True, conv_init="default",
         dilations=None):
  """Convolutional layer with edge bias padding and optional actnorm.

  Args:
    name: variable scope.
    x: 4-D Tensor of shape NHWC.
    output_channels: Number of output channels.
    filter_size: list of ints, defaults to [1, 1].
    stride: list of ints, default stride: 1.
    logscale_factor: see actnorm for parameter meaning.
    apply_actnorm: if apply_actnorm the activations of the first minibatch
      have zero mean and unit variance. Else, a learned bias and logscale
      are applied instead.
    conv_init: default or zeros. default is a normal distribution with 0.05 std.
    dilations: List of integers, apply dilations.
  Returns:
    x: actnorm(conv2d(x))
  Raises:
    ValueError: if conv_init is "zeros" and apply_actnorm is True, or if
      conv_init is not one of "default"/"zeros".
    NotImplementedError: if x is not a 4-D tensor.
  """
  if conv_init == "zeros" and apply_actnorm:
    raise ValueError("apply_actnorm is unstable when init is set to zeros.")
  # FIX: previously an unrecognized conv_init left `initializer` unbound
  # and crashed later with an opaque NameError; validate up front.
  if conv_init not in ("default", "zeros"):
    raise ValueError(
        "conv_init must be 'default' or 'zeros', got %r" % conv_init)

  x_shape = get_shape_list(x)
  is_2d = len(x_shape) == 4
  # (removed unused `num_steps = x_shape[1]`, a leftover from the 5-D
  # video path this port no longer supports)

  # set filter_size, stride and in_channels
  if is_2d:
    if filter_size is None:
      filter_size = [1, 1]  # filter_size = [3, 3]
    if stride is None:
      stride = [1, 1]
    if dilations is None:
      dilations = [1, 1, 1, 1]
    actnorm_func = actnorm
    x = add_edge_bias(x, filter_size=filter_size)
    conv_filter = tf.nn.conv2d
  else:
    raise NotImplementedError('x must be a NHWC 4-D Tensor!')

  in_channels = get_shape_list(x)[-1]
  filter_shape = filter_size + [in_channels, output_channels]
  stride_shape = [1] + stride + [1]

  with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):

    if conv_init == "default":
      initializer = default_initializer()
    else:
      initializer = tf.zeros_initializer()

    w = tf.compat.v1.get_variable("W", filter_shape, tf.float32,
                                  initializer=initializer)
    # VALID padding: add_edge_bias already padded, so size is preserved.
    x = conv_filter(x, w, stride_shape, padding="VALID", dilations=dilations)
    if apply_actnorm:
      x, _ = actnorm_func("actnorm", x, logscale_factor=logscale_factor)
    else:
      x += tf.compat.v1.get_variable("b", [1, 1, 1, output_channels],
                                     initializer=tf.zeros_initializer())
      logs = tf.compat.v1.get_variable("logs", [1, output_channels],
                                       initializer=tf.zeros_initializer())
      x *= tf.exp(logs * logscale_factor)
    return x
412 | # [input, output: 512 channels] 413 | if activation == "relu": 414 | x = conv("1_2", x, output_channels=mid_channels, 415 | filter_size=second_filter, dilations=dilations) 416 | x = tf.nn.relu(x) 417 | elif activation == "gatu": 418 | # x = tanh(w1*x) * sigm(w2*x) 419 | x_tanh = conv("1_tanh", x, output_channels=mid_channels, 420 | filter_size=second_filter, dilations=dilations) 421 | x_sigm = conv("1_sigm", x, output_channels=mid_channels, 422 | filter_size=second_filter, dilations=dilations) 423 | x = tf.nn.tanh(x_tanh) * tf.nn.sigmoid(x_sigm) 424 | 425 | x = get_dropout(x, rate=dropout) 426 | return x 427 | 428 | 429 | @add_arg_scope 430 | def conv_stack(name, x, mid_channels, output_channels, dilations=None, 431 | activation="relu", dropout=0.0): 432 | """3-layer convolutional stack. 433 | Args: 434 | name: variable scope. 435 | x: 5-D Tensor. 436 | mid_channels: Number of output channels of the first layer. 437 | output_channels: Number of output channels. 438 | dilations: Dilations to apply in the first 3x3 layer and the last 3x3 layer. 439 | By default, apply no dilations. 440 | activation: relu or gatu. 441 | If relu, the second layer is relu(W*x) 442 | If gatu, the second layer is tanh(W1*x) * sigmoid(W2*x) 443 | dropout: float, 0.0 444 | Returns: 445 | output: output of 3 layer conv network. 446 | """ 447 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 448 | 449 | x = conv_block("conv_block", x, mid_channels=mid_channels, 450 | dilations=dilations, activation=activation, 451 | dropout=dropout) 452 | 453 | # Final layer. 454 | x = conv("zeros", x, apply_actnorm=False, conv_init="zeros", 455 | output_channels=output_channels, dilations=dilations) 456 | return x 457 | 458 | 459 | @add_arg_scope 460 | def additive_coupling(name, x, mid_channels=512, reverse=False, 461 | activation="relu", dropout=0.0): 462 | """Reversible additive coupling layer. 463 | Args: 464 | name: variable scope. 465 | x: 4-D Tensor, shape=(NHWC). 
466 | mid_channels: number of channels in the coupling layer. 467 | reverse: Forward or reverse operation. 468 | activation: "relu" or "gatu" 469 | dropout: default, 0.0 470 | Returns: 471 | output: 4-D Tensor, shape=(NHWC) 472 | objective: 0.0 473 | """ 474 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 475 | output_channels = get_shape_list(x)[-1] // 2 476 | x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1) 477 | 478 | z1 = x1 479 | shift = conv_stack("nn", x1, mid_channels, output_channels=output_channels, 480 | activation=activation, dropout=dropout) 481 | 482 | if not reverse: 483 | z2 = x2 + shift 484 | else: 485 | z2 = x2 - shift 486 | return tf.concat([z1, z2], axis=3), 0.0 487 | 488 | 489 | @add_arg_scope 490 | def affine_coupling(name, x, mid_channels=512, activation="relu", 491 | reverse=False, dropout=0.0): 492 | """Reversible affine coupling layer. 493 | Args: 494 | name: variable scope. 495 | x: 4-D Tensor. 496 | mid_channels: number of channels in the coupling layer. 497 | activation: Can be either "relu" or "gatu". 498 | reverse: Forward or reverse operation. 499 | dropout: default, 0.0 500 | Returns: 501 | output: x shifted and scaled by an affine transformation. 
502 | objective: log-determinant of the jacobian 503 | """ 504 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 505 | x_shape = get_shape_list(x) 506 | x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1) 507 | 508 | # scale, shift = NN(x1) 509 | # If reverse: 510 | # z2 = scale * (x2 + shift) 511 | # Else: 512 | # z2 = (x2 / scale) - shift 513 | z1 = x1 514 | log_scale_and_shift = conv_stack( 515 | "nn", x1, mid_channels, x_shape[-1], activation=activation, 516 | dropout=dropout) 517 | shift = log_scale_and_shift[:, :, :, 0::2] 518 | scale = tf.nn.sigmoid(log_scale_and_shift[:, :, :, 1::2] + 2.0) 519 | if not reverse: 520 | z2 = (x2 + shift) * scale 521 | else: 522 | z2 = x2 / scale - shift 523 | 524 | objective = tf.reduce_sum(tf.math.log(scale), axis=[1, 2, 3]) 525 | if reverse: 526 | objective *= -1 527 | return tf.concat([z1, z2], axis=3), objective 528 | 529 | 530 | # =============================================== 531 | 532 | 533 | @add_arg_scope 534 | def single_conv_dist(name, x, output_channels=None): 535 | """A 1x1 convolution mapping x to a standard normal distribution at init. 536 | Args: 537 | name: variable scope. 538 | x: 4-D Tensor. 539 | output_channels: number of channels of the mean and std. 540 | """ 541 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 542 | x_shape = get_shape_list(x) 543 | if output_channels is None: 544 | output_channels = x_shape[-1] 545 | mean_log_scale = conv("conv2d", x, output_channels=2*output_channels, 546 | conv_init="zeros", apply_actnorm=False) 547 | mean = mean_log_scale[:, :, :, 0::2] 548 | log_scale = mean_log_scale[:, :, :, 1::2] 549 | return tf.distributions.Normal(mean, tf.exp(log_scale)) 550 | 551 | 552 | # # =============================================== 553 | 554 | 555 | @add_arg_scope 556 | def revnet_step(name, x, hparams, reverse=True): 557 | """One step of glow generative flow. 558 | Actnorm + invertible 1X1 conv + affine_coupling. 
559 | Args: 560 | name: used for variable scope. 561 | x: input 562 | hparams: coupling_width is the only hparam that is being used in 563 | this function. 564 | reverse: forward or reverse pass. 565 | Returns: 566 | z: Output of one step of reversible flow. 567 | """ 568 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 569 | if hparams.coupling == "additive": 570 | coupling_layer = functools.partial( 571 | additive_coupling, name="additive", reverse=reverse, 572 | mid_channels=hparams.coupling_width, 573 | activation=hparams.activation, 574 | dropout=hparams.coupling_dropout if hparams.is_training else 0) 575 | else: 576 | coupling_layer = functools.partial( 577 | affine_coupling, name="affine", reverse=reverse, 578 | mid_channels=hparams.coupling_width, 579 | activation=hparams.activation, 580 | dropout=hparams.coupling_dropout if hparams.is_training else 0) 581 | 582 | if "permutation" in hparams and hparams["permutation"] == True: 583 | ops = [ 584 | functools.partial(actnorm, name="actnorm", reverse=reverse), 585 | functools.partial(invertible_1x1_conv, name="invertible", reverse=reverse, permutation=True), 586 | coupling_layer] 587 | else: 588 | ops = [ 589 | functools.partial(actnorm, name="actnorm", reverse=reverse), 590 | functools.partial(invertible_1x1_conv, name="invertible", reverse=reverse), 591 | coupling_layer] 592 | 593 | if reverse: 594 | ops = ops[::-1] 595 | 596 | objective = 0.0 597 | for op in ops: 598 | x, curr_obj = op(x=x) 599 | objective += curr_obj 600 | return x, objective 601 | 602 | 603 | def revnet(name, x, hparams, reverse=True): 604 | """'hparams.depth' steps of generative flow. 605 | Args: 606 | name: variable scope for the revnet block. 607 | x: 4-D Tensor, shape=(NHWC). 608 | hparams: HParams. 609 | reverse: bool, forward or backward pass. 610 | Returns: 611 | x: 4-D Tensor, shape=(NHWC). 612 | objective: float. 
613 | """ 614 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 615 | steps = np.arange(hparams.depth) 616 | if reverse: 617 | steps = steps[::-1] 618 | 619 | objective = 0.0 620 | for step in steps: 621 | x, curr_obj = revnet_step( 622 | "revnet_step_%d" % step, x, hparams, reverse=reverse) 623 | objective += curr_obj 624 | return x, objective 625 | 626 | # =============================================== 627 | 628 | @add_arg_scope 629 | def compute_prior(name, z, latent, hparams, condition=False, state=None, 630 | temperature=1.0): 631 | """Distribution on z_t conditioned on z_{t-1} and latent. 632 | Args: 633 | name: variable scope. 634 | z: 4-D Tensor. 635 | latent: optional, 636 | if hparams.latent_dist_encoder == "pointwise", this is a list 637 | of 4-D Tensors of length hparams.num_cond_latents. 638 | else, this is just a 4-D Tensor 639 | The first-three dimensions of the latent should be the same as z. 640 | hparams: next_frame_glow_hparams. 641 | condition: Whether or not to condition the distribution on latent. 642 | state: tf.nn.rnn_cell.LSTMStateTuple. 643 | the current state of a LSTM used to model the distribution. Used 644 | only if hparams.latent_dist_encoder = "conv_lstm". 645 | temperature: float, temperature with which to sample from the Gaussian. 646 | Returns: 647 | prior_dist: instance of tfp.distributions.Normal 648 | state: Returns updated state. 649 | Raises: 650 | ValueError: If hparams.latent_dist_encoder is "pointwise" and if the shape 651 | of latent is different from z. 
652 | """ 653 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 654 | z_shape = get_shape_list(z) 655 | h = tf.zeros(z_shape, dtype=tf.float32) 656 | prior_dist = tf.distributions.Normal(h, tf.exp(h)) 657 | return prior_dist, state 658 | 659 | 660 | 661 | @add_arg_scope 662 | def split(name, x, reverse=False, eps=None, eps_std=None, cond_latents=None, 663 | hparams=None, state=None, condition=False, temperature=1.0): 664 | """Splits / concatenates x into x1 and x2 across number of channels. 665 | For the forward pass, x2 is assumed be gaussian, 666 | i.e P(x2 | x1) ~ N(mu, sigma) where mu and sigma are the outputs of 667 | a network conditioned on x1 and optionally on cond_latents. 668 | For the reverse pass, x2 is determined from mu(x1) and sigma(x1). 669 | This is deterministic/stochastic depending on whether eps is provided. 670 | Args: 671 | name: variable scope. 672 | x: 4-D Tensor, shape (NHWC). 673 | reverse: Forward or reverse pass. 674 | eps: If eps is provided, x2 is set to be mu(x1) + eps * sigma(x1). 675 | eps_std: Sample x2 with the provided eps_std. 676 | cond_latents: optionally condition x2 on cond_latents. 677 | hparams: next_frame_glow hparams. 678 | state: tf.nn.rnn_cell.LSTMStateTuple.. Current state of the LSTM over z_2. 679 | Used only when hparams.latent_dist_encoder == "conv_lstm" 680 | condition: bool, Whether or not to condition the distribution on 681 | cond_latents. 682 | temperature: Temperature with which to sample from the gaussian. 683 | Returns: 684 | If reverse: 685 | x: 4-D Tensor, concats input and x2 across channels. 686 | x2: 4-D Tensor, a sample from N(mu(x1), sigma(x1)) 687 | Else: 688 | x1: 4-D Tensor, Output of the split operation. 689 | logpb: log-probability of x2 belonging to mu(x1), sigma(x1) 690 | eps: 4-D Tensor, (x2 - mu(x1)) / sigma(x1) 691 | x2: 4-D Tensor, Latent representation at the current level. 692 | state: Current LSTM state. 
693 | 4-D Tensor, only if hparams.latent_dist_encoder is set to conv_lstm. 694 | Raises: 695 | ValueError: If latent is provided and shape is not equal to NHW(C/2) 696 | where (NHWC) is the size of x. 697 | """ 698 | # TODO(mechcoder) Change the return type to be a dict. 699 | with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE): 700 | if not reverse: 701 | x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1) 702 | 703 | # objective: P(x2|x1) ~N(x2 ; NN(x1)) 704 | prior_dist, state = compute_prior( 705 | "prior_on_z2", x1, cond_latents, hparams, condition, state=state) 706 | logpb = tf.reduce_sum(prior_dist.log_prob(x2), axis=[1, 2, 3]) 707 | eps = get_eps(prior_dist, x2) 708 | return x1, logpb, eps, x2, state 709 | else: 710 | prior_dist, state = compute_prior( 711 | "prior_on_z2", x, cond_latents, hparams, condition, state=state, 712 | temperature=temperature) 713 | if eps is not None: 714 | x2 = set_eps(prior_dist, eps) 715 | elif eps_std is not None: 716 | x2 = eps_std * tf.random_normal(get_shape_list(x)) 717 | else: 718 | x2 = prior_dist.sample() 719 | return tf.concat([x, x2], 3), x2, state 720 | 721 | 722 | @add_arg_scope 723 | def squeeze(name, x, factor=2, reverse=True): 724 | """Block-wise spatial squeezing of x to increase the number of channels. 725 | Args: 726 | name: Used for variable scoping. 727 | x: 4-D Tensor of shape (batch_size X H X W X C) 728 | factor: Factor by which the spatial dimensions should be squeezed. 729 | reverse: Squueze or unsqueeze operation. 730 | Returns: 731 | x: 4-D Tensor of shape (batch_size X (H//factor) X (W//factor) X 732 | (cXfactor^2). 
@add_arg_scope
def squeeze(name, x, factor=2, reverse=True):
  """Block-wise spatial squeeze/unsqueeze trading resolution for channels.

  With reverse=False: (N, H, W, C) -> (N, H//factor, W//factor,
  C * factor**2). With reverse=True the inverse mapping is applied, so the
  spatial dims grow by `factor` and the channels shrink by factor**2.

  Args:
    name: Used for variable scoping.
    x: 4-D Tensor of shape (batch_size X H X W X C)
    factor: Factor by which the spatial dimensions should be squeezed.
    reverse: Squeeze or unsqueeze operation.

  Returns:
    The squeezed (or unsqueezed) 4-D Tensor.
  """
  with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
    shape = get_shape_list(x)
    if factor == 1:
      # Nothing to do for a unit factor.
      return x
    h = int(shape[1])
    w = int(shape[2])
    c = int(shape[3])

    if reverse:
      # Unsqueeze: redistribute factor**2 channel groups back into space.
      x = tf.reshape(x, (-1, h, w, int(c / factor**2), factor, factor))
      x = tf.transpose(x, [0, 1, 4, 2, 5, 3])
      x = tf.reshape(
          x, (-1, int(h * factor), int(w * factor), int(c / factor**2)))
    else:
      # Squeeze: fold each factor x factor spatial block into the channels.
      assert h % factor == 0 and w % factor == 0
      x = tf.reshape(x, [-1, h // factor, factor, w // factor, factor, c])
      x = tf.transpose(x, [0, 1, 3, 5, 2, 4])
      x = tf.reshape(x, [-1, h // factor, w // factor, c * factor * factor])
    return x


def get_cond_latents_at_level(cond_latents, level, hparams):
  """Returns a single or list of conditional latents at level 'level'."""
  if not cond_latents:
    return None
  encoder = hparams.latent_dist_encoder
  if encoder in ("conv_net", "conv3d_net"):
    # One latent per conditioning frame at this level.
    return [frame_latents[level] for frame_latents in cond_latents]
  if encoder in ("pointwise", "conv_lstm"):
    return cond_latents[level]
  return None


def check_cond_latents(cond_latents, hparams):
  """Shape checking for cond_latents."""
  if cond_latents is None:
    return
  # A single frame's latents may be passed directly; wrap into a list.
  if not isinstance(cond_latents[0], list):
    cond_latents = [cond_latents]
  expected = hparams.num_cond_latents
  if hparams.latent_dist_encoder == "conv_net":
    expected += int(hparams.cond_first_frame)
  if len(cond_latents) != expected:
    raise ValueError("Expected number of cond_latents: %d, got %d" %
                     (expected, len(cond_latents)))
  for frame_latents in cond_latents:
    if len(frame_latents) != hparams.n_levels - 1:
      raise ValueError("Expected level_latents to be %d, got %d" %
                       (hparams.n_levels - 1, len(frame_latents)))
@add_arg_scope
def encoder_decoder(name, x, hparams, eps=None, reverse=False,
                    cond_latents=None, condition=False, states=None,
                    temperature=1.0):
  """Glow encoder-decoder. n_levels of (Squeeze + Flow + Split.) operations.

  Note: the per-level squeeze calls are commented out below, so this variant
  actually runs (Flow + Split) only and never changes spatial resolution.

  Args:
    name: variable scope.
    x: 4-D Tensor, shape=(NHWC).
    hparams: HParams.
    eps: Stores (glow(x) - mu) / sigma during the forward pass.
      Used only to test if the network is reversible.
    reverse: Forward or reverse pass.
    cond_latents: list of lists of tensors.
      outer length equals hparams.num_cond_latents
      inner length equals hparams.num_levels - 1.
    condition: If set to True, condition the encoder/decoder on cond_latents.
    states: LSTM states, used only if hparams.latent_dist_encoder is set
      to "conv_lstm".
    temperature: Temperature set during sampling.

  Returns:
    x: If reverse, decoded image, else the encoded glow latent representation.
    objective: log-likelihood.
    eps: list of tensors, shape=(num_levels-1). Forward pass only.
      Stores (glow(x) - mu_level(x)) / sigma_level(x)) for each level.
    all_latents: list of tensors, shape=(num_levels-1).
      Latent representations for each level.
    new_states: list of tensors, shape=(num_levels-1).
      useful only if hparams.latent_dist_encoder="conv_lstm", returns
      the current state of each level.

  Raises:
    ValueError: If the lengths of `states`, `eps` or `cond_latents` do not
      match hparams.n_levels - 1.
  """
  # TODO(mechcoder) Change return_type to a dict to be backward compatible.
  with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):

    if states and len(states) != hparams.n_levels - 1:
      raise ValueError("Expected length of states to be %d, got %d" %
                       (hparams.n_levels - 1, len(states)))
    if states is None:
      states = [None] * (hparams.n_levels - 1)
    if eps and len(eps) != hparams.n_levels - 1:
      raise ValueError("Expected length of eps to be %d, got %d" %
                       (hparams.n_levels - 1, len(eps)))
    if eps is None:
      eps = [None] * (hparams.n_levels - 1)
    check_cond_latents(cond_latents, hparams)

    objective = 0.0
    all_eps = []
    all_latents = []
    new_states = []

    if not reverse:
      # Squeeze + Flow + Split
      for level in range(hparams.n_levels):
        # x = squeeze("squeeze_%d" % level, x, factor=2, reverse=False)

        x, obj = revnet("revnet_%d" % level, x, hparams, reverse=False)
        objective += obj

        if level < hparams.n_levels - 1:
          curr_cond_latents = get_cond_latents_at_level(
              cond_latents, level, hparams)
          # NOTE(review): this rebinds the argument `eps` (a list) to the
          # per-level eps tensor. Harmless — the list was expanded above and
          # is unused on the forward pass — but easy to misread.
          x, obj, eps, z, state = split("split_%d" % level, x, reverse=False,
                                        cond_latents=curr_cond_latents,
                                        condition=condition,
                                        hparams=hparams, state=states[level])
          objective += obj
          all_eps.append(eps)
          all_latents.append(z)
          new_states.append(state)

      return x, objective, all_eps, all_latents, new_states

    else:
      # Reverse pass: undo the levels in the opposite order
      # (Split^-1 then Flow^-1 at each level).
      for level in reversed(range(hparams.n_levels)):
        if level < hparams.n_levels - 1:

          curr_cond_latents = get_cond_latents_at_level(
              cond_latents, level, hparams)

          x, latent, state = split("split_%d" % level, x, eps=eps[level],
                                   reverse=True, cond_latents=curr_cond_latents,
                                   condition=condition, hparams=hparams,
                                   state=states[level],
                                   temperature=temperature)
          new_states.append(state)
          all_latents.append(latent)

        x, obj = revnet(
            "revnet_%d" % level, x, hparams=hparams, reverse=True)
        objective += obj
        # x = squeeze("squeeze_%d" % level, x, reverse=True)
      # Levels were visited top-down, so reverse the lists to keep index i
      # aligned with level i, matching the forward pass.
      return x, objective, all_latents[::-1], new_states[::-1]
@add_arg_scope
def top_prior(name, z_shape, learn_prior="normal", temperature=1.0):
  """Unconditional prior distribution over the top-level latent.

  This implementation always returns a fixed standard Normal: the mean is a
  zeros tensor `h` and the scale is exp(h) == 1. The `learn_prior` and
  `temperature` arguments are accepted but currently ignored; they are kept
  for interface compatibility with the tensor2tensor version, where
  learn_prior="single_conv" parametrizes the prior with a convolution.

  Args:
    name: variable scope
    z_shape: Shape of the mean / scale of the prior distribution.
    learn_prior: Currently unused — the prior is always N(0, 1).
    temperature: Currently unused.

  Returns:
    prior_dist: tf.distributions.Normal with zero mean and unit scale,
      elementwise over `z_shape`.
  """
  with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
    h = tf.zeros(z_shape, dtype=tf.float32)
    # exp(0) == 1, so this is a standard Normal elementwise.
    prior_dist = tf.distributions.Normal(h, tf.exp(h))
    return prior_dist
class BertConfig(object):
  """Configuration for `BertModel`."""

  def __init__(self,
               vocab_size,
               hidden_size=768,
               num_hidden_layers=12,
               num_attention_heads=12,
               intermediate_size=3072,
               hidden_act="gelu",
               hidden_dropout_prob=0.1,
               attention_probs_dropout_prob=0.1,
               max_position_embeddings=512,
               type_vocab_size=16,
               initializer_range=0.02):
    """Constructs BertConfig.

    Args:
      vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
      hidden_size: Size of the encoder layers and the pooler layer.
      num_hidden_layers: Number of hidden layers in the Transformer encoder.
      num_attention_heads: Number of attention heads for each attention layer
        in the Transformer encoder.
      intermediate_size: The size of the "intermediate" (i.e., feed-forward)
        layer in the Transformer encoder.
      hidden_act: The non-linear activation function (function or string) in
        the encoder and pooler.
      hidden_dropout_prob: The dropout probability for all fully connected
        layers in the embeddings, encoder, and pooler.
      attention_probs_dropout_prob: The dropout ratio for the attention
        probabilities.
      max_position_embeddings: The maximum sequence length that this model
        might ever be used with. Typically set this to something large just
        in case (e.g., 512 or 1024 or 2048).
      type_vocab_size: The vocabulary size of the `token_type_ids` passed into
        `BertModel`.
      initializer_range: The stdev of the truncated_normal_initializer for
        initializing all weight matrices.
    """
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.hidden_act = hidden_act
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.initializer_range = initializer_range

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a `BertConfig` from a Python dictionary of parameters.

    Unknown keys are attached to the instance as extra attributes, matching
    the previous behavior.
    """
    # Use `cls` (not a hard-coded class name) so subclasses construct
    # instances of themselves; plain dict.items() replaces the previous
    # six.iteritems(), which is equivalent on Python 3.
    config = cls(vocab_size=None)
    for (key, value) in json_object.items():
      config.__dict__[key] = value
    return config

  @classmethod
  def from_json_file(cls, json_file):
    """Constructs a `BertConfig` from a json file of parameters."""
    with tf.gfile.GFile(json_file, "r") as reader:
      text = reader.read()
    return cls.from_dict(json.loads(text))

  def to_dict(self):
    """Serializes this instance to a Python dictionary."""
    # Deep copy so callers cannot mutate the config through the dict.
    output = copy.deepcopy(self.__dict__)
    return output

  def to_json_string(self):
    """Serializes this instance to a JSON string (trailing newline included)."""
    return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class BertModel(object):
  """BERT model ("Bidirectional Encoder Representations from Transformers").

  Builds the full BERT graph: embeddings -> transformer encoder -> pooler.

  Example usage:

  ```python
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

  config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

  model = modeling.BertModel(config=config, is_training=True,
    input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """

  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=False,
               scope=None):
    """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size,
        seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    # Copy the config so the dropout overrides below never leak back to the
    # caller's object.
    config = copy.deepcopy(config)
    if not is_training:
      # Disable all dropout at eval/inference time.
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # Default mask: attend to every position.
    if input_mask is None:
      input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    # Default segment ids: everything belongs to segment 0.
    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.embedding_output, self.embedding_table) = embedding_lookup(
            input_ids=input_ids,
            vocab_size=config.vocab_size,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
        # mask of shape [batch_size, seq_length, seq_length] which is used
        # for the attention scores.
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained
        first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

  def get_pooled_output(self):
    """Returns the [batch_size, hidden_size] first-token pooled output."""
    return self.pooled_output

  def get_sequence_output(self):
    """Gets final hidden layer of encoder.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the final hidden of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    """Returns the list of all encoder layer outputs, one per hidden layer."""
    return self.all_encoder_layers

  def get_embedding_output(self):
    """Gets output of the embedding lookup (i.e., input to the transformer).

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the output of the embedding layer, after summing the word
      embeddings with the positional embeddings and the token type embeddings,
      then performing layer normalization. This is the input to the
      transformer.
    """
    return self.embedding_output

  def get_embedding_table(self):
    """Returns the [vocab_size, hidden_size] word embedding table."""
    return self.embedding_table
def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU, using the tanh approximation.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf


def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return
    `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """

  # We assume that anything that's not a string is already an activation
  # function, so we just return it. `str` replaces the former
  # `six.string_types`, which is equivalent on Python 3 and drops the
  # unnecessary third-party dependency for this function.
  if not isinstance(activation_string, str):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)
def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Compute the union of the current variables and checkpoint variables.

  Args:
    tvars: list of tf.Variable, the variables of the current graph.
    init_checkpoint: string path to a checkpoint to map names against.

  Returns:
    A tuple (assignment_map, initialized_variable_names) where
    assignment_map maps checkpoint variable names to graph variable names
    (identity mapping for names present in both), and
    initialized_variable_names marks each matched name (with and without
    the ":0" suffix) with 1.
  """
  # (Removed a dead `assignment_map = {}` assignment that was immediately
  # shadowed by the OrderedDict built below.)
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    # Strip the ":0" (device/output index) suffix from variable names.
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  init_vars = tf.train.list_variables(init_checkpoint)

  assignment_map = collections.OrderedDict()
  for x in init_vars:
    (name, var) = (x[0], x[1])
    if name not in name_to_variable:
      # Checkpoint variable with no counterpart in the current graph.
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)
def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT
      of *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied; the input unchanged
    when `dropout_prob` is None or 0.0.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor
  # tf.nn.dropout takes a *keep* probability, so invert.
  keep_prob = 1.0 - dropout_prob
  return tf.nn.dropout(input_tensor, keep_prob)


def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return tf.contrib.layers.layer_norm(
      inputs=input_tensor,
      begin_norm_axis=-1,
      begin_params_axis=-1,
      scope=name)


def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  return dropout(layer_norm(input_tensor, name), dropout_prob)


def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)
def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up words embeddings for id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.gather()`.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  # This function works on inputs of shape [batch_size, seq_length,
  # num_inputs]; a rank-2 [batch_size, seq_length] input is first promoted to
  # [batch_size, seq_length, 1].
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  flat_ids = tf.reshape(input_ids, [-1])
  if use_one_hot_embeddings:
    # One-hot matmul path (faster on TPUs for small vocabularies).
    output = tf.matmul(tf.one_hot(flat_ids, depth=vocab_size),
                       embedding_table)
  else:
    output = tf.gather(embedding_table, flat_ids)

  input_shape = get_shape_list(input_ids)
  # Restore the leading dims and fold num_inputs into the embedding width.
  output = tf.reshape(
      output, input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
  """Performs various post-processing on a word embedding tensor.

  Adds (optional) token-type and position embeddings to `input_tensor`, then
  applies layer normalization and dropout.

  Args:
    input_tensor: float Tensor of shape [batch_size, seq_length,
      embedding_size].
    use_token_type: bool. Whether to add embeddings for `token_type_ids`.
    token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table
      variable for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output
      tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      # Fixed error message: the two adjacent string literals previously
      # concatenated without a space ("...if`use_token_type`...").
      raise ValueError("`token_type_ids` must be specified if "
                       "`use_token_type` is True.")
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is
    # always faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    # Runtime guard: the sequence must fit in the position table.
    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
    with tf.control_dependencies([assert_op]):
      full_position_embeddings = tf.get_variable(
          name=position_embedding_name,
          shape=[max_position_embeddings, width],
          initializer=create_initializer(initializer_range))
      # Since the position embedding table is a learned variable, we create it
      # using a (long) sequence length `max_position_embeddings`. The actual
      # sequence length might be shorter than this, for faster training of
      # tasks that do not have long sequences.
      #
      # So `full_position_embeddings` is effectively an embedding table
      # for position [0, 1, 2, ..., max_position_embeddings-1], and the
      # current sequence has positions [0, 1, 2, ... seq_length-1], so we can
      # just perform a slice.
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
      num_dims = len(output.shape.as_list())

      # Only the last two dimensions are relevant (`seq_length` and `width`),
      # so we broadcast among the first dimensions, which is typically just
      # the batch size.
      position_broadcast_shape = []
      for _ in range(num_dims - 2):
        position_broadcast_shape.append(1)
      position_broadcast_shape.extend([seq_length, width])
      position_embeddings = tf.reshape(position_embeddings,
                                       position_broadcast_shape)
      output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output
def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size, from_seq_length = from_shape[0], from_shape[1]
  to_seq_length = get_shape_list(to_mask, expected_rank=2)[1]

  # [batch_size, 1, to_seq_length]: which *target* positions may be attended.
  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding)
  # tokens so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

  # Broadcasting [B, F, 1] * [B, 1, T] -> [B, F, T].
  return broadcast_ones * to_mask
553 | mask = broadcast_ones * to_mask 554 | 555 | return mask 556 | 557 | 558 | def attention_layer(from_tensor, 559 | to_tensor, 560 | attention_mask=None, 561 | num_attention_heads=1, 562 | size_per_head=512, 563 | query_act=None, 564 | key_act=None, 565 | value_act=None, 566 | attention_probs_dropout_prob=0.0, 567 | initializer_range=0.02, 568 | do_return_2d_tensor=False, 569 | batch_size=None, 570 | from_seq_length=None, 571 | to_seq_length=None): 572 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 573 | 574 | This is an implementation of multi-headed attention based on "Attention 575 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 576 | this is self-attention. Each timestep in `from_tensor` attends to the 577 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 578 | 579 | This function first projects `from_tensor` into a "query" tensor and 580 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 581 | of tensors of length `num_attention_heads`, where each tensor is of shape 582 | [batch_size, seq_length, size_per_head]. 583 | 584 | Then, the query and key tensors are dot-producted and scaled. These are 585 | softmaxed to obtain attention probabilities. The value tensors are then 586 | interpolated by these probabilities, then concatenated back to a single 587 | tensor and returned. 588 | 589 | In practice, the multi-headed attention are done with transposes and 590 | reshapes rather than actual separate tensors. 591 | 592 | Args: 593 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 594 | from_width]. 595 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 596 | attention_mask: (optional) int32 Tensor of shape [batch_size, 597 | from_seq_length, to_seq_length]. The values should be 1 or 0. 
The 598 | attention scores will effectively be set to -infinity for any positions in 599 | the mask that are 0, and will be unchanged for positions that are 1. 600 | num_attention_heads: int. Number of attention heads. 601 | size_per_head: int. Size of each attention head. 602 | query_act: (optional) Activation function for the query transform. 603 | key_act: (optional) Activation function for the key transform. 604 | value_act: (optional) Activation function for the value transform. 605 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 606 | attention probabilities. 607 | initializer_range: float. Range of the weight initializer. 608 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 609 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 610 | output will be of shape [batch_size, from_seq_length, num_attention_heads 611 | * size_per_head]. 612 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 613 | of the 3D version of the `from_tensor` and `to_tensor`. 614 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 615 | of the 3D version of the `from_tensor`. 616 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `to_tensor`. 618 | 619 | Returns: 620 | float Tensor of shape [batch_size, from_seq_length, 621 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 622 | true, this will be of shape [batch_size * from_seq_length, 623 | num_attention_heads * size_per_head]). 624 | 625 | Raises: 626 | ValueError: Any of the arguments or tensor shapes are invalid. 
627 | """ 628 | 629 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 630 | seq_length, width): 631 | output_tensor = tf.reshape( 632 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 633 | 634 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 635 | return output_tensor 636 | 637 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 638 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 639 | 640 | if len(from_shape) != len(to_shape): 641 | raise ValueError( 642 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 643 | 644 | if len(from_shape) == 3: 645 | batch_size = from_shape[0] 646 | from_seq_length = from_shape[1] 647 | to_seq_length = to_shape[1] 648 | elif len(from_shape) == 2: 649 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 650 | raise ValueError( 651 | "When passing in rank 2 tensors to attention_layer, the values " 652 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 653 | "must all be specified.") 654 | 655 | # Scalar dimensions referenced here: 656 | # B = batch size (number of sequences) 657 | # F = `from_tensor` sequence length 658 | # T = `to_tensor` sequence length 659 | # N = `num_attention_heads` 660 | # H = `size_per_head` 661 | 662 | from_tensor_2d = reshape_to_matrix(from_tensor) 663 | to_tensor_2d = reshape_to_matrix(to_tensor) 664 | 665 | # `query_layer` = [B*F, N*H] 666 | query_layer = tf.layers.dense( 667 | from_tensor_2d, 668 | num_attention_heads * size_per_head, 669 | activation=query_act, 670 | name="query", 671 | kernel_initializer=create_initializer(initializer_range)) 672 | 673 | # `key_layer` = [B*T, N*H] 674 | key_layer = tf.layers.dense( 675 | to_tensor_2d, 676 | num_attention_heads * size_per_head, 677 | activation=key_act, 678 | name="key", 679 | kernel_initializer=create_initializer(initializer_range)) 680 | 681 | # `value_layer` = [B*T, N*H] 682 | value_layer = tf.layers.dense( 683 | 
to_tensor_2d, 684 | num_attention_heads * size_per_head, 685 | activation=value_act, 686 | name="value", 687 | kernel_initializer=create_initializer(initializer_range)) 688 | 689 | # `query_layer` = [B, N, F, H] 690 | query_layer = transpose_for_scores(query_layer, batch_size, 691 | num_attention_heads, from_seq_length, 692 | size_per_head) 693 | 694 | # `key_layer` = [B, N, T, H] 695 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 696 | to_seq_length, size_per_head) 697 | 698 | # Take the dot product between "query" and "key" to get the raw 699 | # attention scores. 700 | # `attention_scores` = [B, N, F, T] 701 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 702 | attention_scores = tf.multiply(attention_scores, 703 | 1.0 / math.sqrt(float(size_per_head))) 704 | 705 | if attention_mask is not None: 706 | # `attention_mask` = [B, 1, F, T] 707 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 708 | 709 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 710 | # masked positions, this operation will create a tensor which is 0.0 for 711 | # positions we want to attend and -10000.0 for masked positions. 712 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 713 | 714 | # Since we are adding it to the raw scores before the softmax, this is 715 | # effectively the same as removing these entirely. 716 | attention_scores += adder 717 | 718 | # Normalize the attention scores to probabilities. 719 | # `attention_probs` = [B, N, F, T] 720 | attention_probs = tf.nn.softmax(attention_scores) 721 | 722 | # This is actually dropping out entire tokens to attend to, which might 723 | # seem a bit unusual, but is taken from the original Transformer paper. 
724 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 725 | 726 | # `value_layer` = [B, T, N, H] 727 | value_layer = tf.reshape( 728 | value_layer, 729 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 730 | 731 | # `value_layer` = [B, N, T, H] 732 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 733 | 734 | # `context_layer` = [B, N, F, H] 735 | context_layer = tf.matmul(attention_probs, value_layer) 736 | 737 | # `context_layer` = [B, F, N, H] 738 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 739 | 740 | if do_return_2d_tensor: 741 | # `context_layer` = [B*F, N*H] 742 | context_layer = tf.reshape( 743 | context_layer, 744 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 745 | else: 746 | # `context_layer` = [B, F, N*H] 747 | context_layer = tf.reshape( 748 | context_layer, 749 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 750 | 751 | return context_layer 752 | 753 | 754 | def transformer_model(input_tensor, 755 | attention_mask=None, 756 | hidden_size=768, 757 | num_hidden_layers=12, 758 | num_attention_heads=12, 759 | intermediate_size=3072, 760 | intermediate_act_fn=gelu, 761 | hidden_dropout_prob=0.1, 762 | attention_probs_dropout_prob=0.1, 763 | initializer_range=0.02, 764 | do_return_all_layers=False): 765 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 766 | 767 | This is almost an exact implementation of the original Transformer encoder. 768 | 769 | See the original paper: 770 | https://arxiv.org/abs/1706.03762 771 | 772 | Also see: 773 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 774 | 775 | Args: 776 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 777 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 778 | seq_length], with 1 for positions that can be attended to and 0 in 779 | positions that should not be. 
780 | hidden_size: int. Hidden size of the Transformer. 781 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 782 | num_attention_heads: int. Number of attention heads in the Transformer. 783 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 784 | forward) layer. 785 | intermediate_act_fn: function. The non-linear activation function to apply 786 | to the output of the intermediate/feed-forward layer. 787 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 788 | attention_probs_dropout_prob: float. Dropout probability of the attention 789 | probabilities. 790 | initializer_range: float. Range of the initializer (stddev of truncated 791 | normal). 792 | do_return_all_layers: Whether to also return all layers or just the final 793 | layer. 794 | 795 | Returns: 796 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 797 | hidden layer of the Transformer. 798 | 799 | Raises: 800 | ValueError: A Tensor shape or parameter is invalid. 801 | """ 802 | if hidden_size % num_attention_heads != 0: 803 | raise ValueError( 804 | "The hidden size (%d) is not a multiple of the number of attention " 805 | "heads (%d)" % (hidden_size, num_attention_heads)) 806 | 807 | attention_head_size = int(hidden_size / num_attention_heads) 808 | input_shape = get_shape_list(input_tensor, expected_rank=3) 809 | batch_size = input_shape[0] 810 | seq_length = input_shape[1] 811 | input_width = input_shape[2] 812 | 813 | # The Transformer performs sum residuals on all layers so the input needs 814 | # to be the same as the hidden size. 815 | if input_width != hidden_size: 816 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 817 | (input_width, hidden_size)) 818 | 819 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 820 | # forth from a 3D tensor to a 2D tensor. 
Re-shapes are normally free on 821 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 822 | # help the optimizer. 823 | prev_output = reshape_to_matrix(input_tensor) 824 | 825 | all_layer_outputs = [] 826 | for layer_idx in range(num_hidden_layers): 827 | with tf.variable_scope("layer_%d" % layer_idx): 828 | layer_input = prev_output 829 | 830 | with tf.variable_scope("attention"): 831 | attention_heads = [] 832 | with tf.variable_scope("self"): 833 | attention_head = attention_layer( 834 | from_tensor=layer_input, 835 | to_tensor=layer_input, 836 | attention_mask=attention_mask, 837 | num_attention_heads=num_attention_heads, 838 | size_per_head=attention_head_size, 839 | attention_probs_dropout_prob=attention_probs_dropout_prob, 840 | initializer_range=initializer_range, 841 | do_return_2d_tensor=True, 842 | batch_size=batch_size, 843 | from_seq_length=seq_length, 844 | to_seq_length=seq_length) 845 | attention_heads.append(attention_head) 846 | 847 | attention_output = None 848 | if len(attention_heads) == 1: 849 | attention_output = attention_heads[0] 850 | else: 851 | # In the case where we have other sequences, we just concatenate 852 | # them to the self-attention head before the projection. 853 | attention_output = tf.concat(attention_heads, axis=-1) 854 | 855 | # Run a linear projection of `hidden_size` then add a residual 856 | # with `layer_input`. 857 | with tf.variable_scope("output"): 858 | attention_output = tf.layers.dense( 859 | attention_output, 860 | hidden_size, 861 | kernel_initializer=create_initializer(initializer_range)) 862 | attention_output = dropout(attention_output, hidden_dropout_prob) 863 | attention_output = layer_norm(attention_output + layer_input) 864 | 865 | # The activation is only applied to the "intermediate" hidden layer. 
866 | with tf.variable_scope("intermediate"): 867 | intermediate_output = tf.layers.dense( 868 | attention_output, 869 | intermediate_size, 870 | activation=intermediate_act_fn, 871 | kernel_initializer=create_initializer(initializer_range)) 872 | 873 | # Down-project back to `hidden_size` then add the residual. 874 | with tf.variable_scope("output"): 875 | layer_output = tf.layers.dense( 876 | intermediate_output, 877 | hidden_size, 878 | kernel_initializer=create_initializer(initializer_range)) 879 | layer_output = dropout(layer_output, hidden_dropout_prob) 880 | layer_output = layer_norm(layer_output + attention_output) 881 | prev_output = layer_output 882 | all_layer_outputs.append(layer_output) 883 | 884 | if do_return_all_layers: 885 | final_outputs = [] 886 | for layer_output in all_layer_outputs: 887 | final_output = reshape_from_matrix(layer_output, input_shape) 888 | final_outputs.append(final_output) 889 | return final_outputs 890 | else: 891 | final_output = reshape_from_matrix(prev_output, input_shape) 892 | return final_output 893 | 894 | 895 | def get_shape_list(tensor, expected_rank=None, name=None): 896 | """Returns a list of the shape of tensor, preferring static dimensions. 897 | 898 | Args: 899 | tensor: A tf.Tensor object to find the shape of. 900 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 901 | specified and the `tensor` has a different rank, and exception will be 902 | thrown. 903 | name: Optional name of the tensor for the error message. 904 | 905 | Returns: 906 | A list of dimensions of the shape of tensor. All static dimensions will 907 | be returned as python integers, and dynamic dimensions will be returned 908 | as tf.Tensor scalars. 
909 | """ 910 | if name is None: 911 | name = tensor.name 912 | 913 | if expected_rank is not None: 914 | assert_rank(tensor, expected_rank, name) 915 | 916 | shape = tensor.shape.as_list() 917 | 918 | non_static_indexes = [] 919 | for (index, dim) in enumerate(shape): 920 | if dim is None: 921 | non_static_indexes.append(index) 922 | 923 | if not non_static_indexes: 924 | return shape 925 | 926 | dyn_shape = tf.shape(tensor) 927 | for index in non_static_indexes: 928 | shape[index] = dyn_shape[index] 929 | return shape 930 | 931 | 932 | def reshape_to_matrix(input_tensor): 933 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 934 | ndims = input_tensor.shape.ndims 935 | if ndims < 2: 936 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 937 | (input_tensor.shape)) 938 | if ndims == 2: 939 | return input_tensor 940 | 941 | width = input_tensor.shape[-1] 942 | output_tensor = tf.reshape(input_tensor, [-1, width]) 943 | return output_tensor 944 | 945 | 946 | def reshape_from_matrix(output_tensor, orig_shape_list): 947 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 948 | if len(orig_shape_list) == 2: 949 | return output_tensor 950 | 951 | output_shape = get_shape_list(output_tensor) 952 | 953 | orig_dims = orig_shape_list[0:-1] 954 | width = output_shape[-1] 955 | 956 | return tf.reshape(output_tensor, orig_dims + [width]) 957 | 958 | 959 | def assert_rank(tensor, expected_rank, name=None): 960 | """Raises an exception if the tensor rank is not of the expected rank. 961 | 962 | Args: 963 | tensor: A tf.Tensor to check the rank of. 964 | expected_rank: Python integer or list of integers, expected rank. 965 | name: Optional name of the tensor for the error message. 966 | 967 | Raises: 968 | ValueError: If the expected shape doesn't match the actual shape. 
969 | """ 970 | if name is None: 971 | name = tensor.name 972 | 973 | expected_rank_dict = {} 974 | if isinstance(expected_rank, six.integer_types): 975 | expected_rank_dict[expected_rank] = True 976 | else: 977 | for x in expected_rank: 978 | expected_rank_dict[x] = True 979 | 980 | actual_rank = tensor.shape.ndims 981 | if actual_rank not in expected_rank_dict: 982 | scope_name = tf.get_variable_scope().name 983 | raise ValueError( 984 | "For the tensor `%s` in scope `%s`, the actual rank " 985 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 986 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 987 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = 
tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /optimization_bert_flow.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, flow_loss, init_lr, init_flow_lr, 26 | num_train_steps, num_warmup_steps, use_tpu): 27 | """Creates an optimizer training op.""" 28 | global_step = tf.train.get_or_create_global_step() 29 | 30 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 31 | 32 | # Implements linear decay of the learning rate. 33 | learning_rate = tf.train.polynomial_decay( 34 | learning_rate, 35 | global_step, 36 | num_train_steps, 37 | end_learning_rate=0.0, 38 | power=1.0, 39 | cycle=False) 40 | 41 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 42 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 43 | if num_warmup_steps: 44 | global_steps_int = tf.cast(global_step, tf.int32) 45 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 46 | 47 | global_steps_float = tf.cast(global_steps_int, tf.float32) 48 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 49 | 50 | warmup_percent_done = global_steps_float / warmup_steps_float 51 | warmup_learning_rate = init_lr * warmup_percent_done 52 | 53 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 54 | learning_rate = ( 55 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 56 | 57 | if type(flow_loss) != "NoneType": 58 | # bert 59 | # It is recommended that you use this optimizer for fine tuning, since this 60 | # is how the model was trained (note that the Adam m/v variables are NOT 61 | # loaded from init_checkpoint.) 
62 | optimizer = AdamWeightDecayOptimizer( 63 | learning_rate=learning_rate, 64 | weight_decay_rate=0.01, 65 | beta_1=0.9, 66 | beta_2=0.999, 67 | epsilon=1e-6, 68 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 69 | if use_tpu: 70 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 71 | 72 | tvars = [v for v in tf.trainable_variables() if not v.name.startswith("bert/flow")] 73 | grads = tf.gradients(loss, tvars) 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | train_op = optimizer.apply_gradients( 76 | zip(grads, tvars), global_step=global_step) 77 | 78 | ######################## 79 | # flow 80 | flow_optimizer = AdamWeightDecayOptimizer( 81 | learning_rate=init_flow_lr, #learning_rate / init_lr * 82 | weight_decay_rate=0.01, 83 | beta_1=0.9, 84 | beta_2=0.999, 85 | epsilon=1e-6, 86 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 87 | if use_tpu: 88 | flow_optimizer = tf.contrib.tpu.CrossShardOptimizer(flow_optimizer) 89 | 90 | flow_tvars = tf.trainable_variables("bert/flow") 91 | flow_grads = tf.gradients(flow_loss, flow_tvars) 92 | (flow_grads, _) = tf.clip_by_global_norm(flow_grads, clip_norm=1.0) 93 | flow_train_op = flow_optimizer.apply_gradients( 94 | zip(flow_grads, flow_tvars), global_step=global_step) 95 | 96 | ######################## 97 | 98 | # Normally the global step update is done inside of `apply_gradients`. 99 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 100 | # a different optimizer, you should probably take this line out. 
101 | new_global_step = global_step + 1 102 | train_op = tf.group(train_op, flow_train_op, [global_step.assign(new_global_step)]) 103 | return train_op 104 | else: 105 | raise NotImplementedError 106 | 107 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 108 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 109 | 110 | def __init__(self, 111 | learning_rate, 112 | weight_decay_rate=0.0, 113 | beta_1=0.9, 114 | beta_2=0.999, 115 | epsilon=1e-6, 116 | exclude_from_weight_decay=None, 117 | name="AdamWeightDecayOptimizer"): 118 | """Constructs a AdamWeightDecayOptimizer.""" 119 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 120 | 121 | self.learning_rate = learning_rate 122 | self.weight_decay_rate = weight_decay_rate 123 | self.beta_1 = beta_1 124 | self.beta_2 = beta_2 125 | self.epsilon = epsilon 126 | self.exclude_from_weight_decay = exclude_from_weight_decay 127 | 128 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 129 | """See base class.""" 130 | assignments = [] 131 | for (grad, param) in grads_and_vars: 132 | if grad is None or param is None: 133 | continue 134 | 135 | param_name = self._get_variable_name(param.name) 136 | 137 | m = tf.get_variable( 138 | name=param_name + "/adam_m", 139 | shape=param.shape.as_list(), 140 | dtype=tf.float32, 141 | trainable=False, 142 | initializer=tf.zeros_initializer()) 143 | v = tf.get_variable( 144 | name=param_name + "/adam_v", 145 | shape=param.shape.as_list(), 146 | dtype=tf.float32, 147 | trainable=False, 148 | initializer=tf.zeros_initializer()) 149 | 150 | # Standard Adam update. 
151 | next_m = ( 152 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 153 | next_v = ( 154 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 155 | tf.square(grad))) 156 | 157 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 158 | 159 | # Just adding the square of the weights to the loss function is *not* 160 | # the correct way of using L2 regularization/weight decay with Adam, 161 | # since that will interact with the m and v parameters in strange ways. 162 | # 163 | # Instead we want ot decay the weights in a manner that doesn't interact 164 | # with the m/v parameters. This is equivalent to adding the square 165 | # of the weights to the loss with plain (non-momentum) SGD. 166 | if self._do_use_weight_decay(param_name): 167 | update += self.weight_decay_rate * param 168 | 169 | update_with_lr = self.learning_rate * update 170 | 171 | next_param = param - update_with_lr 172 | 173 | assignments.extend( 174 | [param.assign(next_param), 175 | m.assign(next_m), 176 | v.assign(next_v)]) 177 | return tf.group(*assignments, name=name) 178 | 179 | def _do_use_weight_decay(self, param_name): 180 | """Whether to use L2 weight decay for `param_name`.""" 181 | if not self.weight_decay_rate: 182 | return False 183 | if self.exclude_from_weight_decay: 184 | for r in self.exclude_from_weight_decay: 185 | if re.search(r, param_name) is not None: 186 | return False 187 | return True 188 | 189 | def _get_variable_name(self, param_name): 190 | """Get the variable name from the tensor name.""" 191 | m = re.match("^(.*):\\d+$", param_name) 192 | if m is not None: 193 | param_name = m.group(1) 194 | return param_name 195 | -------------------------------------------------------------------------------- /run_siamese.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner for regression tasks 16 | A large portion of the code is adapted from 17 | https://github.com/zihangdai/xlnet/blob/master/run_classifier.py 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import warnings 25 | warnings.simplefilter(action='ignore', category=FutureWarning) 26 | 27 | import os 28 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 29 | 30 | import collections 31 | import csv 32 | 33 | import modeling 34 | import optimization 35 | import tokenization 36 | 37 | import tensorflow as tf 38 | 39 | import random 40 | import numpy as np 41 | 42 | from flow.glow_1x1 import AttrDict, Glow 43 | from flow.glow_init_hook import GlowInitHook 44 | import optimization_bert_flow 45 | import json 46 | 47 | from siamese_utils import StsbProcessor, SickRProcessor, MnliProcessor, QqpProcessor, \ 48 | SnliTrainProcessor, SnliDevTestProcessor, \ 49 | Sts_12_16_Processor, MrpcRegressionProcessor, QnliRegressionProcessor, \ 50 | file_based_convert_examples_to_features, file_based_input_fn_builder, \ 51 | get_input_mask_segment 52 | 53 | flags = tf.flags 54 | 55 | FLAGS = flags.FLAGS 56 | 57 | # model 58 | flags.DEFINE_string("bert_config_file", None, 59 | "The config json file corresponding to the pre-trained BERT model. 
" 60 | "This specifies the model architecture.") 61 | flags.DEFINE_integer("max_seq_length", 128, 62 | "The maximum total input sequence length after WordPiece tokenization. " 63 | "Sequences longer than this will be truncated, and sequences shorter " 64 | "than this will be padded.") 65 | flags.DEFINE_string("init_checkpoint", None, 66 | "Initial checkpoint (usually from a pre-trained BERT model).") 67 | flags.DEFINE_string("vocab_file", None, 68 | "The vocabulary file that the BERT model was trained on.") 69 | flags.DEFINE_bool("do_lower_case", True, 70 | "Whether to lower case the input text. Should be True for uncased " 71 | "models and False for cased models.") 72 | 73 | 74 | # task and data 75 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 76 | flags.DEFINE_string("data_dir", None, 77 | "The input data dir. Should contain the .tsv files (or other data files) " 78 | "for the task.") 79 | flags.DEFINE_float("label_min", 0., None) 80 | flags.DEFINE_float("label_max", 5., None) 81 | 82 | # exp 83 | flags.DEFINE_string("output_parent_dir", None, None) 84 | flags.DEFINE_string("exp_name", None, None) 85 | flags.DEFINE_string("exp_name_prefix", None, None) 86 | flags.DEFINE_integer("log_every_step", 10, None) 87 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 88 | "How often to save the model checkpoint.") 89 | flags.DEFINE_bool("use_xla", False, None) 90 | flags.DEFINE_integer("seed", 1234, None) 91 | flags.DEFINE_string("cached_dir", None, 92 | "Path to cached training and dev tfrecord file. 
" 93 | "The file will be generated if not exist.") 94 | 95 | # training 96 | flags.DEFINE_bool("do_train", False, None) 97 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 98 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 99 | flags.DEFINE_float("num_train_epochs", 3.0, 100 | "Total number of training epochs to perform.") 101 | flags.DEFINE_float("warmup_proportion", 0.1, 102 | "Proportion of training to perform linear learning rate warmup for. " 103 | "E.g., 0.1 = 10% of training.") 104 | flags.DEFINE_bool("early_stopping", False, None) 105 | 106 | flags.DEFINE_integer("start_delay_secs", 120, "for tf.estimator.EvalSpec") 107 | flags.DEFINE_integer("throttle_secs", 600, "for tf.estimator.EvalSpec") 108 | 109 | # eval 110 | flags.DEFINE_bool("do_eval", False, None) 111 | flags.DEFINE_bool("do_predict", False, None) 112 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 113 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 114 | flags.DEFINE_bool("predict_pool", False, None) 115 | flags.DEFINE_bool("do_predict_on_dev", False, None) 116 | flags.DEFINE_bool("do_predict_on_full", False, None) 117 | flags.DEFINE_string("eval_checkpoint_name", None, "filename of a finetuned checkpoint") 118 | flags.DEFINE_bool("auc", False, None) 119 | 120 | # sentence embedding related parameters 121 | flags.DEFINE_string("sentence_embedding_type", "avg", "avg, cls, ...") 122 | 123 | # flow parameters 124 | flags.DEFINE_integer("flow", 0, "use flow or not") 125 | flags.DEFINE_integer("flow_loss", 0, "use flow loss or not") 126 | flags.DEFINE_float("flow_learning_rate", 1e-3, "The initial learning rate for Adam.") 127 | flags.DEFINE_string("flow_model_config", "config_l3_d3_w32", None) 128 | 129 | # unsupervised or semi-supervised related parameters 130 | flags.DEFINE_integer("num_examples", -1, "# of labeled training examples") 131 | 
flags.DEFINE_integer("use_full_for_training", 0, None) 132 | flags.DEFINE_integer("dupe_factor", 1, "Number of times to duplicate the input data (with different masks).") 133 | 134 | # nli related parameters 135 | # flags.DEFINE_integer("use_snli_full", 0, "augment MNLI training data with SNLI") 136 | flags.DEFINE_float("l2_penalty", -1, "penalize l2 norm of sentence embeddings") 137 | 138 | # dimension reduction related parameters 139 | flags.DEFINE_integer("low_dim", -1, "avg pooling over the embedding") 140 | 141 | # senteval 142 | flags.DEFINE_bool("do_senteval", False, None) 143 | flags.DEFINE_string("senteval_tasks", "", None) 144 | 145 | def get_embedding(bert_config, is_training, 146 | input_ids, input_mask, segment_ids, scope=None): 147 | 148 | model = modeling.BertModel( 149 | config=bert_config, 150 | is_training=is_training, 151 | input_ids=input_ids, 152 | input_mask=input_mask, 153 | token_type_ids=segment_ids, 154 | scope=scope) 155 | 156 | if FLAGS.sentence_embedding_type == "avg": 157 | sequence = model.get_sequence_output() # [batch_size, seq_length, hidden_size] 158 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32) 159 | pooled = tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1) 160 | elif FLAGS.sentence_embedding_type == "cls": 161 | pooled = model.get_pooled_output() 162 | elif FLAGS.sentence_embedding_type.startswith("avg-last-last-"): 163 | pooled = 0 164 | n_last = int(FLAGS.sentence_embedding_type[-1]) 165 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32) 166 | sequence = model.all_encoder_layers[-n_last] # [batch_size, seq_length, hidden_size] 167 | pooled += tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1) 168 | elif FLAGS.sentence_embedding_type.startswith("avg-last-"): 169 | pooled = 0 170 | n_last = int(FLAGS.sentence_embedding_type[-1]) 171 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), 
dtype=tf.float32) 172 | for i in range(n_last): 173 | sequence = model.all_encoder_layers[-i] # [batch_size, seq_length, hidden_size] 174 | pooled += tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1) 175 | pooled /= float(n_last) 176 | elif FLAGS.sentence_embedding_type.startswith("avg-last-concat-"): 177 | pooled = [] 178 | n_last = int(FLAGS.sentence_embedding_type[-1]) 179 | input_mask_ = tf.cast(tf.expand_dims(input_mask, axis=-1), dtype=tf.float32) 180 | for i in range(n_last): 181 | sequence = model.all_encoder_layers[-i] # [batch_size, seq_length, hidden_size] 182 | pooled += [tf.reduce_sum(sequence * input_mask_, axis=1) / tf.reduce_sum(input_mask_, axis=1)] 183 | pooled = tf.concat(pooled, axis=-1) 184 | else: 185 | raise NotImplementedError 186 | 187 | # flow 188 | embedding = None 189 | flow_loss_batch, flow_loss_example = None, None 190 | if FLAGS.flow: 191 | # load model and train config 192 | with open(os.path.join("./flow/config", FLAGS.flow_model_config + ".json"), 'r') as jp: 193 | flow_model_config = AttrDict(json.load(jp)) 194 | flow_model_config.is_training = is_training 195 | flow_model = Glow(flow_model_config) 196 | flow_loss_example = flow_model.body(pooled, is_training) # no l2 normalization here any more 197 | flow_loss_batch = tf.math.reduce_mean(flow_loss_example) 198 | embedding = tf.identity(tf.squeeze(flow_model.z, [1, 2])) # no l2 normalization here any more 199 | else: 200 | embedding = pooled 201 | 202 | if FLAGS.low_dim > 0: 203 | bsz, org_dim = modeling.get_shape_list(embedding) 204 | embedding = tf.reduce_mean( 205 | tf.reshape(embedding, [bsz, FLAGS.low_dim, org_dim // FLAGS.low_dim]), axis=-1) 206 | 207 | return embedding, flow_loss_batch, flow_loss_example 208 | 209 | 210 | def create_model(bert_config, is_regression, 211 | is_training, 212 | input_ids_a, input_mask_a, segment_ids_a, 213 | input_ids_b, input_mask_b, segment_ids_b, 214 | labels, num_labels): 215 | """Creates a classification 
model."""

  # Siamese setup: both sentences are encoded with the SAME BERT weights;
  # the second tower re-enters the "bert" variable scope with AUTO_REUSE.
  with tf.variable_scope("bert") as scope:
    embedding_a, flow_loss_batch_a, flow_loss_example_a = \
        get_embedding(bert_config, is_training,
                      input_ids_a, input_mask_a, segment_ids_a, scope)
  with tf.variable_scope("bert", reuse=tf.AUTO_REUSE) as scope:
    embedding_b, flow_loss_batch_b, flow_loss_example_b = \
        get_embedding(bert_config, is_training,
                      input_ids_b, input_mask_b, segment_ids_b, scope)

  with tf.variable_scope("loss"):
    # Cosine similarity between the l2-normalized sentence embeddings,
    # shape [batch_size]; used directly as the regression prediction.
    cos_similarity = tf.reduce_sum(tf.multiply(
        tf.nn.l2_normalize(embedding_a, axis=-1),
        tf.nn.l2_normalize(embedding_b, axis=-1)), axis=-1)
    if is_regression:
      # changing cos_similarity into (cos_similarity + 1)/2.0
      # leads to large performance decrease in practice
      per_example_loss = tf.square(cos_similarity - labels)
      loss = tf.reduce_mean(per_example_loss)
      logits, predictions = None, None
    else:
      # Classification head over the concatenation [u, v, |u - v|]
      # followed by a single linear layer and softmax cross-entropy.
      output_layer = tf.concat([
          embedding_a, embedding_b, tf.math.abs(embedding_a - embedding_b)
      ], axis=-1)
      output_size = output_layer.shape[-1].value
      output_weights = tf.get_variable(
          "output_weights", [num_labels, output_size],
          initializer=tf.truncated_normal_initializer(stddev=0.02))

      logits = tf.matmul(output_layer, output_weights, transpose_b=True)

      probabilities = tf.nn.softmax(logits, axis=-1)
      predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32)
      log_probs = tf.nn.log_softmax(logits, axis=-1)
      one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

      per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
      loss = tf.reduce_mean(per_example_loss)

    # Semi-supervised masking: main() relabels examples beyond
    # FLAGS.num_examples with -10, so `labels > -1` zeroes their loss.
    # num_examples == 0 means fully unsupervised (supervised loss disabled).
    if FLAGS.num_examples == 0:
      per_example_loss = tf.zeros_like(per_example_loss)
      loss = tf.zeros_like(loss)
    elif FLAGS.num_examples > 0:
      per_example_loss = per_example_loss * tf.cast(labels > -1, 
dtype=tf.float32) 260 | loss = tf.reduce_mean(per_example_loss) 261 | 262 | if FLAGS.l2_penalty > 0: 263 | l2_penalty_loss = tf.norm(embedding_a, axis=-1, keepdims=False) 264 | l2_penalty_loss += tf.norm(embedding_b, axis=-1, keepdims=False) 265 | l2_penalty_loss *= FLAGS.l2_penalty 266 | 267 | per_example_loss += l2_penalty_loss 268 | loss += tf.reduce_mean(l2_penalty_loss) 269 | 270 | model_output = { 271 | "loss": loss, 272 | "per_example_loss": per_example_loss, 273 | "cos_similarity": cos_similarity, 274 | "embedding_a": embedding_a, 275 | "embedding_b": embedding_b, 276 | "logits": logits, 277 | "predictions": predictions, 278 | } 279 | 280 | if FLAGS.flow: 281 | model_output["flow_example_loss"] = flow_loss_example_a + flow_loss_example_b 282 | model_output["flow_loss"] = flow_loss_batch_a + flow_loss_batch_b 283 | 284 | return model_output 285 | 286 | 287 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 288 | num_train_steps, num_warmup_steps, is_regression): 289 | """Returns `model_fn` closure for Estimator.""" 290 | 291 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 292 | """The `model_fn` for Estimator.""" 293 | 294 | tf.logging.info("*** Features ***") 295 | for name in sorted(features.keys()): 296 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 297 | 298 | input_ids_a = features["input_ids_a"] 299 | input_mask_a = features["input_mask_a"] 300 | segment_ids_a = features["segment_ids_a"] 301 | 302 | input_ids_b = features["input_ids_b"] 303 | input_mask_b = features["input_mask_b"] 304 | segment_ids_b = features["segment_ids_b"] 305 | 306 | label_ids = features["label_ids"] 307 | is_real_example = None 308 | if "is_real_example" in features: 309 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 310 | else: 311 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 312 | 313 | #### Training or Evaluation 314 | is_training = 
(mode == tf.estimator.ModeKeys.TRAIN) 315 | 316 | #### Get loss from inputs 317 | model_output = create_model( 318 | bert_config, is_regression, 319 | is_training, 320 | input_ids_a, input_mask_a, segment_ids_a, 321 | input_ids_b, input_mask_b, segment_ids_b, 322 | label_ids, 323 | num_labels) 324 | 325 | tvars = tf.trainable_variables() 326 | initialized_variable_names = {} 327 | if init_checkpoint: 328 | (assignment_map, initialized_variable_names 329 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 330 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 331 | 332 | tf.logging.info("**** Trainable Variables ****") 333 | for var in tvars: 334 | init_string = "" 335 | if var.name in initialized_variable_names: 336 | init_string = ", *INIT_FROM_CKPT*" 337 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 338 | init_string) 339 | # if "flow" in var.name: 340 | # input() 341 | 342 | output_spec = None 343 | if mode == tf.estimator.ModeKeys.TRAIN: 344 | if FLAGS.flow_loss: 345 | train_op = optimization_bert_flow.create_optimizer( 346 | model_output["loss"], model_output["flow_loss"], 347 | learning_rate, FLAGS.flow_learning_rate, 348 | num_train_steps, num_warmup_steps, use_tpu=False) 349 | tf.summary.scalar("loss", model_output["loss"]) 350 | tf.summary.scalar("flow_loss", model_output["flow_loss"]) 351 | 352 | output_spec = tf.estimator.EstimatorSpec( 353 | mode=mode, 354 | loss=model_output["loss"] + model_output["flow_loss"], 355 | train_op=train_op) 356 | else: 357 | train_op = optimization.create_optimizer( 358 | model_output["loss"], learning_rate, 359 | num_train_steps, num_warmup_steps, use_tpu=False) 360 | output_spec = tf.estimator.EstimatorSpec( 361 | mode=mode, 362 | loss=model_output["loss"], 363 | train_op=train_op) 364 | 365 | elif mode == tf.estimator.ModeKeys.EVAL: 366 | def metric_fn(model_output, label_ids, is_real_example): 367 | predictions = tf.argmax(model_output["logits"], axis=-1, 
output_type=tf.int32) 368 | accuracy = tf.metrics.accuracy( 369 | labels=label_ids, predictions=model_output["predictions"], 370 | weights=is_real_example) 371 | loss = tf.metrics.mean( 372 | values=model_output["per_example_loss"], weights=is_real_example) 373 | metric_output = { 374 | "eval_accuracy": accuracy, 375 | "eval_loss": loss, 376 | } 377 | 378 | if "flow_loss" in model_output: 379 | metric_output["eval_loss_flow"] = \ 380 | tf.metrics.mean(values=model_output["flow_example_loss"], weights=is_real_example) 381 | metric_output["eval_loss_total"] = \ 382 | tf.metrics.mean( 383 | values=model_output["per_example_loss"] + model_output["flow_example_loss"], 384 | weights=is_real_example) 385 | 386 | return metric_output 387 | 388 | def regression_metric_fn(model_output, label_ids, is_real_example): 389 | metric_output = { 390 | "eval_loss": tf.metrics.mean( 391 | values=model_output["per_example_loss"], weights=is_real_example), 392 | "eval_pearsonr": tf.contrib.metrics.streaming_pearson_correlation( 393 | model_output["cos_similarity"], label_ids, weights=is_real_example) 394 | } 395 | 396 | # metric_output["auc"] = tf.compat.v1.metrics.auc( 397 | # label_ids, tf.math.maximum(model_output["cos_similarity"], 0), weights=is_real_example, curve='ROC') 398 | 399 | if "flow_loss" in model_output: 400 | metric_output["eval_loss_flow"] = \ 401 | tf.metrics.mean(values=model_output["flow_example_loss"], weights=is_real_example) 402 | metric_output["eval_loss_total"] = \ 403 | tf.metrics.mean( 404 | values=model_output["per_example_loss"] + model_output["flow_example_loss"], 405 | weights=is_real_example) 406 | 407 | return metric_output 408 | 409 | if is_regression: 410 | metric_fn = regression_metric_fn 411 | 412 | eval_metrics = metric_fn(model_output, label_ids, is_real_example) 413 | 414 | output_spec = tf.estimator.EstimatorSpec( 415 | mode=mode, 416 | loss=model_output["loss"], 417 | eval_metric_ops=eval_metrics) 418 | else: 419 | output_spec = 
tf.estimator.EstimatorSpec( 420 | mode=mode, 421 | predictions= {"embedding_a": model_output["embedding_a"], 422 | "embedding_b": model_output["embedding_b"]} if FLAGS.predict_pool else \ 423 | {"cos_similarity": model_output["cos_similarity"]}) 424 | return output_spec 425 | 426 | return model_fn 427 | 428 | 429 | def main(_): 430 | tf.logging.set_verbosity(tf.logging.INFO) 431 | 432 | # random seed 433 | random.seed(FLAGS.seed) 434 | np.random.seed(FLAGS.seed) 435 | tf.compat.v1.set_random_seed(FLAGS.seed) 436 | print("FLAGS.seed", FLAGS.seed) 437 | # input() 438 | 439 | # prevent double printing of the tf logs 440 | logger = tf.get_logger() 441 | logger.propagate = False 442 | 443 | # get tokenizer 444 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 445 | FLAGS.init_checkpoint) 446 | tokenizer = tokenization.FullTokenizer( 447 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 448 | 449 | # get bert config 450 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 451 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 452 | raise ValueError( 453 | "Cannot use sequence length %d because the BERT model " 454 | "was only trained up to sequence length %d" % 455 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 456 | 457 | # GPU config 458 | run_config = tf.compat.v1.ConfigProto() 459 | if FLAGS.use_xla: 460 | run_config.graph_options.optimizer_options.global_jit_level = \ 461 | tf.OptimizerOptions.ON_1 462 | 463 | run_config.gpu_options.allow_growth = True 464 | 465 | if FLAGS.do_senteval: 466 | # Set up logger 467 | import logging 468 | tf.logging.set_verbosity(0) 469 | logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG) 470 | 471 | # load senteval 472 | import sys 473 | PATH_TO_SENTEVAL, PATH_TO_DATA = '../SentEval', '../SentEval/data' 474 | sys.path.insert(0, PATH_TO_SENTEVAL) 475 | import senteval 476 | 477 | # model 478 | tf.logging.info("***** Running 
SentEval *****") 479 | with tf.Graph().as_default(): 480 | with tf.variable_scope("bert") as scope: 481 | input_ids = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_ids") 482 | input_mask = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_mask") 483 | segment_ids = tf.placeholder(shape=[None, None], dtype=tf.int32, name="segment_ids") 484 | 485 | embedding, flow_loss_batch, flow_loss_example = \ 486 | get_embedding(bert_config, False, 487 | input_ids, input_mask, segment_ids, scope=scope) 488 | embedding = tf.nn.l2_normalize(embedding, axis=-1) 489 | 490 | tvars = tf.trainable_variables() 491 | initialized_variable_names = {} 492 | if FLAGS.init_checkpoint: 493 | (assignment_map, initialized_variable_names 494 | ) = modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_checkpoint) 495 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map) 496 | 497 | tf.logging.info("**** Trainable Variables ****") 498 | for var in tvars: 499 | init_string = "" 500 | if var.name in initialized_variable_names: 501 | init_string = ", *INIT_FROM_CKPT*" 502 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 503 | init_string) 504 | 505 | with tf.train.MonitoredSession( 506 | session_creator=tf.compat.v1.train.ChiefSessionCreator(config=run_config)) as session: 507 | 508 | # SentEval prepare and batcher 509 | def prepare(params, samples): 510 | return 511 | 512 | def batcher(params, batch): 513 | batch_input_ids, batch_input_mask, batch_segment_ids = [], [], [] 514 | for sent in batch: 515 | if type(sent[0]) == bytes: 516 | sent = [_.decode() for _ in sent] 517 | text = ' '.join(sent) if sent != [] else '.' 
518 | # print(text) 519 | 520 | _input_ids, _input_mask, _segment_ids, _tokens = \ 521 | get_input_mask_segment(text, FLAGS.max_seq_length, tokenizer) 522 | batch_input_ids.append(_input_ids) 523 | batch_input_mask.append(_input_mask) 524 | batch_segment_ids.append(_segment_ids) 525 | 526 | batch_input_ids = np.asarray(batch_input_ids) 527 | batch_input_mask = np.asarray(batch_input_mask) 528 | batch_segment_ids = np.asarray(batch_segment_ids) 529 | 530 | print(".", end="") 531 | 532 | return session.run(embedding, 533 | {input_ids: batch_input_ids, 534 | input_mask: batch_input_mask, 535 | segment_ids: batch_segment_ids}) 536 | 537 | # Set params for SentEval 538 | params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5} 539 | params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 540 | 'tenacity': 3, 'epoch_size': 2} 541 | 542 | # main 543 | se = senteval.engine.SE(params_senteval, batcher, prepare) 544 | 545 | # transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 546 | # 'MR', 'CR', 'MPQA', 'SUBJ', 'SST2', 'SST5', 'TREC', 'MRPC', 547 | # 'SICKEntailment', 'SICKRelatedness', 'STSBenchmark', 548 | # 'Length', 'WordContent', 'Depth', 'TopConstituents', 549 | # 'BigramShift', 'Tense', 'SubjNumber', 'ObjNumber', 550 | # 'OddManOut', 'CoordinationInversion'] 551 | #transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness'] 552 | #transfer_tasks = ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC'] 553 | transfer_tasks = FLAGS.senteval_tasks.split(",") 554 | results = se.eval(transfer_tasks) 555 | from collections import OrderedDict 556 | results = OrderedDict(results) 557 | for key in sorted(results): 558 | value = results[key] 559 | if key.startswith("STS"): 560 | print("'" + key + "':", value["all"]) 561 | else: 562 | print(key, value) 563 | 564 | return 565 | 566 | processors = { 567 | 'sts-b': StsbProcessor, 568 | 'sick-r': SickRProcessor, 569 | 'mnli': 
MnliProcessor, 570 | 'allnli': MnliProcessor, 571 | 'qqp': QqpProcessor, 572 | 'sts-12-16': Sts_12_16_Processor, 573 | 'sts-12': Sts_12_16_Processor, 574 | 'sts-13': Sts_12_16_Processor, 575 | 'sts-14': Sts_12_16_Processor, 576 | 'sts-15': Sts_12_16_Processor, 577 | 'sts-16': Sts_12_16_Processor, 578 | 'mrpc-regression': MrpcRegressionProcessor, 579 | 'qnli-regression': QnliRegressionProcessor, 580 | } 581 | 582 | task_name = FLAGS.task_name.lower() 583 | if task_name not in processors: 584 | raise ValueError("Task not found: %s" % (task_name)) 585 | 586 | if task_name == 'sick-r' or task_name.startswith("sts"): 587 | is_regression = True 588 | label_min, label_max = 0., 5. 589 | elif task_name in ['qqp', 'mrpc-regression', 'qnli-regression']: 590 | is_regression = True 591 | label_min, label_max = 0., 1. 592 | else: 593 | is_regression = False 594 | label_min, label_max = 0., 1. 595 | 596 | dupe_factor = FLAGS.dupe_factor 597 | 598 | processor = processors[task_name]() 599 | 600 | label_list = processor.get_labels() 601 | 602 | # this block is moved here for calculating the epoch_step for save_checkpoints_steps 603 | train_examples = None 604 | num_train_steps = None 605 | num_warmup_steps = None 606 | 607 | if task_name == "allnli": 608 | FLAGS.data_dir = os.path.join(os.path.dirname(FLAGS.data_dir), "MNLI") 609 | 610 | if FLAGS.do_train and FLAGS.num_train_epochs > 1e-6: 611 | train_examples = processor.get_train_examples(FLAGS.data_dir) 612 | 613 | if task_name == "allnli": 614 | snli_data_dir = os.path.join(os.path.dirname(FLAGS.data_dir), "SNLI") 615 | train_examples.extend(SnliTrainProcessor().get_train_examples(snli_data_dir)) 616 | train_examples.extend(SnliDevTestProcessor().get_dev_examples(snli_data_dir)) 617 | train_examples.extend(SnliDevTestProcessor().get_test_examples(snli_data_dir)) 618 | 619 | if FLAGS.use_full_for_training: 620 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 621 | predict_examples = 
processor.get_test_examples(FLAGS.data_dir) 622 | train_examples.extend(eval_examples + predict_examples) 623 | 624 | num_train_steps = int( 625 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 626 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 627 | epoch_step = int(len(train_examples) / FLAGS.train_batch_size) 628 | 629 | if FLAGS.num_examples > 0: 630 | random.shuffle(train_examples) 631 | for i in range(FLAGS.num_examples, len(train_examples)): 632 | train_examples[i].label = -10 633 | 634 | random.shuffle(train_examples) 635 | 636 | 637 | # ==== # 638 | 639 | if FLAGS.early_stopping: 640 | save_checkpoints_steps = epoch_step 641 | else: 642 | save_checkpoints_steps = FLAGS.save_checkpoints_steps 643 | 644 | keep_checkpoint_max = 3 645 | save_summary_steps = log_every_step = FLAGS.log_every_step 646 | 647 | tf.logging.info("save_checkpoints_steps: %d" % save_checkpoints_steps) 648 | 649 | # make exp dir 650 | if FLAGS.exp_name: 651 | output_dir = os.path.join(FLAGS.output_parent_dir, FLAGS.exp_name) 652 | elif FLAGS.exp_name_prefix: 653 | output_dir = os.path.join(FLAGS.output_parent_dir, FLAGS.exp_name_prefix) 654 | 655 | output_dir += "_t_%s" % (FLAGS.task_name) 656 | output_dir += "_ep_%.2f" % (FLAGS.num_train_epochs) 657 | output_dir += "_lr_%.2e" % (FLAGS.learning_rate) 658 | 659 | if FLAGS.train_batch_size != 32: 660 | output_dir += "_bsz_%d" % (FLAGS.train_batch_size) 661 | 662 | if FLAGS.sentence_embedding_type != "avg": 663 | output_dir += "_e_%s" % (FLAGS.sentence_embedding_type) 664 | 665 | if FLAGS.flow > 0: 666 | output_dir += "_f_%d%d" % (FLAGS.flow, FLAGS.flow_loss) 667 | 668 | if FLAGS.flow_loss > 0: 669 | output_dir += "_%.2e" % (FLAGS.flow_learning_rate) 670 | 671 | if FLAGS.use_full_for_training > 0: 672 | output_dir += "_allsplits" 673 | 674 | if FLAGS.flow_model_config != "config_l3_d3_w32": 675 | output_dir += "_%s" % (FLAGS.flow_model_config) 676 | 677 | if FLAGS.num_examples > 0: 678 | 
output_dir += "_n_%d" % (FLAGS.num_examples) 679 | 680 | if FLAGS.low_dim > -1: 681 | output_dir += "_ld_%d" % (FLAGS.low_dim) 682 | 683 | if FLAGS.l2_penalty > 0: 684 | output_dir += "_l2_%.2e" % (FLAGS.l2_penalty) 685 | 686 | else: 687 | raise NotImplementedError 688 | 689 | if tf.gfile.Exists(output_dir) and FLAGS.do_train: 690 | tf.io.gfile.rmtree(output_dir) 691 | tf.gfile.MakeDirs(output_dir) 692 | 693 | # set up estimator 694 | run_config = tf.estimator.RunConfig( 695 | model_dir=output_dir, 696 | save_summary_steps=save_summary_steps, 697 | save_checkpoints_steps=save_checkpoints_steps, 698 | keep_checkpoint_max=keep_checkpoint_max, 699 | log_step_count_steps=log_every_step, 700 | session_config=run_config) 701 | 702 | model_fn = model_fn_builder( 703 | bert_config=bert_config, 704 | num_labels=len(label_list), 705 | init_checkpoint=FLAGS.init_checkpoint, 706 | learning_rate=FLAGS.learning_rate, 707 | num_train_steps=num_train_steps, 708 | num_warmup_steps=num_warmup_steps, 709 | is_regression=is_regression) 710 | 711 | estimator = tf.estimator.Estimator( 712 | model_fn=model_fn, 713 | config=run_config, 714 | params={ 715 | 'train_batch_size': FLAGS.train_batch_size, 716 | 'eval_batch_size': FLAGS.eval_batch_size, 717 | 'predict_batch_size': FLAGS.predict_batch_size}) 718 | 719 | def get_train_input_fn(): 720 | cached_dir = FLAGS.cached_dir 721 | if not cached_dir: 722 | cached_dir = output_dir 723 | 724 | data_name = task_name 725 | 726 | if FLAGS.num_examples > 0: 727 | train_file = os.path.join(cached_dir, 728 | data_name + "_n_%d" % (FLAGS.num_examples) \ 729 | + "_seed_%d" % (FLAGS.seed) + "_train.tf_record") 730 | elif FLAGS.use_full_for_training > 0: 731 | train_file = os.path.join(cached_dir, data_name + "_allsplits.tf_record") 732 | else: 733 | train_file = os.path.join(cached_dir, data_name + "_train.tf_record") 734 | 735 | if not tf.gfile.Exists(train_file): 736 | file_based_convert_examples_to_features( 737 | train_examples, label_list, 
FLAGS.max_seq_length, tokenizer, train_file, 738 | dupe_factor, label_min, label_max, 739 | is_training=True) 740 | 741 | tf.logging.info("***** Running training *****") 742 | tf.logging.info(" Num examples = %d", len(train_examples)) 743 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 744 | tf.logging.info(" Num steps = %d", num_train_steps) 745 | 746 | train_input_fn = file_based_input_fn_builder( 747 | input_file=train_file, 748 | seq_length=FLAGS.max_seq_length, 749 | is_training=True, 750 | drop_remainder=True, 751 | is_regression=is_regression) 752 | 753 | return train_input_fn 754 | 755 | def get_eval_input_fn(): 756 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 757 | num_actual_eval_examples = len(eval_examples) 758 | 759 | cached_dir = FLAGS.cached_dir 760 | if not cached_dir: 761 | cached_dir = output_dir 762 | eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record") 763 | 764 | if not tf.gfile.Exists(eval_file): 765 | file_based_convert_examples_to_features( 766 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file, 767 | dupe_factor, label_min, label_max) 768 | 769 | tf.logging.info("***** Running evaluation *****") 770 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 771 | len(eval_examples), num_actual_eval_examples, 772 | len(eval_examples) - num_actual_eval_examples) 773 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 774 | 775 | # This tells the estimator to run through the entire set. 
776 | eval_drop_remainder = False 777 | eval_input_fn = file_based_input_fn_builder( 778 | input_file=eval_file, 779 | seq_length=FLAGS.max_seq_length, 780 | is_training=False, 781 | drop_remainder=eval_drop_remainder, 782 | is_regression=is_regression) 783 | 784 | return eval_input_fn 785 | 786 | def get_predict_input_fn(): 787 | predict_examples = None 788 | if FLAGS.do_predict_on_dev: 789 | predict_examples = processor.get_dev_examples(FLAGS.data_dir) 790 | elif FLAGS.do_predict_on_full: 791 | train_examples = processor.get_train_examples(FLAGS.data_dir) 792 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 793 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 794 | predict_examples.extend(eval_examples + train_examples) 795 | else: 796 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 797 | num_actual_predict_examples = len(predict_examples) 798 | 799 | cached_dir = FLAGS.cached_dir 800 | if not cached_dir: 801 | cached_dir = output_dir 802 | predict_file = os.path.join(cached_dir, task_name + "_predict.tf_record") 803 | 804 | file_based_convert_examples_to_features( 805 | predict_examples, label_list, FLAGS.max_seq_length, tokenizer, predict_file, 806 | dupe_factor, label_min, label_max) 807 | 808 | tf.logging.info("***** Running prediction*****") 809 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 810 | len(predict_examples), num_actual_predict_examples, 811 | len(predict_examples) - num_actual_predict_examples) 812 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 813 | 814 | predict_drop_remainder = False 815 | predict_input_fn = file_based_input_fn_builder( 816 | input_file=predict_file, 817 | seq_length=FLAGS.max_seq_length, 818 | is_training=False, 819 | drop_remainder=predict_drop_remainder, 820 | is_regression=is_regression) 821 | 822 | return predict_input_fn, num_actual_predict_examples 823 | 824 | eval_steps = None 825 | 826 | if FLAGS.do_train and FLAGS.num_train_epochs > 1e-6: 
827 | train_input_fn = get_train_input_fn() 828 | if FLAGS.early_stopping: 829 | eval_input_fn = get_eval_input_fn() 830 | early_stopping_hook = tf.estimator.experimental.stop_if_no_decrease_hook( 831 | estimator, metric_name="eval_pearsonr", 832 | max_steps_without_decrease=epoch_step//2, run_every_steps=epoch_step, run_every_secs=None) 833 | train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps, 834 | hooks=[early_stopping_hook]) 835 | 836 | start_delay_secs = FLAGS.start_delay_secs 837 | throttle_secs = FLAGS.throttle_secs 838 | tf.logging.info("start_delay_secs: %d; throttle_secs: %d" % (start_delay_secs, throttle_secs)) 839 | eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=eval_steps, 840 | start_delay_secs=start_delay_secs, throttle_secs=throttle_secs) 841 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 842 | else: 843 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 844 | 845 | if FLAGS.do_eval: 846 | eval_input_fn = get_eval_input_fn() 847 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 848 | 849 | output_eval_file = os.path.join(output_dir, "eval_results.txt") 850 | with tf.gfile.GFile(output_eval_file, "w") as writer: 851 | tf.logging.info("***** Eval results *****") 852 | for key in sorted(result.keys()): 853 | tf.logging.info(" %s = %s", key, str(result[key])) 854 | writer.write("%s = %s\n" % (key, str(result[key]))) 855 | 856 | if FLAGS.do_predict: 857 | predict_input_fn, num_actual_predict_examples = get_predict_input_fn() 858 | checkpoint_path = None 859 | if FLAGS.eval_checkpoint_name: 860 | checkpoint_path = os.path.join(output_dir, FLAGS.eval_checkpoint_name) 861 | result = estimator.predict(input_fn=predict_input_fn, 862 | checkpoint_path=checkpoint_path) 863 | 864 | def round_float_list(values): 865 | values = [round(float(x), 6) for x in values.flat] 866 | return values 867 | 868 | fname = "" 869 | if FLAGS.do_predict_on_full: 870 | 
fname += "full" 871 | elif FLAGS.do_predict_on_dev: 872 | fname += "dev" 873 | else: 874 | fname += "test" 875 | 876 | if FLAGS.predict_pool: 877 | fname += "_pooled.tsv" 878 | else: 879 | fname += "_results.tsv" 880 | 881 | if FLAGS.eval_checkpoint_name: 882 | fname = FLAGS.eval_checkpoint_name + "." + fname 883 | output_predict_file = os.path.join(output_dir, fname) 884 | with tf.gfile.GFile(output_predict_file, "w") as writer: 885 | num_written_lines = 0 886 | tf.logging.info("***** Predict results *****") 887 | for (i, prediction) in enumerate(result): 888 | 889 | if is_regression: 890 | if FLAGS.predict_pool: 891 | embedding_a = prediction["embedding_a"] 892 | embedding_b = prediction["embedding_b"] 893 | 894 | output_json = collections.OrderedDict() 895 | output_json["embedding_a"] = round_float_list(embedding_a) 896 | output_json["embedding_b"] = round_float_list(embedding_b) 897 | 898 | output_line = json.dumps(output_json) + "\n" 899 | else: 900 | cos_similarity = prediction["cos_similarity"] 901 | if i >= num_actual_predict_examples: 902 | break 903 | output_line = str(cos_similarity) + "\n" 904 | else: 905 | raise NotImplementedError 906 | 907 | writer.write(output_line) 908 | num_written_lines += 1 909 | assert num_written_lines == num_actual_predict_examples 910 | 911 | tf.logging.info("*** output_dir ***") 912 | tf.logging.info(output_dir) 913 | 914 | 915 | if __name__ == "__main__": 916 | flags.mark_flag_as_required("data_dir") 917 | flags.mark_flag_as_required("task_name") 918 | flags.mark_flag_as_required("vocab_file") 919 | flags.mark_flag_as_required("bert_config_file") 920 | tf.app.run() 921 | -------------------------------------------------------------------------------- /scripts/eval_stsb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import sys 5 | import numpy as np 6 | import scipy 7 | import scipy.stats 8 | 9 | def get_pred(fpath): 10 | with 
open(fpath) as f:
        x = [float(_) for _ in f.readlines()]
    return x

def get_gt(fpath, col, header=False):
    # Read ground-truth scores: column `col` of a tab-separated file,
    # skipping the first line when `header` is truthy (int(header) -> 0/1).
    with open(fpath) as f:
        y = np.asarray([float(_.split('\t')[col]) for _ in f.readlines()[int(header):]])
    return y

def get_correlation(x, y):
    # Print Pearson and Spearman correlations between predictions and gold
    # scores on one line.
    print("Pearson: %f" % pearson_r(x, y), end=", ")
    print("Spearman: %f" % scipy.stats.spearmanr(x, y).correlation)

def get_auc(pred, y):
    # Print ROC-AUC; used below for the binary-label "regression" tasks
    # (mrpc-regression, qnli-regression).
    from sklearn import metrics
    fpr, tpr, thresholds = metrics.roc_curve(y, pred)
    print("AUC: %f" % metrics.auc(fpr, tpr))

def pearson_r(x, y):
    """Compute Pearson correlation coefficient between two arrays."""
    corr_mat = np.corrcoef(x, y)
    # Off-diagonal entry of the 2x2 correlation matrix.
    return corr_mat[0, 1]

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=None)
    # NOTE(review): the help strings for --glue_path/--task_name/--is_test
    # look copy-pasted from --pred_path — worth fixing in a follow-up.
    parser.add_argument('--glue_path', type=str, default="../glue_data",
                        help='path to predicted sentence vectors')
    parser.add_argument('--task_name', type=str, default="sts-b",
                        help='path to predicted sentence vectors')
    parser.add_argument('--pred_path', type=str,
                        help='path to predicted sentence vectors')
    parser.add_argument('--is_test', type=int, default=0,
                        help='eval/test set')
    args = parser.parse_args()

    x = get_pred(args.pred_path)
    # Select the gold-score file and column for the requested task/split:
    # is_test: 1 -> test set, 0 -> dev set, -1 -> train set (STS-B only).
    if args.task_name.lower() == "sts-b":
        if args.is_test == 1:
            fpath = os.path.join(args.glue_path, "STS-B/sts-test.csv")
            y = get_gt(fpath, 4, 0)
        elif args.is_test == 0:
            fpath = os.path.join(args.glue_path, "STS-B/dev.tsv")
            y = get_gt(fpath, 9, 1)
        elif args.is_test == -1:
            fpath = os.path.join(args.glue_path, "STS-B/train.tsv")
            y = get_gt(fpath, 9, 1)
        else:
            raise NotImplementedError
    elif args.task_name.lower() == "sick-r":
        fpath = os.path.join(args.glue_path, "SICK-R/SICK_test_annotated.txt")
        y = get_gt(fpath, 3, 1)
    elif args.task_name.lower() == "mrpc-regression":
        fpath = 
os.path.join(args.glue_path, "MRPC-Regression/msr_paraphrase_test.txt") 63 | y = get_gt(fpath, 0, 1) 64 | else: 65 | raise NotImplementedError 66 | 67 | get_correlation(x, y) 68 | if args.task_name.lower() in ["mrpc-regression", "qnli-regression"]: 69 | get_auc(x, y) 70 | 71 | -------------------------------------------------------------------------------- /scripts/train_siamese.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CURDIR=$(cd $(dirname $0); cd ..; pwd) 3 | 4 | export BERT_DIR=${BERT_PREMODELS}/uncased_L-12_H-768_A-12 5 | 6 | if [[ $BERT_NAME == "large-wwm" ]];then 7 | export BERT_DIR=${BERT_PREMODELS}/wwm_uncased_L-24_H-1024_A-16 8 | elif [[ $BERT_NAME == "large" ]];then 9 | export BERT_DIR=${BERT_PREMODELS}/uncased_L-24_H-1024_A-16 10 | else 11 | export BERT_DIR=${BERT_PREMODELS}/uncased_L-12_H-768_A-12 12 | fi 13 | 14 | if [ -z "$INIT_CKPT" ]; then 15 | export INIT_CKPT=$BERT_DIR/bert_model.ckpt 16 | fi 17 | 18 | if [ -z "$TASK_NAME" ]; then 19 | export TASK_NAME="STS-B" 20 | fi 21 | 22 | 23 | if [[ $1 == "train" ]];then 24 | echo "train" 25 | 26 | exec python3 ${CURDIR}/run_siamese.py \ 27 | --task_name=${TASK_NAME} \ 28 | --do_train=true \ 29 | --do_eval=true \ 30 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 31 | --vocab_file=${BERT_DIR}/vocab.txt \ 32 | --bert_config_file=${BERT_DIR}/bert_config.json \ 33 | --init_checkpoint=${INIT_CKPT} \ 34 | --max_seq_length=64 \ 35 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 36 | ${@:2} 37 | elif [[ $1 == "eval" ]];then 38 | echo "eval" 39 | python3 ${CURDIR}/run_siamese.py \ 40 | --task_name=${TASK_NAME} \ 41 | --do_eval=true \ 42 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 43 | --vocab_file=${BERT_DIR}/vocab.txt \ 44 | --bert_config_file=${BERT_DIR}/bert_config.json \ 45 | --init_checkpoint=${INIT_CKPT} \ 46 | --max_seq_length=64 \ 47 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 48 | ${@:2} 49 | elif [[ $1 == "predict" ]];then 50 | echo "predict" 51 | python3 
${CURDIR}/run_siamese.py \ 52 | --task_name=${TASK_NAME} \ 53 | --do_predict=true \ 54 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 55 | --vocab_file=${BERT_DIR}/vocab.txt \ 56 | --bert_config_file=${BERT_DIR}/bert_config.json \ 57 | --init_checkpoint=${INIT_CKPT} \ 58 | --max_seq_length=64 \ 59 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 60 | ${@:2} 61 | 62 | python3 scripts/eval_stsb.py \ 63 | --glue_path=${GLUE_DIR} \ 64 | --task_name=${TASK_NAME} \ 65 | --pred_path=${OUTPUT_PARENT_DIR}/${EXP_NAME}/test_results.tsv \ 66 | --is_test=1 67 | 68 | elif [[ $1 == "predict_pool" ]];then 69 | echo "predict_dev" 70 | python3 ${CURDIR}/run_siamese.py \ 71 | --task_name=${TASK_NAME} \ 72 | --do_predict=true \ 73 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 74 | --vocab_file=${BERT_DIR}/vocab.txt \ 75 | --bert_config_file=${BERT_DIR}/bert_config.json \ 76 | --max_seq_length=64 \ 77 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 78 | --predict_pool=True \ 79 | ${@:2} 80 | 81 | elif [[ $1 == "predict_dev" ]];then 82 | echo "predict_dev" 83 | python3 ${CURDIR}/run_siamese.py \ 84 | --task_name=${TASK_NAME} \ 85 | --do_predict=true \ 86 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 87 | --vocab_file=${BERT_DIR}/vocab.txt \ 88 | --bert_config_file=${BERT_DIR}/bert_config.json \ 89 | --max_seq_length=64 \ 90 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 91 | --do_predict_on_dev=True \ 92 | --predict_pool=True \ 93 | ${@:2} 94 | 95 | elif [[ $1 == "predict_full" ]];then 96 | echo "predict_dev" 97 | python3 ${CURDIR}/run_siamese.py \ 98 | --task_name=${TASK_NAME} \ 99 | --do_predict=true \ 100 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 101 | --vocab_file=${BERT_DIR}/vocab.txt \ 102 | --bert_config_file=${BERT_DIR}/bert_config.json \ 103 | --max_seq_length=64 \ 104 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 105 | --do_predict_on_full=True \ 106 | --predict_pool=True \ 107 | ${@:2} 108 | 109 | elif [[ $1 == "do_senteval" ]];then 110 | echo "do_senteval" 111 | python3 ${CURDIR}/run_siamese.py \ 112 | 
--task_name=${TASK_NAME} \ 113 | --do_senteval=true \ 114 | --data_dir=${GLUE_DIR}/${TASK_NAME} \ 115 | --vocab_file=${BERT_DIR}/vocab.txt \ 116 | --bert_config_file=${BERT_DIR}/bert_config.json \ 117 | --init_checkpoint=${INIT_CKPT} \ 118 | --max_seq_length=64 \ 119 | --output_parent_dir=${OUTPUT_PARENT_DIR} \ 120 | ${@:2} 121 | 122 | else 123 | echo "NotImplementedError" 124 | fi 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /siamese_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import warnings 6 | warnings.simplefilter(action='ignore', category=FutureWarning) 7 | 8 | import os 9 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 10 | 11 | import collections 12 | import csv 13 | 14 | import modeling 15 | import optimization 16 | import tokenization 17 | 18 | import tensorflow as tf 19 | 20 | import random 21 | import numpy as np 22 | 23 | 24 | class InputExample(object): 25 | """A single training/test example for simple sequence classification.""" 26 | 27 | def __init__(self, guid, text_a, text_b=None, label=None): 28 | """Constructs a InputExample. 29 | 30 | Args: 31 | guid: Unique id for the example. 32 | text_a: string. The untokenized text of the first sequence. For single 33 | sequence tasks, only this sequence must be specified. 34 | text_b: (Optional) string. The untokenized text of the second sequence. 35 | Only must be specified for sequence pair tasks. 36 | label: (Optional) string. The label of the example. This should be 37 | specified for train and dev examples, but not for test examples. 
38 | """ 39 | self.guid = guid 40 | self.text_a = text_a 41 | self.text_b = text_b 42 | self.label = label 43 | 44 | 45 | class InputFeatures(object): 46 | """A single set of features of data.""" 47 | 48 | def __init__(self, 49 | input_ids_a, 50 | input_mask_a, 51 | segment_ids_a, 52 | input_ids_b, 53 | input_mask_b, 54 | segment_ids_b, 55 | label_id, 56 | is_real_example=True): 57 | self.input_ids_a = input_ids_a 58 | self.input_mask_a = input_mask_a 59 | self.segment_ids_a = segment_ids_a 60 | self.input_ids_b = input_ids_b 61 | self.input_mask_b = input_mask_b 62 | self.segment_ids_b = segment_ids_b 63 | self.label_id = label_id 64 | self.is_real_example = is_real_example 65 | 66 | 67 | class DataProcessor(object): 68 | """Base class for data converters for sequence classification data sets.""" 69 | 70 | def get_train_examples(self, data_dir): 71 | """Gets a collection of `InputExample`s for the train set.""" 72 | raise NotImplementedError() 73 | 74 | def get_dev_examples(self, data_dir): 75 | """Gets a collection of `InputExample`s for the dev set.""" 76 | raise NotImplementedError() 77 | 78 | def get_test_examples(self, data_dir): 79 | """Gets a collection of `InputExample`s for prediction.""" 80 | raise NotImplementedError() 81 | 82 | def get_labels(self): 83 | """Gets the list of labels for this data set.""" 84 | raise NotImplementedError() 85 | 86 | @classmethod 87 | def _read_tsv(cls, input_file, quotechar=None): 88 | """Reads a tab separated value file.""" 89 | with tf.gfile.Open(input_file, "r") as f: 90 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 91 | lines = [] 92 | for line in reader: 93 | lines.append(line) 94 | return lines 95 | 96 | 97 | class GLUEProcessor(DataProcessor): 98 | def __init__(self): 99 | self.train_file = "train.tsv" 100 | self.dev_file = "dev.tsv" 101 | self.test_file = "test.tsv" 102 | self.label_column = None 103 | self.text_a_column = None 104 | self.text_b_column = None 105 | self.contains_header = True 106 | 
self.test_text_a_column = None 107 | self.test_text_b_column = None 108 | self.test_contains_header = True 109 | 110 | def get_train_examples(self, data_dir): 111 | """See base class.""" 112 | return self._create_examples( 113 | self._read_tsv(os.path.join(data_dir, self.train_file)), "train") 114 | 115 | def get_dev_examples(self, data_dir): 116 | """See base class.""" 117 | return self._create_examples( 118 | self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev") 119 | 120 | def get_test_examples(self, data_dir): 121 | """See base class.""" 122 | if self.test_text_a_column is None: 123 | self.test_text_a_column = self.text_a_column 124 | if self.test_text_b_column is None: 125 | self.test_text_b_column = self.text_b_column 126 | 127 | return self._create_examples( 128 | self._read_tsv(os.path.join(data_dir, self.test_file)), "test") 129 | 130 | def get_labels(self): 131 | """See base class.""" 132 | return ["0", "1"] 133 | 134 | def _create_examples(self, lines, set_type): 135 | """Creates examples for the training and dev sets.""" 136 | examples = [] 137 | for (i, line) in enumerate(lines): 138 | if i == 0 and self.contains_header and set_type != "test": 139 | continue 140 | if i == 0 and self.test_contains_header and set_type == "test": 141 | continue 142 | guid = "%s-%s" % (set_type, i) 143 | 144 | a_column = (self.text_a_column if set_type != "test" else 145 | self.test_text_a_column) 146 | b_column = (self.text_b_column if set_type != "test" else 147 | self.test_text_b_column) 148 | 149 | # there are some incomplete lines in QNLI 150 | if len(line) <= a_column: 151 | tf.logging.warning('Incomplete line, ignored.') 152 | continue 153 | text_a = line[a_column] 154 | 155 | if b_column is not None: 156 | if len(line) <= b_column: 157 | tf.logging.warning('Incomplete line, ignored.') 158 | continue 159 | text_b = line[b_column] 160 | else: 161 | text_b = None 162 | 163 | if set_type == "test": 164 | label = self.get_labels()[0] 165 | else: 166 | if 
len(line) <= self.label_column: 167 | tf.logging.warning('Incomplete line, ignored.') 168 | continue 169 | label = line[self.label_column] 170 | if len(label) == 0: 171 | raise Exception 172 | examples.append( 173 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 174 | return examples 175 | 176 | 177 | class StsbProcessor(GLUEProcessor): 178 | def __init__(self): 179 | super(StsbProcessor, self).__init__() 180 | self.label_column = 9 181 | self.text_a_column = 7 182 | self.text_b_column = 8 183 | 184 | def get_labels(self): 185 | return [0.] 186 | 187 | def _create_examples(self, lines, set_type): 188 | """Creates examples for the training and dev sets.""" 189 | examples = [] 190 | for (i, line) in enumerate(lines): 191 | if i == 0 and self.contains_header and set_type != "test": 192 | continue 193 | if i == 0 and self.test_contains_header and set_type == "test": 194 | continue 195 | guid = "%s-%s" % (set_type, i) 196 | 197 | a_column = (self.text_a_column if set_type != "test" else 198 | self.test_text_a_column) 199 | b_column = (self.text_b_column if set_type != "test" else 200 | self.test_text_b_column) 201 | 202 | # there are some incomplete lines in QNLI 203 | if len(line) <= a_column: 204 | tf.logging.warning('Incomplete line, ignored.') 205 | continue 206 | text_a = line[a_column] 207 | 208 | if b_column is not None: 209 | if len(line) <= b_column: 210 | tf.logging.warning('Incomplete line, ignored.') 211 | continue 212 | text_b = line[b_column] 213 | else: 214 | text_b = None 215 | 216 | if set_type == "test": 217 | label = self.get_labels()[0] 218 | else: 219 | if len(line) <= self.label_column: 220 | tf.logging.warning('Incomplete line, ignored.') 221 | continue 222 | label = float(line[self.label_column]) 223 | examples.append( 224 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 225 | 226 | return examples 227 | 228 | 229 | class SickRProcessor(StsbProcessor): 230 | """Processor for the MultiNLI data set (GLUE 
class QqpProcessor(GLUEProcessor):
  # NOTE(review): original docstring said "MultiNLI data set" — copy-pasted
  # from MnliProcessor; this processor actually reads QQP-formatted TSVs.
  """Processor for the QQP (Quora Question Pairs) data set (GLUE version)."""
  def __init__(self):
    super(QqpProcessor, self).__init__()
    # Column layout of the QQP train/dev TSVs (question1/question2/is_duplicate).
    self.train_file = "train.tsv"
    self.dev_file = "dev.tsv"
    self.test_file = "test.tsv"
    self.label_column = 5
    self.text_a_column = 3
    self.text_b_column = 4
    self.contains_header = True
    # The test split uses a different column layout than train/dev.
    self.test_text_a_column = 1
    self.test_text_b_column = 2
    self.test_contains_header = True

  def get_labels(self):
    """See base class."""
    # Single-element list marks the task as regression downstream.
    return [0.]
class QnliRegressionProcessor(GLUEProcessor):
  """Processor for QNLI cast as a regression task.

  The textual labels "not_entailment"/"entailment" are mapped to 0/1.
  """

  # Label-to-score mapping; built once instead of on every input line.
  _LABEL_SCORE_MAP = {"not_entailment": 0, "entailment": 1}

  def __init__(self):
    super(QnliRegressionProcessor, self).__init__()
    self.label_column = -1  # label is the last field of each row
    self.text_a_column = 1
    self.text_b_column = 2

  def get_labels(self):
    """Single pseudo-label marks this as a regression task."""
    return [0.]

  @staticmethod
  def _min_fields(column):
    """Number of fields a row needs for `column` to be a valid index.

    Handles negative (from-the-end) indices, which the naive
    `len(line) <= column` check never catches.
    """
    return column + 1 if column >= 0 else -column

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0 and self.contains_header and set_type != "test":
        continue
      if i == 0 and self.test_contains_header and set_type == "test":
        continue
      guid = "%s-%s" % (set_type, i)

      a_column = (self.text_a_column if set_type != "test" else
                  self.test_text_a_column)
      b_column = (self.text_b_column if set_type != "test" else
                  self.test_text_b_column)

      # there are some incomplete lines in QNLI
      if len(line) < self._min_fields(a_column):
        tf.logging.warning('Incomplete line, ignored.')
        continue
      text_a = line[a_column]

      if b_column is not None:
        if len(line) < self._min_fields(b_column):
          tf.logging.warning('Incomplete line, ignored.')
          continue
        text_b = line[b_column]
      else:
        text_b = None

      if set_type == "test":
        label = self.get_labels()[0]
      else:
        # BUG FIX: the original guard `len(line) <= self.label_column` is
        # always False for label_column == -1, so truncated lines crashed
        # with IndexError instead of being skipped as intended.
        if len(line) < self._min_fields(self.label_column):
          tf.logging.warning('Incomplete line, ignored.')
          continue
        label = self._LABEL_SCORE_MAP[line[self.label_column]]
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))

    return examples
def get_input_mask_segment(text,
                           max_seq_length, tokenizer, random_mask=0):
  """Converts one text into padded (input_ids, input_mask, segment_ids).

  Args:
    text: untokenized input string.
    max_seq_length: total length of the returned id/mask/segment lists.
    tokenizer: object exposing `tokenize` and `convert_tokens_to_ids`
      (e.g. tokenization.FullTokenizer).
    random_mask: if truthy, replace one random token with "[MASK]"
      (used to augment duplicated training examples).

  Returns:
    (input_ids, input_mask, segment_ids, tokens): the first three are
    zero-padded to `max_seq_length`; `tokens` is the unpadded token list
    including [CLS]/[SEP].
  """
  tokens = tokenizer.tokenize(text)

  # Account for [CLS] and [SEP] with "- 2"
  if len(tokens) > max_seq_length - 2:
    tokens = tokens[0:(max_seq_length - 2)]

  # BUG FIX: guard against an empty token list -- random.randint(0, -1)
  # raises ValueError when `text` tokenizes to nothing.
  if random_mask and tokens:
    tokens[random.randint(0, len(tokens) - 1)] = "[MASK]"

  tokens = ["[CLS]"] + tokens + ["[SEP]"]
  segment_ids = [0] * len(tokens)
  input_ids = tokenizer.convert_tokens_to_ids(tokens)

  # The mask has 1 for real tokens and 0 for padding tokens. Only real
  # tokens are attended to.
  input_mask = [1] * len(input_ids)

  # Zero-pad up to the sequence length.
  pad_len = max_seq_length - len(input_ids)
  input_ids.extend([0] * pad_len)
  input_mask.extend([0] * pad_len)
  segment_ids.extend([0] * pad_len)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  return (input_ids, input_mask, segment_ids, tokens)
def file_based_convert_examples_to_features(
    examples, label_list, max_seq_length, tokenizer, output_file,
    dupe_factor, label_min, label_max, is_training=False):
  """Convert a set of `InputExample`s to a TFRecord file.

  Args:
    examples: list of `InputExample`s to serialize.
    label_list: candidate labels; a single-element list marks regression.
    max_seq_length: padded sequence length for both sentences.
    tokenizer: tokenizer passed through to `convert_single_example`.
    output_file: path of the TFRecord file to write.
    dupe_factor: number of (randomly [MASK]-augmented) copies of each
      training example to write; ignored unless `is_training`.
    label_min: minimum of the regression label range (for normalization).
    label_max: maximum of the regression label range (for normalization).
    is_training: whether `examples` is the training split.
  """

  # Feature constructors are loop-invariant; the original re-defined these
  # closures on every (example, duplicate) iteration.
  def create_int_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  def create_float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

  is_regression = len(label_list) <= 1

  writer = tf.python_io.TFRecordWriter(output_file)
  try:
    for (ex_index, example) in enumerate(examples):
      if ex_index % 10000 == 0:
        tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

      # Training examples are duplicated for augmentation; the first copy
      # (t == 0) is left unmasked.
      for t in range(dupe_factor if is_training else 1):
        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer,
                                         random_mask=0 if t == 0 else 1)

        features = collections.OrderedDict()
        features["input_ids_a"] = create_int_feature(feature.input_ids_a)
        features["input_mask_a"] = create_int_feature(feature.input_mask_a)
        features["segment_ids_a"] = create_int_feature(feature.segment_ids_a)

        features["input_ids_b"] = create_int_feature(feature.input_ids_b)
        features["input_mask_b"] = create_int_feature(feature.input_mask_b)
        features["segment_ids_b"] = create_int_feature(feature.segment_ids_b)

        if not is_regression:
          features["label_ids"] = create_int_feature([feature.label_id])
        else:
          # Min-max normalize the regression target into [0, 1].
          features["label_ids"] = create_float_feature(
              [(float(feature.label_id) - label_min) / (label_max - label_min)])
        features["is_real_example"] = create_int_feature(
            [int(feature.is_real_example)])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
  finally:
    # Ensure the record file is flushed and closed even if conversion fails;
    # the original leaked the writer on any exception.
    writer.close()
540 | for name in list(example.keys()): 541 | t = example[name] 542 | if t.dtype == tf.int64: 543 | t = tf.to_int32(t) 544 | example[name] = t 545 | 546 | return example 547 | 548 | def input_fn(mode, params): 549 | """The actual input function.""" 550 | if mode == tf.estimator.ModeKeys.TRAIN: 551 | batch_size = params["train_batch_size"] 552 | elif mode == tf.estimator.ModeKeys.EVAL: 553 | batch_size = params["eval_batch_size"] 554 | elif mode == tf.estimator.ModeKeys.PREDICT: 555 | batch_size = params["predict_batch_size"] 556 | else: 557 | raise NotImplementedError 558 | 559 | # For training, we want a lot of parallel reading and shuffling. 560 | # For eval, we want no shuffling and parallel reading doesn't matter. 561 | d = tf.data.TFRecordDataset(input_file) 562 | # if is_training: 563 | # d = d.repeat() 564 | # d = d.shuffle(buffer_size=100) 565 | 566 | if is_training: 567 | d = d.shuffle(buffer_size=1000, reshuffle_each_iteration=True) 568 | d = d.repeat() 569 | 570 | d = d.apply( 571 | tf.contrib.data.map_and_batch( 572 | lambda record: _decode_record(record, name_to_features), 573 | batch_size=batch_size, 574 | drop_remainder=drop_remainder)) 575 | 576 | return d 577 | 578 | return input_fn 579 | 580 | 581 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. 
If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
def convert_by_vocab(vocab, items):
  """Converts a sequence of [tokens|ids] using the vocab."""
  return [vocab[item] for item in items]


def convert_tokens_to_ids(vocab, tokens):
  """Maps each token string to its integer id via `vocab`."""
  return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
  """Maps each integer id back to its token string via `inv_vocab`."""
  return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
  """Runs basic whitespace cleaning and splitting on a piece of text."""
  stripped = text.strip()
  if not stripped:
    return []
  return stripped.split()
170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. 
The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 
324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 
390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | import tokenization 22 | import six 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | vocab_writer.write("".join( 38 | [x + "\n" for x in vocab_tokens]).encode("utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | self.assertFalse(tokenization._is_control(u"\U0001F4A9")) 125 | 126 | def test_is_punctuation(self): 127 | 
self.assertTrue(tokenization._is_punctuation(u"-")) 128 | self.assertTrue(tokenization._is_punctuation(u"$")) 129 | self.assertTrue(tokenization._is_punctuation(u"`")) 130 | self.assertTrue(tokenization._is_punctuation(u".")) 131 | 132 | self.assertFalse(tokenization._is_punctuation(u"A")) 133 | self.assertFalse(tokenization._is_punctuation(u" ")) 134 | 135 | 136 | if __name__ == "__main__": 137 | tf.test.main() 138 | --------------------------------------------------------------------------------