├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── albert_glue_fine_tuning_tutorial.ipynb ├── classifier_utils.py ├── create_pretraining_data.py ├── export_checkpoints.py ├── export_to_tfhub.py ├── fine_tuning_utils.py ├── lamb_optimizer.py ├── modeling.py ├── modeling_test.py ├── optimization.py ├── optimization_test.py ├── race_utils.py ├── requirements.txt ├── run_classifier.py ├── run_glue.sh ├── run_pretraining.py ├── run_pretraining_test.py ├── run_race.py ├── run_squad_v1.py ├── run_squad_v2.py ├── run_trivial_model_test.sh ├── squad_utils.py ├── tokenization.py └── tokenization_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 
13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ALBERT 2 | ====== 3 | 4 | ***************New March 28, 2020 *************** 5 | 6 | Add a colab [tutorial](https://github.com/google-research/albert/blob/master/albert_glue_fine_tuning_tutorial.ipynb) to run fine-tuning for GLUE datasets. 
7 | 8 | ***************New January 7, 2020 *************** 9 | 10 | v2 TF-Hub models should be working now with TF 1.15, as we removed the 11 | native Einsum op from the graph. See updated TF-Hub links below. 12 | 13 | ***************New December 30, 2019 *************** 14 | 15 | Chinese models are released. We would like to thank the [CLUE team](https://github.com/CLUEbenchmark/CLUE) for providing the training data. 16 | 17 | - [Base](https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz) 18 | - [Large](https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz) 19 | - [Xlarge](https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz) 20 | - [Xxlarge](https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz) 21 | 22 | Version 2 of the ALBERT models is released. 23 | 24 | - Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/3)] 25 | - Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/3)] 26 | - Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/3)] 27 | - Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/3)] 28 | 29 | In this version, we apply 'no dropout', 'additional training data', and 'long training time' strategies to all models. We train ALBERT-base for 10M steps and the other models for 3M steps. 30 | 31 | The comparison of results to the v1 models is as follows: 32 | 33 | | | Average | SQuAD1.1 | SQuAD2.0 | MNLI | SST-2 | RACE | 34 | |----------------|----------|----------|----------|----------|----------|----------| 35 | |V2 | 36 | |ALBERT-base |82.3 |90.2/83.2 |82.1/79.3 |84.6 |92.9 |66.8 | 37 | |ALBERT-large |85.7 |91.8/85.2 |84.9/81.8 |86.5 |94.9 |75.2 | 38 | |ALBERT-xlarge |87.9 |92.9/86.4 |87.9/84.1 |87.9 |95.4 |80.7 | 39 | |ALBERT-xxlarge |90.9 |94.6/89.1 |89.8/86.9 |90.6 |96.8 |86.8 | 40 | |V1 | 41 | |ALBERT-base |80.1 |89.3/82.3 | 80.0/77.1|81.6 |90.3 | 64.0 | 42 | |ALBERT-large |82.4 |90.6/83.9 | 82.3/79.4|83.5 |91.7 | 68.5 | 43 | |ALBERT-xlarge |85.5 |92.5/86.1 | 86.1/83.1|86.4 |92.4 | 74.8 | 44 | |ALBERT-xxlarge |91.0 |94.8/89.3 | 90.2/87.4|90.8 |96.9 | 86.5 | 45 | 46 | The comparison shows that for ALBERT-base, ALBERT-large, and ALBERT-xlarge, v2 is much better than v1, indicating the importance of applying the above three strategies. On average, the v2 ALBERT-xxlarge is slightly worse than v1, for two reasons: 1) Training for an additional 1.5M steps (the only difference between these two models is training for 1.5M vs. 3M steps) did not lead to significant performance improvement. 2) For v1, we did a little bit of hyperparameter search among the parameter sets given by BERT, RoBERTa, and XLNet. For v2, we simply adopt the parameters from v1, except for RACE, where we use a learning rate of 1e-5 and 0 [ALBERT DR](https://arxiv.org/pdf/1909.11942.pdf) (dropout rate for ALBERT in fine-tuning). The original (v1) RACE hyperparameters will cause model divergence for v2 models. Given that the downstream tasks are sensitive to the fine-tuning hyperparameters, we should be careful about so-called slight improvements. 47 | 48 | ALBERT is "A Lite" version of BERT, a popular unsupervised language 49 | representation learning algorithm. 
ALBERT uses parameter-reduction techniques 50 | that allow for large-scale configurations, overcome previous memory limitations, 51 | and achieve better behavior with respect to model degradation. 52 | 53 | For a technical description of the algorithm, see our paper: 54 | 55 | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) 56 | 57 | Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 58 | 59 | Release Notes 60 | ============= 61 | 62 | - Initial release: 10/9/2019 63 | 64 | Results 65 | ======= 66 | 67 | Performance of ALBERT on GLUE benchmark results using a single-model setup on 68 | dev: 69 | 70 | | Models | MNLI | QNLI | QQP | RTE | SST | MRPC | CoLA | STS | 71 | |-------------------|----------|----------|----------|----------|----------|----------|----------|----------| 72 | | BERT-large | 86.6 | 92.3 | 91.3 | 70.4 | 93.2 | 88.0 | 60.6 | 90.0 | 73 | | XLNet-large | 89.8 | 93.9 | 91.8 | 83.8 | 95.6 | 89.2 | 63.6 | 91.8 | 74 | | RoBERTa-large | 90.2 | 94.7 | **92.2** | 86.6 | 96.4 | **90.9** | 68.0 | 92.4 | 75 | | ALBERT (1M) | 90.4 | 95.2 | 92.0 | 88.1 | 96.8 | 90.2 | 68.7 | 92.7 | 76 | | ALBERT (1.5M) | **90.8** | **95.3** | **92.2** | **89.2** | **96.9** | **90.9** | **71.4** | **93.0** | 77 | 78 | Performance of ALBERT-xxl on SQuaD and RACE benchmarks using a single-model 79 | setup: 80 | 81 | |Models | SQuAD1.1 dev | SQuAD2.0 dev | SQuAD2.0 test | RACE test (Middle/High) | 82 | |--------------------------|---------------|---------------|---------------|-------------------------| 83 | |BERT-large | 90.9/84.1 | 81.8/79.0 | 89.1/86.3 | 72.0 (76.6/70.1) | 84 | |XLNet | 94.5/89.0 | 88.8/86.1 | 89.1/86.3 | 81.8 (85.5/80.2) | 85 | |RoBERTa | 94.6/88.9 | 89.4/86.5 | 89.8/86.8 | 83.2 (86.5/81.3) | 86 | |UPM | - | - | 89.9/87.2 | - | 87 | |XLNet + SG-Net Verifier++ | - | - | 90.1/87.2 | - | 88 | |ALBERT (1M) | 94.8/89.2 | 89.9/87.2 | - | 86.0 (88.2/85.1) | 89 | |ALBERT (1.5M) | **94.8/89.3** | **90.2/87.4** | **90.9/88.1** | **86.5 (89.0/85.5)** | 90 | 91 | 92 | Pre-trained Models 93 | ================== 94 | TF-Hub modules are available: 95 | 96 | - Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/1)] 97 | - Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/1)] 98 | - Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/1)] 99 | - Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/1)] 100 | 101 | Example usage of the TF-Hub module in code: 102 | 103 | ``` 104 | tags = set() 105 | if is_training: 106 | tags.add("train") 107 | albert_module = hub.Module("https://tfhub.dev/google/albert_base/1", tags=tags, 108 | trainable=True) 109 | albert_inputs = dict( 110 | input_ids=input_ids, 111 | input_mask=input_mask, 112 | segment_ids=segment_ids) 113 | albert_outputs = albert_module( 114 | inputs=albert_inputs, 115 | signature="tokens", 116 | as_dict=True) 117 | 118 | # If you want to use the token-level output, use 119 | # albert_outputs["sequence_output"] instead. 
120 | output_layer = albert_outputs["pooled_output"] 121 | ``` 122 | 123 | Most of the fine-tuning scripts in this repository support TF-hub modules 124 | via the `--albert_hub_module_handle` flag. 125 | 126 | Pre-training Instructions 127 | ========================= 128 | To pretrain ALBERT, use `run_pretraining.py`: 129 | 130 | ``` 131 | pip install -r albert/requirements.txt 132 | python -m albert.run_pretraining \ 133 | --input_file=... \ 134 | --output_dir=... \ 135 | --init_checkpoint=... \ 136 | --albert_config_file=... \ 137 | --do_train \ 138 | --do_eval \ 139 | --train_batch_size=4096 \ 140 | --eval_batch_size=64 \ 141 | --max_seq_length=512 \ 142 | --max_predictions_per_seq=20 \ 143 | --optimizer='lamb' \ 144 | --learning_rate=.00176 \ 145 | --num_train_steps=125000 \ 146 | --num_warmup_steps=3125 \ 147 | --save_checkpoints_steps=5000 148 | ``` 149 | 150 | Fine-tuning on GLUE 151 | =================== 152 | To fine-tune and evaluate a pretrained ALBERT on GLUE, please see the 153 | convenience script `run_glue.sh`. 154 | 155 | Lower-level use cases may want to use the `run_classifier.py` script directly. 156 | The `run_classifier.py` script is used both for fine-tuning and evaluation of 157 | ALBERT on individual GLUE benchmark tasks, such as MNLI: 158 | 159 | ``` 160 | pip install -r albert/requirements.txt 161 | python -m albert.run_classifier \ 162 | --data_dir=... \ 163 | --output_dir=... \ 164 | --init_checkpoint=... \ 165 | --albert_config_file=... \ 166 | --spm_model_file=... \ 167 | --do_train \ 168 | --do_eval \ 169 | --do_predict \ 170 | --do_lower_case \ 171 | --max_seq_length=128 \ 172 | --optimizer=adamw \ 173 | --task_name=MNLI \ 174 | --warmup_step=1000 \ 175 | --learning_rate=3e-5 \ 176 | --train_step=10000 \ 177 | --save_checkpoints_steps=100 \ 178 | --train_batch_size=128 179 | ``` 180 | 181 | Good default flag values for each GLUE task can be found in `run_glue.sh`. 182 | 183 | You can fine-tune the model starting from TF-Hub modules instead of raw 184 | checkpoints by setting e.g. 185 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead 186 | of `--init_checkpoint`. 187 | 188 | You can find the spm_model_file in the tar files or under the assets folder of 189 | the tf-hub module. The name of the model file is "30k-clean.model". 190 | 191 | After evaluation, the script should report some output like this: 192 | 193 | ``` 194 | ***** Eval results ***** 195 | global_step = ... 196 | loss = ... 197 | masked_lm_accuracy = ... 198 | masked_lm_loss = ... 199 | sentence_order_accuracy = ... 200 | sentence_order_loss = ... 201 | ``` 202 | 203 | Fine-tuning on SQuAD 204 | ==================== 205 | To fine-tune and evaluate a pretrained model on SQuAD v1, use the 206 | `run_squad_v1.py` script: 207 | 208 | ``` 209 | pip install -r albert/requirements.txt 210 | python -m albert.run_squad_v1 \ 211 | --albert_config_file=... \ 212 | --output_dir=... \ 213 | --train_file=... \ 214 | --predict_file=... \ 215 | --train_feature_file=... \ 216 | --predict_feature_file=... \ 217 | --predict_feature_left_file=... \ 218 | --init_checkpoint=... \ 219 | --spm_model_file=... 
\ 220 | --do_lower_case \ 221 | --max_seq_length=384 \ 222 | --doc_stride=128 \ 223 | --max_query_length=64 \ 224 | --do_train=true \ 225 | --do_predict=true \ 226 | --train_batch_size=48 \ 227 | --predict_batch_size=8 \ 228 | --learning_rate=5e-5 \ 229 | --num_train_epochs=2.0 \ 230 | --warmup_proportion=.1 \ 231 | --save_checkpoints_steps=5000 \ 232 | --n_best_size=20 \ 233 | --max_answer_length=30 234 | ``` 235 | 236 | You can fine-tune the model starting from TF-Hub modules instead of raw 237 | checkpoints by setting e.g. 238 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead 239 | of `--init_checkpoint`. 240 | 241 | For SQuAD v2, use the `run_squad_v2.py` script: 242 | 243 | ``` 244 | pip install -r albert/requirements.txt 245 | python -m albert.run_squad_v2 \ 246 | --albert_config_file=... \ 247 | --output_dir=... \ 248 | --train_file=... \ 249 | --predict_file=... \ 250 | --train_feature_file=... \ 251 | --predict_feature_file=... \ 252 | --predict_feature_left_file=... \ 253 | --init_checkpoint=... \ 254 | --spm_model_file=... \ 255 | --do_lower_case \ 256 | --max_seq_length=384 \ 257 | --doc_stride=128 \ 258 | --max_query_length=64 \ 259 | --do_train \ 260 | --do_predict \ 261 | --train_batch_size=48 \ 262 | --predict_batch_size=8 \ 263 | --learning_rate=5e-5 \ 264 | --num_train_epochs=2.0 \ 265 | --warmup_proportion=.1 \ 266 | --save_checkpoints_steps=5000 \ 267 | --n_best_size=20 \ 268 | --max_answer_length=30 269 | ``` 270 | 271 | You can fine-tune the model starting from TF-Hub modules instead of raw 272 | checkpoints by setting e.g. 273 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead 274 | of `--init_checkpoint`. 275 | 276 | Fine-tuning on RACE 277 | =================== 278 | For RACE, use the `run_race.py` script: 279 | 280 | ``` 281 | pip install -r albert/requirements.txt 282 | python -m albert.run_race \ 283 | --albert_config_file=... \ 284 | --output_dir=... \ 285 | --train_file=... \ 286 | --eval_file=... \ 287 | --data_dir=...\ 288 | --init_checkpoint=... \ 289 | --spm_model_file=... \ 290 | --max_seq_length=512 \ 291 | --max_qa_length=128 \ 292 | --do_train \ 293 | --do_eval \ 294 | --train_batch_size=32 \ 295 | --eval_batch_size=8 \ 296 | --learning_rate=1e-5 \ 297 | --train_step=12000 \ 298 | --warmup_step=1000 \ 299 | --save_checkpoints_steps=100 300 | ``` 301 | 302 | You can fine-tune the model starting from TF-Hub modules instead of raw 303 | checkpoints by setting e.g. 304 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead 305 | of `--init_checkpoint`. 306 | 307 | SentencePiece 308 | ============= 309 | Command for generating the sentence piece vocabulary: 310 | 311 | ``` 312 | spm_train \ 313 | --input all.txt --model_prefix=30k-clean --vocab_size=30000 --logtostderr 314 | --pad_id=0 --unk_id=1 --eos_id=-1 --bos_id=-1 315 | --control_symbols=[CLS],[SEP],[MASK] 316 | --user_defined_symbols="(,),\",-,.,–,£,€" 317 | --shuffle_input_sentence=true --input_sentence_size=10000000 318 | --character_coverage=0.99995 --model_type=unigram 319 | ``` 320 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /albert_glue_fine_tuning_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "albert_glue_fine_tuning_tutorial", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "TPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "y8SJfpgTccDB", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "\n", 26 | "\"Open In Colab\"" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "wHQH4OCHZ9bq", 33 | "colab_type": "code", 34 | "cellView": "form", 35 | "colab": {} 36 | }, 37 | "source": [ 38 | "# @title Copyright 2020 The ALBERT Authors. All Rights Reserved.\n", 39 | "#\n", 40 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", 41 | "# you may not use this file except in compliance with the License.\n", 42 | "# You may obtain a copy of the License at\n", 43 | "#\n", 44 | "# http://www.apache.org/licenses/LICENSE-2.0\n", 45 | "#\n", 46 | "# Unless required by applicable law or agreed to in writing, software\n", 47 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 48 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 49 | "# See the License for the specific language governing permissions and\n", 50 | "# limitations under the License.\n", 51 | "# ==============================================================================" 52 | ], 53 | "execution_count": 0, 54 | "outputs": [] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "rkTLZ3I4_7c_", 60 | "colab_type": "text" 61 | }, 62 | "source": [ 63 | "# ALBERT End to End (Fine-tuning + Predicting) with Cloud TPU" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "id": "1wtjs1QDb3DX", 70 | "colab_type": "text" 71 | }, 72 | "source": [ 73 | "## Overview\n", 74 | "\n", 75 | "ALBERT is \"A Lite\" version of BERT, a popular unsupervised language representation learning algorithm. ALBERT uses parameter-reduction techniques that allow for large-scale configurations, overcome previous memory limitations, and achieve better behavior with respect to model degradation.\n", 76 | "\n", 77 | "For a technical description of the algorithm, see our paper:\n", 78 | "\n", 79 | "https://arxiv.org/abs/1909.11942\n", 80 | "\n", 81 | "Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut\n", 82 | "\n", 83 | "This Colab demonstrates using a free Colab Cloud TPU to fine-tune GLUE tasks built on top of pretrained ALBERT models and \n", 84 | "run predictions on the tuned model. 
The Colab demonstrates loading pretrained ALBERT models from both [TF Hub](https://www.tensorflow.org/hub) and checkpoints.\n", 85 | "\n", 86 | "**Note:** You will need a GCP (Google Cloud Platform) account and a GCS (Google Cloud \n", 87 | "Storage) bucket for this Colab to run.\n", 88 | "\n", 89 | "Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) for how to create a GCP account and GCS bucket. You have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.\n", 90 | "\n", 91 | "This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select **File > View on GitHub**." 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "id": "Ld-JXlueIuPH", 98 | "colab_type": "text" 99 | }, 100 | "source": [ 101 | "## Instructions" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "id": "POkof5uHaQ_c", 108 | "colab_type": "text" 109 | }, 110 | "source": [ 111 | "
Train on TPU
\n", 112 | "\n", 113 | " 1. Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage and fill in the BUCKET parameter in the \"Parameters\" section below.\n", 114 | " \n", 115 | " 1. On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n", 116 | " 1. Click Runtime again and select **Runtime > Run All** (Watch out: the \"Colab-only auth for this notebook and the TPU\" cell requires user input). You can also run the cells manually with Shift-ENTER." 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": { 122 | "id": "UdMmwCJFaT8F", 123 | "colab_type": "text" 124 | }, 125 | "source": [ 126 | "### Set up your TPU environment\n", 127 | "\n", 128 | "In this section, you perform the following tasks:\n", 129 | "\n", 130 | "* Set up a Colab TPU running environment\n", 131 | "* Verify that you are connected to a TPU device\n", 132 | "* Upload your credentials to TPU to access your GCS bucket." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "metadata": { 138 | "id": "191zq3ZErihP", 139 | "colab_type": "code", 140 | "colab": {} 141 | }, 142 | "source": [ 143 | "# TODO(lanzhzh): Add support for 2.x.\n", 144 | "%tensorflow_version 1.x\n", 145 | "import os\n", 146 | "import pprint\n", 147 | "import json\n", 148 | "import tensorflow as tf\n", 149 | "\n", 150 | "assert \"COLAB_TPU_ADDR\" in os.environ, \"ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!\"\n", 151 | "TPU_ADDRESS = \"grpc://\" + os.environ[\"COLAB_TPU_ADDR\"] \n", 152 | "TPU_TOPOLOGY = \"2x2\"\n", 153 | "print(\"TPU address is\", TPU_ADDRESS)\n", 154 | "\n", 155 | "from google.colab import auth\n", 156 | "auth.authenticate_user()\n", 157 | "with tf.Session(TPU_ADDRESS) as session:\n", 158 | " print('TPU devices:')\n", 159 | " pprint.pprint(session.list_devices())\n", 160 | "\n", 161 | " # Upload credentials to TPU.\n", 162 | " with open('/content/adc.json', 'r') as f:\n", 163 | " auth_info = json.load(f)\n", 164 | " tf.contrib.cloud.configure_gcs(session, credentials=auth_info)\n", 165 | " # Now credentials are set for all future sessions on this TPU." 166 | ], 167 | "execution_count": 0, 168 | "outputs": [] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": { 173 | "id": "HUBP35oCDmbF", 174 | "colab_type": "text" 175 | }, 176 | "source": [ 177 | "### Prepare and import ALBERT modules\n", 178 | "​\n", 179 | "With your environment configured, you can now prepare and import the ALBERT modules. The following step clones the source code from GitHub." 
180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "metadata": { 185 | "id": "7wzwke0sxS6W", 186 | "colab_type": "code", 187 | "colab": {}, 188 | "cellView": "code" 189 | }, 190 | "source": [ 191 | "#TODO(lanzhzh): Add pip support\n", 192 | "import sys\n", 193 | "\n", 194 | "!test -d albert || git clone https://github.com/google-research/albert albert\n", 195 | "if not 'albert' in sys.path:\n", 196 | " sys.path += ['albert']\n", 197 | " \n", 198 | "!pip install sentencepiece\n" 199 | ], 200 | "execution_count": 0, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": { 206 | "id": "RRu1aKO1D7-Z", 207 | "colab_type": "text" 208 | }, 209 | "source": [ 210 | "### Prepare for training\n", 211 | "\n", 212 | "This next section of code performs the following tasks:\n", 213 | "\n", 214 | "* Specify the GCS bucket and create the output directory for model checkpoints and eval results.\n", 215 | "* Specify the task and download the training data.\n", 216 | "* Specify the ALBERT pretrained model.\n", 217 | "\n", 218 | "\n", 219 | "\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "metadata": { 225 | "id": "tYkaAlJNfhul", 226 | "colab_type": "code", 227 | "colab": {}, 228 | "cellView": "form" 229 | }, 230 | "source": [ 231 | "# Please find the full list of tasks and their fine-tuning hyperparameters\n", 232 | "# here https://github.com/google-research/albert/blob/master/run_glue.sh\n", 233 | "\n", 234 | "BUCKET = \"albert_tutorial_glue\" #@param { type: \"string\" }\n", 235 | "TASK = 'MRPC' #@param {type:\"string\"}\n", 236 | "# Available pretrained model checkpoints:\n", 237 | "# base, large, xlarge, xxlarge\n", 238 | "ALBERT_MODEL = 'base' #@param {type:\"string\"}\n", 239 | "\n", 240 | "TASK_DATA_DIR = 'glue_data'\n", 241 | "\n", 242 | "BASE_DIR = \"gs://\" + BUCKET\n", 243 | "if not BASE_DIR or BASE_DIR == \"gs://\":\n", 244 | " raise ValueError(\"You must enter a BUCKET.\")\n", 245 | "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n", 246 | "MODELS_DIR = os.path.join(BASE_DIR, \"models\")\n", 247 | "OUTPUT_DIR = 'gs://{}/albert-tfhub/models/{}'.format(BUCKET, TASK)\n", 248 | "tf.gfile.MakeDirs(OUTPUT_DIR)\n", 249 | "print('***** Model output directory: {} *****'.format(OUTPUT_DIR))\n", 250 | "\n", 251 | "# Download GLUE data.\n", 252 | "! test -d download_glue_repo || git clone https://gist.github.com/60c2bdb54d156a41194446737ce03e2e.git download_glue_repo\n", 253 | "!python download_glue_repo/download_glue_data.py --data_dir=$TASK_DATA_DIR --tasks=$TASK\n", 254 | "print('***** Task data directory: {} *****'.format(TASK_DATA_DIR))\n", 255 | "\n", 256 | "ALBERT_MODEL_HUB = 'https://tfhub.dev/google/albert_' + ALBERT_MODEL + '/3'" 257 | ], 258 | "execution_count": 0, 259 | "outputs": [] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": { 264 | "id": "Hcpfl4N2EdOk", 265 | "colab_type": "text" 266 | }, 267 | "source": [ 268 | "Now let's run the fine-tuning scripts. If you use the default MRPC task, this should be finished in around 10 minutes and you will get an accuracy of around 86.5."
269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "o8qXPxv8-kBO", 275 | "colab_type": "code", 276 | "colab": {} 277 | }, 278 | "source": [ 279 | "os.environ['TFHUB_CACHE_DIR'] = OUTPUT_DIR\n", 280 | "!python -m albert.run_classifier \\\n", 281 | " --data_dir=\"glue_data/\" \\\n", 282 | " --output_dir=$OUTPUT_DIR \\\n", 283 | " --albert_hub_module_handle=$ALBERT_MODEL_HUB \\\n", 284 | " --spm_model_file=\"from_tf_hub\" \\\n", 285 | " --do_train=True \\\n", 286 | " --do_eval=True \\\n", 287 | " --do_predict=False \\\n", 288 | " --max_seq_length=512 \\\n", 289 | " --optimizer=adamw \\\n", 290 | " --task_name=$TASK \\\n", 291 | " --warmup_step=200 \\\n", 292 | " --learning_rate=2e-5 \\\n", 293 | " --train_step=800 \\\n", 294 | " --save_checkpoints_steps=100 \\\n", 295 | " --train_batch_size=32 \\\n", 296 | " --tpu_name=$TPU_ADDRESS \\\n", 297 | " --use_tpu=True" 298 | ], 299 | "execution_count": 0, 300 | "outputs": [] 301 | } 302 | ] 303 | } 304 | -------------------------------------------------------------------------------- /export_checkpoints.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Exports a minimal module for ALBERT models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import os 21 | from absl import app 22 | from absl import flags 23 | from albert import modeling 24 | import tensorflow.compat.v1 as tf 25 | 26 | flags.DEFINE_string( 27 | "albert_directory", None, 28 | "The config json file corresponding to the pre-trained ALBERT model. " 29 | "This specifies the model architecture.") 30 | 31 | flags.DEFINE_string( 32 | "checkpoint_name", "model.ckpt-best", 33 | "Name of the checkpoint under albert_directory to be exported.") 34 | 35 | flags.DEFINE_bool( 36 | "do_lower_case", True, 37 | "Whether to lower case the input text. 
Should be True for uncased " 38 | "models and False for cased models.") 39 | 40 | flags.DEFINE_string("export_path", None, "Path to the output module.") 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | 45 | def gather_indexes(sequence_tensor, positions): 46 | """Gathers the vectors at the specific positions over a minibatch.""" 47 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 48 | batch_size = sequence_shape[0] 49 | seq_length = sequence_shape[1] 50 | width = sequence_shape[2] 51 | 52 | flat_offsets = tf.reshape( 53 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 54 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 55 | flat_sequence_tensor = tf.reshape(sequence_tensor, 56 | [batch_size * seq_length, width]) 57 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 58 | return output_tensor 59 | 60 | 61 | def get_mlm_logits(input_tensor, albert_config, mlm_positions, output_weights): 62 | """From run_pretraining.py.""" 63 | input_tensor = gather_indexes(input_tensor, mlm_positions) 64 | with tf.variable_scope("cls/predictions"): 65 | # We apply one more non-linear transformation before the output layer. 66 | # This matrix is not used after pre-training. 67 | with tf.variable_scope("transform"): 68 | input_tensor = tf.layers.dense( 69 | input_tensor, 70 | units=albert_config.embedding_size, 71 | activation=modeling.get_activation(albert_config.hidden_act), 72 | kernel_initializer=modeling.create_initializer( 73 | albert_config.initializer_range)) 74 | input_tensor = modeling.layer_norm(input_tensor) 75 | 76 | # The output weights are the same as the input embeddings, but there is 77 | # an output-only bias for each token. 78 | output_bias = tf.get_variable( 79 | "output_bias", 80 | shape=[albert_config.vocab_size], 81 | initializer=tf.zeros_initializer()) 82 | logits = tf.matmul( 83 | input_tensor, output_weights, transpose_b=True) 84 | logits = tf.nn.bias_add(logits, output_bias) 85 | return logits 86 | 87 | 88 | def get_sentence_order_logits(input_tensor, albert_config): 89 | """Get loss and log probs for the next sentence prediction.""" 90 | 91 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 92 | # "random sentence". This weight matrix is not used after pre-training. 
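# (Note: callers pass the pooled [CLS] output as `input_tensor`, so it has
# shape [batch_size, hidden_size] and the matmul below produces
# [batch_size, 2] sentence-order logits.)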
93 | with tf.variable_scope("cls/seq_relationship"): 94 | output_weights = tf.get_variable( 95 | "output_weights", 96 | shape=[2, albert_config.hidden_size], 97 | initializer=modeling.create_initializer( 98 | albert_config.initializer_range)) 99 | output_bias = tf.get_variable( 100 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 101 | 102 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 103 | logits = tf.nn.bias_add(logits, output_bias) 104 | return logits 105 | 106 | 107 | def build_model(sess): 108 | """Module function.""" 109 | input_ids = tf.placeholder(tf.int32, [None, None], "input_ids") 110 | input_mask = tf.placeholder(tf.int32, [None, None], "input_mask") 111 | segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids") 112 | mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions") 113 | 114 | albert_config_path = os.path.join( 115 | FLAGS.albert_directory, "albert_config.json") 116 | albert_config = modeling.AlbertConfig.from_json_file(albert_config_path) 117 | model = modeling.AlbertModel( 118 | config=albert_config, 119 | is_training=False, 120 | input_ids=input_ids, 121 | input_mask=input_mask, 122 | token_type_ids=segment_ids, 123 | use_one_hot_embeddings=False) 124 | 125 | get_mlm_logits(model.get_sequence_output(), albert_config, 126 | mlm_positions, model.get_embedding_table()) 127 | get_sentence_order_logits(model.get_pooled_output(), albert_config) 128 | 129 | checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name) 130 | tvars = tf.trainable_variables() 131 | (assignment_map, initialized_variable_names 132 | ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path) 133 | 134 | tf.logging.info("**** Trainable Variables ****") 135 | for var in tvars: 136 | init_string = "" 137 | if var.name in initialized_variable_names: 138 | init_string = ", *INIT_FROM_CKPT*" 139 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 140 | init_string) 141 | tf.train.init_from_checkpoint(checkpoint_path, assignment_map) 142 | init = tf.global_variables_initializer() 143 | sess.run(init) 144 | return sess 145 | 146 | 147 | def main(_): 148 | sess = tf.Session() 149 | tf.train.get_or_create_global_step() 150 | sess = build_model(sess) 151 | my_vars = [] 152 | for var in tf.global_variables(): 153 | if "lamb_v" not in var.name and "lamb_m" not in var.name: 154 | my_vars.append(var) 155 | saver = tf.train.Saver(my_vars) 156 | saver.save(sess, FLAGS.export_path) 157 | 158 | 159 | if __name__ == "__main__": 160 | flags.mark_flag_as_required("albert_directory") 161 | flags.mark_flag_as_required("export_path") 162 | app.run(main) 163 | -------------------------------------------------------------------------------- /export_to_tfhub.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | r"""Exports a minimal TF-Hub module for ALBERT models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import os 21 | from absl import app 22 | from absl import flags 23 | from albert import modeling 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_hub as hub 26 | 27 | flags.DEFINE_string( 28 | "albert_directory", None, 29 | "The config json file corresponding to the pre-trained ALBERT model. " 30 | "This specifies the model architecture.") 31 | 32 | flags.DEFINE_string( 33 | "checkpoint_name", "model.ckpt-best", 34 | "Name of the checkpoint under albert_directory to be exported.") 35 | 36 | flags.DEFINE_bool( 37 | "do_lower_case", True, 38 | "Whether to lower case the input text. Should be True for uncased " 39 | "models and False for cased models.") 40 | 41 | flags.DEFINE_bool( 42 | "use_einsum", True, 43 | "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must " 44 | "be set to False for TFLite compatibility.") 45 | 46 | flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.") 47 | 48 | FLAGS = flags.FLAGS 49 | 50 | 51 | def gather_indexes(sequence_tensor, positions): 52 | """Gathers the vectors at the specific positions over a minibatch.""" 53 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 54 | batch_size = sequence_shape[0] 55 | seq_length = sequence_shape[1] 56 | width = sequence_shape[2] 57 | 58 | flat_offsets = tf.reshape( 59 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 60 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 61 | flat_sequence_tensor = tf.reshape(sequence_tensor, 62 | [batch_size * seq_length, width]) 63 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 64 | return output_tensor 65 | 66 | 67 | def get_mlm_logits(model, albert_config, mlm_positions): 68 | """From run_pretraining.py.""" 69 | input_tensor = gather_indexes(model.get_sequence_output(), mlm_positions) 70 | with tf.variable_scope("cls/predictions"): 71 | # We apply one more non-linear transformation before the output layer. 72 | # This matrix is not used after pre-training. 73 | with tf.variable_scope("transform"): 74 | input_tensor = tf.layers.dense( 75 | input_tensor, 76 | units=albert_config.embedding_size, 77 | activation=modeling.get_activation(albert_config.hidden_act), 78 | kernel_initializer=modeling.create_initializer( 79 | albert_config.initializer_range)) 80 | input_tensor = modeling.layer_norm(input_tensor) 81 | 82 | # The output weights are the same as the input embeddings, but there is 83 | # an output-only bias for each token. 84 | output_bias = tf.get_variable( 85 | "output_bias", 86 | shape=[albert_config.vocab_size], 87 | initializer=tf.zeros_initializer()) 88 | logits = tf.matmul( 89 | input_tensor, model.get_embedding_table(), transpose_b=True) 90 | logits = tf.nn.bias_add(logits, output_bias) 91 | return logits 92 | 93 | 94 | def get_sop_log_probs(model, albert_config): 95 | """Get loss and log probs for the next sentence prediction.""" 96 | input_tensor = model.get_pooled_output() 97 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 98 | # "random sentence". This weight matrix is not used after pre-training. 
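# (Note: `input_tensor` above is the pooled [CLS] output of shape
# [batch_size, hidden_size]; the [batch_size, 2] logits computed below are
# turned into log-probabilities by the final log_softmax.)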
99 | with tf.variable_scope("cls/seq_relationship"): 100 | output_weights = tf.get_variable( 101 | "output_weights", 102 | shape=[2, albert_config.hidden_size], 103 | initializer=modeling.create_initializer( 104 | albert_config.initializer_range)) 105 | output_bias = tf.get_variable( 106 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 107 | 108 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 109 | logits = tf.nn.bias_add(logits, output_bias) 110 | log_probs = tf.nn.log_softmax(logits, axis=-1) 111 | return log_probs 112 | 113 | 114 | def module_fn(is_training): 115 | """Module function.""" 116 | input_ids = tf.placeholder(tf.int32, [None, None], "input_ids") 117 | input_mask = tf.placeholder(tf.int32, [None, None], "input_mask") 118 | segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids") 119 | mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions") 120 | 121 | albert_config_path = os.path.join( 122 | FLAGS.albert_directory, "albert_config.json") 123 | albert_config = modeling.AlbertConfig.from_json_file(albert_config_path) 124 | model = modeling.AlbertModel( 125 | config=albert_config, 126 | is_training=is_training, 127 | input_ids=input_ids, 128 | input_mask=input_mask, 129 | token_type_ids=segment_ids, 130 | use_one_hot_embeddings=False, 131 | use_einsum=FLAGS.use_einsum) 132 | 133 | mlm_logits = get_mlm_logits(model, albert_config, mlm_positions) 134 | sop_log_probs = get_sop_log_probs(model, albert_config) 135 | 136 | vocab_model_path = os.path.join(FLAGS.albert_directory, "30k-clean.model") 137 | vocab_file_path = os.path.join(FLAGS.albert_directory, "30k-clean.vocab") 138 | 139 | config_file = tf.constant( 140 | value=albert_config_path, dtype=tf.string, name="config_file") 141 | vocab_model = tf.constant( 142 | value=vocab_model_path, dtype=tf.string, name="vocab_model") 143 | # This is only for visualization purpose. 144 | vocab_file = tf.constant( 145 | value=vocab_file_path, dtype=tf.string, name="vocab_file") 146 | 147 | # By adding `config_file, vocab_model and vocab_file` 148 | # to the ASSET_FILEPATHS collection, TF-Hub will 149 | # rewrite this tensor so that this asset is portable. 
150 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file) 151 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_model) 152 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file) 153 | 154 | hub.add_signature( 155 | name="tokens", 156 | inputs=dict( 157 | input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids), 158 | outputs=dict( 159 | sequence_output=model.get_sequence_output(), 160 | pooled_output=model.get_pooled_output())) 161 | 162 | hub.add_signature( 163 | name="sop", 164 | inputs=dict( 165 | input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids), 166 | outputs=dict( 167 | sequence_output=model.get_sequence_output(), 168 | pooled_output=model.get_pooled_output(), 169 | sop_log_probs=sop_log_probs)) 170 | 171 | hub.add_signature( 172 | name="mlm", 173 | inputs=dict( 174 | input_ids=input_ids, 175 | input_mask=input_mask, 176 | segment_ids=segment_ids, 177 | mlm_positions=mlm_positions), 178 | outputs=dict( 179 | sequence_output=model.get_sequence_output(), 180 | pooled_output=model.get_pooled_output(), 181 | mlm_logits=mlm_logits)) 182 | 183 | hub.add_signature( 184 | name="tokenization_info", 185 | inputs={}, 186 | outputs=dict( 187 | vocab_file=vocab_model, 188 | do_lower_case=tf.constant(FLAGS.do_lower_case))) 189 | 190 | 191 | def main(_): 192 | tags_and_args = [] 193 | for is_training in (True, False): 194 | tags = set() 195 | if is_training: 196 | tags.add("train") 197 | tags_and_args.append((tags, dict(is_training=is_training))) 198 | spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args) 199 | checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name) 200 | tf.logging.info("Using checkpoint {}".format(checkpoint_path)) 201 | spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path) 202 | 203 | 204 | if __name__ == "__main__": 205 | flags.mark_flag_as_required("albert_directory") 206 | flags.mark_flag_as_required("export_path") 207 | app.run(main) 208 | -------------------------------------------------------------------------------- /fine_tuning_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Helper library for ALBERT fine-tuning. 16 | 17 | This library can be used to construct ALBERT models for fine-tuning, either from 18 | json config files or from TF-Hub modules. 
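A minimal usage sketch (the `albert_config` and input tensors here are
hypothetical placeholders):

  pooled_output, sequence_output = create_albert(
      albert_config=albert_config,
      is_training=True,
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      use_one_hot_embeddings=False,
      use_einsum=True,
      hub_module=None)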
19 | """ 20 | 21 | from albert import modeling 22 | from albert import tokenization 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_hub as hub 25 | 26 | 27 | def _create_model_from_hub(hub_module, is_training, input_ids, input_mask, 28 | segment_ids): 29 | """Creates an ALBERT model from TF-Hub.""" 30 | tags = set() 31 | if is_training: 32 | tags.add("train") 33 | albert_module = hub.Module(hub_module, tags=tags, trainable=True) 34 | albert_inputs = dict( 35 | input_ids=input_ids, 36 | input_mask=input_mask, 37 | segment_ids=segment_ids) 38 | albert_outputs = albert_module( 39 | inputs=albert_inputs, 40 | signature="tokens", 41 | as_dict=True) 42 | return (albert_outputs["pooled_output"], albert_outputs["sequence_output"]) 43 | 44 | 45 | def _create_model_from_scratch(albert_config, is_training, input_ids, 46 | input_mask, segment_ids, use_one_hot_embeddings, 47 | use_einsum): 48 | """Creates an ALBERT model from scratch/config.""" 49 | model = modeling.AlbertModel( 50 | config=albert_config, 51 | is_training=is_training, 52 | input_ids=input_ids, 53 | input_mask=input_mask, 54 | token_type_ids=segment_ids, 55 | use_one_hot_embeddings=use_one_hot_embeddings, 56 | use_einsum=use_einsum) 57 | return (model.get_pooled_output(), model.get_sequence_output()) 58 | 59 | 60 | def create_albert(albert_config, is_training, input_ids, input_mask, 61 | segment_ids, use_one_hot_embeddings, use_einsum, hub_module): 62 | """Creates an ALBERT, either from TF-Hub or from scratch.""" 63 | if hub_module: 64 | tf.logging.info("creating model from hub_module: %s", hub_module) 65 | return _create_model_from_hub(hub_module, is_training, input_ids, 66 | input_mask, segment_ids) 67 | else: 68 | tf.logging.info("creating model from albert_config") 69 | return _create_model_from_scratch(albert_config, is_training, input_ids, 70 | input_mask, segment_ids, 71 | use_one_hot_embeddings, use_einsum) 72 | 73 | 74 | def create_vocab(vocab_file, do_lower_case, spm_model_file, hub_module): 75 | """Creates a vocab, either from vocab file or from a TF-Hub module.""" 76 | if hub_module: 77 | use_spm = True if spm_model_file else False 78 | return tokenization.FullTokenizer.from_hub_module( 79 | hub_module=hub_module, use_spm=use_spm) 80 | else: 81 | return tokenization.FullTokenizer.from_scratch( 82 | vocab_file=vocab_file, do_lower_case=do_lower_case, 83 | spm_model_file=spm_model_file) 84 | 85 | -------------------------------------------------------------------------------- /lamb_optimizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import six 23 | import tensorflow.compat.v1 as tf 24 | 25 | # pylint: disable=g-direct-tensorflow-import 26 | from tensorflow.python.ops import array_ops 27 | from tensorflow.python.ops import linalg_ops 28 | from tensorflow.python.ops import math_ops 29 | # pylint: enable=g-direct-tensorflow-import 30 | 31 | 32 | class LAMBOptimizer(tf.train.Optimizer): 33 | """LAMB (Layer-wise Adaptive Moments optimizer for Batch training).""" 34 | # A new optimizer that includes correct L2 weight decay, adaptive 35 | # element-wise updating, and layer-wise justification. The LAMB optimizer 36 | # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song, 37 | # James Demmel, and Cho-Jui Hsieh in a paper titled as Reducing BERT 38 | # Pre-Training Time from 3 Days to 76 Minutes (arxiv.org/abs/1904.00962) 39 | 40 | def __init__(self, 41 | learning_rate, 42 | weight_decay_rate=0.0, 43 | beta_1=0.9, 44 | beta_2=0.999, 45 | epsilon=1e-6, 46 | exclude_from_weight_decay=None, 47 | exclude_from_layer_adaptation=None, 48 | name="LAMBOptimizer"): 49 | """Constructs a LAMBOptimizer.""" 50 | super(LAMBOptimizer, self).__init__(False, name) 51 | 52 | self.learning_rate = learning_rate 53 | self.weight_decay_rate = weight_decay_rate 54 | self.beta_1 = beta_1 55 | self.beta_2 = beta_2 56 | self.epsilon = epsilon 57 | self.exclude_from_weight_decay = exclude_from_weight_decay 58 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the 59 | # arg is None. 60 | # TODO(jingli): validate if exclude_from_layer_adaptation is necessary. 61 | if exclude_from_layer_adaptation: 62 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation 63 | else: 64 | self.exclude_from_layer_adaptation = exclude_from_weight_decay 65 | 66 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 67 | """See base class.""" 68 | assignments = [] 69 | for (grad, param) in grads_and_vars: 70 | if grad is None or param is None: 71 | continue 72 | 73 | param_name = self._get_variable_name(param.name) 74 | 75 | m = tf.get_variable( 76 | name=six.ensure_str(param_name) + "/adam_m", 77 | shape=param.shape.as_list(), 78 | dtype=tf.float32, 79 | trainable=False, 80 | initializer=tf.zeros_initializer()) 81 | v = tf.get_variable( 82 | name=six.ensure_str(param_name) + "/adam_v", 83 | shape=param.shape.as_list(), 84 | dtype=tf.float32, 85 | trainable=False, 86 | initializer=tf.zeros_initializer()) 87 | 88 | # Standard Adam update. 89 | next_m = ( 90 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 91 | next_v = ( 92 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 93 | tf.square(grad))) 94 | 95 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 96 | 97 | # Just adding the square of the weights to the loss function is *not* 98 | # the correct way of using L2 regularization/weight decay with Adam, 99 | # since that will interact with the m and v parameters in strange ways. 100 | # 101 | # Instead we want ot decay the weights in a manner that doesn't interact 102 | # with the m/v parameters. This is equivalent to adding the square 103 | # of the weights to the loss with plain (non-momentum) SGD. 
104 | if self._do_use_weight_decay(param_name): 105 | update += self.weight_decay_rate * param 106 | 107 | ratio = 1.0 108 | if self._do_layer_adaptation(param_name): 109 | w_norm = linalg_ops.norm(param, ord=2) 110 | g_norm = linalg_ops.norm(update, ord=2) 111 | ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( 112 | math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) 113 | 114 | update_with_lr = ratio * self.learning_rate * update 115 | 116 | next_param = param - update_with_lr 117 | 118 | assignments.extend( 119 | [param.assign(next_param), 120 | m.assign(next_m), 121 | v.assign(next_v)]) 122 | return tf.group(*assignments, name=name) 123 | 124 | def _do_use_weight_decay(self, param_name): 125 | """Whether to use L2 weight decay for `param_name`.""" 126 | if not self.weight_decay_rate: 127 | return False 128 | if self.exclude_from_weight_decay: 129 | for r in self.exclude_from_weight_decay: 130 | if re.search(r, param_name) is not None: 131 | return False 132 | return True 133 | 134 | def _do_layer_adaptation(self, param_name): 135 | """Whether to do layer-wise learning rate adaptation for `param_name`.""" 136 | if self.exclude_from_layer_adaptation: 137 | for r in self.exclude_from_layer_adaptation: 138 | if re.search(r, param_name) is not None: 139 | return False 140 | return True 141 | 142 | def _get_variable_name(self, param_name): 143 | """Get the variable name from the tensor name.""" 144 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name)) 145 | if m is not None: 146 | param_name = m.group(1) 147 | return param_name 148 | -------------------------------------------------------------------------------- /modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | from albert import modeling 25 | import numpy as np 26 | import six 27 | from six.moves import range 28 | import tensorflow.compat.v1 as tf 29 | 30 | 31 | class AlbertModelTest(tf.test.TestCase): 32 | 33 | class AlbertModelTester(object): 34 | 35 | def __init__(self, 36 | parent, 37 | batch_size=13, 38 | seq_length=7, 39 | is_training=True, 40 | use_input_mask=True, 41 | use_token_type_ids=True, 42 | vocab_size=99, 43 | embedding_size=32, 44 | hidden_size=32, 45 | num_hidden_layers=5, 46 | num_attention_heads=4, 47 | intermediate_size=37, 48 | hidden_act="gelu", 49 | hidden_dropout_prob=0.1, 50 | attention_probs_dropout_prob=0.1, 51 | max_position_embeddings=512, 52 | type_vocab_size=16, 53 | initializer_range=0.02, 54 | scope=None): 55 | self.parent = parent 56 | self.batch_size = batch_size 57 | self.seq_length = seq_length 58 | self.is_training = is_training 59 | self.use_input_mask = use_input_mask 60 | self.use_token_type_ids = use_token_type_ids 61 | self.vocab_size = vocab_size 62 | self.embedding_size = embedding_size 63 | self.hidden_size = hidden_size 64 | self.num_hidden_layers = num_hidden_layers 65 | self.num_attention_heads = num_attention_heads 66 | self.intermediate_size = intermediate_size 67 | self.hidden_act = hidden_act 68 | self.hidden_dropout_prob = hidden_dropout_prob 69 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 70 | self.max_position_embeddings = max_position_embeddings 71 | self.type_vocab_size = type_vocab_size 72 | self.initializer_range = initializer_range 73 | self.scope = scope 74 | 75 | def create_model(self): 76 | input_ids = AlbertModelTest.ids_tensor([self.batch_size, self.seq_length], 77 | self.vocab_size) 78 | 79 | input_mask = None 80 | if self.use_input_mask: 81 | input_mask = AlbertModelTest.ids_tensor( 82 | [self.batch_size, self.seq_length], vocab_size=2) 83 | 84 | token_type_ids = None 85 | if self.use_token_type_ids: 86 | token_type_ids = AlbertModelTest.ids_tensor( 87 | [self.batch_size, self.seq_length], self.type_vocab_size) 88 | 89 | config = modeling.AlbertConfig( 90 | vocab_size=self.vocab_size, 91 | embedding_size=self.embedding_size, 92 | hidden_size=self.hidden_size, 93 | num_hidden_layers=self.num_hidden_layers, 94 | num_attention_heads=self.num_attention_heads, 95 | intermediate_size=self.intermediate_size, 96 | hidden_act=self.hidden_act, 97 | hidden_dropout_prob=self.hidden_dropout_prob, 98 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 99 | max_position_embeddings=self.max_position_embeddings, 100 | type_vocab_size=self.type_vocab_size, 101 | initializer_range=self.initializer_range) 102 | 103 | model = modeling.AlbertModel( 104 | config=config, 105 | is_training=self.is_training, 106 | input_ids=input_ids, 107 | input_mask=input_mask, 108 | token_type_ids=token_type_ids, 109 | scope=self.scope) 110 | 111 | outputs = { 112 | "embedding_output": model.get_embedding_output(), 113 | "sequence_output": model.get_sequence_output(), 114 | "pooled_output": model.get_pooled_output(), 115 | "all_encoder_layers": model.get_all_encoder_layers(), 116 | } 117 | return outputs 118 | 119 | def check_output(self, result): 120 | self.parent.assertAllEqual( 121 | result["embedding_output"].shape, 122 | [self.batch_size, self.seq_length, self.embedding_size]) 123 | 124 | self.parent.assertAllEqual( 
125 | result["sequence_output"].shape, 126 | [self.batch_size, self.seq_length, self.hidden_size]) 127 | 128 | self.parent.assertAllEqual(result["pooled_output"].shape, 129 | [self.batch_size, self.hidden_size]) 130 | 131 | def test_default(self): 132 | self.run_tester(AlbertModelTest.AlbertModelTester(self)) 133 | 134 | def test_config_to_json_string(self): 135 | config = modeling.AlbertConfig(vocab_size=99, hidden_size=37) 136 | obj = json.loads(config.to_json_string()) 137 | self.assertEqual(obj["vocab_size"], 99) 138 | self.assertEqual(obj["hidden_size"], 37) 139 | 140 | def test_einsum_via_matmul(self): 141 | batch_size = 8 142 | seq_length = 12 143 | num_attention_heads = 3 144 | head_size = 6 145 | hidden_size = 10 146 | 147 | input_tensor = np.random.uniform(0, 1, 148 | [batch_size, seq_length, hidden_size]) 149 | input_tensor = tf.constant(input_tensor, dtype=tf.float32) 150 | w = np.random.uniform(0, 1, [hidden_size, num_attention_heads, head_size]) 151 | w = tf.constant(w, dtype=tf.float32) 152 | ret1 = tf.einsum("BFH,HND->BFND", input_tensor, w) 153 | ret2 = modeling.einsum_via_matmul(input_tensor, w, 1) 154 | self.assertAllClose(ret1, ret2) 155 | 156 | input_tensor = np.random.uniform(0, 1, 157 | [batch_size, seq_length, 158 | num_attention_heads, head_size]) 159 | input_tensor = tf.constant(input_tensor, dtype=tf.float32) 160 | w = np.random.uniform(0, 1, [num_attention_heads, head_size, hidden_size]) 161 | w = tf.constant(w, dtype=tf.float32) 162 | ret1 = tf.einsum("BFND,NDH->BFH", input_tensor, w) 163 | ret2 = modeling.einsum_via_matmul(input_tensor, w, 2) 164 | self.assertAllClose(ret1, ret2) 165 | 166 | def run_tester(self, tester): 167 | with self.test_session() as sess: 168 | ops = tester.create_model() 169 | init_op = tf.group(tf.global_variables_initializer(), 170 | tf.local_variables_initializer()) 171 | sess.run(init_op) 172 | output_result = sess.run(ops) 173 | tester.check_output(output_result) 174 | 175 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 176 | 177 | @classmethod 178 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 179 | """Creates a random int32 tensor of the shape within the vocab size.""" 180 | if rng is None: 181 | rng = random.Random() 182 | 183 | total_dims = 1 184 | for dim in shape: 185 | total_dims *= dim 186 | 187 | values = [] 188 | for _ in range(total_dims): 189 | values.append(rng.randint(0, vocab_size - 1)) 190 | 191 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 192 | 193 | def assert_all_tensors_reachable(self, sess, outputs): 194 | """Checks that all the tensors in the graph are reachable from outputs.""" 195 | graph = sess.graph 196 | 197 | ignore_strings = [ 198 | "^.*/assert_less_equal/.*$", 199 | "^.*/dilation_rate$", 200 | "^.*/Tensordot/concat$", 201 | "^.*/Tensordot/concat/axis$", 202 | "^testing/.*$", 203 | ] 204 | 205 | ignore_regexes = [re.compile(x) for x in ignore_strings] 206 | 207 | unreachable = self.get_unreachable_ops(graph, outputs) 208 | filtered_unreachable = [] 209 | for x in unreachable: 210 | do_ignore = False 211 | for r in ignore_regexes: 212 | m = r.match(six.ensure_str(x.name)) 213 | if m is not None: 214 | do_ignore = True 215 | if do_ignore: 216 | continue 217 | filtered_unreachable.append(x) 218 | unreachable = filtered_unreachable 219 | 220 | self.assertEqual( 221 | len(unreachable), 0, "The following ops are unreachable: %s" % 222 | (" ".join([x.name for x in unreachable]))) 223 | 224 | @classmethod 225 | def get_unreachable_ops(cls, graph, outputs): 
226 | """Finds all of the tensors in graph that are unreachable from outputs.""" 227 | outputs = cls.flatten_recursive(outputs) 228 | output_to_op = collections.defaultdict(list) 229 | op_to_all = collections.defaultdict(list) 230 | assign_out_to_in = collections.defaultdict(list) 231 | 232 | for op in graph.get_operations(): 233 | for x in op.inputs: 234 | op_to_all[op.name].append(x.name) 235 | for y in op.outputs: 236 | output_to_op[y.name].append(op.name) 237 | op_to_all[op.name].append(y.name) 238 | if str(op.type) == "Assign": 239 | for y in op.outputs: 240 | for x in op.inputs: 241 | assign_out_to_in[y.name].append(x.name) 242 | 243 | assign_groups = collections.defaultdict(list) 244 | for out_name in assign_out_to_in.keys(): 245 | name_group = assign_out_to_in[out_name] 246 | for n1 in name_group: 247 | assign_groups[n1].append(out_name) 248 | for n2 in name_group: 249 | if n1 != n2: 250 | assign_groups[n1].append(n2) 251 | 252 | seen_tensors = {} 253 | stack = [x.name for x in outputs] 254 | while stack: 255 | name = stack.pop() 256 | if name in seen_tensors: 257 | continue 258 | seen_tensors[name] = True 259 | 260 | if name in output_to_op: 261 | for op_name in output_to_op[name]: 262 | if op_name in op_to_all: 263 | for input_name in op_to_all[op_name]: 264 | if input_name not in stack: 265 | stack.append(input_name) 266 | 267 | expanded_names = [] 268 | if name in assign_groups: 269 | for assign_name in assign_groups[name]: 270 | expanded_names.append(assign_name) 271 | 272 | for expanded_name in expanded_names: 273 | if expanded_name not in stack: 274 | stack.append(expanded_name) 275 | 276 | unreachable_ops = [] 277 | for op in graph.get_operations(): 278 | is_unreachable = False 279 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 280 | for name in all_names: 281 | if name not in seen_tensors: 282 | is_unreachable = True 283 | if is_unreachable: 284 | unreachable_ops.append(op) 285 | return unreachable_ops 286 | 287 | @classmethod 288 | def flatten_recursive(cls, item): 289 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 290 | output = [] 291 | if isinstance(item, list): 292 | output.extend(item) 293 | elif isinstance(item, tuple): 294 | output.extend(list(item)) 295 | elif isinstance(item, dict): 296 | for (_, v) in six.iteritems(item): 297 | output.append(v) 298 | else: 299 | return [item] 300 | 301 | flat_output = [] 302 | for x in output: 303 | flat_output.extend(cls.flatten_recursive(x)) 304 | return flat_output 305 | 306 | 307 | if __name__ == "__main__": 308 | tf.test.main() 309 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import re 21 | from albert import lamb_optimizer 22 | import six 23 | from six.moves import zip 24 | import tensorflow.compat.v1 as tf 25 | from tensorflow.contrib import tpu as contrib_tpu 26 | 27 | 28 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, 29 | optimizer="adamw", poly_power=1.0, start_warmup_step=0, 30 | colocate_gradients_with_ops=False, excluded_tvars=None): 31 | """Creates an optimizer training op.""" 32 | global_step = tf.train.get_or_create_global_step() 33 | 34 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 35 | 36 | # Implements linear decay of the learning rate. 37 | learning_rate = tf.train.polynomial_decay( 38 | learning_rate, 39 | global_step, 40 | num_train_steps, 41 | end_learning_rate=0.0, 42 | power=poly_power, 43 | cycle=False) 44 | 45 | # Implements linear warmup. I.e., if global_step - start_warmup_step < 46 | # num_warmup_steps, the learning rate will be 47 | # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`. 48 | if num_warmup_steps: 49 | tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step) 50 | + ", for " + str(num_warmup_steps) + " steps ++++++") 51 | global_steps_int = tf.cast(global_step, tf.int32) 52 | start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32) 53 | global_steps_int = global_steps_int - start_warm_int 54 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 55 | 56 | global_steps_float = tf.cast(global_steps_int, tf.float32) 57 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 58 | 59 | warmup_percent_done = global_steps_float / warmup_steps_float 60 | warmup_learning_rate = init_lr * warmup_percent_done 61 | 62 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 63 | learning_rate = ( 64 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 65 | 66 | # It is OK that you use this optimizer for finetuning, since this 67 | # is how the model was trained (note that the Adam m/v variables are NOT 68 | # loaded from init_checkpoint.) 69 | # It is OK to use AdamW in the finetuning even the model is trained by LAMB. 70 | # As report in the Bert pulic github, the learning rate for SQuAD 1.1 finetune 71 | # is 3e-5, 4e-5 or 5e-5. For LAMB, the users can use 3e-4, 4e-4,or 5e-4 for a 72 | # batch size of 64 in the finetune. 
73 |   if optimizer == "adamw":
74 |     tf.logging.info("using adamw")
75 |     optimizer = AdamWeightDecayOptimizer(
76 |         learning_rate=learning_rate,
77 |         weight_decay_rate=0.01,
78 |         beta_1=0.9,
79 |         beta_2=0.999,
80 |         epsilon=1e-6,
81 |         exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
82 |   elif optimizer == "lamb":
83 |     tf.logging.info("using lamb")
84 |     optimizer = lamb_optimizer.LAMBOptimizer(
85 |         learning_rate=learning_rate,
86 |         weight_decay_rate=0.01,
87 |         beta_1=0.9,
88 |         beta_2=0.999,
89 |         epsilon=1e-6,
90 |         exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
91 |   else:
92 |     raise ValueError("Unsupported optimizer: %s" % optimizer)
93 | 
94 |   if use_tpu:
95 |     optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
96 | 
97 |   tvars = tf.trainable_variables()
98 |   if excluded_tvars:
99 |     # Build a new list rather than removing entries while iterating.
100 |     tvars = [tvar for tvar in tvars if tvar.name not in excluded_tvars]
101 | 
102 |   grads = tf.gradients(
103 |       loss, tvars, colocate_gradients_with_ops=colocate_gradients_with_ops)
104 | 
105 |   # This is how the model was pre-trained.
106 |   (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
107 | 
108 |   train_op = optimizer.apply_gradients(
109 |       list(zip(grads, tvars)), global_step=global_step)
110 | 
111 |   # Normally the global step update is done inside of `apply_gradients`.
112 |   # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` does this.
113 |   # If you use a different optimizer that already updates the global step,
114 |   # you should probably take this line out.
115 |   new_global_step = global_step + 1
116 |   train_op = tf.group(train_op, [global_step.assign(new_global_step)])
117 |   return train_op
118 | 
119 | 
120 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
121 |   """A basic Adam optimizer that includes "correct" L2 weight decay."""
122 | 
123 |   def __init__(self,
124 |                learning_rate,
125 |                weight_decay_rate=0.0,
126 |                beta_1=0.9,
127 |                beta_2=0.999,
128 |                epsilon=1e-6,
129 |                exclude_from_weight_decay=None,
130 |                name="AdamWeightDecayOptimizer"):
131 |     """Constructs an AdamWeightDecayOptimizer."""
132 |     super(AdamWeightDecayOptimizer, self).__init__(False, name)
133 | 
134 |     self.learning_rate = learning_rate
135 |     self.weight_decay_rate = weight_decay_rate
136 |     self.beta_1 = beta_1
137 |     self.beta_2 = beta_2
138 |     self.epsilon = epsilon
139 |     self.exclude_from_weight_decay = exclude_from_weight_decay
140 | 
141 |   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
142 |     """See base class."""
143 |     assignments = []
144 |     for (grad, param) in grads_and_vars:
145 |       if grad is None or param is None:
146 |         continue
147 | 
148 |       param_name = self._get_variable_name(param.name)
149 | 
150 |       m = tf.get_variable(
151 |           name=six.ensure_str(param_name) + "/adam_m",
152 |           shape=param.shape.as_list(),
153 |           dtype=tf.float32,
154 |           trainable=False,
155 |           initializer=tf.zeros_initializer())
156 |       v = tf.get_variable(
157 |           name=six.ensure_str(param_name) + "/adam_v",
158 |           shape=param.shape.as_list(),
159 |           dtype=tf.float32,
160 |           trainable=False,
161 |           initializer=tf.zeros_initializer())
162 | 
163 |       # Standard Adam update.
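      # Note: as in the original BERT implementation, the m/v moment
      # estimates below are not bias-corrected, unlike textbook Adam.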
164 |       next_m = (
165 |           tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
166 |       next_v = (
167 |           tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
168 |                                                     tf.square(grad)))
169 | 
170 |       update = next_m / (tf.sqrt(next_v) + self.epsilon)
171 | 
172 |       # Just adding the square of the weights to the loss function is *not*
173 |       # the correct way of using L2 regularization/weight decay with Adam,
174 |       # since that will interact with the m and v parameters in strange ways.
175 |       #
176 |       # Instead we want to decay the weights in a manner that doesn't interact
177 |       # with the m/v parameters. This is equivalent to adding the square
178 |       # of the weights to the loss with plain (non-momentum) SGD.
179 |       if self._do_use_weight_decay(param_name):
180 |         update += self.weight_decay_rate * param
181 | 
182 |       update_with_lr = self.learning_rate * update
183 | 
184 |       next_param = param - update_with_lr
185 | 
186 |       assignments.extend(
187 |           [param.assign(next_param),
188 |            m.assign(next_m),
189 |            v.assign(next_v)])
190 |     return tf.group(*assignments, name=name)
191 | 
192 |   def _do_use_weight_decay(self, param_name):
193 |     """Whether to use L2 weight decay for `param_name`."""
194 |     if not self.weight_decay_rate:
195 |       return False
196 |     if self.exclude_from_weight_decay:
197 |       for r in self.exclude_from_weight_decay:
198 |         if re.search(r, param_name) is not None:
199 |           return False
200 |     return True
201 | 
202 |   def _get_variable_name(self, param_name):
203 |     """Get the variable name from the tensor name."""
204 |     m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
205 |     if m is not None:
206 |       param_name = m.group(1)
207 |     return param_name
208 | 
--------------------------------------------------------------------------------
/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | from albert import optimization 19 | from six.moves import range 20 | from six.moves import zip 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | class OptimizationTest(tf.test.TestCase): 25 | 26 | def test_adam(self): 27 | with self.test_session() as sess: 28 | w = tf.get_variable( 29 | "w", 30 | shape=[3], 31 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 32 | x = tf.constant([0.4, 0.2, -0.5]) 33 | loss = tf.reduce_mean(tf.square(x - w)) 34 | tvars = tf.trainable_variables() 35 | grads = tf.gradients(loss, tvars) 36 | global_step = tf.train.get_or_create_global_step() 37 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 38 | train_op = optimizer.apply_gradients(list(zip(grads, tvars)), global_step) 39 | init_op = tf.group(tf.global_variables_initializer(), 40 | tf.local_variables_initializer()) 41 | sess.run(init_op) 42 | for _ in range(100): 43 | sess.run(train_op) 44 | w_np = sess.run(w) 45 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 46 | 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /race_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
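# RACE is a multiple-choice reading-comprehension dataset: each question about
# a passage comes with four candidate answers labeled "A" through "D", which
# is why the processor below always emits exactly four endings per example.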
15 | """Utility functions for RACE dataset.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import json 23 | import os 24 | from albert import classifier_utils 25 | from albert import fine_tuning_utils 26 | from albert import modeling 27 | from albert import optimization 28 | from albert import tokenization 29 | import tensorflow.compat.v1 as tf 30 | from tensorflow.compat.v1 import estimator as tf_estimator 31 | from tensorflow.contrib import tpu as contrib_tpu 32 | 33 | 34 | class InputExample(object): 35 | """A single training/test example for the RACE dataset.""" 36 | 37 | def __init__(self, 38 | example_id, 39 | context_sentence, 40 | start_ending, 41 | endings, 42 | label=None): 43 | self.example_id = example_id 44 | self.context_sentence = context_sentence 45 | self.start_ending = start_ending 46 | self.endings = endings 47 | self.label = label 48 | 49 | def __str__(self): 50 | return self.__repr__() 51 | 52 | def __repr__(self): 53 | l = [ 54 | "id: {}".format(self.example_id), 55 | "context_sentence: {}".format(self.context_sentence), 56 | "start_ending: {}".format(self.start_ending), 57 | "ending_0: {}".format(self.endings[0]), 58 | "ending_1: {}".format(self.endings[1]), 59 | "ending_2: {}".format(self.endings[2]), 60 | "ending_3: {}".format(self.endings[3]), 61 | ] 62 | 63 | if self.label is not None: 64 | l.append("label: {}".format(self.label)) 65 | 66 | return ", ".join(l) 67 | 68 | 69 | class RaceProcessor(object): 70 | """Processor for the RACE data set.""" 71 | 72 | def __init__(self, use_spm, do_lower_case, high_only, middle_only): 73 | super(RaceProcessor, self).__init__() 74 | self.use_spm = use_spm 75 | self.do_lower_case = do_lower_case 76 | self.high_only = high_only 77 | self.middle_only = middle_only 78 | 79 | def get_train_examples(self, data_dir): 80 | """Gets a collection of `InputExample`s for the train set.""" 81 | return self.read_examples( 82 | os.path.join(data_dir, "RACE", "train")) 83 | 84 | def get_dev_examples(self, data_dir): 85 | """Gets a collection of `InputExample`s for the dev set.""" 86 | return self.read_examples( 87 | os.path.join(data_dir, "RACE", "dev")) 88 | 89 | def get_test_examples(self, data_dir): 90 | """Gets a collection of `InputExample`s for prediction.""" 91 | return self.read_examples( 92 | os.path.join(data_dir, "RACE", "test")) 93 | 94 | def get_labels(self): 95 | """Gets the list of labels for this data set.""" 96 | return ["A", "B", "C", "D"] 97 | 98 | def process_text(self, text): 99 | if self.use_spm: 100 | return tokenization.preprocess_text(text, lower=self.do_lower_case) 101 | else: 102 | return tokenization.convert_to_unicode(text) 103 | 104 | def read_examples(self, data_dir): 105 | """Read examples from RACE json files.""" 106 | examples = [] 107 | for level in ["middle", "high"]: 108 | if level == "middle" and self.high_only: continue 109 | if level == "high" and self.middle_only: continue 110 | cur_dir = os.path.join(data_dir, level) 111 | 112 | cur_path = os.path.join(cur_dir, "all.txt") 113 | with tf.gfile.Open(cur_path) as f: 114 | for line in f: 115 | cur_data = json.loads(line.strip()) 116 | 117 | answers = cur_data["answers"] 118 | options = cur_data["options"] 119 | questions = cur_data["questions"] 120 | context = self.process_text(cur_data["article"]) 121 | 122 | for i in range(len(answers)): 123 | label = ord(answers[i]) - ord("A") 124 | qa_list = [] 125 | 126 | question = 
self.process_text(questions[i]) 127 | for j in range(4): 128 | option = self.process_text(options[i][j]) 129 | 130 | if "_" in question: 131 | qa_cat = question.replace("_", option) 132 | else: 133 | qa_cat = " ".join([question, option]) 134 | 135 | qa_list.append(qa_cat) 136 | 137 | examples.append( 138 | InputExample( 139 | example_id=cur_data["id"], 140 | context_sentence=context, 141 | start_ending=None, 142 | endings=[qa_list[0], qa_list[1], qa_list[2], qa_list[3]], 143 | label=label 144 | ) 145 | ) 146 | 147 | return examples 148 | 149 | 150 | def convert_single_example(example_index, example, label_size, max_seq_length, 151 | tokenizer, max_qa_length): 152 | """Loads a data file into a list of `InputBatch`s.""" 153 | 154 | # RACE is a multiple choice task. To perform this task using AlBERT, 155 | # we will use the formatting proposed in "Improving Language 156 | # Understanding by Generative Pre-Training" and suggested by 157 | # @jacobdevlin-google in this issue 158 | # https://github.com/google-research/bert/issues/38. 159 | # 160 | # Each choice will correspond to a sample on which we run the 161 | # inference. For a given RACE example, we will create the 4 162 | # following inputs: 163 | # - [CLS] context [SEP] choice_1 [SEP] 164 | # - [CLS] context [SEP] choice_2 [SEP] 165 | # - [CLS] context [SEP] choice_3 [SEP] 166 | # - [CLS] context [SEP] choice_4 [SEP] 167 | # The model will output a single value for each input. To get the 168 | # final decision of the model, we will run a softmax over these 4 169 | # outputs. 170 | if isinstance(example, classifier_utils.PaddingInputExample): 171 | return classifier_utils.InputFeatures( 172 | example_id=0, 173 | input_ids=[[0] * max_seq_length] * label_size, 174 | input_mask=[[0] * max_seq_length] * label_size, 175 | segment_ids=[[0] * max_seq_length] * label_size, 176 | label_id=0, 177 | is_real_example=False) 178 | else: 179 | context_tokens = tokenizer.tokenize(example.context_sentence) 180 | if example.start_ending is not None: 181 | start_ending_tokens = tokenizer.tokenize(example.start_ending) 182 | 183 | all_input_tokens = [] 184 | all_input_ids = [] 185 | all_input_mask = [] 186 | all_segment_ids = [] 187 | for ending in example.endings: 188 | # We create a copy of the context tokens in order to be 189 | # able to shrink it according to ending_tokens 190 | context_tokens_choice = context_tokens[:] 191 | if example.start_ending is not None: 192 | ending_tokens = start_ending_tokens + tokenizer.tokenize(ending) 193 | else: 194 | ending_tokens = tokenizer.tokenize(ending) 195 | # Modifies `context_tokens_choice` and `ending_tokens` in 196 | # place so that the total length is less than the 197 | # specified length. Account for [CLS], [SEP], [SEP] with 198 | # "- 3" 199 | ending_tokens = ending_tokens[- max_qa_length:] 200 | 201 | if len(context_tokens_choice) + len(ending_tokens) > max_seq_length - 3: 202 | context_tokens_choice = context_tokens_choice[: ( 203 | max_seq_length - 3 - len(ending_tokens))] 204 | tokens = ["[CLS]"] + context_tokens_choice + ( 205 | ["[SEP]"] + ending_tokens + ["[SEP]"]) 206 | segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * ( 207 | len(ending_tokens) + 1) 208 | 209 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 210 | input_mask = [1] * len(input_ids) 211 | 212 | # Zero-pad up to the sequence length. 
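      # For example (hypothetical sizes): with max_seq_length=8 and
      # len(input_ids)=5, padding is [0, 0, 0], so input_ids, input_mask and
      # segment_ids all reach length 8, and the trailing zeros in input_mask
      # mark those positions as padding to be ignored.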
213 | padding = [0] * (max_seq_length - len(input_ids)) 214 | input_ids += padding 215 | input_mask += padding 216 | segment_ids += padding 217 | 218 | assert len(input_ids) == max_seq_length 219 | assert len(input_mask) == max_seq_length 220 | assert len(segment_ids) == max_seq_length 221 | 222 | all_input_tokens.append(tokens) 223 | all_input_ids.append(input_ids) 224 | all_input_mask.append(input_mask) 225 | all_segment_ids.append(segment_ids) 226 | 227 | label = example.label 228 | if example_index < 5: 229 | tf.logging.info("*** Example ***") 230 | tf.logging.info("id: {}".format(example.example_id)) 231 | for choice_idx, (tokens, input_ids, input_mask, segment_ids) in \ 232 | enumerate(zip(all_input_tokens, all_input_ids, all_input_mask, all_segment_ids)): 233 | tf.logging.info("choice: {}".format(choice_idx)) 234 | tf.logging.info("tokens: {}".format(" ".join(tokens))) 235 | tf.logging.info( 236 | "input_ids: {}".format(" ".join(map(str, input_ids)))) 237 | tf.logging.info( 238 | "input_mask: {}".format(" ".join(map(str, input_mask)))) 239 | tf.logging.info( 240 | "segment_ids: {}".format(" ".join(map(str, segment_ids)))) 241 | tf.logging.info("label: {}".format(label)) 242 | 243 | return classifier_utils.InputFeatures( 244 | example_id=example.example_id, 245 | input_ids=all_input_ids, 246 | input_mask=all_input_mask, 247 | segment_ids=all_segment_ids, 248 | label_id=label 249 | ) 250 | 251 | 252 | def file_based_convert_examples_to_features( 253 | examples, label_list, max_seq_length, tokenizer, 254 | output_file, max_qa_length): 255 | """Convert a set of `InputExample`s to a TFRecord file.""" 256 | 257 | writer = tf.python_io.TFRecordWriter(output_file) 258 | 259 | for (ex_index, example) in enumerate(examples): 260 | if ex_index % 10000 == 0: 261 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 262 | 263 | feature = convert_single_example(ex_index, example, len(label_list), 264 | max_seq_length, tokenizer, max_qa_length) 265 | 266 | def create_int_feature(values): 267 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 268 | return f 269 | 270 | features = collections.OrderedDict() 271 | features["input_ids"] = create_int_feature(sum(feature.input_ids, [])) 272 | features["input_mask"] = create_int_feature(sum(feature.input_mask, [])) 273 | features["segment_ids"] = create_int_feature(sum(feature.segment_ids, [])) 274 | features["label_ids"] = create_int_feature([feature.label_id]) 275 | features["is_real_example"] = create_int_feature( 276 | [int(feature.is_real_example)]) 277 | 278 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 279 | writer.write(tf_example.SerializeToString()) 280 | writer.close() 281 | 282 | 283 | def create_model(albert_config, is_training, input_ids, input_mask, segment_ids, 284 | labels, num_labels, use_one_hot_embeddings, max_seq_length, 285 | dropout_prob, hub_module): 286 | """Creates a classification model.""" 287 | bsz_per_core = tf.shape(input_ids)[0] 288 | 289 | input_ids = tf.reshape(input_ids, [bsz_per_core * num_labels, max_seq_length]) 290 | input_mask = tf.reshape(input_mask, 291 | [bsz_per_core * num_labels, max_seq_length]) 292 | token_type_ids = tf.reshape(segment_ids, 293 | [bsz_per_core * num_labels, max_seq_length]) 294 | 295 | (output_layer, _) = fine_tuning_utils.create_albert( 296 | albert_config=albert_config, 297 | is_training=is_training, 298 | input_ids=input_ids, 299 | input_mask=input_mask, 300 | segment_ids=token_type_ids, 301 | 
use_one_hot_embeddings=use_one_hot_embeddings, 302 | use_einsum=True, 303 | hub_module=hub_module) 304 | 305 | hidden_size = output_layer.shape[-1].value 306 | 307 | output_weights = tf.get_variable( 308 | "output_weights", [1, hidden_size], 309 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 310 | 311 | output_bias = tf.get_variable( 312 | "output_bias", [1], 313 | initializer=tf.zeros_initializer()) 314 | 315 | with tf.variable_scope("loss"): 316 | if is_training: 317 | # I.e., 0.1 dropout 318 | output_layer = tf.nn.dropout( 319 | output_layer, keep_prob=1 - dropout_prob) 320 | 321 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 322 | logits = tf.nn.bias_add(logits, output_bias) 323 | logits = tf.reshape(logits, [bsz_per_core, num_labels]) 324 | probabilities = tf.nn.softmax(logits, axis=-1) 325 | predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32) 326 | log_probs = tf.nn.log_softmax(logits, axis=-1) 327 | 328 | one_hot_labels = tf.one_hot( 329 | labels, depth=tf.cast(num_labels, dtype=tf.int32), dtype=tf.float32) 330 | 331 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 332 | loss = tf.reduce_mean(per_example_loss) 333 | 334 | return (loss, per_example_loss, probabilities, logits, predictions) 335 | 336 | 337 | def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate, 338 | num_train_steps, num_warmup_steps, use_tpu, 339 | use_one_hot_embeddings, max_seq_length, dropout_prob, 340 | hub_module): 341 | """Returns `model_fn` closure for TPUEstimator.""" 342 | 343 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 344 | """The `model_fn` for TPUEstimator.""" 345 | 346 | tf.logging.info("*** Features ***") 347 | for name in sorted(features.keys()): 348 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 349 | 350 | input_ids = features["input_ids"] 351 | input_mask = features["input_mask"] 352 | segment_ids = features["segment_ids"] 353 | label_ids = features["label_ids"] 354 | is_real_example = None 355 | if "is_real_example" in features: 356 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 357 | else: 358 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 359 | 360 | is_training = (mode == tf_estimator.ModeKeys.TRAIN) 361 | 362 | (total_loss, per_example_loss, probabilities, logits, predictions) = \ 363 | create_model(albert_config, is_training, input_ids, input_mask, 364 | segment_ids, label_ids, num_labels, 365 | use_one_hot_embeddings, max_seq_length, dropout_prob, 366 | hub_module) 367 | 368 | tvars = tf.trainable_variables() 369 | initialized_variable_names = {} 370 | scaffold_fn = None 371 | if init_checkpoint: 372 | (assignment_map, initialized_variable_names 373 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 374 | if use_tpu: 375 | 376 | def tpu_scaffold(): 377 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 378 | return tf.train.Scaffold() 379 | 380 | scaffold_fn = tpu_scaffold 381 | else: 382 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 383 | 384 | tf.logging.info("**** Trainable Variables ****") 385 | for var in tvars: 386 | init_string = "" 387 | if var.name in initialized_variable_names: 388 | init_string = ", *INIT_FROM_CKPT*" 389 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 390 | init_string) 391 | 392 | output_spec = None 393 | if mode == tf_estimator.ModeKeys.TRAIN: 394 | 395 | train_op = 
optimization.create_optimizer( 396 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 397 | 398 | output_spec = contrib_tpu.TPUEstimatorSpec( 399 | mode=mode, 400 | loss=total_loss, 401 | train_op=train_op, 402 | scaffold_fn=scaffold_fn) 403 | elif mode == tf_estimator.ModeKeys.EVAL: 404 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 405 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 406 | accuracy = tf.metrics.accuracy( 407 | labels=label_ids, predictions=predictions, 408 | weights=is_real_example) 409 | loss = tf.metrics.mean( 410 | values=per_example_loss, weights=is_real_example) 411 | return { 412 | "eval_accuracy": accuracy, 413 | "eval_loss": loss, 414 | } 415 | 416 | eval_metrics = (metric_fn, 417 | [per_example_loss, label_ids, logits, is_real_example]) 418 | output_spec = contrib_tpu.TPUEstimatorSpec( 419 | mode=mode, 420 | loss=total_loss, 421 | eval_metrics=eval_metrics, 422 | scaffold_fn=scaffold_fn) 423 | else: 424 | output_spec = contrib_tpu.TPUEstimatorSpec( 425 | mode=mode, 426 | predictions={"probabilities": probabilities, 427 | "predictions": predictions}, 428 | scaffold_fn=scaffold_fn) 429 | return output_spec 430 | 431 | return model_fn 432 | 433 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Run pip install --upgrade pip if tensorflow 1.15 cannot be found 2 | tensorflow==1.15.2 # CPU Version of TensorFlow 3 | tensorflow_hub==0.7 4 | # tensorflow-gpu==1.15 # GPU version of TensorFlow 5 | sentencepiece 6 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning on classification tasks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import os 23 | import time 24 | from albert import classifier_utils 25 | from albert import fine_tuning_utils 26 | from albert import modeling 27 | import tensorflow.compat.v1 as tf 28 | from tensorflow.compat.v1 import estimator as tf_estimator 29 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver 30 | from tensorflow.contrib import tpu as contrib_tpu 31 | 32 | flags = tf.flags 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | ## Required parameters 37 | flags.DEFINE_string( 38 | "data_dir", None, 39 | "The input data dir. Should contain the .tsv files (or other data files) " 40 | "for the task.") 41 | 42 | flags.DEFINE_string( 43 | "albert_config_file", None, 44 | "The config json file corresponding to the pre-trained ALBERT model. 
" 45 | "This specifies the model architecture.") 46 | 47 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 48 | 49 | flags.DEFINE_string( 50 | "vocab_file", None, 51 | "The vocabulary file that the ALBERT model was trained on.") 52 | 53 | flags.DEFINE_string("spm_model_file", None, 54 | "The model file for sentence piece tokenization.") 55 | 56 | flags.DEFINE_string( 57 | "output_dir", None, 58 | "The output directory where the model checkpoints will be written.") 59 | 60 | flags.DEFINE_string("cached_dir", None, 61 | "Path to cached training and dev tfrecord file. " 62 | "The file will be generated if not exist.") 63 | 64 | ## Other parameters 65 | 66 | flags.DEFINE_string( 67 | "init_checkpoint", None, 68 | "Initial checkpoint (usually from a pre-trained BERT model).") 69 | 70 | flags.DEFINE_string( 71 | "albert_hub_module_handle", None, 72 | "If set, the ALBERT hub module to use.") 73 | 74 | flags.DEFINE_bool( 75 | "do_lower_case", True, 76 | "Whether to lower case the input text. Should be True for uncased " 77 | "models and False for cased models.") 78 | 79 | flags.DEFINE_integer( 80 | "max_seq_length", 512, 81 | "The maximum total input sequence length after WordPiece tokenization. " 82 | "Sequences longer than this will be truncated, and sequences shorter " 83 | "than this will be padded.") 84 | 85 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 86 | 87 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 88 | 89 | flags.DEFINE_bool( 90 | "do_predict", False, 91 | "Whether to run the model in inference mode on the test set.") 92 | 93 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 94 | 95 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 96 | 97 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 98 | 99 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 100 | 101 | flags.DEFINE_integer("train_step", 1000, 102 | "Total number of training steps to perform.") 103 | 104 | flags.DEFINE_integer( 105 | "warmup_step", 0, 106 | "number of steps to perform linear learning rate warmup for.") 107 | 108 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 109 | "How often to save the model checkpoint.") 110 | 111 | flags.DEFINE_integer("keep_checkpoint_max", 5, 112 | "How many checkpoints to keep.") 113 | 114 | flags.DEFINE_integer("iterations_per_loop", 1000, 115 | "How many steps to make in each estimator call.") 116 | 117 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 118 | 119 | flags.DEFINE_string("optimizer", "adamw", "Optimizer to use") 120 | 121 | tf.flags.DEFINE_string( 122 | "tpu_name", None, 123 | "The Cloud TPU to use for training. This should be either the name " 124 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 125 | "url.") 126 | 127 | tf.flags.DEFINE_string( 128 | "tpu_zone", None, 129 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 130 | "specified, we will attempt to automatically detect the GCE project from " 131 | "metadata.") 132 | 133 | tf.flags.DEFINE_string( 134 | "gcp_project", None, 135 | "[Optional] Project name for the Cloud TPU-enabled project. 
If not " 136 | "specified, we will attempt to automatically detect the GCE project from " 137 | "metadata.") 138 | 139 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 140 | 141 | flags.DEFINE_integer( 142 | "num_tpu_cores", 8, 143 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 144 | 145 | flags.DEFINE_string( 146 | "export_dir", None, 147 | "The directory where the exported SavedModel will be stored.") 148 | 149 | flags.DEFINE_float( 150 | "threshold_to_export", float("nan"), 151 | "The threshold value that should be used with the exported classifier. " 152 | "When specified, the threshold will be attached to the exported " 153 | "SavedModel, and served along with the predictions. Please use the " 154 | "saved model cli (" 155 | "https://www.tensorflow.org/guide/saved_model#details_of_the_savedmodel_command_line_interface" 156 | ") to view the output signature of the threshold.") 157 | 158 | 159 | def _serving_input_receiver_fn(): 160 | """Creates an input function for serving.""" 161 | seq_len = FLAGS.max_seq_length 162 | serialized_example = tf.placeholder( 163 | dtype=tf.string, shape=[None], name="serialized_example") 164 | features = { 165 | "input_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64), 166 | "input_mask": tf.FixedLenFeature([seq_len], dtype=tf.int64), 167 | "segment_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64), 168 | } 169 | feature_map = tf.parse_example(serialized_example, features=features) 170 | feature_map["is_real_example"] = tf.constant(1, dtype=tf.int32) 171 | feature_map["label_ids"] = tf.constant(0, dtype=tf.int32) 172 | 173 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 174 | # So cast all int64 to int32. 175 | for name in feature_map.keys(): 176 | t = feature_map[name] 177 | if t.dtype == tf.int64: 178 | t = tf.to_int32(t) 179 | feature_map[name] = t 180 | 181 | return tf_estimator.export.ServingInputReceiver( 182 | features=feature_map, receiver_tensors=serialized_example) 183 | 184 | 185 | def _add_threshold_to_model_fn(model_fn, threshold): 186 | """Adds the classifier threshold to the given model_fn.""" 187 | 188 | def new_model_fn(features, labels, mode, params): 189 | spec = model_fn(features, labels, mode, params) 190 | threshold_tensor = tf.constant(threshold, dtype=tf.float32) 191 | default_serving_export = spec.export_outputs[ 192 | tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 193 | default_serving_export.outputs["threshold"] = threshold_tensor 194 | return spec 195 | 196 | return new_model_fn 197 | 198 | 199 | def main(_): 200 | tf.logging.set_verbosity(tf.logging.INFO) 201 | 202 | processors = { 203 | "cola": classifier_utils.ColaProcessor, 204 | "mnli": classifier_utils.MnliProcessor, 205 | "mismnli": classifier_utils.MisMnliProcessor, 206 | "mrpc": classifier_utils.MrpcProcessor, 207 | "rte": classifier_utils.RteProcessor, 208 | "sst-2": classifier_utils.Sst2Processor, 209 | "sts-b": classifier_utils.StsbProcessor, 210 | "qqp": classifier_utils.QqpProcessor, 211 | "qnli": classifier_utils.QnliProcessor, 212 | "wnli": classifier_utils.WnliProcessor, 213 | } 214 | 215 | if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict or 216 | FLAGS.export_dir): 217 | raise ValueError( 218 | "At least one of `do_train`, `do_eval`, `do_predict' or `export_dir` " 219 | "must be True.") 220 | 221 | if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle: 222 | raise ValueError("At least one of `--albert_config_file` and " 223 | 
"`--albert_hub_module_handle` must be set") 224 | 225 | if FLAGS.albert_config_file: 226 | albert_config = modeling.AlbertConfig.from_json_file( 227 | FLAGS.albert_config_file) 228 | if FLAGS.max_seq_length > albert_config.max_position_embeddings: 229 | raise ValueError( 230 | "Cannot use sequence length %d because the ALBERT model " 231 | "was only trained up to sequence length %d" % 232 | (FLAGS.max_seq_length, albert_config.max_position_embeddings)) 233 | else: 234 | albert_config = None # Get the config from TF-Hub. 235 | 236 | tf.gfile.MakeDirs(FLAGS.output_dir) 237 | 238 | task_name = FLAGS.task_name.lower() 239 | 240 | if task_name not in processors: 241 | raise ValueError("Task not found: %s" % (task_name)) 242 | 243 | processor = processors[task_name]( 244 | use_spm=True if FLAGS.spm_model_file else False, 245 | do_lower_case=FLAGS.do_lower_case) 246 | 247 | label_list = processor.get_labels() 248 | 249 | tokenizer = fine_tuning_utils.create_vocab( 250 | vocab_file=FLAGS.vocab_file, 251 | do_lower_case=FLAGS.do_lower_case, 252 | spm_model_file=FLAGS.spm_model_file, 253 | hub_module=FLAGS.albert_hub_module_handle) 254 | 255 | tpu_cluster_resolver = None 256 | if FLAGS.use_tpu and FLAGS.tpu_name: 257 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 258 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 259 | 260 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 261 | if FLAGS.do_train: 262 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 263 | FLAGS.save_checkpoints_steps)) 264 | else: 265 | iterations_per_loop = FLAGS.iterations_per_loop 266 | run_config = contrib_tpu.RunConfig( 267 | cluster=tpu_cluster_resolver, 268 | master=FLAGS.master, 269 | model_dir=FLAGS.output_dir, 270 | save_checkpoints_steps=int(FLAGS.save_checkpoints_steps), 271 | keep_checkpoint_max=0, 272 | tpu_config=contrib_tpu.TPUConfig( 273 | iterations_per_loop=iterations_per_loop, 274 | num_shards=FLAGS.num_tpu_cores, 275 | per_host_input_for_training=is_per_host)) 276 | 277 | train_examples = None 278 | if FLAGS.do_train: 279 | train_examples = processor.get_train_examples(FLAGS.data_dir) 280 | model_fn = classifier_utils.model_fn_builder( 281 | albert_config=albert_config, 282 | num_labels=len(label_list), 283 | init_checkpoint=FLAGS.init_checkpoint, 284 | learning_rate=FLAGS.learning_rate, 285 | num_train_steps=FLAGS.train_step, 286 | num_warmup_steps=FLAGS.warmup_step, 287 | use_tpu=FLAGS.use_tpu, 288 | use_one_hot_embeddings=FLAGS.use_tpu, 289 | task_name=task_name, 290 | hub_module=FLAGS.albert_hub_module_handle, 291 | optimizer=FLAGS.optimizer) 292 | 293 | if not math.isnan(FLAGS.threshold_to_export): 294 | model_fn = _add_threshold_to_model_fn(model_fn, FLAGS.threshold_to_export) 295 | 296 | # If TPU is not available, this will fall back to normal Estimator on CPU 297 | # or GPU. 
298 | estimator = contrib_tpu.TPUEstimator( 299 | use_tpu=FLAGS.use_tpu, 300 | model_fn=model_fn, 301 | config=run_config, 302 | train_batch_size=FLAGS.train_batch_size, 303 | eval_batch_size=FLAGS.eval_batch_size, 304 | predict_batch_size=FLAGS.predict_batch_size, 305 | export_to_tpu=False) # http://yaqs/4707241341091840 306 | 307 | if FLAGS.do_train: 308 | cached_dir = FLAGS.cached_dir 309 | if not cached_dir: 310 | cached_dir = FLAGS.output_dir 311 | train_file = os.path.join(cached_dir, task_name + "_train.tf_record") 312 | if not tf.gfile.Exists(train_file): 313 | classifier_utils.file_based_convert_examples_to_features( 314 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, 315 | train_file, task_name) 316 | tf.logging.info("***** Running training *****") 317 | tf.logging.info(" Num examples = %d", len(train_examples)) 318 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 319 | tf.logging.info(" Num steps = %d", FLAGS.train_step) 320 | train_input_fn = classifier_utils.file_based_input_fn_builder( 321 | input_file=train_file, 322 | seq_length=FLAGS.max_seq_length, 323 | is_training=True, 324 | drop_remainder=True, 325 | task_name=task_name, 326 | use_tpu=FLAGS.use_tpu, 327 | bsz=FLAGS.train_batch_size) 328 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step) 329 | 330 | if FLAGS.do_eval: 331 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 332 | num_actual_eval_examples = len(eval_examples) 333 | if FLAGS.use_tpu: 334 | # TPU requires a fixed batch size for all batches, therefore the number 335 | # of examples must be a multiple of the batch size, or else examples 336 | # will get dropped. So we pad with fake examples which are ignored 337 | # later on. These do NOT count towards the metric (all tf.metrics 338 | # support a per-instance weight, and these get a weight of 0.0). 339 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 340 | eval_examples.append(classifier_utils.PaddingInputExample()) 341 | 342 | cached_dir = FLAGS.cached_dir 343 | if not cached_dir: 344 | cached_dir = FLAGS.output_dir 345 | eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record") 346 | if not tf.gfile.Exists(eval_file): 347 | classifier_utils.file_based_convert_examples_to_features( 348 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, 349 | eval_file, task_name) 350 | 351 | tf.logging.info("***** Running evaluation *****") 352 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 353 | len(eval_examples), num_actual_eval_examples, 354 | len(eval_examples) - num_actual_eval_examples) 355 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 356 | 357 | # This tells the estimator to run through the entire set. 358 | eval_steps = None 359 | # However, if running eval on the TPU, you will need to specify the 360 | # number of steps. 
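    # For example (hypothetical counts): with 877 dev examples and
    # eval_batch_size=8, three padding examples are appended above to reach
    # 880, and eval_steps = 880 // 8 = 110.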
361 | if FLAGS.use_tpu: 362 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 363 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 364 | 365 | eval_drop_remainder = True if FLAGS.use_tpu else False 366 | eval_input_fn = classifier_utils.file_based_input_fn_builder( 367 | input_file=eval_file, 368 | seq_length=FLAGS.max_seq_length, 369 | is_training=False, 370 | drop_remainder=eval_drop_remainder, 371 | task_name=task_name, 372 | use_tpu=FLAGS.use_tpu, 373 | bsz=FLAGS.eval_batch_size) 374 | 375 | best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt") 376 | 377 | def _best_trial_info(): 378 | """Returns information about which checkpoints have been evaled so far.""" 379 | if tf.gfile.Exists(best_trial_info_file): 380 | with tf.gfile.GFile(best_trial_info_file, "r") as best_info: 381 | global_step, best_metric_global_step, metric_value = ( 382 | best_info.read().split(":")) 383 | global_step = int(global_step) 384 | best_metric_global_step = int(best_metric_global_step) 385 | metric_value = float(metric_value) 386 | else: 387 | metric_value = -1 388 | best_metric_global_step = -1 389 | global_step = -1 390 | tf.logging.info( 391 | "Best trial info: Step: %s, Best Value Step: %s, " 392 | "Best Value: %s", global_step, best_metric_global_step, metric_value) 393 | return global_step, best_metric_global_step, metric_value 394 | 395 | def _remove_checkpoint(checkpoint_path): 396 | for ext in ["meta", "data-00000-of-00001", "index"]: 397 | src_ckpt = checkpoint_path + ".{}".format(ext) 398 | tf.logging.info("removing {}".format(src_ckpt)) 399 | tf.gfile.Remove(src_ckpt) 400 | 401 | def _find_valid_cands(curr_step): 402 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 403 | candidates = [] 404 | for filename in filenames: 405 | if filename.endswith(".index"): 406 | ckpt_name = filename[:-6] 407 | idx = ckpt_name.split("-")[-1] 408 | if int(idx) > curr_step: 409 | candidates.append(filename) 410 | return candidates 411 | 412 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 413 | 414 | if task_name == "sts-b": 415 | key_name = "pearson" 416 | elif task_name == "cola": 417 | key_name = "matthew_corr" 418 | else: 419 | key_name = "eval_accuracy" 420 | 421 | global_step, best_perf_global_step, best_perf = _best_trial_info() 422 | writer = tf.gfile.GFile(output_eval_file, "w") 423 | while global_step < FLAGS.train_step: 424 | steps_and_files = {} 425 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 426 | for filename in filenames: 427 | if filename.endswith(".index"): 428 | ckpt_name = filename[:-6] 429 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 430 | if cur_filename.split("-")[-1] == "best": 431 | continue 432 | gstep = int(cur_filename.split("-")[-1]) 433 | if gstep not in steps_and_files: 434 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 435 | steps_and_files[gstep] = cur_filename 436 | tf.logging.info("found {} files.".format(len(steps_and_files))) 437 | if not steps_and_files: 438 | tf.logging.info("found 0 file, global step: {}. Sleeping." 
439 | .format(global_step)) 440 | time.sleep(60) 441 | else: 442 | for checkpoint in sorted(steps_and_files.items()): 443 | step, checkpoint_path = checkpoint 444 | if global_step >= step: 445 | if (best_perf_global_step != step and 446 | len(_find_valid_cands(step)) > 1): 447 | _remove_checkpoint(checkpoint_path) 448 | continue 449 | result = estimator.evaluate( 450 | input_fn=eval_input_fn, 451 | steps=eval_steps, 452 | checkpoint_path=checkpoint_path) 453 | global_step = result["global_step"] 454 | tf.logging.info("***** Eval results *****") 455 | for key in sorted(result.keys()): 456 | tf.logging.info(" %s = %s", key, str(result[key])) 457 | writer.write("%s = %s\n" % (key, str(result[key]))) 458 | writer.write("best = {}\n".format(best_perf)) 459 | if result[key_name] > best_perf: 460 | best_perf = result[key_name] 461 | best_perf_global_step = global_step 462 | elif len(_find_valid_cands(global_step)) > 1: 463 | _remove_checkpoint(checkpoint_path) 464 | writer.write("=" * 50 + "\n") 465 | writer.flush() 466 | with tf.gfile.GFile(best_trial_info_file, "w") as best_info: 467 | best_info.write("{}:{}:{}".format( 468 | global_step, best_perf_global_step, best_perf)) 469 | writer.close() 470 | 471 | for ext in ["meta", "data-00000-of-00001", "index"]: 472 | src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext) 473 | tgt_ckpt = "model.ckpt-best.{}".format(ext) 474 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 475 | tf.io.gfile.rename( 476 | os.path.join(FLAGS.output_dir, src_ckpt), 477 | os.path.join(FLAGS.output_dir, tgt_ckpt), 478 | overwrite=True) 479 | 480 | if FLAGS.do_predict: 481 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 482 | num_actual_predict_examples = len(predict_examples) 483 | if FLAGS.use_tpu: 484 | # TPU requires a fixed batch size for all batches, therefore the number 485 | # of examples must be a multiple of the batch size, or else examples 486 | # will get dropped. So we pad with fake examples which are ignored 487 | # later on. 
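      # Prediction uses no per-instance weights; instead, the padded examples
      # are dropped from the output below once num_actual_predict_examples
      # real rows have been written.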
488 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 489 | predict_examples.append(classifier_utils.PaddingInputExample()) 490 | 491 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 492 | classifier_utils.file_based_convert_examples_to_features( 493 | predict_examples, label_list, 494 | FLAGS.max_seq_length, tokenizer, 495 | predict_file, task_name) 496 | 497 | tf.logging.info("***** Running prediction*****") 498 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 499 | len(predict_examples), num_actual_predict_examples, 500 | len(predict_examples) - num_actual_predict_examples) 501 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 502 | 503 | predict_drop_remainder = True if FLAGS.use_tpu else False 504 | predict_input_fn = classifier_utils.file_based_input_fn_builder( 505 | input_file=predict_file, 506 | seq_length=FLAGS.max_seq_length, 507 | is_training=False, 508 | drop_remainder=predict_drop_remainder, 509 | task_name=task_name, 510 | use_tpu=FLAGS.use_tpu, 511 | bsz=FLAGS.predict_batch_size) 512 | 513 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 514 | result = estimator.predict( 515 | input_fn=predict_input_fn, 516 | checkpoint_path=checkpoint_path) 517 | 518 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 519 | output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv") 520 | with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\ 521 | tf.gfile.GFile(output_submit_file, "w") as sub_writer: 522 | sub_writer.write("index" + "\t" + "prediction\n") 523 | num_written_lines = 0 524 | tf.logging.info("***** Predict results *****") 525 | for (i, (example, prediction)) in\ 526 | enumerate(zip(predict_examples, result)): 527 | probabilities = prediction["probabilities"] 528 | if i >= num_actual_predict_examples: 529 | break 530 | output_line = "\t".join( 531 | str(class_probability) 532 | for class_probability in probabilities) + "\n" 533 | pred_writer.write(output_line) 534 | 535 | if task_name != "sts-b": 536 | actual_label = label_list[int(prediction["predictions"])] 537 | else: 538 | actual_label = str(prediction["predictions"]) 539 | sub_writer.write(example.guid + "\t" + actual_label + "\n") 540 | num_written_lines += 1 541 | assert num_written_lines == num_actual_predict_examples 542 | 543 | if FLAGS.export_dir: 544 | tf.gfile.MakeDirs(FLAGS.export_dir) 545 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 546 | tf.logging.info("Starting to export model.") 547 | subfolder = estimator.export_saved_model( 548 | export_dir_base=FLAGS.export_dir, 549 | serving_input_receiver_fn=_serving_input_receiver_fn, 550 | checkpoint_path=checkpoint_path) 551 | tf.logging.info("Model exported to %s.", subfolder) 552 | 553 | 554 | if __name__ == "__main__": 555 | flags.mark_flag_as_required("data_dir") 556 | flags.mark_flag_as_required("task_name") 557 | flags.mark_flag_as_required("spm_model_file") 558 | flags.mark_flag_as_required("output_dir") 559 | tf.app.run() 560 | -------------------------------------------------------------------------------- /run_glue.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This is a convenience script for evaluating ALBERT on the GLUE benchmark. 
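# Note: the run_task function below reads GLUE data and the sentencepiece
# vocab/model from ${ALBERT_ROOT}, which this script assumes the caller has
# already exported.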
3 | # 4 | # By default, this script uses a pretrained ALBERT v1 BASE model, but you may 5 | # use a custom checkpoint or any compatible TF-Hub checkpoint with minimal 6 | # edits to environment variables (see ALBERT_HUB_MODULE_HANDLE below). 7 | # 8 | # This script does fine-tuning and evaluation on 8 tasks, so it may take a 9 | # while to complete if you do not have a hardware accelerator. 10 | 11 | set -ex 12 | 13 | python3 -m venv $HOME/albertenv 14 | . $HOME/albertenv/bin/activate 15 | 16 | OUTPUT_DIR_BASE="$(mktemp -d)" 17 | OUTPUT_DIR="${OUTPUT_DIR_BASE}/output" 18 | 19 | # To start from a custom pretrained checkpoint, set ALBERT_HUB_MODULE_HANDLE 20 | # below to an empty string and set INIT_CHECKPOINT to your checkpoint path. 21 | ALBERT_HUB_MODULE_HANDLE="https://tfhub.dev/google/albert_base/1" 22 | INIT_CHECKPOINT="" 23 | 24 | pip3 install --upgrade pip 25 | pip3 install numpy 26 | pip3 install -r requirements.txt 27 | 28 | function run_task() { 29 | COMMON_ARGS="--output_dir="${OUTPUT_DIR}/$1" --data_dir="${ALBERT_ROOT}/glue" --vocab_file="${ALBERT_ROOT}/vocab.txt" --spm_model_file="${ALBERT_ROOT}/30k-clean.model" --do_lower_case --max_seq_length=512 --optimizer=adamw --task_name=$1 --warmup_step=$2 --learning_rate=$3 --train_step=$4 --save_checkpoints_steps=$5 --train_batch_size=$6" 30 | python3 -m run_classifier \ 31 | ${COMMON_ARGS} \ 32 | --do_train \ 33 | --nodo_eval \ 34 | --nodo_predict \ 35 | --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}" \ 36 | --init_checkpoint="${INIT_CHECKPOINT}" 37 | python3 -m run_classifier \ 38 | ${COMMON_ARGS} \ 39 | --nodo_train \ 40 | --do_eval \ 41 | --do_predict \ 42 | --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}" 43 | } 44 | 45 | run_task SST-2 1256 1e-5 20935 100 32 46 | run_task MNLI 1000 3e-5 10000 100 128 47 | run_task CoLA 320 1e-5 5336 100 16 48 | run_task QNLI 1986 1e-5 33112 200 32 49 | run_task QQP 1000 5e-5 14000 100 128 50 | run_task RTE 200 3e-5 800 100 32 51 | run_task STS-B 214 2e-5 3598 100 16 52 | run_task MRPC 200 2e-5 800 100 32 53 | -------------------------------------------------------------------------------- /run_pretraining_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tests for run_pretraining.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import random 23 | import tempfile 24 | from absl.testing import flagsaver 25 | from albert import modeling 26 | from albert import run_pretraining 27 | import tensorflow.compat.v1 as tf 28 | 29 | FLAGS = tf.app.flags.FLAGS 30 | 31 | 32 | def _create_config_file(filename, max_seq_length, vocab_size): 33 | """Creates an AlbertConfig and saves it to file.""" 34 | albert_config = modeling.AlbertConfig( 35 | vocab_size, 36 | embedding_size=5, 37 | hidden_size=14, 38 | num_hidden_layers=3, 39 | num_hidden_groups=1, 40 | num_attention_heads=2, 41 | intermediate_size=19, 42 | inner_group_num=1, 43 | down_scale_factor=1, 44 | hidden_act="gelu", 45 | hidden_dropout_prob=0, 46 | attention_probs_dropout_prob=0, 47 | max_position_embeddings=max_seq_length, 48 | type_vocab_size=2, 49 | initializer_range=0.02) 50 | with tf.gfile.Open(filename, "w") as outfile: 51 | outfile.write(albert_config.to_json_string()) 52 | 53 | 54 | def _create_record(max_predictions_per_seq, max_seq_length, vocab_size): 55 | """Returns a tf.train.Example containing random data.""" 56 | example = tf.train.Example() 57 | example.features.feature["input_ids"].int64_list.value.extend( 58 | [random.randint(0, vocab_size - 1) for _ in range(max_seq_length)]) 59 | example.features.feature["input_mask"].int64_list.value.extend( 60 | [random.randint(0, 1) for _ in range(max_seq_length)]) 61 | example.features.feature["masked_lm_positions"].int64_list.value.extend([ 62 | random.randint(0, max_seq_length - 1) 63 | for _ in range(max_predictions_per_seq) 64 | ]) 65 | example.features.feature["masked_lm_ids"].int64_list.value.extend([ 66 | random.randint(0, vocab_size - 1) for _ in range(max_predictions_per_seq) 67 | ]) 68 | example.features.feature["masked_lm_weights"].float_list.value.extend( 69 | [1. for _ in range(max_predictions_per_seq)]) 70 | example.features.feature["segment_ids"].int64_list.value.extend( 71 | [0 for _ in range(max_seq_length)]) 72 | example.features.feature["next_sentence_labels"].int64_list.value.append( 73 | random.randint(0, 1)) 74 | return example 75 | 76 | 77 | def _create_input_file(filename, 78 | max_predictions_per_seq, 79 | max_seq_length, 80 | vocab_size, 81 | size=1000): 82 | """Creates an input TFRecord file of specified size.""" 83 | with tf.io.TFRecordWriter(filename) as writer: 84 | for _ in range(size): 85 | ex = _create_record(max_predictions_per_seq, max_seq_length, vocab_size) 86 | writer.write(ex.SerializeToString()) 87 | 88 | 89 | class RunPretrainingTest(tf.test.TestCase): 90 | 91 | def _verify_output_file(self, basename): 92 | self.assertTrue(tf.gfile.Exists(os.path.join(FLAGS.output_dir, basename))) 93 | 94 | def _verify_checkpoint_files(self, name): 95 | self._verify_output_file(name + ".meta") 96 | self._verify_output_file(name + ".index") 97 | self._verify_output_file(name + ".data-00000-of-00001") 98 | 99 | @flagsaver.flagsaver 100 | def test_pretraining(self): 101 | # Set up required flags. 
102 |     vocab_size = 97
103 |     FLAGS.max_predictions_per_seq = 7
104 |     FLAGS.max_seq_length = 13
105 |     FLAGS.output_dir = tempfile.mkdtemp("output_dir")
106 |     FLAGS.albert_config_file = os.path.join(
107 |         tempfile.mkdtemp("config_dir"), "albert_config.json")
108 |     FLAGS.input_file = os.path.join(
109 |         tempfile.mkdtemp("input_dir"), "input_data.tfrecord")
110 |     FLAGS.do_train = True
111 |     FLAGS.do_eval = True
112 |     FLAGS.num_train_steps = 1
113 |     FLAGS.save_checkpoints_steps = 1
114 | 
115 |     # Construct requisite input files.
116 |     _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length,
117 |                         vocab_size)
118 |     _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq,
119 |                        FLAGS.max_seq_length, vocab_size)
120 | 
121 |     # Run the pretraining.
122 |     run_pretraining.main(None)
123 | 
124 |     # Verify output.
125 |     self._verify_checkpoint_files("model.ckpt-best")
126 |     self._verify_checkpoint_files("model.ckpt-1")
127 |     self._verify_output_file("eval_results.txt")
128 |     self._verify_output_file("checkpoint")
129 | 
130 | 
131 | if __name__ == "__main__":
132 |   tf.test.main()
133 | 
--------------------------------------------------------------------------------
/run_race.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ALBERT finetuning runner with sentence piece tokenization."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import os
22 | import time
23 | from albert import classifier_utils
24 | from albert import fine_tuning_utils
25 | from albert import modeling
26 | from albert import race_utils
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
29 | from tensorflow.contrib import tpu as contrib_tpu
30 | 
31 | flags = tf.flags
32 | 
33 | FLAGS = flags.FLAGS
34 | 
35 | ## Required parameters
36 | flags.DEFINE_string(
37 |     "data_dir", None,
38 |     "The input data dir. Should contain the .tsv files (or other data files) "
39 |     "for the task.")
40 | 
41 | flags.DEFINE_string(
42 |     "albert_config_file", None,
43 |     "The config json file corresponding to the pre-trained ALBERT model. "
44 |     "This specifies the model architecture.")
45 | 
46 | flags.DEFINE_string("task_name", "race", "The name of the task to train.")
47 | 
48 | flags.DEFINE_string("vocab_file", None,
49 |                     "The vocabulary file that the ALBERT model was trained on.")
50 | 
51 | flags.DEFINE_string("train_file", None,
52 |                     "Path to the preprocessed training tfrecord file. "
53 |                     "The file will be generated if it does not exist.")
54 | 
55 | flags.DEFINE_string("eval_file", None,
56 |                     "Path to the preprocessed eval tfrecord file. "
57 |                     "The file will be generated if it does not exist.")
58 | 
59 | flags.DEFINE_string("predict_file", None,
60 |                     "Path to the preprocessed predict tfrecord file. "
" 61 | "The file will be generated if not exst.") 62 | 63 | flags.DEFINE_string("spm_model_file", None, 64 | "The model file for sentence piece tokenization.") 65 | 66 | flags.DEFINE_string( 67 | "output_dir", None, 68 | "The output directory where the model checkpoints will be written.") 69 | 70 | ## Other parameters 71 | 72 | flags.DEFINE_string( 73 | "init_checkpoint", None, 74 | "Initial checkpoint (usually from a pre-trained ALBERT model).") 75 | 76 | flags.DEFINE_string( 77 | "albert_hub_module_handle", None, 78 | "If set, the ALBERT hub module to use.") 79 | 80 | flags.DEFINE_bool( 81 | "do_lower_case", True, 82 | "Whether to lower case the input text. Should be True for uncased " 83 | "models and False for cased models.") 84 | 85 | flags.DEFINE_float("dropout_prob", 0.1, "dropout probability.") 86 | 87 | flags.DEFINE_integer( 88 | "max_seq_length", 512, 89 | "The maximum total input sequence length after WordPiece tokenization. " 90 | "Sequences longer than this will be truncated, and sequences shorter " 91 | "than this will be padded.") 92 | 93 | flags.DEFINE_integer( 94 | "max_qa_length", 128, 95 | "The maximum total input sequence length after WordPiece tokenization. " 96 | "Sequences longer than this will be truncated, and sequences shorter " 97 | "than this will be padded.") 98 | 99 | flags.DEFINE_integer( 100 | "num_keep_checkpoint", 5, 101 | "maximum number of keep checkpoints") 102 | 103 | 104 | flags.DEFINE_bool( 105 | "high_only", False, 106 | "Whether to only run the model on the high school set.") 107 | 108 | flags.DEFINE_bool( 109 | "middle_only", False, 110 | "Whether to only run the model on the middle school set.") 111 | 112 | flags.DEFINE_bool("do_train", True, "Whether to run training.") 113 | 114 | flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.") 115 | 116 | flags.DEFINE_bool( 117 | "do_predict", False, 118 | "Whether to run the model in inference mode on the test set.") 119 | 120 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 121 | 122 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 123 | 124 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 125 | 126 | flags.DEFINE_float("learning_rate", 1e-5, "The initial learning rate for Adam.") 127 | 128 | flags.DEFINE_integer("train_step", 12000, 129 | "Total number of training epochs to perform.") 130 | 131 | flags.DEFINE_integer( 132 | "warmup_step", 1000, 133 | "number of steps to perform linear learning rate warmup for.") 134 | 135 | flags.DEFINE_integer("save_checkpoints_steps", 100, 136 | "How often to save the model checkpoint.") 137 | 138 | flags.DEFINE_integer("iterations_per_loop", 1000, 139 | "How many steps to make in each estimator call.") 140 | 141 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 142 | 143 | tf.flags.DEFINE_string( 144 | "tpu_name", None, 145 | "The Cloud TPU to use for training. This should be either the name " 146 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 147 | "url.") 148 | 149 | tf.flags.DEFINE_string( 150 | "tpu_zone", None, 151 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 152 | "specified, we will attempt to automatically detect the GCE project from " 153 | "metadata.") 154 | 155 | tf.flags.DEFINE_string( 156 | "gcp_project", None, 157 | "[Optional] Project name for the Cloud TPU-enabled project. 
If not " 158 | "specified, we will attempt to automatically detect the GCE project from " 159 | "metadata.") 160 | 161 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 162 | 163 | flags.DEFINE_integer( 164 | "num_tpu_cores", 8, 165 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 166 | 167 | 168 | def main(_): 169 | tf.logging.set_verbosity(tf.logging.INFO) 170 | 171 | processors = { 172 | "race": race_utils.RaceProcessor 173 | } 174 | 175 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 176 | raise ValueError( 177 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 178 | 179 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) 180 | 181 | if FLAGS.max_seq_length > albert_config.max_position_embeddings: 182 | raise ValueError( 183 | "Cannot use sequence length %d because the ALBERT model " 184 | "was only trained up to sequence length %d" % 185 | (FLAGS.max_seq_length, albert_config.max_position_embeddings)) 186 | 187 | tf.gfile.MakeDirs(FLAGS.output_dir) 188 | 189 | task_name = FLAGS.task_name.lower() 190 | 191 | if task_name not in processors: 192 | raise ValueError("Task not found: %s" % (task_name)) 193 | 194 | processor = processors[task_name]( 195 | use_spm=True if FLAGS.spm_model_file else False, 196 | do_lower_case=FLAGS.do_lower_case, 197 | high_only=FLAGS.high_only, 198 | middle_only=FLAGS.middle_only) 199 | 200 | label_list = processor.get_labels() 201 | 202 | tokenizer = fine_tuning_utils.create_vocab( 203 | vocab_file=FLAGS.vocab_file, 204 | do_lower_case=FLAGS.do_lower_case, 205 | spm_model_file=FLAGS.spm_model_file, 206 | hub_module=FLAGS.albert_hub_module_handle) 207 | 208 | tpu_cluster_resolver = None 209 | if FLAGS.use_tpu and FLAGS.tpu_name: 210 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 211 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 212 | 213 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 214 | if FLAGS.do_train: 215 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 216 | FLAGS.save_checkpoints_steps)) 217 | else: 218 | iterations_per_loop = FLAGS.iterations_per_loop 219 | run_config = contrib_tpu.RunConfig( 220 | cluster=tpu_cluster_resolver, 221 | master=FLAGS.master, 222 | model_dir=FLAGS.output_dir, 223 | save_checkpoints_steps=int(FLAGS.save_checkpoints_steps), 224 | keep_checkpoint_max=0, 225 | tpu_config=contrib_tpu.TPUConfig( 226 | iterations_per_loop=iterations_per_loop, 227 | num_shards=FLAGS.num_tpu_cores, 228 | per_host_input_for_training=is_per_host)) 229 | 230 | train_examples = None 231 | if FLAGS.do_train: 232 | train_examples = processor.get_train_examples(FLAGS.data_dir) 233 | 234 | model_fn = race_utils.model_fn_builder( 235 | albert_config=albert_config, 236 | num_labels=len(label_list), 237 | init_checkpoint=FLAGS.init_checkpoint, 238 | learning_rate=FLAGS.learning_rate, 239 | num_train_steps=FLAGS.train_step, 240 | num_warmup_steps=FLAGS.warmup_step, 241 | use_tpu=FLAGS.use_tpu, 242 | use_one_hot_embeddings=FLAGS.use_tpu, 243 | max_seq_length=FLAGS.max_seq_length, 244 | dropout_prob=FLAGS.dropout_prob, 245 | hub_module=FLAGS.albert_hub_module_handle) 246 | 247 | # If TPU is not available, this will fall back to normal Estimator on CPU 248 | # or GPU. 
249 | estimator = contrib_tpu.TPUEstimator( 250 | use_tpu=FLAGS.use_tpu, 251 | model_fn=model_fn, 252 | config=run_config, 253 | train_batch_size=FLAGS.train_batch_size, 254 | eval_batch_size=FLAGS.eval_batch_size, 255 | predict_batch_size=FLAGS.predict_batch_size) 256 | 257 | if FLAGS.do_train: 258 | if not tf.gfile.Exists(FLAGS.train_file): 259 | race_utils.file_based_convert_examples_to_features( 260 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, 261 | FLAGS.train_file, FLAGS.max_qa_length) 262 | tf.logging.info("***** Running training *****") 263 | tf.logging.info(" Num examples = %d", len(train_examples)) 264 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 265 | tf.logging.info(" Num steps = %d", FLAGS.train_step) 266 | train_input_fn = classifier_utils.file_based_input_fn_builder( 267 | input_file=FLAGS.train_file, 268 | seq_length=FLAGS.max_seq_length, 269 | is_training=True, 270 | drop_remainder=True, 271 | task_name=task_name, 272 | use_tpu=FLAGS.use_tpu, 273 | bsz=FLAGS.train_batch_size, 274 | multiple=len(label_list)) 275 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step) 276 | 277 | if FLAGS.do_eval: 278 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 279 | num_actual_eval_examples = len(eval_examples) 280 | if FLAGS.use_tpu: 281 | # TPU requires a fixed batch size for all batches, therefore the number 282 | # of examples must be a multiple of the batch size, or else examples 283 | # will get dropped. So we pad with fake examples which are ignored 284 | # later on. These do NOT count towards the metric (all tf.metrics 285 | # support a per-instance weight, and these get a weight of 0.0). 286 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 287 | eval_examples.append(classifier_utils.PaddingInputExample()) 288 | 289 | if not tf.gfile.Exists(FLAGS.eval_file): 290 | race_utils.file_based_convert_examples_to_features( 291 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, 292 | FLAGS.eval_file, FLAGS.max_qa_length) 293 | 294 | tf.logging.info("***** Running evaluation *****") 295 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 296 | len(eval_examples), num_actual_eval_examples, 297 | len(eval_examples) - num_actual_eval_examples) 298 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 299 | 300 | # This tells the estimator to run through the entire set. 301 | eval_steps = None 302 | # However, if running eval on the TPU, you will need to specify the 303 | # number of steps. 
304 | if FLAGS.use_tpu: 305 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 306 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 307 | 308 | eval_drop_remainder = True if FLAGS.use_tpu else False 309 | eval_input_fn = classifier_utils.file_based_input_fn_builder( 310 | input_file=FLAGS.eval_file, 311 | seq_length=FLAGS.max_seq_length, 312 | is_training=False, 313 | drop_remainder=eval_drop_remainder, 314 | task_name=task_name, 315 | use_tpu=FLAGS.use_tpu, 316 | bsz=FLAGS.eval_batch_size, 317 | multiple=len(label_list)) 318 | 319 | def _find_valid_cands(curr_step): 320 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 321 | candidates = [] 322 | for filename in filenames: 323 | if filename.endswith(".index"): 324 | ckpt_name = filename[:-6] 325 | idx = ckpt_name.split("-")[-1] 326 | if idx != "best" and int(idx) > curr_step: 327 | candidates.append(filename) 328 | return candidates 329 | 330 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 331 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 332 | key_name = "eval_accuracy" 333 | if tf.gfile.Exists(checkpoint_path + ".index"): 334 | result = estimator.evaluate( 335 | input_fn=eval_input_fn, 336 | steps=eval_steps, 337 | checkpoint_path=checkpoint_path) 338 | best_perf = result[key_name] 339 | global_step = result["global_step"] 340 | else: 341 | global_step = -1 342 | best_perf = -1 343 | checkpoint_path = None 344 | writer = tf.gfile.GFile(output_eval_file, "w") 345 | while global_step < FLAGS.train_step: 346 | steps_and_files = {} 347 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 348 | for filename in filenames: 349 | if filename.endswith(".index"): 350 | ckpt_name = filename[:-6] 351 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 352 | if cur_filename.split("-")[-1] == "best": 353 | continue 354 | gstep = int(cur_filename.split("-")[-1]) 355 | if gstep not in steps_and_files: 356 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 357 | steps_and_files[gstep] = cur_filename 358 | tf.logging.info("found {} files.".format(len(steps_and_files))) 359 | # steps_and_files = sorted(steps_and_files, key=lambda x: x[0]) 360 | if not steps_and_files: 361 | tf.logging.info("found 0 file, global step: {}. Sleeping." 
362 | .format(global_step)) 363 | time.sleep(1) 364 | else: 365 | for ele in sorted(steps_and_files.items()): 366 | step, checkpoint_path = ele 367 | if global_step >= step: 368 | if len(_find_valid_cands(step)) > 1: 369 | for ext in ["meta", "data-00000-of-00001", "index"]: 370 | src_ckpt = checkpoint_path + ".{}".format(ext) 371 | tf.logging.info("removing {}".format(src_ckpt)) 372 | tf.gfile.Remove(src_ckpt) 373 | continue 374 | result = estimator.evaluate( 375 | input_fn=eval_input_fn, 376 | steps=eval_steps, 377 | checkpoint_path=checkpoint_path) 378 | global_step = result["global_step"] 379 | tf.logging.info("***** Eval results *****") 380 | for key in sorted(result.keys()): 381 | tf.logging.info(" %s = %s", key, str(result[key])) 382 | writer.write("%s = %s\n" % (key, str(result[key]))) 383 | writer.write("best = {}\n".format(best_perf)) 384 | if result[key_name] > best_perf: 385 | best_perf = result[key_name] 386 | for ext in ["meta", "data-00000-of-00001", "index"]: 387 | src_ckpt = checkpoint_path + ".{}".format(ext) 388 | tgt_ckpt = checkpoint_path.rsplit("-", 1)[0] + "-best.{}".format(ext) 389 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 390 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True) 391 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt)) 392 | 393 | if len(_find_valid_cands(global_step)) > 1: 394 | for ext in ["meta", "data-00000-of-00001", "index"]: 395 | src_ckpt = checkpoint_path + ".{}".format(ext) 396 | tf.logging.info("removing {}".format(src_ckpt)) 397 | tf.gfile.Remove(src_ckpt) 398 | writer.write("=" * 50 + "\n") 399 | writer.close() 400 | if FLAGS.do_predict: 401 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 402 | num_actual_predict_examples = len(predict_examples) 403 | if FLAGS.use_tpu: 404 | # TPU requires a fixed batch size for all batches, therefore the number 405 | # of examples must be a multiple of the batch size, or else examples 406 | # will get dropped. So we pad with fake examples which are ignored 407 | # later on. 
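      # This padding also makes the predict_steps division below exact, which
      # lets estimator.evaluate be given a precise step count on TPU.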
408 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 409 | predict_examples.append(classifier_utils.PaddingInputExample()) 410 | assert len(predict_examples) % FLAGS.predict_batch_size == 0 411 | predict_steps = int(len(predict_examples) // FLAGS.predict_batch_size) 412 | 413 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 414 | race_utils.file_based_convert_examples_to_features( 415 | predict_examples, label_list, 416 | FLAGS.max_seq_length, tokenizer, 417 | predict_file, FLAGS.max_qa_length) 418 | 419 | tf.logging.info("***** Running prediction*****") 420 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 421 | len(predict_examples), num_actual_predict_examples, 422 | len(predict_examples) - num_actual_predict_examples) 423 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 424 | 425 | predict_drop_remainder = True if FLAGS.use_tpu else False 426 | predict_input_fn = classifier_utils.file_based_input_fn_builder( 427 | input_file=predict_file, 428 | seq_length=FLAGS.max_seq_length, 429 | is_training=False, 430 | drop_remainder=predict_drop_remainder, 431 | task_name=task_name, 432 | use_tpu=FLAGS.use_tpu, 433 | bsz=FLAGS.predict_batch_size, 434 | multiple=len(label_list)) 435 | 436 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 437 | result = estimator.evaluate( 438 | input_fn=predict_input_fn, 439 | steps=predict_steps, 440 | checkpoint_path=checkpoint_path) 441 | 442 | output_predict_file = os.path.join(FLAGS.output_dir, "predict_results.txt") 443 | with tf.gfile.GFile(output_predict_file, "w") as pred_writer: 444 | # num_written_lines = 0 445 | tf.logging.info("***** Predict results *****") 446 | pred_writer.write("***** Predict results *****\n") 447 | for key in sorted(result.keys()): 448 | tf.logging.info(" %s = %s", key, str(result[key])) 449 | pred_writer.write("%s = %s\n" % (key, str(result[key]))) 450 | pred_writer.write("best = {}\n".format(best_perf)) 451 | 452 | 453 | if __name__ == "__main__": 454 | flags.mark_flag_as_required("data_dir") 455 | flags.mark_flag_as_required("spm_model_file") 456 | flags.mark_flag_as_required("albert_config_file") 457 | flags.mark_flag_as_required("output_dir") 458 | tf.app.run() 459 | -------------------------------------------------------------------------------- /run_squad_v1.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Run ALBERT on SQuAD v1.1 using sentence piece tokenization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | import json 23 | import os 24 | import random 25 | import time 26 | from albert import fine_tuning_utils 27 | from albert import modeling 28 | from albert import squad_utils 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | from tensorflow.compat.v1 import estimator as tf_estimator 32 | 33 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver 34 | from tensorflow.contrib import tpu as contrib_tpu 35 | 36 | 37 | # pylint: disable=g-import-not-at-top 38 | if six.PY2: 39 | import six.moves.cPickle as pickle 40 | else: 41 | import pickle 42 | # pylint: enable=g-import-not-at-top 43 | 44 | flags = tf.flags 45 | 46 | FLAGS = flags.FLAGS 47 | 48 | ## Required parameters 49 | flags.DEFINE_string( 50 | "albert_config_file", None, 51 | "The config json file corresponding to the pre-trained BERT model. " 52 | "This specifies the model architecture.") 53 | 54 | flags.DEFINE_string("vocab_file", None, 55 | "The vocabulary file that the BERT model was trained on.") 56 | 57 | flags.DEFINE_string("spm_model_file", None, 58 | "The model file for sentence piece tokenization.") 59 | 60 | flags.DEFINE_string( 61 | "output_dir", None, 62 | "The output directory where the model checkpoints will be written.") 63 | 64 | ## Other parameters 65 | flags.DEFINE_string("train_file", None, 66 | "SQuAD json for training. E.g., train-v1.1.json") 67 | 68 | flags.DEFINE_string( 69 | "predict_file", None, 70 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 71 | 72 | flags.DEFINE_string("train_feature_file", None, 73 | "training feature file.") 74 | 75 | flags.DEFINE_string( 76 | "predict_feature_file", None, 77 | "Location of predict features. If it doesn't exist, it will be written. " 78 | "If it does exist, it will be read.") 79 | 80 | flags.DEFINE_string( 81 | "predict_feature_left_file", None, 82 | "Location of predict features not passed to TPU. If it doesn't exist, it " 83 | "will be written. If it does exist, it will be read.") 84 | 85 | flags.DEFINE_string( 86 | "init_checkpoint", None, 87 | "Initial checkpoint (usually from a pre-trained BERT model).") 88 | 89 | flags.DEFINE_string( 90 | "albert_hub_module_handle", None, 91 | "If set, the ALBERT hub module to use.") 92 | 93 | flags.DEFINE_bool( 94 | "do_lower_case", True, 95 | "Whether to lower case the input text. Should be True for uncased " 96 | "models and False for cased models.") 97 | 98 | flags.DEFINE_integer( 99 | "max_seq_length", 384, 100 | "The maximum total input sequence length after WordPiece tokenization. " 101 | "Sequences longer than this will be truncated, and sequences shorter " 102 | "than this will be padded.") 103 | 104 | flags.DEFINE_integer( 105 | "doc_stride", 128, 106 | "When splitting up a long document into chunks, how much stride to " 107 | "take between chunks.") 108 | 109 | flags.DEFINE_integer( 110 | "max_query_length", 64, 111 | "The maximum number of tokens for the question. 
112 |     "this will be truncated to this length.")
113 | 
114 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
115 | 
116 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
117 | 
118 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
119 | 
120 | flags.DEFINE_integer("predict_batch_size", 8,
121 |                      "Total batch size for predictions.")
122 | 
123 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
124 | 
125 | flags.DEFINE_float("num_train_epochs", 3.0,
126 |                    "Total number of training epochs to perform.")
127 | 
128 | flags.DEFINE_float(
129 |     "warmup_proportion", 0.1,
130 |     "Proportion of training to perform linear learning rate warmup for. "
131 |     "E.g., 0.1 = 10% of training.")
132 | 
133 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
134 |                      "How often to save the model checkpoint.")
135 | 
136 | flags.DEFINE_integer("iterations_per_loop", 1000,
137 |                      "How many steps to make in each estimator call.")
138 | 
139 | flags.DEFINE_integer(
140 |     "n_best_size", 20,
141 |     "The total number of n-best predictions to generate in the "
142 |     "nbest_predictions.json output file.")
143 | 
144 | flags.DEFINE_integer(
145 |     "max_answer_length", 30,
146 |     "The maximum length of an answer that can be generated. This is needed "
147 |     "because the start and end predictions are not conditioned on one another.")
148 | 
149 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
150 | 
151 | tf.flags.DEFINE_string(
152 |     "tpu_name", None,
153 |     "The Cloud TPU to use for training. This should be either the name "
154 |     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
155 |     "url.")
156 | 
157 | tf.flags.DEFINE_string(
158 |     "tpu_zone", None,
159 |     "[Optional] GCE zone where the Cloud TPU is located in. If not "
160 |     "specified, we will attempt to automatically detect the GCE project from "
161 |     "metadata.")
162 | 
163 | tf.flags.DEFINE_string(
164 |     "gcp_project", None,
165 |     "[Optional] Project name for the Cloud TPU-enabled project. If not "
166 |     "specified, we will attempt to automatically detect the GCE project from "
167 |     "metadata.")
168 | 
169 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
170 | 
171 | flags.DEFINE_integer(
172 |     "num_tpu_cores", 8,
173 |     "Only used if `use_tpu` is True. Total number of TPU cores to use.")
174 | 
175 | flags.DEFINE_bool(
176 |     "use_einsum", True,
177 |     "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
178 |     "be set to False for TFLite compatibility.")
179 | 
180 | flags.DEFINE_string(
181 |     "export_dir",
182 |     default=None,
183 |     help=("The directory where the exported SavedModel will be stored."))
184 | 
185 | 
186 | def validate_flags_or_throw(albert_config):
187 |   """Validate the input FLAGS or throw an exception."""
188 | 
189 |   if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_dir:
190 |     err_msg = "At least one of `do_train`, `do_predict` or `export_dir` must be True."
191 | raise ValueError(err_msg) 192 | 193 | if FLAGS.do_train: 194 | if not FLAGS.train_file: 195 | raise ValueError( 196 | "If `do_train` is True, then `train_file` must be specified.") 197 | if FLAGS.do_predict: 198 | if not FLAGS.predict_file: 199 | raise ValueError( 200 | "If `do_predict` is True, then `predict_file` must be specified.") 201 | if not FLAGS.predict_feature_file: 202 | raise ValueError( 203 | "If `do_predict` is True, then `predict_feature_file` must be " 204 | "specified.") 205 | if not FLAGS.predict_feature_left_file: 206 | raise ValueError( 207 | "If `do_predict` is True, then `predict_feature_left_file` must be " 208 | "specified.") 209 | 210 | if FLAGS.max_seq_length > albert_config.max_position_embeddings: 211 | raise ValueError( 212 | "Cannot use sequence length %d because the ALBERT model " 213 | "was only trained up to sequence length %d" % 214 | (FLAGS.max_seq_length, albert_config.max_position_embeddings)) 215 | 216 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: 217 | raise ValueError( 218 | "The max_seq_length (%d) must be greater than max_query_length " 219 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) 220 | 221 | 222 | def build_squad_serving_input_fn(seq_length): 223 | """Builds a serving input fn for raw input.""" 224 | 225 | def _seq_serving_input_fn(): 226 | """Serving input fn for raw images.""" 227 | input_ids = tf.placeholder( 228 | shape=[1, seq_length], name="input_ids", dtype=tf.int32) 229 | input_mask = tf.placeholder( 230 | shape=[1, seq_length], name="input_mask", dtype=tf.int32) 231 | segment_ids = tf.placeholder( 232 | shape=[1, seq_length], name="segment_ids", dtype=tf.int32) 233 | 234 | inputs = { 235 | "input_ids": input_ids, 236 | "input_mask": input_mask, 237 | "segment_ids": segment_ids 238 | } 239 | return tf_estimator.export.ServingInputReceiver(features=inputs, 240 | receiver_tensors=inputs) 241 | 242 | return _seq_serving_input_fn 243 | 244 | 245 | def main(_): 246 | tf.logging.set_verbosity(tf.logging.INFO) 247 | 248 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) 249 | 250 | validate_flags_or_throw(albert_config) 251 | 252 | tf.gfile.MakeDirs(FLAGS.output_dir) 253 | 254 | tokenizer = fine_tuning_utils.create_vocab( 255 | vocab_file=FLAGS.vocab_file, 256 | do_lower_case=FLAGS.do_lower_case, 257 | spm_model_file=FLAGS.spm_model_file, 258 | hub_module=FLAGS.albert_hub_module_handle) 259 | 260 | tpu_cluster_resolver = None 261 | if FLAGS.use_tpu and FLAGS.tpu_name: 262 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 263 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 264 | 265 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 266 | if FLAGS.do_train: 267 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 268 | FLAGS.save_checkpoints_steps)) 269 | else: 270 | iterations_per_loop = FLAGS.iterations_per_loop 271 | run_config = contrib_tpu.RunConfig( 272 | cluster=tpu_cluster_resolver, 273 | master=FLAGS.master, 274 | model_dir=FLAGS.output_dir, 275 | keep_checkpoint_max=0, 276 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 277 | tpu_config=contrib_tpu.TPUConfig( 278 | iterations_per_loop=iterations_per_loop, 279 | num_shards=FLAGS.num_tpu_cores, 280 | per_host_input_for_training=is_per_host)) 281 | 282 | train_examples = None 283 | num_train_steps = None 284 | num_warmup_steps = None 285 | if FLAGS.do_train: 286 | train_examples = squad_utils.read_squad_examples( 287 | input_file=FLAGS.train_file, 
is_training=True) 288 | num_train_steps = int( 289 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 290 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 291 | 292 | # Pre-shuffle the input to avoid having to make a very large shuffle 293 | # buffer in in the `input_fn`. 294 | rng = random.Random(12345) 295 | rng.shuffle(train_examples) 296 | 297 | model_fn = squad_utils.v1_model_fn_builder( 298 | albert_config=albert_config, 299 | init_checkpoint=FLAGS.init_checkpoint, 300 | learning_rate=FLAGS.learning_rate, 301 | num_train_steps=num_train_steps, 302 | num_warmup_steps=num_warmup_steps, 303 | use_tpu=FLAGS.use_tpu, 304 | use_one_hot_embeddings=FLAGS.use_tpu, 305 | use_einsum=FLAGS.use_einsum, 306 | hub_module=FLAGS.albert_hub_module_handle) 307 | 308 | # If TPU is not available, this will fall back to normal Estimator on CPU 309 | # or GPU. 310 | estimator = contrib_tpu.TPUEstimator( 311 | use_tpu=FLAGS.use_tpu, 312 | model_fn=model_fn, 313 | config=run_config, 314 | train_batch_size=FLAGS.train_batch_size, 315 | predict_batch_size=FLAGS.predict_batch_size) 316 | 317 | if FLAGS.do_train: 318 | # We write to a temporary file to avoid storing very large constant tensors 319 | # in memory. 320 | 321 | if not tf.gfile.Exists(FLAGS.train_feature_file): 322 | train_writer = squad_utils.FeatureWriter( 323 | filename=os.path.join(FLAGS.train_feature_file), is_training=True) 324 | squad_utils.convert_examples_to_features( 325 | examples=train_examples, 326 | tokenizer=tokenizer, 327 | max_seq_length=FLAGS.max_seq_length, 328 | doc_stride=FLAGS.doc_stride, 329 | max_query_length=FLAGS.max_query_length, 330 | is_training=True, 331 | output_fn=train_writer.process_feature, 332 | do_lower_case=FLAGS.do_lower_case) 333 | train_writer.close() 334 | 335 | tf.logging.info("***** Running training *****") 336 | tf.logging.info(" Num orig examples = %d", len(train_examples)) 337 | # tf.logging.info(" Num split examples = %d", train_writer.num_features) 338 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 339 | tf.logging.info(" Num steps = %d", num_train_steps) 340 | del train_examples 341 | 342 | train_input_fn = squad_utils.input_fn_builder( 343 | input_file=FLAGS.train_feature_file, 344 | seq_length=FLAGS.max_seq_length, 345 | is_training=True, 346 | drop_remainder=True, 347 | use_tpu=FLAGS.use_tpu, 348 | bsz=FLAGS.train_batch_size, 349 | is_v2=False) 350 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 351 | 352 | if FLAGS.do_predict: 353 | with tf.gfile.Open(FLAGS.predict_file) as predict_file: 354 | prediction_json = json.load(predict_file)["data"] 355 | 356 | eval_examples = squad_utils.read_squad_examples( 357 | input_file=FLAGS.predict_file, is_training=False) 358 | 359 | if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists( 360 | FLAGS.predict_feature_left_file)): 361 | tf.logging.info("Loading eval features from {}".format( 362 | FLAGS.predict_feature_left_file)) 363 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin: 364 | eval_features = pickle.load(fin) 365 | else: 366 | eval_writer = squad_utils.FeatureWriter( 367 | filename=FLAGS.predict_feature_file, is_training=False) 368 | eval_features = [] 369 | 370 | def append_feature(feature): 371 | eval_features.append(feature) 372 | eval_writer.process_feature(feature) 373 | 374 | squad_utils.convert_examples_to_features( 375 | examples=eval_examples, 376 | tokenizer=tokenizer, 377 | max_seq_length=FLAGS.max_seq_length, 378 | 
doc_stride=FLAGS.doc_stride, 379 | max_query_length=FLAGS.max_query_length, 380 | is_training=False, 381 | output_fn=append_feature, 382 | do_lower_case=FLAGS.do_lower_case) 383 | eval_writer.close() 384 | 385 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout: 386 | pickle.dump(eval_features, fout) 387 | 388 | tf.logging.info("***** Running predictions *****") 389 | tf.logging.info(" Num orig examples = %d", len(eval_examples)) 390 | tf.logging.info(" Num split examples = %d", len(eval_features)) 391 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 392 | 393 | predict_input_fn = squad_utils.input_fn_builder( 394 | input_file=FLAGS.predict_feature_file, 395 | seq_length=FLAGS.max_seq_length, 396 | is_training=False, 397 | drop_remainder=False, 398 | use_tpu=FLAGS.use_tpu, 399 | bsz=FLAGS.predict_batch_size, 400 | is_v2=False) 401 | 402 | def get_result(checkpoint): 403 | """Evaluate the checkpoint on SQuAD 1.0.""" 404 | # If running eval on the TPU, you will need to specify the number of 405 | # steps. 406 | reader = tf.train.NewCheckpointReader(checkpoint) 407 | global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP) 408 | all_results = [] 409 | for result in estimator.predict( 410 | predict_input_fn, yield_single_examples=True, 411 | checkpoint_path=checkpoint): 412 | if len(all_results) % 1000 == 0: 413 | tf.logging.info("Processing example: %d" % (len(all_results))) 414 | unique_id = int(result["unique_ids"]) 415 | start_log_prob = [float(x) for x in result["start_log_prob"].flat] 416 | end_log_prob = [float(x) for x in result["end_log_prob"].flat] 417 | all_results.append( 418 | squad_utils.RawResult( 419 | unique_id=unique_id, 420 | start_log_prob=start_log_prob, 421 | end_log_prob=end_log_prob)) 422 | 423 | output_prediction_file = os.path.join( 424 | FLAGS.output_dir, "predictions.json") 425 | output_nbest_file = os.path.join( 426 | FLAGS.output_dir, "nbest_predictions.json") 427 | 428 | result_dict = {} 429 | squad_utils.accumulate_predictions_v1( 430 | result_dict, eval_examples, eval_features, 431 | all_results, FLAGS.n_best_size, FLAGS.max_answer_length) 432 | predictions = squad_utils.write_predictions_v1( 433 | result_dict, eval_examples, eval_features, all_results, 434 | FLAGS.n_best_size, FLAGS.max_answer_length, 435 | output_prediction_file, output_nbest_file) 436 | 437 | return squad_utils.evaluate_v1( 438 | prediction_json, predictions), int(global_step) 439 | 440 | def _find_valid_cands(curr_step): 441 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 442 | candidates = [] 443 | for filename in filenames: 444 | if filename.endswith(".index"): 445 | ckpt_name = filename[:-6] 446 | idx = ckpt_name.split("-")[-1] 447 | if idx != "best" and int(idx) > curr_step: 448 | candidates.append(filename) 449 | return candidates 450 | 451 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 452 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 453 | key_name = "f1" 454 | writer = tf.gfile.GFile(output_eval_file, "w") 455 | if tf.gfile.Exists(checkpoint_path + ".index"): 456 | result = get_result(checkpoint_path) 457 | best_perf = result[0][key_name] 458 | global_step = result[1] 459 | else: 460 | global_step = -1 461 | best_perf = -1 462 | checkpoint_path = None 463 | while global_step < num_train_steps: 464 | steps_and_files = {} 465 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 466 | for filename in filenames: 467 | if filename.endswith(".index"): 468 | ckpt_name = filename[:-6] 469 | 
cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 470 | if cur_filename.split("-")[-1] == "best": 471 | continue 472 | gstep = int(cur_filename.split("-")[-1]) 473 | if gstep not in steps_and_files: 474 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 475 | steps_and_files[gstep] = cur_filename 476 | tf.logging.info("found {} files.".format(len(steps_and_files))) 477 | if not steps_and_files: 478 | tf.logging.info("found 0 file, global step: {}. Sleeping." 479 | .format(global_step)) 480 | time.sleep(60) 481 | else: 482 | for ele in sorted(steps_and_files.items()): 483 | step, checkpoint_path = ele 484 | if global_step >= step: 485 | if len(_find_valid_cands(step)) > 1: 486 | for ext in ["meta", "data-00000-of-00001", "index"]: 487 | src_ckpt = checkpoint_path + ".{}".format(ext) 488 | tf.logging.info("removing {}".format(src_ckpt)) 489 | tf.gfile.Remove(src_ckpt) 490 | continue 491 | result, global_step = get_result(checkpoint_path) 492 | tf.logging.info("***** Eval results *****") 493 | for key in sorted(result.keys()): 494 | tf.logging.info(" %s = %s", key, str(result[key])) 495 | writer.write("%s = %s\n" % (key, str(result[key]))) 496 | if result[key_name] > best_perf: 497 | best_perf = result[key_name] 498 | for ext in ["meta", "data-00000-of-00001", "index"]: 499 | src_ckpt = checkpoint_path + ".{}".format(ext) 500 | tgt_ckpt = checkpoint_path.rsplit( 501 | "-", 1)[0] + "-best.{}".format(ext) 502 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 503 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True) 504 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt)) 505 | writer.write("best {} = {}\n".format(key_name, best_perf)) 506 | tf.logging.info(" best {} = {}\n".format(key_name, best_perf)) 507 | 508 | if len(_find_valid_cands(global_step)) > 2: 509 | for ext in ["meta", "data-00000-of-00001", "index"]: 510 | src_ckpt = checkpoint_path + ".{}".format(ext) 511 | tf.logging.info("removing {}".format(src_ckpt)) 512 | tf.gfile.Remove(src_ckpt) 513 | writer.write("=" * 50 + "\n") 514 | 515 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 516 | result, global_step = get_result(checkpoint_path) 517 | tf.logging.info("***** Final Eval results *****") 518 | for key in sorted(result.keys()): 519 | tf.logging.info(" %s = %s", key, str(result[key])) 520 | writer.write("%s = %s\n" % (key, str(result[key]))) 521 | writer.write("best perf happened at step: {}".format(global_step)) 522 | 523 | if FLAGS.export_dir: 524 | tf.gfile.MakeDirs(FLAGS.export_dir) 525 | squad_serving_input_fn = ( 526 | build_squad_serving_input_fn(FLAGS.max_seq_length)) 527 | tf.logging.info("Starting to export model.") 528 | subfolder = estimator.export_saved_model( 529 | export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"), 530 | serving_input_receiver_fn=squad_serving_input_fn) 531 | 532 | tf.logging.info("Starting to export TFLite.") 533 | converter = tf.lite.TFLiteConverter.from_saved_model( 534 | subfolder, 535 | input_arrays=["input_ids", "input_mask", "segment_ids"], 536 | output_arrays=["start_logits", "end_logits"]) 537 | float_model = converter.convert() 538 | tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite") 539 | with tf.gfile.GFile(tflite_file, "wb") as f: 540 | f.write(float_model) 541 | 542 | 543 | if __name__ == "__main__": 544 | flags.mark_flag_as_required("spm_model_file") 545 | flags.mark_flag_as_required("albert_config_file") 546 | flags.mark_flag_as_required("output_dir") 547 | tf.app.run() 548 | 
-------------------------------------------------------------------------------- /run_squad_v2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run ALBERT on SQuAD v2.0 using sentence piece tokenization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | 22 | import json 23 | import os 24 | import random 25 | import time 26 | 27 | from albert import fine_tuning_utils 28 | from albert import modeling 29 | from albert import squad_utils 30 | import six 31 | import tensorflow.compat.v1 as tf 32 | 33 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver 34 | from tensorflow.contrib import tpu as contrib_tpu 35 | 36 | 37 | # pylint: disable=g-import-not-at-top 38 | if six.PY2: 39 | import six.moves.cPickle as pickle 40 | else: 41 | import pickle 42 | # pylint: enable=g-import-not-at-top 43 | 44 | flags = tf.flags 45 | 46 | FLAGS = flags.FLAGS 47 | 48 | ## Required parameters 49 | flags.DEFINE_string( 50 | "albert_config_file", None, 51 | "The config json file corresponding to the pre-trained ALBERT model. " 52 | "This specifies the model architecture.") 53 | 54 | flags.DEFINE_string("vocab_file", None, 55 | "The vocabulary file that the ALBERT model was trained on.") 56 | 57 | flags.DEFINE_string("spm_model_file", None, 58 | "The model file for sentence piece tokenization.") 59 | 60 | flags.DEFINE_string( 61 | "output_dir", None, 62 | "The output directory where the model checkpoints will be written.") 63 | 64 | ## Other parameters 65 | flags.DEFINE_string("train_file", None, 66 | "SQuAD json for training. E.g., train-v1.1.json") 67 | 68 | flags.DEFINE_string( 69 | "predict_file", None, 70 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 71 | 72 | flags.DEFINE_string("train_feature_file", None, 73 | "training feature file.") 74 | 75 | flags.DEFINE_string( 76 | "predict_feature_file", None, 77 | "Location of predict features. If it doesn't exist, it will be written. " 78 | "If it does exist, it will be read.") 79 | 80 | flags.DEFINE_string( 81 | "predict_feature_left_file", None, 82 | "Location of predict features not passed to TPU. If it doesn't exist, it " 83 | "will be written. If it does exist, it will be read.") 84 | 85 | flags.DEFINE_string( 86 | "init_checkpoint", None, 87 | "Initial checkpoint (usually from a pre-trained BERT model).") 88 | 89 | flags.DEFINE_string( 90 | "albert_hub_module_handle", None, 91 | "If set, the ALBERT hub module to use.") 92 | 93 | flags.DEFINE_bool( 94 | "do_lower_case", True, 95 | "Whether to lower case the input text. Should be True for uncased " 96 | "models and False for cased models.") 97 | 98 | flags.DEFINE_integer( 99 | "max_seq_length", 384, 100 | "The maximum total input sequence length after WordPiece tokenization. 
" 101 | "Sequences longer than this will be truncated, and sequences shorter " 102 | "than this will be padded.") 103 | 104 | flags.DEFINE_integer( 105 | "doc_stride", 128, 106 | "When splitting up a long document into chunks, how much stride to " 107 | "take between chunks.") 108 | 109 | flags.DEFINE_integer( 110 | "max_query_length", 64, 111 | "The maximum number of tokens for the question. Questions longer than " 112 | "this will be truncated to this length.") 113 | 114 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 115 | 116 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.") 117 | 118 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 119 | 120 | flags.DEFINE_integer("predict_batch_size", 8, 121 | "Total batch size for predictions.") 122 | 123 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 124 | 125 | flags.DEFINE_float("num_train_epochs", 3.0, 126 | "Total number of training epochs to perform.") 127 | 128 | flags.DEFINE_float( 129 | "warmup_proportion", 0.1, 130 | "Proportion of training to perform linear learning rate warmup for. " 131 | "E.g., 0.1 = 10% of training.") 132 | 133 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 134 | "How often to save the model checkpoint.") 135 | 136 | flags.DEFINE_integer("iterations_per_loop", 1000, 137 | "How many steps to make in each estimator call.") 138 | 139 | flags.DEFINE_integer( 140 | "n_best_size", 20, 141 | "The total number of n-best predictions to generate in the " 142 | "nbest_predictions.json output file.") 143 | 144 | flags.DEFINE_integer( 145 | "max_answer_length", 30, 146 | "The maximum length of an answer that can be generated. This is needed " 147 | "because the start and end predictions are not conditioned on one another.") 148 | 149 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 150 | 151 | tf.flags.DEFINE_string( 152 | "tpu_name", None, 153 | "The Cloud TPU to use for training. This should be either the name " 154 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 155 | "url.") 156 | 157 | tf.flags.DEFINE_string( 158 | "tpu_zone", None, 159 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 160 | "specified, we will attempt to automatically detect the GCE project from " 161 | "metadata.") 162 | 163 | tf.flags.DEFINE_string( 164 | "gcp_project", None, 165 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 166 | "specified, we will attempt to automatically detect the GCE project from " 167 | "metadata.") 168 | 169 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 170 | 171 | flags.DEFINE_integer( 172 | "num_tpu_cores", 8, 173 | "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 174 | 175 | 176 | flags.DEFINE_integer("start_n_top", 5, "beam size for the start positions.") 177 | 178 | flags.DEFINE_integer("end_n_top", 5, "beam size for the end positions.") 179 | 180 | flags.DEFINE_float("dropout_prob", 0.1, "dropout probability.") 181 | 182 | 183 | def validate_flags_or_throw(albert_config): 184 | """Validate the input FLAGS or throw an exception.""" 185 | 186 | if not FLAGS.do_train and not FLAGS.do_predict: 187 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 188 | 189 | if FLAGS.do_train: 190 | if not FLAGS.train_file: 191 | raise ValueError( 192 | "If `do_train` is True, then `train_file` must be specified.") 193 | if FLAGS.do_predict: 194 | if not FLAGS.predict_file: 195 | raise ValueError( 196 | "If `do_predict` is True, then `predict_file` must be specified.") 197 | if not FLAGS.predict_feature_file: 198 | raise ValueError( 199 | "If `do_predict` is True, then `predict_feature_file` must be " 200 | "specified.") 201 | if not FLAGS.predict_feature_left_file: 202 | raise ValueError( 203 | "If `do_predict` is True, then `predict_feature_left_file` must be " 204 | "specified.") 205 | 206 | if FLAGS.max_seq_length > albert_config.max_position_embeddings: 207 | raise ValueError( 208 | "Cannot use sequence length %d because the ALBERT model " 209 | "was only trained up to sequence length %d" % 210 | (FLAGS.max_seq_length, albert_config.max_position_embeddings)) 211 | 212 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: 213 | raise ValueError( 214 | "The max_seq_length (%d) must be greater than max_query_length " 215 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) 216 | 217 | 218 | def main(_): 219 | tf.logging.set_verbosity(tf.logging.INFO) 220 | 221 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) 222 | 223 | validate_flags_or_throw(albert_config) 224 | 225 | tf.gfile.MakeDirs(FLAGS.output_dir) 226 | 227 | tokenizer = fine_tuning_utils.create_vocab( 228 | vocab_file=FLAGS.vocab_file, 229 | do_lower_case=FLAGS.do_lower_case, 230 | spm_model_file=FLAGS.spm_model_file, 231 | hub_module=FLAGS.albert_hub_module_handle) 232 | 233 | tpu_cluster_resolver = None 234 | if FLAGS.use_tpu and FLAGS.tpu_name: 235 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 236 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 237 | 238 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 239 | if FLAGS.do_train: 240 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 241 | FLAGS.save_checkpoints_steps)) 242 | else: 243 | iterations_per_loop = FLAGS.iterations_per_loop 244 | run_config = contrib_tpu.RunConfig( 245 | cluster=tpu_cluster_resolver, 246 | master=FLAGS.master, 247 | model_dir=FLAGS.output_dir, 248 | keep_checkpoint_max=0, 249 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 250 | tpu_config=contrib_tpu.TPUConfig( 251 | iterations_per_loop=iterations_per_loop, 252 | num_shards=FLAGS.num_tpu_cores, 253 | per_host_input_for_training=is_per_host)) 254 | 255 | train_examples = None 256 | num_train_steps = None 257 | num_warmup_steps = None 258 | train_examples = squad_utils.read_squad_examples( 259 | input_file=FLAGS.train_file, is_training=True) 260 | num_train_steps = int( 261 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 262 | if FLAGS.do_train: 263 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 264 | 265 | # Pre-shuffle the input to avoid having to make a 
very large shuffle 266 | # buffer in the `input_fn`. 267 | rng = random.Random(12345) 268 | rng.shuffle(train_examples) 269 | 270 | model_fn = squad_utils.v2_model_fn_builder( 271 | albert_config=albert_config, 272 | init_checkpoint=FLAGS.init_checkpoint, 273 | learning_rate=FLAGS.learning_rate, 274 | num_train_steps=num_train_steps, 275 | num_warmup_steps=num_warmup_steps, 276 | use_tpu=FLAGS.use_tpu, 277 | use_one_hot_embeddings=FLAGS.use_tpu, 278 | max_seq_length=FLAGS.max_seq_length, 279 | start_n_top=FLAGS.start_n_top, 280 | end_n_top=FLAGS.end_n_top, 281 | dropout_prob=FLAGS.dropout_prob, 282 | hub_module=FLAGS.albert_hub_module_handle) 283 | 284 | # If TPU is not available, this will fall back to normal Estimator on CPU 285 | # or GPU. 286 | estimator = contrib_tpu.TPUEstimator( 287 | use_tpu=FLAGS.use_tpu, 288 | model_fn=model_fn, 289 | config=run_config, 290 | train_batch_size=FLAGS.train_batch_size, 291 | predict_batch_size=FLAGS.predict_batch_size) 292 | 293 | if FLAGS.do_train: 294 | # We write to a temporary file to avoid storing very large constant tensors 295 | # in memory. 296 | 297 | if not tf.gfile.Exists(FLAGS.train_feature_file): 298 | train_writer = squad_utils.FeatureWriter( 299 | filename=os.path.join(FLAGS.train_feature_file), is_training=True) 300 | squad_utils.convert_examples_to_features( 301 | examples=train_examples, 302 | tokenizer=tokenizer, 303 | max_seq_length=FLAGS.max_seq_length, 304 | doc_stride=FLAGS.doc_stride, 305 | max_query_length=FLAGS.max_query_length, 306 | is_training=True, 307 | output_fn=train_writer.process_feature, 308 | do_lower_case=FLAGS.do_lower_case) 309 | train_writer.close() 310 | 311 | tf.logging.info("***** Running training *****") 312 | tf.logging.info(" Num orig examples = %d", len(train_examples)) 313 | # tf.logging.info(" Num split examples = %d", train_writer.num_features) 314 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 315 | tf.logging.info(" Num steps = %d", num_train_steps) 316 | del train_examples 317 | 318 | train_input_fn = squad_utils.input_fn_builder( 319 | input_file=FLAGS.train_feature_file, 320 | seq_length=FLAGS.max_seq_length, 321 | is_training=True, 322 | drop_remainder=True, 323 | use_tpu=FLAGS.use_tpu, 324 | bsz=FLAGS.train_batch_size, 325 | is_v2=True) 326 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 327 | 328 | if FLAGS.do_predict: 329 | with tf.gfile.Open(FLAGS.predict_file) as predict_file: 330 | prediction_json = json.load(predict_file)["data"] 331 | eval_examples = squad_utils.read_squad_examples( 332 | input_file=FLAGS.predict_file, is_training=False) 333 | 334 | if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists( 335 | FLAGS.predict_feature_left_file)): 336 | tf.logging.info("Loading eval features from {}".format( 337 | FLAGS.predict_feature_left_file)) 338 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin: 339 | eval_features = pickle.load(fin) 340 | else: 341 | eval_writer = squad_utils.FeatureWriter( 342 | filename=FLAGS.predict_feature_file, is_training=False) 343 | eval_features = [] 344 | 345 | def append_feature(feature): 346 | eval_features.append(feature) 347 | eval_writer.process_feature(feature) 348 | 349 | squad_utils.convert_examples_to_features( 350 | examples=eval_examples, 351 | tokenizer=tokenizer, 352 | max_seq_length=FLAGS.max_seq_length, 353 | doc_stride=FLAGS.doc_stride, 354 | max_query_length=FLAGS.max_query_length, 355 | is_training=False, 356 | output_fn=append_feature, 357 | 
do_lower_case=FLAGS.do_lower_case) 358 | eval_writer.close() 359 | 360 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout: 361 | pickle.dump(eval_features, fout) 362 | 363 | tf.logging.info("***** Running predictions *****") 364 | tf.logging.info(" Num orig examples = %d", len(eval_examples)) 365 | tf.logging.info(" Num split examples = %d", len(eval_features)) 366 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 367 | 368 | predict_input_fn = squad_utils.input_fn_builder( 369 | input_file=FLAGS.predict_feature_file, 370 | seq_length=FLAGS.max_seq_length, 371 | is_training=False, 372 | drop_remainder=False, 373 | use_tpu=FLAGS.use_tpu, 374 | bsz=FLAGS.predict_batch_size, 375 | is_v2=True) 376 | 377 | def get_result(checkpoint): 378 | """Evaluate the checkpoint on SQuAD v2.0.""" 379 | # If running eval on the TPU, you will need to specify the number of 380 | # steps. 381 | reader = tf.train.NewCheckpointReader(checkpoint) 382 | global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP) 383 | all_results = [] 384 | for result in estimator.predict( 385 | predict_input_fn, yield_single_examples=True, 386 | checkpoint_path=checkpoint): 387 | if len(all_results) % 1000 == 0: 388 | tf.logging.info("Processing example: %d" % (len(all_results))) 389 | unique_id = int(result["unique_ids"]) 390 | start_top_log_probs = ( 391 | [float(x) for x in result["start_top_log_probs"].flat]) 392 | start_top_index = [int(x) for x in result["start_top_index"].flat] 393 | end_top_log_probs = ( 394 | [float(x) for x in result["end_top_log_probs"].flat]) 395 | end_top_index = [int(x) for x in result["end_top_index"].flat] 396 | 397 | cls_logits = float(result["cls_logits"].flat[0]) 398 | all_results.append( 399 | squad_utils.RawResultV2( 400 | unique_id=unique_id, 401 | start_top_log_probs=start_top_log_probs, 402 | start_top_index=start_top_index, 403 | end_top_log_probs=end_top_log_probs, 404 | end_top_index=end_top_index, 405 | cls_logits=cls_logits)) 406 | 407 | output_prediction_file = os.path.join( 408 | FLAGS.output_dir, "predictions.json") 409 | output_nbest_file = os.path.join( 410 | FLAGS.output_dir, "nbest_predictions.json") 411 | output_null_log_odds_file = os.path.join( 412 | FLAGS.output_dir, "null_odds.json") 413 | 414 | result_dict = {} 415 | cls_dict = {} 416 | squad_utils.accumulate_predictions_v2( 417 | result_dict, cls_dict, eval_examples, eval_features, 418 | all_results, FLAGS.n_best_size, FLAGS.max_answer_length, 419 | FLAGS.start_n_top, FLAGS.end_n_top) 420 | 421 | return squad_utils.evaluate_v2( 422 | result_dict, cls_dict, prediction_json, eval_examples, 423 | eval_features, all_results, FLAGS.n_best_size, 424 | FLAGS.max_answer_length, output_prediction_file, output_nbest_file, 425 | output_null_log_odds_file), int(global_step) 426 | 427 | def _find_valid_cands(curr_step): 428 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 429 | candidates = [] 430 | for filename in filenames: 431 | if filename.endswith(".index"): 432 | ckpt_name = filename[:-6] 433 | idx = ckpt_name.split("-")[-1] 434 | if idx != "best" and int(idx) > curr_step: 435 | candidates.append(filename) 436 | return candidates 437 | 438 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 439 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 440 | key_name = "f1" 441 | writer = tf.gfile.GFile(output_eval_file, "w") 442 | if tf.gfile.Exists(checkpoint_path + ".index"): 443 | result = get_result(checkpoint_path) 444 | best_perf = result[0][key_name] 445 | 
global_step = result[1] 446 | else: 447 | global_step = -1 448 | best_perf = -1 449 | checkpoint_path = None 450 | while global_step < num_train_steps: 451 | steps_and_files = {} 452 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 453 | for filename in filenames: 454 | if filename.endswith(".index"): 455 | ckpt_name = filename[:-6] 456 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 457 | if cur_filename.split("-")[-1] == "best": 458 | continue 459 | gstep = int(cur_filename.split("-")[-1]) 460 | if gstep not in steps_and_files: 461 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 462 | steps_and_files[gstep] = cur_filename 463 | tf.logging.info("found {} files.".format(len(steps_and_files))) 464 | if not steps_and_files: 465 | tf.logging.info("found 0 file, global step: {}. Sleeping." 466 | .format(global_step)) 467 | time.sleep(60) 468 | else: 469 | for ele in sorted(steps_and_files.items()): 470 | step, checkpoint_path = ele 471 | if global_step >= step: 472 | if len(_find_valid_cands(step)) > 1: 473 | for ext in ["meta", "data-00000-of-00001", "index"]: 474 | src_ckpt = checkpoint_path + ".{}".format(ext) 475 | tf.logging.info("removing {}".format(src_ckpt)) 476 | tf.gfile.Remove(src_ckpt) 477 | continue 478 | result, global_step = get_result(checkpoint_path) 479 | tf.logging.info("***** Eval results *****") 480 | for key in sorted(result.keys()): 481 | tf.logging.info(" %s = %s", key, str(result[key])) 482 | writer.write("%s = %s\n" % (key, str(result[key]))) 483 | if result[key_name] > best_perf: 484 | best_perf = result[key_name] 485 | for ext in ["meta", "data-00000-of-00001", "index"]: 486 | src_ckpt = checkpoint_path + ".{}".format(ext) 487 | tgt_ckpt = checkpoint_path.rsplit( 488 | "-", 1)[0] + "-best.{}".format(ext) 489 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 490 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True) 491 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt)) 492 | writer.write("best {} = {}\n".format(key_name, best_perf)) 493 | tf.logging.info(" best {} = {}\n".format(key_name, best_perf)) 494 | 495 | if len(_find_valid_cands(global_step)) > 2: 496 | for ext in ["meta", "data-00000-of-00001", "index"]: 497 | src_ckpt = checkpoint_path + ".{}".format(ext) 498 | tf.logging.info("removing {}".format(src_ckpt)) 499 | tf.gfile.Remove(src_ckpt) 500 | writer.write("=" * 50 + "\n") 501 | 502 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 503 | result, global_step = get_result(checkpoint_path) 504 | tf.logging.info("***** Final Eval results *****") 505 | for key in sorted(result.keys()): 506 | tf.logging.info(" %s = %s", key, str(result[key])) 507 | writer.write("%s = %s\n" % (key, str(result[key]))) 508 | writer.write("best perf happened at step: {}".format(global_step)) 509 | 510 | 511 | if __name__ == "__main__": 512 | flags.mark_flag_as_required("spm_model_file") 513 | flags.mark_flag_as_required("albert_config_file") 514 | flags.mark_flag_as_required("output_dir") 515 | tf.app.run() 516 | -------------------------------------------------------------------------------- /run_trivial_model_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Small integration test script. 3 | # The values in this file are **not** meant for reproducing actual results. 4 | 5 | set -e 6 | set -x 7 | 8 | virtualenv -p python3 . 
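# Run the rest of the test inside the virtualenv created above so its pinned dependencies stay isolated.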
9 | source ./bin/activate 10 | 11 | OUTPUT_DIR_BASE="$(mktemp -d)" 12 | OUTPUT_DIR="${OUTPUT_DIR_BASE}/output" 13 | 14 | pip install numpy 15 | pip install -r requirements.txt 16 | python -m run_pretraining_test \ 17 | --output_dir="${OUTPUT_DIR}" \ 18 | --do_train \ 19 | --do_eval \ 20 | --nouse_tpu \ 21 | --train_batch_size=2 \ 22 | --eval_batch_size=1 \ 23 | --max_seq_length=4 \ 24 | --num_train_steps=2 \ 25 | --max_eval_steps=3 26 | 27 | 28 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # coding=utf-8 16 | """Tokenization classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import unicodedata 24 | import six 25 | from six.moves import range 26 | import tensorflow.compat.v1 as tf 27 | import tensorflow_hub as hub 28 | import sentencepiece as spm 29 | 30 | SPIECE_UNDERLINE = u"▁".encode("utf-8") 31 | 32 | 33 | def preprocess_text(inputs, remove_space=True, lower=False): 34 | """Preprocess text by removing extra space and normalizing it.""" 35 | outputs = inputs 36 | if remove_space: 37 | outputs = " ".join(inputs.strip().split()) 38 | 39 | if six.PY2 and isinstance(outputs, str): 40 | try: 41 | outputs = six.ensure_text(outputs, "utf-8") 42 | except UnicodeDecodeError: 43 | outputs = six.ensure_text(outputs, "latin-1") 44 | 45 | outputs = unicodedata.normalize("NFKD", outputs) 46 | outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) 47 | if lower: 48 | outputs = outputs.lower() 49 | 50 | return outputs 51 | 52 | 53 | def encode_pieces(sp_model, text, return_unicode=True, sample=False): 54 | """Turn sentences into word pieces.""" 55 | 56 | if six.PY2 and isinstance(text, six.text_type): 57 | text = six.ensure_binary(text, "utf-8") 58 | 59 | if not sample: 60 | pieces = sp_model.EncodeAsPieces(text) 61 | else: 62 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1) 63 | new_pieces = [] 64 | for piece in pieces: 65 | piece = printable_text(piece) 66 | if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit(): 67 | cur_pieces = sp_model.EncodeAsPieces( 68 | six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b"")) 69 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 70 | if len(cur_pieces[0]) == 1: 71 | cur_pieces = cur_pieces[1:] 72 | else: 73 | cur_pieces[0] = cur_pieces[0][1:] 74 | cur_pieces.append(piece[-1]) 75 | new_pieces.extend(cur_pieces) 76 | else: 77 | new_pieces.append(piece) 78 | 79 | # note(zhiliny): convert back to unicode for py2 80 | if six.PY2 and return_unicode: 81 | ret_pieces = [] 82 | for piece in new_pieces: 83 | if isinstance(piece, str): 84 | piece = six.ensure_text(piece, "utf-8") 85 | 
ret_pieces.append(piece) 86 | new_pieces = ret_pieces 87 | 88 | return new_pieces 89 | 90 | 91 | def encode_ids(sp_model, text, sample=False): 92 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample) 93 | ids = [sp_model.PieceToId(piece) for piece in pieces] 94 | return ids 95 | 96 | 97 | def convert_to_unicode(text): 98 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 99 | if six.PY3: 100 | if isinstance(text, str): 101 | return text 102 | elif isinstance(text, bytes): 103 | return six.ensure_text(text, "utf-8", "ignore") 104 | else: 105 | raise ValueError("Unsupported string type: %s" % (type(text))) 106 | elif six.PY2: 107 | if isinstance(text, str): 108 | return six.ensure_text(text, "utf-8", "ignore") 109 | elif isinstance(text, six.text_type): 110 | return text 111 | else: 112 | raise ValueError("Unsupported string type: %s" % (type(text))) 113 | else: 114 | raise ValueError("Not running on Python 2 or Python 3?") 115 | 116 | 117 | def printable_text(text): 118 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 119 | 120 | # These functions want `str` for both Python 2 and Python 3, but in one case 121 | # it's a Unicode string and in the other it's a byte string. 122 | if six.PY3: 123 | if isinstance(text, str): 124 | return text 125 | elif isinstance(text, bytes): 126 | return six.ensure_text(text, "utf-8", "ignore") 127 | else: 128 | raise ValueError("Unsupported string type: %s" % (type(text))) 129 | elif six.PY2: 130 | if isinstance(text, str): 131 | return text 132 | elif isinstance(text, six.text_type): 133 | return six.ensure_binary(text, "utf-8") 134 | else: 135 | raise ValueError("Unsupported string type: %s" % (type(text))) 136 | else: 137 | raise ValueError("Not running on Python 2 or Python 3?") 138 | 139 | 140 | def load_vocab(vocab_file): 141 | """Loads a vocabulary file into a dictionary.""" 142 | vocab = collections.OrderedDict() 143 | with tf.gfile.GFile(vocab_file, "r") as reader: 144 | while True: 145 | token = convert_to_unicode(reader.readline()) 146 | if not token: 147 | break 148 | token = token.strip().split()[0] if token.strip() else " " 149 | if token not in vocab: 150 | vocab[token] = len(vocab) 151 | return vocab 152 | 153 | 154 | def convert_by_vocab(vocab, items): 155 | """Converts a sequence of [tokens|ids] using the vocab.""" 156 | output = [] 157 | for item in items: 158 | output.append(vocab[item]) 159 | return output 160 | 161 | 162 | def convert_tokens_to_ids(vocab, tokens): 163 | return convert_by_vocab(vocab, tokens) 164 | 165 | 166 | def convert_ids_to_tokens(inv_vocab, ids): 167 | return convert_by_vocab(inv_vocab, ids) 168 | 169 | 170 | def whitespace_tokenize(text): 171 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 172 | text = text.strip() 173 | if not text: 174 | return [] 175 | tokens = text.split() 176 | return tokens 177 | 178 | 179 | class FullTokenizer(object): 180 | """Runs end-to-end tokenization.""" 181 | 182 | def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None): 183 | self.vocab = None 184 | self.sp_model = None 185 | if spm_model_file: 186 | self.sp_model = spm.SentencePieceProcessor() 187 | tf.logging.info("loading sentence piece model") 188 | # Handle cases where SP can't load the file, but gfile can.
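# (For example, files on GCS: read the raw bytes through tf.gfile and hand the serialized model proto to SentencePiece directly.)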
189 | sp_model_ = tf.gfile.GFile(spm_model_file, "rb").read() 190 | self.sp_model.LoadFromSerializedProto(sp_model_) 191 | # Note(mingdachen): For the purpose of a consistent API, we are 192 | # generating a vocabulary for the sentence piece tokenizer. 193 | self.vocab = {self.sp_model.IdToPiece(i): i for i 194 | in range(self.sp_model.GetPieceSize())} 195 | else: 196 | self.vocab = load_vocab(vocab_file) 197 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 198 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 199 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 200 | 201 | @classmethod 202 | def from_scratch(cls, vocab_file, do_lower_case, spm_model_file): 203 | return FullTokenizer(vocab_file, do_lower_case, spm_model_file) 204 | 205 | @classmethod 206 | def from_hub_module(cls, hub_module, use_spm=True): 207 | """Get the vocab file and casing info from the Hub module.""" 208 | with tf.Graph().as_default(): 209 | albert_module = hub.Module(hub_module) 210 | tokenization_info = albert_module(signature="tokenization_info", 211 | as_dict=True) 212 | with tf.Session() as sess: 213 | vocab_file, do_lower_case = sess.run( 214 | [tokenization_info["vocab_file"], 215 | tokenization_info["do_lower_case"]]) 216 | if use_spm: 217 | spm_model_file = vocab_file 218 | vocab_file = None 219 | return FullTokenizer( 220 | vocab_file=vocab_file, do_lower_case=do_lower_case, 221 | spm_model_file=spm_model_file) 222 | 223 | def tokenize(self, text): 224 | if self.sp_model: 225 | split_tokens = encode_pieces(self.sp_model, text, return_unicode=False) 226 | else: 227 | split_tokens = [] 228 | for token in self.basic_tokenizer.tokenize(text): 229 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 230 | split_tokens.append(sub_token) 231 | 232 | return split_tokens 233 | 234 | def convert_tokens_to_ids(self, tokens): 235 | if self.sp_model: 236 | tf.logging.info("using sentence piece tokenizer.") 237 | return [self.sp_model.PieceToId( 238 | printable_text(token)) for token in tokens] 239 | else: 240 | return convert_by_vocab(self.vocab, tokens) 241 | 242 | def convert_ids_to_tokens(self, ids): 243 | if self.sp_model: 244 | tf.logging.info("using sentence piece tokenizer.") 245 | return [self.sp_model.IdToPiece(id_) for id_ in ids] 246 | else: 247 | return convert_by_vocab(self.inv_vocab, ids) 248 | 249 | 250 | class BasicTokenizer(object): 251 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 252 | 253 | def __init__(self, do_lower_case=True): 254 | """Constructs a BasicTokenizer. 255 | 256 | Args: 257 | do_lower_case: Whether to lower case the input. 258 | """ 259 | self.do_lower_case = do_lower_case 260 | 261 | def tokenize(self, text): 262 | """Tokenizes a piece of text.""" 263 | text = convert_to_unicode(text) 264 | text = self._clean_text(text) 265 | 266 | # This was added on November 1st, 2018 for the multilingual and Chinese 267 | # models. This is also applied to the English models now, but it doesn't 268 | # matter since the English models were not trained on any Chinese data 269 | # and generally don't have any Chinese data in them (there are Chinese 270 | # characters in the vocabulary because Wikipedia does have some Chinese 271 | # words in the English Wikipedia.).
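# Surrounding each CJK codepoint with spaces (see _tokenize_chinese_chars below) makes every CJK character its own token before wordpiece splitting.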
272 | text = self._tokenize_chinese_chars(text) 273 | 274 | orig_tokens = whitespace_tokenize(text) 275 | split_tokens = [] 276 | for token in orig_tokens: 277 | if self.do_lower_case: 278 | token = token.lower() 279 | token = self._run_strip_accents(token) 280 | split_tokens.extend(self._run_split_on_punc(token)) 281 | 282 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 283 | return output_tokens 284 | 285 | def _run_strip_accents(self, text): 286 | """Strips accents from a piece of text.""" 287 | text = unicodedata.normalize("NFD", text) 288 | output = [] 289 | for char in text: 290 | cat = unicodedata.category(char) 291 | if cat == "Mn": 292 | continue 293 | output.append(char) 294 | return "".join(output) 295 | 296 | def _run_split_on_punc(self, text): 297 | """Splits punctuation on a piece of text.""" 298 | chars = list(text) 299 | i = 0 300 | start_new_word = True 301 | output = [] 302 | while i < len(chars): 303 | char = chars[i] 304 | if _is_punctuation(char): 305 | output.append([char]) 306 | start_new_word = True 307 | else: 308 | if start_new_word: 309 | output.append([]) 310 | start_new_word = False 311 | output[-1].append(char) 312 | i += 1 313 | 314 | return ["".join(x) for x in output] 315 | 316 | def _tokenize_chinese_chars(self, text): 317 | """Adds whitespace around any CJK character.""" 318 | output = [] 319 | for char in text: 320 | cp = ord(char) 321 | if self._is_chinese_char(cp): 322 | output.append(" ") 323 | output.append(char) 324 | output.append(" ") 325 | else: 326 | output.append(char) 327 | return "".join(output) 328 | 329 | def _is_chinese_char(self, cp): 330 | """Checks whether CP is the codepoint of a CJK character.""" 331 | # This defines a "chinese character" as anything in the CJK Unicode block: 332 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 333 | # 334 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 335 | # despite its name. The modern Korean Hangul alphabet is a different block, 336 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 337 | # space-separated words, so they are not treated specially and handled 338 | # like all of the other languages. 339 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 340 | (cp >= 0x3400 and cp <= 0x4DBF) or # 341 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 342 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 343 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 344 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 345 | (cp >= 0xF900 and cp <= 0xFAFF) or # 346 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 347 | return True 348 | 349 | return False 350 | 351 | def _clean_text(self, text): 352 | """Performs invalid character removal and whitespace cleanup on text.""" 353 | output = [] 354 | for char in text: 355 | cp = ord(char) 356 | if cp == 0 or cp == 0xfffd or _is_control(char): 357 | continue 358 | if _is_whitespace(char): 359 | output.append(" ") 360 | else: 361 | output.append(char) 362 | return "".join(output) 363 | 364 | 365 | class WordpieceTokenizer(object): 366 | """Runs WordPiece tokenization.""" 367 | 368 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 369 | self.vocab = vocab 370 | self.unk_token = unk_token 371 | self.max_input_chars_per_word = max_input_chars_per_word 372 | 373 | def tokenize(self, text): 374 | """Tokenizes a piece of text into its word pieces. 375 | 376 | This uses a greedy longest-match-first algorithm to perform tokenization 377 | using the given vocabulary.
378 | 379 | For example: 380 | input = "unaffable" 381 | output = ["un", "##aff", "##able"] 382 | 383 | Args: 384 | text: A single token or whitespace separated tokens. This should have 385 | already been passed through `BasicTokenizer`. 386 | 387 | Returns: 388 | A list of wordpiece tokens. 389 | """ 390 | 391 | text = convert_to_unicode(text) 392 | 393 | output_tokens = [] 394 | for token in whitespace_tokenize(text): 395 | chars = list(token) 396 | if len(chars) > self.max_input_chars_per_word: 397 | output_tokens.append(self.unk_token) 398 | continue 399 | 400 | is_bad = False 401 | start = 0 402 | sub_tokens = [] 403 | while start < len(chars): 404 | end = len(chars) 405 | cur_substr = None 406 | while start < end: 407 | substr = "".join(chars[start:end]) 408 | if start > 0: 409 | substr = "##" + six.ensure_str(substr) 410 | if substr in self.vocab: 411 | cur_substr = substr 412 | break 413 | end -= 1 414 | if cur_substr is None: 415 | is_bad = True 416 | break 417 | sub_tokens.append(cur_substr) 418 | start = end 419 | 420 | if is_bad: 421 | output_tokens.append(self.unk_token) 422 | else: 423 | output_tokens.extend(sub_tokens) 424 | return output_tokens 425 | 426 | 427 | def _is_whitespace(char): 428 | """Checks whether `char` is a whitespace character.""" 429 | # \t, \n, and \r are technically control characters but we treat them 430 | # as whitespace since they are generally considered as such. 431 | if char == " " or char == "\t" or char == "\n" or char == "\r": 432 | return True 433 | cat = unicodedata.category(char) 434 | if cat == "Zs": 435 | return True 436 | return False 437 | 438 | 439 | def _is_control(char): 440 | """Checks whether `char` is a control character.""" 441 | # These are technically control characters but we count them as whitespace 442 | # characters. 443 | if char == "\t" or char == "\n" or char == "\r": 444 | return False 445 | cat = unicodedata.category(char) 446 | if cat in ("Cc", "Cf"): 447 | return True 448 | return False 449 | 450 | 451 | def _is_punctuation(char): 452 | """Checks whether `char` is a punctuation character.""" 453 | cp = ord(char) 454 | # We treat all non-letter/number ASCII as punctuation. 455 | # Characters such as "^", "$", and "`" are not in the Unicode 456 | # Punctuation class but we treat them as punctuation anyways, for 457 | # consistency. 458 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 459 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 460 | return True 461 | cat = unicodedata.category(char) 462 | if cat.startswith("P"): 463 | return True 464 | return False 465 | -------------------------------------------------------------------------------- /tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
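# Unit tests for FullTokenizer, BasicTokenizer, WordpieceTokenizer and the character-class helpers in tokenization.py.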
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | import os 19 | import tempfile 20 | from albert import tokenization 21 | import six 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | class TokenizationTest(tf.test.TestCase): 26 | 27 | def test_full_tokenizer(self): 28 | vocab_tokens = [ 29 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 30 | "##ing", "," 31 | ] 32 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 33 | if six.PY2: 34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 35 | else: 36 | contents = "".join([six.ensure_str(x) + "\n" for x in vocab_tokens]) 37 | vocab_writer.write(six.ensure_binary(contents, "utf-8")) 38 | 39 | vocab_file = vocab_writer.name 40 | 41 | tokenizer = tokenization.FullTokenizer(vocab_file) 42 | os.unlink(vocab_file) 43 | 44 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 45 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 46 | 47 | self.assertAllEqual( 48 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 49 | 50 | def test_chinese(self): 51 | tokenizer = tokenization.BasicTokenizer() 52 | 53 | self.assertAllEqual( 54 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 55 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 56 | 57 | def test_basic_tokenizer_lower(self): 58 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 59 | 60 | self.assertAllEqual( 61 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 62 | ["hello", "!", "how", "are", "you", "?"]) 63 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 64 | 65 | def test_basic_tokenizer_no_lower(self): 66 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 67 | 68 | self.assertAllEqual( 69 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 70 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 71 | 72 | def test_wordpiece_tokenizer(self): 73 | vocab_tokens = [ 74 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 75 | "##ing" 76 | ] 77 | 78 | vocab = {} 79 | for (i, token) in enumerate(vocab_tokens): 80 | vocab[token] = i 81 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 82 | 83 | self.assertAllEqual(tokenizer.tokenize(""), []) 84 | 85 | self.assertAllEqual( 86 | tokenizer.tokenize("unwanted running"), 87 | ["un", "##want", "##ed", "runn", "##ing"]) 88 | 89 | self.assertAllEqual( 90 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 91 | 92 | def test_convert_tokens_to_ids(self): 93 | vocab_tokens = [ 94 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 95 | "##ing" 96 | ] 97 | 98 | vocab = {} 99 | for (i, token) in enumerate(vocab_tokens): 100 | vocab[token] = i 101 | 102 | self.assertAllEqual( 103 | tokenization.convert_tokens_to_ids( 104 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 105 | 106 | def test_is_whitespace(self): 107 | self.assertTrue(tokenization._is_whitespace(u" ")) 108 | self.assertTrue(tokenization._is_whitespace(u"\t")) 109 | self.assertTrue(tokenization._is_whitespace(u"\r")) 110 | self.assertTrue(tokenization._is_whitespace(u"\n")) 111 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 112 | 113 | self.assertFalse(tokenization._is_whitespace(u"A")) 114 | self.assertFalse(tokenization._is_whitespace(u"-")) 115 | 116 | def test_is_control(self): 117 | self.assertTrue(tokenization._is_control(u"\u0005")) 118 | 119 | self.assertFalse(tokenization._is_control(u"A")) 120 | self.assertFalse(tokenization._is_control(u" ")) 121 | self.assertFalse(tokenization._is_control(u"\t")) 122 | self.assertFalse(tokenization._is_control(u"\r")) 123 | self.assertFalse(tokenization._is_control(u"\U0001F4A9")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | --------------------------------------------------------------------------------