├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── albert_glue_fine_tuning_tutorial.ipynb
├── classifier_utils.py
├── create_pretraining_data.py
├── export_checkpoints.py
├── export_to_tfhub.py
├── fine_tuning_utils.py
├── lamb_optimizer.py
├── modeling.py
├── modeling_test.py
├── optimization.py
├── optimization_test.py
├── race_utils.py
├── requirements.txt
├── run_classifier.py
├── run_glue.sh
├── run_pretraining.py
├── run_pretraining_test.py
├── run_race.py
├── run_squad_v1.py
├── run_squad_v2.py
├── run_trivial_model_test.sh
├── squad_utils.py
├── tokenization.py
└── tokenization_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # SageMath parsed files
89 | *.sage.py
90 |
91 | # Environments
92 | .env
93 | .venv
94 | env/
95 | venv/
96 | ENV/
97 | env.bak/
98 | venv.bak/
99 |
100 | # Spyder project settings
101 | .spyderproject
102 | .spyproject
103 |
104 | # Rope project settings
105 | .ropeproject
106 |
107 | # mkdocs documentation
108 | /site
109 |
110 | # mypy
111 | .mypy_cache/
112 | .dmypy.json
113 | dmypy.json
114 |
115 | # Pyre type checker
116 | .pyre/
117 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
 11 | part of the project. Head over to <https://cla.developers.google.com/> to see
 12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows
28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ALBERT
2 | ======
3 |
4 | ***************New March 28, 2020 ***************
5 |
6 | Add a colab [tutorial](https://github.com/google-research/albert/blob/master/albert_glue_fine_tuning_tutorial.ipynb) to run fine-tuning for GLUE datasets.
7 |
8 | ***************New January 7, 2020 ***************
9 |
10 | v2 TF-Hub models should be working now with TF 1.15, as we removed the
11 | native Einsum op from the graph. See updated TF-Hub links below.
12 |
13 | ***************New December 30, 2019 ***************
14 |
 15 | Chinese models are released. We would like to thank the [CLUE team](https://github.com/CLUEbenchmark/CLUE) for providing the training data.
16 |
17 | - [Base](https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz)
18 | - [Large](https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz)
19 | - [Xlarge](https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz)
20 | - [Xxlarge](https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz)
21 |
22 | Version 2 of ALBERT models is released.
23 |
24 | - Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/3)]
25 | - Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/3)]
26 | - Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/3)]
27 | - Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/3)]
28 |
 29 | In this version, we apply 'no dropout', 'additional training data', and 'long training time' strategies to all models. We train ALBERT-base for 10M steps and the other models for 3M steps.
30 |
 31 | A comparison of results with the v1 models is as follows:
32 |
33 | | | Average | SQuAD1.1 | SQuAD2.0 | MNLI | SST-2 | RACE |
34 | |----------------|----------|----------|----------|----------|----------|----------|
35 | |V2 |
36 | |ALBERT-base |82.3 |90.2/83.2 |82.1/79.3 |84.6 |92.9 |66.8 |
37 | |ALBERT-large |85.7 |91.8/85.2 |84.9/81.8 |86.5 |94.9 |75.2 |
38 | |ALBERT-xlarge |87.9 |92.9/86.4 |87.9/84.1 |87.9 |95.4 |80.7 |
39 | |ALBERT-xxlarge |90.9 |94.6/89.1 |89.8/86.9 |90.6 |96.8 |86.8 |
40 | |V1 |
41 | |ALBERT-base |80.1 |89.3/82.3 | 80.0/77.1|81.6 |90.3 | 64.0 |
42 | |ALBERT-large |82.4 |90.6/83.9 | 82.3/79.4|83.5 |91.7 | 68.5 |
43 | |ALBERT-xlarge |85.5 |92.5/86.1 | 86.1/83.1|86.4 |92.4 | 74.8 |
44 | |ALBERT-xxlarge |91.0 |94.8/89.3 | 90.2/87.4|90.8 |96.9 | 86.5 |
45 |
 46 | The comparison shows that for ALBERT-base, ALBERT-large, and ALBERT-xlarge, v2 is much better than v1, indicating the importance of applying the above three strategies. On average, ALBERT-xxlarge is slightly worse than the v1 model, for two reasons: 1) Training for an additional 1.5M steps (the only difference between these two models is training for 1.5M vs. 3M steps) did not lead to significant performance improvement. 2) For v1, we did a bit of hyperparameter search among the parameter sets given by BERT, RoBERTa, and XLNet. For v2, we simply adopt the parameters from v1 except for RACE, where we use a learning rate of 1e-5 and 0 [ALBERT DR](https://arxiv.org/pdf/1909.11942.pdf) (dropout rate for ALBERT in fine-tuning). The original (v1) RACE hyperparameters cause model divergence for v2 models. Given that the downstream tasks are sensitive to the fine-tuning hyperparameters, we should be careful about so-called slight improvements.
47 |
48 | ALBERT is "A Lite" version of BERT, a popular unsupervised language
49 | representation learning algorithm. ALBERT uses parameter-reduction techniques
50 | that allow for large-scale configurations, overcome previous memory limitations,
51 | and achieve better behavior with respect to model degradation.
52 |
53 | For a technical description of the algorithm, see our paper:
54 |
55 | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942)
56 |
57 | Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut
58 |
59 | Release Notes
60 | =============
61 |
62 | - Initial release: 10/9/2019
63 |
64 | Results
65 | =======
66 |
 67 | Performance of ALBERT on the GLUE benchmark using a single-model setup on the
 68 | dev set:
69 |
70 | | Models | MNLI | QNLI | QQP | RTE | SST | MRPC | CoLA | STS |
71 | |-------------------|----------|----------|----------|----------|----------|----------|----------|----------|
72 | | BERT-large | 86.6 | 92.3 | 91.3 | 70.4 | 93.2 | 88.0 | 60.6 | 90.0 |
73 | | XLNet-large | 89.8 | 93.9 | 91.8 | 83.8 | 95.6 | 89.2 | 63.6 | 91.8 |
74 | | RoBERTa-large | 90.2 | 94.7 | **92.2** | 86.6 | 96.4 | **90.9** | 68.0 | 92.4 |
75 | | ALBERT (1M) | 90.4 | 95.2 | 92.0 | 88.1 | 96.8 | 90.2 | 68.7 | 92.7 |
76 | | ALBERT (1.5M) | **90.8** | **95.3** | **92.2** | **89.2** | **96.9** | **90.9** | **71.4** | **93.0** |
77 |
 78 | Performance of ALBERT-xxlarge on the SQuAD and RACE benchmarks using a
 79 | single-model setup:
80 |
81 | |Models | SQuAD1.1 dev | SQuAD2.0 dev | SQuAD2.0 test | RACE test (Middle/High) |
82 | |--------------------------|---------------|---------------|---------------|-------------------------|
83 | |BERT-large | 90.9/84.1 | 81.8/79.0 | 89.1/86.3 | 72.0 (76.6/70.1) |
84 | |XLNet | 94.5/89.0 | 88.8/86.1 | 89.1/86.3 | 81.8 (85.5/80.2) |
85 | |RoBERTa | 94.6/88.9 | 89.4/86.5 | 89.8/86.8 | 83.2 (86.5/81.3) |
86 | |UPM | - | - | 89.9/87.2 | - |
87 | |XLNet + SG-Net Verifier++ | - | - | 90.1/87.2 | - |
88 | |ALBERT (1M) | 94.8/89.2 | 89.9/87.2 | - | 86.0 (88.2/85.1) |
89 | |ALBERT (1.5M) | **94.8/89.3** | **90.2/87.4** | **90.9/88.1** | **86.5 (89.0/85.5)** |
90 |
91 |
92 | Pre-trained Models
93 | ==================
94 | TF-Hub modules are available:
95 |
96 | - Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/1)]
97 | - Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/1)]
98 | - Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/1)]
99 | - Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/1)]
100 |
101 | Example usage of the TF-Hub module in code:
102 |
103 | ```
104 | tags = set()
105 | if is_training:
106 | tags.add("train")
107 | albert_module = hub.Module("https://tfhub.dev/google/albert_base/1", tags=tags,
108 | trainable=True)
109 | albert_inputs = dict(
110 | input_ids=input_ids,
111 | input_mask=input_mask,
112 | segment_ids=segment_ids)
113 | albert_outputs = albert_module(
114 | inputs=albert_inputs,
115 | signature="tokens",
116 | as_dict=True)
117 |
118 | # If you want to use the token-level output, use
119 | # albert_outputs["sequence_output"] instead.
120 | output_layer = albert_outputs["pooled_output"]
121 | ```
122 |
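The snippet above assumes `tf`, `hub`, `is_training`, and the three `int32`
input tensors are already in scope. A minimal sketch of that setup, with
illustrative `[batch, seq_len]` placeholder shapes:

```
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

is_training = True  # Assumption: set according to whether you are training.
input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
```
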
123 | Most of the fine-tuning scripts in this repository support TF-Hub modules
124 | via the `--albert_hub_module_handle` flag.
125 |
126 | Pre-training Instructions
127 | =========================
128 | To pretrain ALBERT, use `run_pretraining.py`:
129 |
130 | ```
131 | pip install -r albert/requirements.txt
132 | python -m albert.run_pretraining \
133 | --input_file=... \
134 | --output_dir=... \
135 | --init_checkpoint=... \
136 | --albert_config_file=... \
137 | --do_train \
138 | --do_eval \
139 | --train_batch_size=4096 \
140 | --eval_batch_size=64 \
141 | --max_seq_length=512 \
142 | --max_predictions_per_seq=20 \
143 | --optimizer='lamb' \
144 | --learning_rate=.00176 \
145 | --num_train_steps=125000 \
146 | --num_warmup_steps=3125 \
147 | --save_checkpoints_steps=5000
148 | ```
149 |
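The `--input_file` flag expects pre-processed TFRecords produced by
`create_pretraining_data.py` from this repository. A minimal sketch of that
step (the flag names below are assumptions based on the BERT version of this
script, plus the SentencePiece model flag; check the script's flag definitions
before use):

```
python -m albert.create_pretraining_data \
    --input_file=... \
    --output_file=... \
    --spm_model_file=... \
    --max_seq_length=512 \
    --max_predictions_per_seq=20
```
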
150 | Fine-tuning on GLUE
151 | ===================
152 | To fine-tune and evaluate a pretrained ALBERT on GLUE, please see the
153 | convenience script `run_glue.sh`.
154 |
155 | For lower-level use cases, you can use the `run_classifier.py` script
156 | directly. It handles both fine-tuning and evaluation of ALBERT on individual
157 | GLUE benchmark tasks, such as MNLI:
158 |
159 | ```
160 | pip install -r albert/requirements.txt
161 | python -m albert.run_classifier \
162 | --data_dir=... \
163 | --output_dir=... \
164 | --init_checkpoint=... \
165 | --albert_config_file=... \
166 | --spm_model_file=... \
167 | --do_train \
168 | --do_eval \
169 | --do_predict \
170 | --do_lower_case \
171 | --max_seq_length=128 \
172 | --optimizer=adamw \
173 | --task_name=MNLI \
174 | --warmup_step=1000 \
175 | --learning_rate=3e-5 \
176 | --train_step=10000 \
177 | --save_checkpoints_steps=100 \
178 | --train_batch_size=128
179 | ```
180 |
181 | Good default flag values for each GLUE task can be found in `run_glue.sh`.
182 |
183 | You can fine-tune the model starting from TF-Hub modules instead of raw
184 | checkpoints by setting e.g.
185 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
186 | of `--init_checkpoint`.
187 |
188 | You can find the `spm_model_file` in the tar files or under the assets folder
189 | of the TF-Hub module. The name of the model file is "30k-clean.model".
190 |
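As a quick sanity check, you can load the model with the `sentencepiece`
Python package (a minimal sketch; the path assumes the tar file was extracted
into the current directory):

```
import sentencepiece as spm

# Load the vocabulary model shipped with the pretrained checkpoints.
sp = spm.SentencePieceProcessor()
sp.Load("30k-clean.model")

# Tokenize a sample sentence into SentencePiece pieces.
print(sp.EncodeAsPieces("ALBERT uses a SentencePiece vocabulary."))
```
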
191 | After evaluation, `run_pretraining.py` should report output like this:
192 |
193 | ```
194 | ***** Eval results *****
195 | global_step = ...
196 | loss = ...
197 | masked_lm_accuracy = ...
198 | masked_lm_loss = ...
199 | sentence_order_accuracy = ...
200 | sentence_order_loss = ...
201 | ```
202 |
203 | Fine-tuning on SQuAD
204 | ====================
205 | To fine-tune and evaluate a pretrained model on SQuAD v1, use the
206 | `run_squad_v1.py` script:
207 |
208 | ```
209 | pip install -r albert/requirements.txt
210 | python -m albert.run_squad_v1 \
211 | --albert_config_file=... \
212 | --output_dir=... \
213 | --train_file=... \
214 | --predict_file=... \
215 | --train_feature_file=... \
216 | --predict_feature_file=... \
217 | --predict_feature_left_file=... \
218 | --init_checkpoint=... \
219 | --spm_model_file=... \
220 | --do_lower_case \
221 | --max_seq_length=384 \
222 | --doc_stride=128 \
223 | --max_query_length=64 \
224 | --do_train=true \
225 | --do_predict=true \
226 | --train_batch_size=48 \
227 | --predict_batch_size=8 \
228 | --learning_rate=5e-5 \
229 | --num_train_epochs=2.0 \
230 | --warmup_proportion=.1 \
231 | --save_checkpoints_steps=5000 \
232 | --n_best_size=20 \
233 | --max_answer_length=30
234 | ```
235 |
236 | You can fine-tune the model starting from TF-Hub modules instead of raw
237 | checkpoints by setting e.g.
238 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
239 | of `--init_checkpoint`.
240 |
241 | For SQuAD v2, use the `run_squad_v2.py` script:
242 |
243 | ```
244 | pip install -r albert/requirements.txt
245 | python -m albert.run_squad_v2 \
246 | --albert_config_file=... \
247 | --output_dir=... \
248 | --train_file=... \
249 | --predict_file=... \
250 | --train_feature_file=... \
251 | --predict_feature_file=... \
252 | --predict_feature_left_file=... \
253 | --init_checkpoint=... \
254 | --spm_model_file=... \
255 | --do_lower_case \
256 | --max_seq_length=384 \
257 | --doc_stride=128 \
258 | --max_query_length=64 \
259 | --do_train \
260 | --do_predict \
261 | --train_batch_size=48 \
262 | --predict_batch_size=8 \
263 | --learning_rate=5e-5 \
264 | --num_train_epochs=2.0 \
265 | --warmup_proportion=.1 \
266 | --save_checkpoints_steps=5000 \
267 | --n_best_size=20 \
268 | --max_answer_length=30
269 | ```
270 |
271 | You can fine-tune the model starting from TF-Hub modules instead of raw
272 | checkpoints by setting e.g.
273 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
274 | of `--init_checkpoint`.
275 |
276 | Fine-tuning on RACE
277 | ===================
278 | For RACE, use the `run_race.py` script:
279 |
280 | ```
281 | pip install -r albert/requirements.txt
282 | python -m albert.run_race \
283 | --albert_config_file=... \
284 | --output_dir=... \
285 | --train_file=... \
286 | --eval_file=... \
287 | --data_dir=...\
288 | --init_checkpoint=... \
289 | --spm_model_file=... \
290 | --max_seq_length=512 \
291 | --max_qa_length=128 \
292 | --do_train \
293 | --do_eval \
294 | --train_batch_size=32 \
295 | --eval_batch_size=8 \
296 | --learning_rate=1e-5 \
297 | --train_step=12000 \
298 | --warmup_step=1000 \
299 | --save_checkpoints_steps=100
300 | ```
301 |
302 | You can fine-tune the model starting from TF-Hub modules instead of raw
303 | checkpoints by setting e.g.
304 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
305 | of `--init_checkpoint`.
306 |
307 | SentencePiece
308 | =============
309 | Command for generating the SentencePiece vocabulary:
310 |
311 | ```
312 | spm_train \
313 | --input all.txt --model_prefix=30k-clean --vocab_size=30000 --logtostderr
314 | --pad_id=0 --unk_id=1 --eos_id=-1 --bos_id=-1
315 | --control_symbols=[CLS],[SEP],[MASK]
316 | --user_defined_symbols="(,),\",-,.,–,£,€"
317 | --shuffle_input_sentence=true --input_sentence_size=10000000
318 | --character_coverage=0.99995 --model_type=unigram
319 | ```
320 |
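Equivalently, a minimal sketch of the same training run through the
SentencePiece Python API (assuming `all.txt` is your raw text corpus):

```
import sentencepiece as spm

# Mirrors the spm_train command above; flags are passed as one string.
spm.SentencePieceTrainer.Train(
    "--input=all.txt --model_prefix=30k-clean --vocab_size=30000 "
    "--pad_id=0 --unk_id=1 --eos_id=-1 --bos_id=-1 "
    "--control_symbols=[CLS],[SEP],[MASK] "
    "--user_defined_symbols=(,),\",-,.,–,£,€ "
    "--shuffle_input_sentence=true --input_sentence_size=10000000 "
    "--character_coverage=0.99995 --model_type=unigram")
```
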
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/albert_glue_fine_tuning_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "albert_glue_fine_tuning_tutorial",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "TPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "y8SJfpgTccDB",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "\n",
 26 |         ""
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "wHQH4OCHZ9bq",
33 | "colab_type": "code",
34 | "cellView": "form",
35 | "colab": {}
36 | },
37 | "source": [
38 | "# @title Copyright 2020 The ALBERT Authors. All Rights Reserved.\n",
39 | "#\n",
40 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
41 | "# you may not use this file except in compliance with the License.\n",
42 | "# You may obtain a copy of the License at\n",
43 | "#\n",
44 | "# http://www.apache.org/licenses/LICENSE-2.0\n",
45 | "#\n",
46 | "# Unless required by applicable law or agreed to in writing, software\n",
47 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
48 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
49 | "# See the License for the specific language governing permissions and\n",
50 | "# limitations under the License.\n",
51 | "# =============================================================================="
52 | ],
53 | "execution_count": 0,
54 | "outputs": []
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {
59 | "id": "rkTLZ3I4_7c_",
60 | "colab_type": "text"
61 | },
62 | "source": [
63 | "# ALBERT End to End (Fine-tuning + Predicting) with Cloud TPU"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "id": "1wtjs1QDb3DX",
70 | "colab_type": "text"
71 | },
72 | "source": [
73 | "## Overview\n",
74 | "\n",
75 | "ALBERT is \"A Lite\" version of BERT, a popular unsupervised language representation learning algorithm. ALBERT uses parameter-reduction techniques that allow for large-scale configurations, overcome previous memory limitations, and achieve better behavior with respect to model degradation.\n",
76 | "\n",
77 | "For a technical description of the algorithm, see our paper:\n",
78 | "\n",
79 | "https://arxiv.org/abs/1909.11942\n",
80 | "\n",
81 | "Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut\n",
82 | "\n",
 83 |         "This Colab demonstrates using a free Colab Cloud TPU to fine-tune GLUE tasks built on top of pretrained ALBERT models and \n",
 84 |         "run predictions on the tuned model. The Colab demonstrates loading pretrained ALBERT models from both [TF Hub](https://www.tensorflow.org/hub) and checkpoints.\n",
85 | "\n",
86 | "**Note:** You will need a GCP (Google Compute Engine) account and a GCS (Google Cloud \n",
87 | "Storage) bucket for this Colab to run.\n",
88 | "\n",
 89 |         "Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) for how to create a GCP account and a GCS bucket. You have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.\n",
90 | "\n",
91 | "This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select **File > View on GitHub**."
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {
97 | "id": "Ld-JXlueIuPH",
98 | "colab_type": "text"
99 | },
100 | "source": [
101 | "## Instructions"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {
107 | "id": "POkof5uHaQ_c",
108 | "colab_type": "text"
109 | },
110 | "source": [
111 |         "Train on TPU\n",
112 | "\n",
113 | " 1. Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage and fill in the BUCKET parameter in the \"Parameters\" section below.\n",
114 | " \n",
115 | " 1. On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
116 | " 1. Click Runtime again and select **Runtime > Run All** (Watch out: the \"Colab-only auth for this notebook and the TPU\" cell requires user input). You can also run the cells manually with Shift-ENTER."
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {
122 | "id": "UdMmwCJFaT8F",
123 | "colab_type": "text"
124 | },
125 | "source": [
126 | "### Set up your TPU environment\n",
127 | "\n",
128 | "In this section, you perform the following tasks:\n",
129 | "\n",
130 | "* Set up a Colab TPU running environment\n",
131 | "* Verify that you are connected to a TPU device\n",
132 |         "* Upload your credentials to the TPU to access your GCS bucket."
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "metadata": {
138 | "id": "191zq3ZErihP",
139 | "colab_type": "code",
140 | "colab": {}
141 | },
142 | "source": [
143 | "# TODO(lanzhzh): Add support for 2.x.\n",
144 | "%tensorflow_version 1.x\n",
145 | "import os\n",
146 | "import pprint\n",
147 | "import json\n",
148 | "import tensorflow as tf\n",
149 | "\n",
150 | "assert \"COLAB_TPU_ADDR\" in os.environ, \"ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!\"\n",
151 | "TPU_ADDRESS = \"grpc://\" + os.environ[\"COLAB_TPU_ADDR\"] \n",
152 | "TPU_TOPOLOGY = \"2x2\"\n",
153 | "print(\"TPU address is\", TPU_ADDRESS)\n",
154 | "\n",
155 | "from google.colab import auth\n",
156 | "auth.authenticate_user()\n",
157 | "with tf.Session(TPU_ADDRESS) as session:\n",
158 | " print('TPU devices:')\n",
159 | " pprint.pprint(session.list_devices())\n",
160 | "\n",
161 | " # Upload credentials to TPU.\n",
162 | " with open('/content/adc.json', 'r') as f:\n",
163 | " auth_info = json.load(f)\n",
164 | " tf.contrib.cloud.configure_gcs(session, credentials=auth_info)\n",
165 | " # Now credentials are set for all future sessions on this TPU."
166 | ],
167 | "execution_count": 0,
168 | "outputs": []
169 | },
170 | {
171 | "cell_type": "markdown",
172 | "metadata": {
173 | "id": "HUBP35oCDmbF",
174 | "colab_type": "text"
175 | },
176 | "source": [
177 | "### Prepare and import ALBERT modules\n",
178 | "\n",
179 | "With your environment configured, you can now prepare and import the ALBERT modules. The following step clones the source code from GitHub."
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "metadata": {
185 | "id": "7wzwke0sxS6W",
186 | "colab_type": "code",
187 | "colab": {},
188 | "cellView": "code"
189 | },
190 | "source": [
191 | "#TODO(lanzhzh): Add pip support\n",
192 | "import sys\n",
193 | "\n",
194 | "!test -d albert || git clone https://github.com/google-research/albert albert\n",
195 |         "if 'albert' not in sys.path:\n",
196 | " sys.path += ['albert']\n",
197 | " \n",
198 | "!pip install sentencepiece\n"
199 | ],
200 | "execution_count": 0,
201 | "outputs": []
202 | },
203 | {
204 | "cell_type": "markdown",
205 | "metadata": {
206 | "id": "RRu1aKO1D7-Z",
207 | "colab_type": "text"
208 | },
209 | "source": [
210 | "### Prepare for training\n",
211 | "\n",
212 | "This next section of code performs the following tasks:\n",
213 | "\n",
214 |         "* Specify the GCS bucket and create an output directory for model checkpoints and eval results.\n",
215 |         "* Specify the task and download training data.\n",
216 |         "* Specify the ALBERT pretrained model.\n",
217 | "\n",
218 | "\n",
219 | "\n"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "metadata": {
225 | "id": "tYkaAlJNfhul",
226 | "colab_type": "code",
227 | "colab": {},
228 | "cellView": "form"
229 | },
230 | "source": [
231 |         "# Please find the full list of tasks and their fine-tuning hyperparameters\n",
232 | "# here https://github.com/google-research/albert/blob/master/run_glue.sh\n",
233 | "\n",
234 | "BUCKET = \"albert_tutorial_glue\" #@param { type: \"string\" }\n",
235 | "TASK = 'MRPC' #@param {type:\"string\"}\n",
236 | "# Available pretrained model checkpoints:\n",
237 | "# base, large, xlarge, xxlarge\n",
238 | "ALBERT_MODEL = 'base' #@param {type:\"string\"}\n",
239 | "\n",
240 | "TASK_DATA_DIR = 'glue_data'\n",
241 | "\n",
242 | "BASE_DIR = \"gs://\" + BUCKET\n",
243 | "if not BASE_DIR or BASE_DIR == \"gs://\":\n",
244 | " raise ValueError(\"You must enter a BUCKET.\")\n",
245 | "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n",
246 | "MODELS_DIR = os.path.join(BASE_DIR, \"models\")\n",
247 | "OUTPUT_DIR = 'gs://{}/albert-tfhub/models/{}'.format(BUCKET, TASK)\n",
248 | "tf.gfile.MakeDirs(OUTPUT_DIR)\n",
249 | "print('***** Model output directory: {} *****'.format(OUTPUT_DIR))\n",
250 | "\n",
251 | "# Download glue data.\n",
252 | "! test -d download_glue_repo || git clone https://gist.github.com/60c2bdb54d156a41194446737ce03e2e.git download_glue_repo\n",
253 | "!python download_glue_repo/download_glue_data.py --data_dir=$TASK_DATA_DIR --tasks=$TASK\n",
254 | "print('***** Task data directory: {} *****'.format(TASK_DATA_DIR))\n",
255 | "\n",
256 | "ALBERT_MODEL_HUB = 'https://tfhub.dev/google/albert_' + ALBERT_MODEL + '/3'"
257 | ],
258 | "execution_count": 0,
259 | "outputs": []
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {
264 | "id": "Hcpfl4N2EdOk",
265 | "colab_type": "text"
266 | },
267 | "source": [
268 |         "Now let's run the fine-tuning scripts. If you use the default MRPC task, this should finish in around 10 minutes, and you will get an accuracy of around 86.5."
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "metadata": {
274 | "id": "o8qXPxv8-kBO",
275 | "colab_type": "code",
276 | "colab": {}
277 | },
278 | "source": [
279 | "os.environ['TFHUB_CACHE_DIR'] = OUTPUT_DIR\n",
280 | "!python -m albert.run_classifier \\\n",
281 | " --data_dir=\"glue_data/\" \\\n",
282 | " --output_dir=$OUTPUT_DIR \\\n",
283 | " --albert_hub_module_handle=$ALBERT_MODEL_HUB \\\n",
284 | " --spm_model_file=\"from_tf_hub\" \\\n",
285 | " --do_train=True \\\n",
286 | " --do_eval=True \\\n",
287 | " --do_predict=False \\\n",
288 | " --max_seq_length=512 \\\n",
289 | " --optimizer=adamw \\\n",
290 | " --task_name=$TASK \\\n",
291 | " --warmup_step=200 \\\n",
292 | " --learning_rate=2e-5 \\\n",
293 | " --train_step=800 \\\n",
294 | " --save_checkpoints_steps=100 \\\n",
295 | " --train_batch_size=32 \\\n",
296 | " --tpu_name=$TPU_ADDRESS \\\n",
297 | " --use_tpu=True"
298 | ],
299 | "execution_count": 0,
300 | "outputs": []
301 | }
302 | ]
303 | }
304 |
--------------------------------------------------------------------------------
/export_checkpoints.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Exports a minimal module for ALBERT models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import os
21 | from absl import app
22 | from absl import flags
23 | from albert import modeling
24 | import tensorflow.compat.v1 as tf
25 |
26 | flags.DEFINE_string(
27 | "albert_directory", None,
 28 |     "The directory containing the pre-trained ALBERT model checkpoint and "
 29 |     "its albert_config.json file, which specifies the model architecture.")
30 |
31 | flags.DEFINE_string(
32 | "checkpoint_name", "model.ckpt-best",
33 | "Name of the checkpoint under albert_directory to be exported.")
34 |
35 | flags.DEFINE_bool(
36 | "do_lower_case", True,
37 | "Whether to lower case the input text. Should be True for uncased "
38 | "models and False for cased models.")
39 |
40 | flags.DEFINE_string("export_path", None, "Path to the output module.")
41 |
42 | FLAGS = flags.FLAGS
43 |
44 |
45 | def gather_indexes(sequence_tensor, positions):
46 | """Gathers the vectors at the specific positions over a minibatch."""
47 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
48 | batch_size = sequence_shape[0]
49 | seq_length = sequence_shape[1]
50 | width = sequence_shape[2]
51 |
52 | flat_offsets = tf.reshape(
53 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
54 | flat_positions = tf.reshape(positions + flat_offsets, [-1])
55 | flat_sequence_tensor = tf.reshape(sequence_tensor,
56 | [batch_size * seq_length, width])
57 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
58 | return output_tensor
59 |
60 |
61 | def get_mlm_logits(input_tensor, albert_config, mlm_positions, output_weights):
62 | """From run_pretraining.py."""
63 | input_tensor = gather_indexes(input_tensor, mlm_positions)
64 | with tf.variable_scope("cls/predictions"):
65 | # We apply one more non-linear transformation before the output layer.
66 | # This matrix is not used after pre-training.
67 | with tf.variable_scope("transform"):
68 | input_tensor = tf.layers.dense(
69 | input_tensor,
70 | units=albert_config.embedding_size,
71 | activation=modeling.get_activation(albert_config.hidden_act),
72 | kernel_initializer=modeling.create_initializer(
73 | albert_config.initializer_range))
74 | input_tensor = modeling.layer_norm(input_tensor)
75 |
76 | # The output weights are the same as the input embeddings, but there is
77 | # an output-only bias for each token.
78 | output_bias = tf.get_variable(
79 | "output_bias",
80 | shape=[albert_config.vocab_size],
81 | initializer=tf.zeros_initializer())
82 | logits = tf.matmul(
83 | input_tensor, output_weights, transpose_b=True)
84 | logits = tf.nn.bias_add(logits, output_bias)
85 | return logits
86 |
87 |
88 | def get_sentence_order_logits(input_tensor, albert_config):
 89 |   """Gets logits for sentence-order prediction."""
90 |
91 | # Simple binary classification. Note that 0 is "next sentence" and 1 is
92 | # "random sentence". This weight matrix is not used after pre-training.
93 | with tf.variable_scope("cls/seq_relationship"):
94 | output_weights = tf.get_variable(
95 | "output_weights",
96 | shape=[2, albert_config.hidden_size],
97 | initializer=modeling.create_initializer(
98 | albert_config.initializer_range))
99 | output_bias = tf.get_variable(
100 | "output_bias", shape=[2], initializer=tf.zeros_initializer())
101 |
102 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
103 | logits = tf.nn.bias_add(logits, output_bias)
104 | return logits
105 |
106 |
107 | def build_model(sess):
 108 |   """Builds the ALBERT graph and initializes it from the checkpoint."""
109 | input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
110 | input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
111 | segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
112 | mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")
113 |
114 | albert_config_path = os.path.join(
115 | FLAGS.albert_directory, "albert_config.json")
116 | albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
117 | model = modeling.AlbertModel(
118 | config=albert_config,
119 | is_training=False,
120 | input_ids=input_ids,
121 | input_mask=input_mask,
122 | token_type_ids=segment_ids,
123 | use_one_hot_embeddings=False)
124 |
125 | get_mlm_logits(model.get_sequence_output(), albert_config,
126 | mlm_positions, model.get_embedding_table())
127 | get_sentence_order_logits(model.get_pooled_output(), albert_config)
128 |
129 | checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
130 | tvars = tf.trainable_variables()
131 | (assignment_map, initialized_variable_names
132 | ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path)
133 |
134 | tf.logging.info("**** Trainable Variables ****")
135 | for var in tvars:
136 | init_string = ""
137 | if var.name in initialized_variable_names:
138 | init_string = ", *INIT_FROM_CKPT*"
139 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
140 | init_string)
141 | tf.train.init_from_checkpoint(checkpoint_path, assignment_map)
142 | init = tf.global_variables_initializer()
143 | sess.run(init)
144 | return sess
145 |
146 |
147 | def main(_):
148 | sess = tf.Session()
149 | tf.train.get_or_create_global_step()
150 | sess = build_model(sess)
151 | my_vars = []
152 | for var in tf.global_variables():
153 | if "lamb_v" not in var.name and "lamb_m" not in var.name:
154 | my_vars.append(var)
155 | saver = tf.train.Saver(my_vars)
156 | saver.save(sess, FLAGS.export_path)
157 |
158 |
159 | if __name__ == "__main__":
160 | flags.mark_flag_as_required("albert_directory")
161 | flags.mark_flag_as_required("export_path")
162 | app.run(main)
163 |
--------------------------------------------------------------------------------
/export_to_tfhub.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | r"""Exports a minimal TF-Hub module for ALBERT models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import os
21 | from absl import app
22 | from absl import flags
23 | from albert import modeling
24 | import tensorflow.compat.v1 as tf
25 | import tensorflow_hub as hub
26 |
27 | flags.DEFINE_string(
28 | "albert_directory", None,
 29 |     "The directory containing the pre-trained ALBERT model checkpoint and "
 30 |     "its albert_config.json file, which specifies the model architecture.")
31 |
32 | flags.DEFINE_string(
33 | "checkpoint_name", "model.ckpt-best",
34 | "Name of the checkpoint under albert_directory to be exported.")
35 |
36 | flags.DEFINE_bool(
37 | "do_lower_case", True,
38 | "Whether to lower case the input text. Should be True for uncased "
39 | "models and False for cased models.")
40 |
41 | flags.DEFINE_bool(
42 | "use_einsum", True,
43 | "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
44 | "be set to False for TFLite compatibility.")
45 |
46 | flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.")
47 |
48 | FLAGS = flags.FLAGS
49 |
50 |
51 | def gather_indexes(sequence_tensor, positions):
52 | """Gathers the vectors at the specific positions over a minibatch."""
53 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
54 | batch_size = sequence_shape[0]
55 | seq_length = sequence_shape[1]
56 | width = sequence_shape[2]
57 |
58 | flat_offsets = tf.reshape(
59 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
60 | flat_positions = tf.reshape(positions + flat_offsets, [-1])
61 | flat_sequence_tensor = tf.reshape(sequence_tensor,
62 | [batch_size * seq_length, width])
63 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
64 | return output_tensor
65 |
66 |
67 | def get_mlm_logits(model, albert_config, mlm_positions):
68 | """From run_pretraining.py."""
69 | input_tensor = gather_indexes(model.get_sequence_output(), mlm_positions)
70 | with tf.variable_scope("cls/predictions"):
71 | # We apply one more non-linear transformation before the output layer.
72 | # This matrix is not used after pre-training.
73 | with tf.variable_scope("transform"):
74 | input_tensor = tf.layers.dense(
75 | input_tensor,
76 | units=albert_config.embedding_size,
77 | activation=modeling.get_activation(albert_config.hidden_act),
78 | kernel_initializer=modeling.create_initializer(
79 | albert_config.initializer_range))
80 | input_tensor = modeling.layer_norm(input_tensor)
81 |
82 | # The output weights are the same as the input embeddings, but there is
83 | # an output-only bias for each token.
84 | output_bias = tf.get_variable(
85 | "output_bias",
86 | shape=[albert_config.vocab_size],
87 | initializer=tf.zeros_initializer())
88 | logits = tf.matmul(
89 | input_tensor, model.get_embedding_table(), transpose_b=True)
90 | logits = tf.nn.bias_add(logits, output_bias)
91 | return logits
92 |
93 |
94 | def get_sop_log_probs(model, albert_config):
 95 |   """Gets log probs for sentence-order prediction."""
96 | input_tensor = model.get_pooled_output()
97 | # Simple binary classification. Note that 0 is "next sentence" and 1 is
98 | # "random sentence". This weight matrix is not used after pre-training.
99 | with tf.variable_scope("cls/seq_relationship"):
100 | output_weights = tf.get_variable(
101 | "output_weights",
102 | shape=[2, albert_config.hidden_size],
103 | initializer=modeling.create_initializer(
104 | albert_config.initializer_range))
105 | output_bias = tf.get_variable(
106 | "output_bias", shape=[2], initializer=tf.zeros_initializer())
107 |
108 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
109 | logits = tf.nn.bias_add(logits, output_bias)
110 | log_probs = tf.nn.log_softmax(logits, axis=-1)
111 | return log_probs
112 |
113 |
114 | def module_fn(is_training):
115 | """Module function."""
116 | input_ids = tf.placeholder(tf.int32, [None, None], "input_ids")
117 | input_mask = tf.placeholder(tf.int32, [None, None], "input_mask")
118 | segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids")
119 | mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions")
120 |
121 | albert_config_path = os.path.join(
122 | FLAGS.albert_directory, "albert_config.json")
123 | albert_config = modeling.AlbertConfig.from_json_file(albert_config_path)
124 | model = modeling.AlbertModel(
125 | config=albert_config,
126 | is_training=is_training,
127 | input_ids=input_ids,
128 | input_mask=input_mask,
129 | token_type_ids=segment_ids,
130 | use_one_hot_embeddings=False,
131 | use_einsum=FLAGS.use_einsum)
132 |
133 | mlm_logits = get_mlm_logits(model, albert_config, mlm_positions)
134 | sop_log_probs = get_sop_log_probs(model, albert_config)
135 |
136 | vocab_model_path = os.path.join(FLAGS.albert_directory, "30k-clean.model")
137 | vocab_file_path = os.path.join(FLAGS.albert_directory, "30k-clean.vocab")
138 |
139 | config_file = tf.constant(
140 | value=albert_config_path, dtype=tf.string, name="config_file")
141 | vocab_model = tf.constant(
142 | value=vocab_model_path, dtype=tf.string, name="vocab_model")
143 | # This is only for visualization purpose.
144 | vocab_file = tf.constant(
145 | value=vocab_file_path, dtype=tf.string, name="vocab_file")
146 |
147 | # By adding `config_file, vocab_model and vocab_file`
148 | # to the ASSET_FILEPATHS collection, TF-Hub will
149 | # rewrite this tensor so that this asset is portable.
150 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file)
151 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_model)
152 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file)
153 |
154 | hub.add_signature(
155 | name="tokens",
156 | inputs=dict(
157 | input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids),
158 | outputs=dict(
159 | sequence_output=model.get_sequence_output(),
160 | pooled_output=model.get_pooled_output()))
161 |
162 | hub.add_signature(
163 | name="sop",
164 | inputs=dict(
165 | input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids),
166 | outputs=dict(
167 | sequence_output=model.get_sequence_output(),
168 | pooled_output=model.get_pooled_output(),
169 | sop_log_probs=sop_log_probs))
170 |
171 | hub.add_signature(
172 | name="mlm",
173 | inputs=dict(
174 | input_ids=input_ids,
175 | input_mask=input_mask,
176 | segment_ids=segment_ids,
177 | mlm_positions=mlm_positions),
178 | outputs=dict(
179 | sequence_output=model.get_sequence_output(),
180 | pooled_output=model.get_pooled_output(),
181 | mlm_logits=mlm_logits))
182 |
183 | hub.add_signature(
184 | name="tokenization_info",
185 | inputs={},
186 | outputs=dict(
187 | vocab_file=vocab_model,
188 | do_lower_case=tf.constant(FLAGS.do_lower_case)))
189 |
190 |
191 | def main(_):
192 | tags_and_args = []
193 | for is_training in (True, False):
194 | tags = set()
195 | if is_training:
196 | tags.add("train")
197 | tags_and_args.append((tags, dict(is_training=is_training)))
198 | spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args)
199 | checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name)
200 | tf.logging.info("Using checkpoint {}".format(checkpoint_path))
201 | spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path)
202 |
203 |
204 | if __name__ == "__main__":
205 | flags.mark_flag_as_required("albert_directory")
206 | flags.mark_flag_as_required("export_path")
207 | app.run(main)
208 |
--------------------------------------------------------------------------------
/fine_tuning_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Helper library for ALBERT fine-tuning.
16 |
17 | This library can be used to construct ALBERT models for fine-tuning, either from
18 | json config files or from TF-Hub modules.
19 | """
20 |
21 | from albert import modeling
22 | from albert import tokenization
23 | import tensorflow.compat.v1 as tf
24 | import tensorflow_hub as hub
25 |
26 |
27 | def _create_model_from_hub(hub_module, is_training, input_ids, input_mask,
28 | segment_ids):
29 | """Creates an ALBERT model from TF-Hub."""
30 | tags = set()
31 | if is_training:
32 | tags.add("train")
33 | albert_module = hub.Module(hub_module, tags=tags, trainable=True)
34 | albert_inputs = dict(
35 | input_ids=input_ids,
36 | input_mask=input_mask,
37 | segment_ids=segment_ids)
38 | albert_outputs = albert_module(
39 | inputs=albert_inputs,
40 | signature="tokens",
41 | as_dict=True)
42 | return (albert_outputs["pooled_output"], albert_outputs["sequence_output"])
43 |
44 |
45 | def _create_model_from_scratch(albert_config, is_training, input_ids,
46 | input_mask, segment_ids, use_one_hot_embeddings,
47 | use_einsum):
48 | """Creates an ALBERT model from scratch/config."""
49 | model = modeling.AlbertModel(
50 | config=albert_config,
51 | is_training=is_training,
52 | input_ids=input_ids,
53 | input_mask=input_mask,
54 | token_type_ids=segment_ids,
55 | use_one_hot_embeddings=use_one_hot_embeddings,
56 | use_einsum=use_einsum)
57 | return (model.get_pooled_output(), model.get_sequence_output())
58 |
59 |
60 | def create_albert(albert_config, is_training, input_ids, input_mask,
61 | segment_ids, use_one_hot_embeddings, use_einsum, hub_module):
62 |   """Creates an ALBERT model, either from TF-Hub or from scratch."""
63 | if hub_module:
64 | tf.logging.info("creating model from hub_module: %s", hub_module)
65 | return _create_model_from_hub(hub_module, is_training, input_ids,
66 | input_mask, segment_ids)
67 | else:
68 | tf.logging.info("creating model from albert_config")
69 | return _create_model_from_scratch(albert_config, is_training, input_ids,
70 | input_mask, segment_ids,
71 | use_one_hot_embeddings, use_einsum)
72 |
73 |
74 | def create_vocab(vocab_file, do_lower_case, spm_model_file, hub_module):
75 |   """Creates a tokenizer, either from a vocab file or from a TF-Hub module."""
76 | if hub_module:
77 |     use_spm = bool(spm_model_file)
78 | return tokenization.FullTokenizer.from_hub_module(
79 | hub_module=hub_module, use_spm=use_spm)
80 | else:
81 | return tokenization.FullTokenizer.from_scratch(
82 | vocab_file=vocab_file, do_lower_case=do_lower_case,
83 | spm_model_file=spm_model_file)
84 |
85 |
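A minimal sketch of driving `create_albert` from a config file, assuming a
hypothetical "albert_config.json" on disk and a fixed sequence length of 128;
passing hub_module=None selects the from-scratch path above:

    import tensorflow.compat.v1 as tf
    from albert import fine_tuning_utils
    from albert import modeling

    albert_config = modeling.AlbertConfig.from_json_file("albert_config.json")
    input_ids = tf.placeholder(tf.int32, [None, 128])
    input_mask = tf.placeholder(tf.int32, [None, 128])
    segment_ids = tf.placeholder(tf.int32, [None, 128])

    pooled_output, sequence_output = fine_tuning_utils.create_albert(
        albert_config=albert_config,
        is_training=False,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=False,
        use_einsum=True,
        hub_module=None)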
--------------------------------------------------------------------------------
/lamb_optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import six
23 | import tensorflow.compat.v1 as tf
24 |
25 | # pylint: disable=g-direct-tensorflow-import
26 | from tensorflow.python.ops import array_ops
27 | from tensorflow.python.ops import linalg_ops
28 | from tensorflow.python.ops import math_ops
29 | # pylint: enable=g-direct-tensorflow-import
30 |
31 |
32 | class LAMBOptimizer(tf.train.Optimizer):
33 | """LAMB (Layer-wise Adaptive Moments optimizer for Batch training)."""
34 |   # A new optimizer that includes correct L2 weight decay, adaptive
35 |   # element-wise updating, and layer-wise learning rate adaptation. The
36 |   # LAMB optimizer was proposed by Yang You, Jing Li, Jonathan Hseu,
37 |   # Xiaodan Song, James Demmel, and Cho-Jui Hsieh in the paper "Reducing
38 |   # BERT Pre-Training Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962).
39 |
40 | def __init__(self,
41 | learning_rate,
42 | weight_decay_rate=0.0,
43 | beta_1=0.9,
44 | beta_2=0.999,
45 | epsilon=1e-6,
46 | exclude_from_weight_decay=None,
47 | exclude_from_layer_adaptation=None,
48 | name="LAMBOptimizer"):
49 | """Constructs a LAMBOptimizer."""
50 | super(LAMBOptimizer, self).__init__(False, name)
51 |
52 | self.learning_rate = learning_rate
53 | self.weight_decay_rate = weight_decay_rate
54 | self.beta_1 = beta_1
55 | self.beta_2 = beta_2
56 | self.epsilon = epsilon
57 | self.exclude_from_weight_decay = exclude_from_weight_decay
58 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
59 | # arg is None.
60 | # TODO(jingli): validate if exclude_from_layer_adaptation is necessary.
61 | if exclude_from_layer_adaptation:
62 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
63 | else:
64 | self.exclude_from_layer_adaptation = exclude_from_weight_decay
65 |
66 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
67 | """See base class."""
68 | assignments = []
69 | for (grad, param) in grads_and_vars:
70 | if grad is None or param is None:
71 | continue
72 |
73 | param_name = self._get_variable_name(param.name)
74 |
75 | m = tf.get_variable(
76 | name=six.ensure_str(param_name) + "/adam_m",
77 | shape=param.shape.as_list(),
78 | dtype=tf.float32,
79 | trainable=False,
80 | initializer=tf.zeros_initializer())
81 | v = tf.get_variable(
82 | name=six.ensure_str(param_name) + "/adam_v",
83 | shape=param.shape.as_list(),
84 | dtype=tf.float32,
85 | trainable=False,
86 | initializer=tf.zeros_initializer())
87 |
88 | # Standard Adam update.
89 | next_m = (
90 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
91 | next_v = (
92 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
93 | tf.square(grad)))
94 |
95 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
96 |
97 | # Just adding the square of the weights to the loss function is *not*
98 | # the correct way of using L2 regularization/weight decay with Adam,
99 | # since that will interact with the m and v parameters in strange ways.
100 | #
101 |       # Instead we want to decay the weights in a manner that doesn't interact
102 | # with the m/v parameters. This is equivalent to adding the square
103 | # of the weights to the loss with plain (non-momentum) SGD.
104 | if self._do_use_weight_decay(param_name):
105 | update += self.weight_decay_rate * param
106 |
107 | ratio = 1.0
108 | if self._do_layer_adaptation(param_name):
109 | w_norm = linalg_ops.norm(param, ord=2)
110 | g_norm = linalg_ops.norm(update, ord=2)
111 | ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where(
112 | math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0)
113 |
114 | update_with_lr = ratio * self.learning_rate * update
115 |
116 | next_param = param - update_with_lr
117 |
118 | assignments.extend(
119 | [param.assign(next_param),
120 | m.assign(next_m),
121 | v.assign(next_v)])
122 | return tf.group(*assignments, name=name)
123 |
124 | def _do_use_weight_decay(self, param_name):
125 | """Whether to use L2 weight decay for `param_name`."""
126 | if not self.weight_decay_rate:
127 | return False
128 | if self.exclude_from_weight_decay:
129 | for r in self.exclude_from_weight_decay:
130 | if re.search(r, param_name) is not None:
131 | return False
132 | return True
133 |
134 | def _do_layer_adaptation(self, param_name):
135 | """Whether to do layer-wise learning rate adaptation for `param_name`."""
136 | if self.exclude_from_layer_adaptation:
137 | for r in self.exclude_from_layer_adaptation:
138 | if re.search(r, param_name) is not None:
139 | return False
140 | return True
141 |
142 | def _get_variable_name(self, param_name):
143 | """Get the variable name from the tensor name."""
144 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
145 | if m is not None:
146 | param_name = m.group(1)
147 | return param_name
148 |
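For intuition, one LAMB step for a single variable can be written out in plain
NumPy, mirroring apply_gradients above. A sketch with arbitrary toy values,
not a replacement for the graph-mode implementation:

    import numpy as np

    beta_1, beta_2, eps, lr, wd = 0.9, 0.999, 1e-6, 1e-3, 0.01
    param = np.random.randn(4, 4).astype(np.float32)
    grad = np.random.randn(4, 4).astype(np.float32)
    m = np.zeros_like(param)
    v = np.zeros_like(param)

    # Standard Adam moments, plus decoupled weight decay on the update.
    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad**2
    update = m / (np.sqrt(v) + eps) + wd * param

    # Layer-wise trust ratio: rescale the update by ||w|| / ||update||.
    w_norm = np.linalg.norm(param)
    g_norm = np.linalg.norm(update)
    ratio = (w_norm / g_norm) if (w_norm > 0 and g_norm > 0) else 1.0

    param = param - ratio * lr * update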
--------------------------------------------------------------------------------
/modeling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import json
21 | import random
22 | import re
23 |
24 | from albert import modeling
25 | import numpy as np
26 | import six
27 | from six.moves import range
28 | import tensorflow.compat.v1 as tf
29 |
30 |
31 | class AlbertModelTest(tf.test.TestCase):
32 |
33 | class AlbertModelTester(object):
34 |
35 | def __init__(self,
36 | parent,
37 | batch_size=13,
38 | seq_length=7,
39 | is_training=True,
40 | use_input_mask=True,
41 | use_token_type_ids=True,
42 | vocab_size=99,
43 | embedding_size=32,
44 | hidden_size=32,
45 | num_hidden_layers=5,
46 | num_attention_heads=4,
47 | intermediate_size=37,
48 | hidden_act="gelu",
49 | hidden_dropout_prob=0.1,
50 | attention_probs_dropout_prob=0.1,
51 | max_position_embeddings=512,
52 | type_vocab_size=16,
53 | initializer_range=0.02,
54 | scope=None):
55 | self.parent = parent
56 | self.batch_size = batch_size
57 | self.seq_length = seq_length
58 | self.is_training = is_training
59 | self.use_input_mask = use_input_mask
60 | self.use_token_type_ids = use_token_type_ids
61 | self.vocab_size = vocab_size
62 | self.embedding_size = embedding_size
63 | self.hidden_size = hidden_size
64 | self.num_hidden_layers = num_hidden_layers
65 | self.num_attention_heads = num_attention_heads
66 | self.intermediate_size = intermediate_size
67 | self.hidden_act = hidden_act
68 | self.hidden_dropout_prob = hidden_dropout_prob
69 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
70 | self.max_position_embeddings = max_position_embeddings
71 | self.type_vocab_size = type_vocab_size
72 | self.initializer_range = initializer_range
73 | self.scope = scope
74 |
75 | def create_model(self):
76 | input_ids = AlbertModelTest.ids_tensor([self.batch_size, self.seq_length],
77 | self.vocab_size)
78 |
79 | input_mask = None
80 | if self.use_input_mask:
81 | input_mask = AlbertModelTest.ids_tensor(
82 | [self.batch_size, self.seq_length], vocab_size=2)
83 |
84 | token_type_ids = None
85 | if self.use_token_type_ids:
86 | token_type_ids = AlbertModelTest.ids_tensor(
87 | [self.batch_size, self.seq_length], self.type_vocab_size)
88 |
89 | config = modeling.AlbertConfig(
90 | vocab_size=self.vocab_size,
91 | embedding_size=self.embedding_size,
92 | hidden_size=self.hidden_size,
93 | num_hidden_layers=self.num_hidden_layers,
94 | num_attention_heads=self.num_attention_heads,
95 | intermediate_size=self.intermediate_size,
96 | hidden_act=self.hidden_act,
97 | hidden_dropout_prob=self.hidden_dropout_prob,
98 | attention_probs_dropout_prob=self.attention_probs_dropout_prob,
99 | max_position_embeddings=self.max_position_embeddings,
100 | type_vocab_size=self.type_vocab_size,
101 | initializer_range=self.initializer_range)
102 |
103 | model = modeling.AlbertModel(
104 | config=config,
105 | is_training=self.is_training,
106 | input_ids=input_ids,
107 | input_mask=input_mask,
108 | token_type_ids=token_type_ids,
109 | scope=self.scope)
110 |
111 | outputs = {
112 | "embedding_output": model.get_embedding_output(),
113 | "sequence_output": model.get_sequence_output(),
114 | "pooled_output": model.get_pooled_output(),
115 | "all_encoder_layers": model.get_all_encoder_layers(),
116 | }
117 | return outputs
118 |
119 | def check_output(self, result):
120 | self.parent.assertAllEqual(
121 | result["embedding_output"].shape,
122 | [self.batch_size, self.seq_length, self.embedding_size])
123 |
124 | self.parent.assertAllEqual(
125 | result["sequence_output"].shape,
126 | [self.batch_size, self.seq_length, self.hidden_size])
127 |
128 | self.parent.assertAllEqual(result["pooled_output"].shape,
129 | [self.batch_size, self.hidden_size])
130 |
131 | def test_default(self):
132 | self.run_tester(AlbertModelTest.AlbertModelTester(self))
133 |
134 | def test_config_to_json_string(self):
135 | config = modeling.AlbertConfig(vocab_size=99, hidden_size=37)
136 | obj = json.loads(config.to_json_string())
137 | self.assertEqual(obj["vocab_size"], 99)
138 | self.assertEqual(obj["hidden_size"], 37)
139 |
140 | def test_einsum_via_matmul(self):
141 | batch_size = 8
142 | seq_length = 12
143 | num_attention_heads = 3
144 | head_size = 6
145 | hidden_size = 10
146 |
147 | input_tensor = np.random.uniform(0, 1,
148 | [batch_size, seq_length, hidden_size])
149 | input_tensor = tf.constant(input_tensor, dtype=tf.float32)
150 | w = np.random.uniform(0, 1, [hidden_size, num_attention_heads, head_size])
151 | w = tf.constant(w, dtype=tf.float32)
152 | ret1 = tf.einsum("BFH,HND->BFND", input_tensor, w)
153 | ret2 = modeling.einsum_via_matmul(input_tensor, w, 1)
154 | self.assertAllClose(ret1, ret2)
155 |
156 | input_tensor = np.random.uniform(0, 1,
157 | [batch_size, seq_length,
158 | num_attention_heads, head_size])
159 | input_tensor = tf.constant(input_tensor, dtype=tf.float32)
160 | w = np.random.uniform(0, 1, [num_attention_heads, head_size, hidden_size])
161 | w = tf.constant(w, dtype=tf.float32)
162 | ret1 = tf.einsum("BFND,NDH->BFH", input_tensor, w)
163 | ret2 = modeling.einsum_via_matmul(input_tensor, w, 2)
164 | self.assertAllClose(ret1, ret2)
165 |
166 | def run_tester(self, tester):
167 | with self.test_session() as sess:
168 | ops = tester.create_model()
169 | init_op = tf.group(tf.global_variables_initializer(),
170 | tf.local_variables_initializer())
171 | sess.run(init_op)
172 | output_result = sess.run(ops)
173 | tester.check_output(output_result)
174 |
175 | self.assert_all_tensors_reachable(sess, [init_op, ops])
176 |
177 | @classmethod
178 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
179 |     """Creates a random int32 tensor of the given shape with ids < vocab_size."""
180 | if rng is None:
181 | rng = random.Random()
182 |
183 | total_dims = 1
184 | for dim in shape:
185 | total_dims *= dim
186 |
187 | values = []
188 | for _ in range(total_dims):
189 | values.append(rng.randint(0, vocab_size - 1))
190 |
191 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
192 |
193 | def assert_all_tensors_reachable(self, sess, outputs):
194 | """Checks that all the tensors in the graph are reachable from outputs."""
195 | graph = sess.graph
196 |
197 | ignore_strings = [
198 | "^.*/assert_less_equal/.*$",
199 | "^.*/dilation_rate$",
200 | "^.*/Tensordot/concat$",
201 | "^.*/Tensordot/concat/axis$",
202 | "^testing/.*$",
203 | ]
204 |
205 | ignore_regexes = [re.compile(x) for x in ignore_strings]
206 |
207 | unreachable = self.get_unreachable_ops(graph, outputs)
208 | filtered_unreachable = []
209 | for x in unreachable:
210 | do_ignore = False
211 | for r in ignore_regexes:
212 | m = r.match(six.ensure_str(x.name))
213 | if m is not None:
214 | do_ignore = True
215 | if do_ignore:
216 | continue
217 | filtered_unreachable.append(x)
218 | unreachable = filtered_unreachable
219 |
220 | self.assertEqual(
221 | len(unreachable), 0, "The following ops are unreachable: %s" %
222 | (" ".join([x.name for x in unreachable])))
223 |
224 | @classmethod
225 | def get_unreachable_ops(cls, graph, outputs):
226 | """Finds all of the tensors in graph that are unreachable from outputs."""
227 | outputs = cls.flatten_recursive(outputs)
228 | output_to_op = collections.defaultdict(list)
229 | op_to_all = collections.defaultdict(list)
230 | assign_out_to_in = collections.defaultdict(list)
231 |
232 | for op in graph.get_operations():
233 | for x in op.inputs:
234 | op_to_all[op.name].append(x.name)
235 | for y in op.outputs:
236 | output_to_op[y.name].append(op.name)
237 | op_to_all[op.name].append(y.name)
238 | if str(op.type) == "Assign":
239 | for y in op.outputs:
240 | for x in op.inputs:
241 | assign_out_to_in[y.name].append(x.name)
242 |
243 | assign_groups = collections.defaultdict(list)
244 | for out_name in assign_out_to_in.keys():
245 | name_group = assign_out_to_in[out_name]
246 | for n1 in name_group:
247 | assign_groups[n1].append(out_name)
248 | for n2 in name_group:
249 | if n1 != n2:
250 | assign_groups[n1].append(n2)
251 |
252 | seen_tensors = {}
253 | stack = [x.name for x in outputs]
254 | while stack:
255 | name = stack.pop()
256 | if name in seen_tensors:
257 | continue
258 | seen_tensors[name] = True
259 |
260 | if name in output_to_op:
261 | for op_name in output_to_op[name]:
262 | if op_name in op_to_all:
263 | for input_name in op_to_all[op_name]:
264 | if input_name not in stack:
265 | stack.append(input_name)
266 |
267 | expanded_names = []
268 | if name in assign_groups:
269 | for assign_name in assign_groups[name]:
270 | expanded_names.append(assign_name)
271 |
272 | for expanded_name in expanded_names:
273 | if expanded_name not in stack:
274 | stack.append(expanded_name)
275 |
276 | unreachable_ops = []
277 | for op in graph.get_operations():
278 | is_unreachable = False
279 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
280 | for name in all_names:
281 | if name not in seen_tensors:
282 | is_unreachable = True
283 | if is_unreachable:
284 | unreachable_ops.append(op)
285 | return unreachable_ops
286 |
287 | @classmethod
288 | def flatten_recursive(cls, item):
289 |     """Flattens a (potentially nested) tuple/dictionary/list into a flat list."""
290 | output = []
291 | if isinstance(item, list):
292 | output.extend(item)
293 | elif isinstance(item, tuple):
294 | output.extend(list(item))
295 | elif isinstance(item, dict):
296 | for (_, v) in six.iteritems(item):
297 | output.append(v)
298 | else:
299 | return [item]
300 |
301 | flat_output = []
302 | for x in output:
303 | flat_output.extend(cls.flatten_recursive(x))
304 | return flat_output
305 |
306 |
307 | if __name__ == "__main__":
308 | tf.test.main()
309 |
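The identity that test_einsum_via_matmul verifies is just a reshape plus a
single matmul; a NumPy sketch of the "BFH,HND->BFND" case (the dimension
sizes here are arbitrary):

    import numpy as np

    B, F, H, N, D = 2, 3, 5, 4, 6
    x = np.random.rand(B, F, H)
    w = np.random.rand(H, N, D)

    via_einsum = np.einsum("BFH,HND->BFND", x, w)
    # Flatten the non-contracted axes, contract over H, restore the shape.
    via_matmul = (x.reshape(B * F, H) @ w.reshape(H, N * D)).reshape(B, F, N, D)

    assert np.allclose(via_einsum, via_matmul)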
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import re
21 | from albert import lamb_optimizer
22 | import six
23 | from six.moves import zip
24 | import tensorflow.compat.v1 as tf
25 | from tensorflow.contrib import tpu as contrib_tpu
26 |
27 |
28 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu,
29 | optimizer="adamw", poly_power=1.0, start_warmup_step=0,
30 | colocate_gradients_with_ops=False, excluded_tvars=None):
31 | """Creates an optimizer training op."""
32 | global_step = tf.train.get_or_create_global_step()
33 |
34 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
35 |
36 | # Implements linear decay of the learning rate.
37 | learning_rate = tf.train.polynomial_decay(
38 | learning_rate,
39 | global_step,
40 | num_train_steps,
41 | end_learning_rate=0.0,
42 | power=poly_power,
43 | cycle=False)
44 |
45 | # Implements linear warmup. I.e., if global_step - start_warmup_step <
46 | # num_warmup_steps, the learning rate will be
47 | # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
48 | if num_warmup_steps:
49 | tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step)
50 | + ", for " + str(num_warmup_steps) + " steps ++++++")
51 | global_steps_int = tf.cast(global_step, tf.int32)
52 | start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
53 | global_steps_int = global_steps_int - start_warm_int
54 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
55 |
56 | global_steps_float = tf.cast(global_steps_int, tf.float32)
57 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
58 |
59 | warmup_percent_done = global_steps_float / warmup_steps_float
60 | warmup_learning_rate = init_lr * warmup_percent_done
61 |
62 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
63 | learning_rate = (
64 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
65 |
66 |   # It is OK to use this optimizer for fine-tuning, since this is how the
67 |   # model was trained (note that the Adam m/v variables are NOT loaded
68 |   # from init_checkpoint).
69 |   # It is also OK to use AdamW for fine-tuning even if the model was
70 |   # pre-trained with LAMB. As reported in the public BERT GitHub repo, the
71 |   # learning rate for SQuAD 1.1 fine-tuning is 3e-5, 4e-5 or 5e-5. For
72 |   # LAMB, users can use 3e-4, 4e-4, or 5e-4 with a batch size of 64.
73 | if optimizer == "adamw":
74 | tf.logging.info("using adamw")
75 | optimizer = AdamWeightDecayOptimizer(
76 | learning_rate=learning_rate,
77 | weight_decay_rate=0.01,
78 | beta_1=0.9,
79 | beta_2=0.999,
80 | epsilon=1e-6,
81 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
82 | elif optimizer == "lamb":
83 | tf.logging.info("using lamb")
84 | optimizer = lamb_optimizer.LAMBOptimizer(
85 | learning_rate=learning_rate,
86 | weight_decay_rate=0.01,
87 | beta_1=0.9,
88 | beta_2=0.999,
89 | epsilon=1e-6,
90 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
91 | else:
92 |     raise ValueError("Unsupported optimizer: %s" % optimizer)
93 |
94 | if use_tpu:
95 | optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
96 |
97 | tvars = tf.trainable_variables()
98 |   if excluded_tvars:
99 |     # Filter rather than remove-while-iterating, which skips elements.
100 |     tvars = [tvar for tvar in tvars if tvar.name not in excluded_tvars]
101 |
102 | grads = tf.gradients(
103 | loss, tvars, colocate_gradients_with_ops=colocate_gradients_with_ops)
104 |
105 | # This is how the model was pre-trained.
106 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
107 |
108 | train_op = optimizer.apply_gradients(
109 | list(zip(grads, tvars)), global_step=global_step)
110 |
111 |   # Normally the global step update is done inside of `apply_gradients`.
112 |   # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` does
113 |   # this. If you use a different optimizer that does, you should take this
114 |   # line out to avoid incrementing the global step twice.
115 | new_global_step = global_step + 1
116 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
117 | return train_op
118 |
119 |
120 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
121 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
122 |
123 | def __init__(self,
124 | learning_rate,
125 | weight_decay_rate=0.0,
126 | beta_1=0.9,
127 | beta_2=0.999,
128 | epsilon=1e-6,
129 | exclude_from_weight_decay=None,
130 | name="AdamWeightDecayOptimizer"):
131 |     """Constructs an AdamWeightDecayOptimizer."""
132 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
133 |
134 | self.learning_rate = learning_rate
135 | self.weight_decay_rate = weight_decay_rate
136 | self.beta_1 = beta_1
137 | self.beta_2 = beta_2
138 | self.epsilon = epsilon
139 | self.exclude_from_weight_decay = exclude_from_weight_decay
140 |
141 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
142 | """See base class."""
143 | assignments = []
144 | for (grad, param) in grads_and_vars:
145 | if grad is None or param is None:
146 | continue
147 |
148 | param_name = self._get_variable_name(param.name)
149 |
150 | m = tf.get_variable(
151 | name=six.ensure_str(param_name) + "/adam_m",
152 | shape=param.shape.as_list(),
153 | dtype=tf.float32,
154 | trainable=False,
155 | initializer=tf.zeros_initializer())
156 | v = tf.get_variable(
157 | name=six.ensure_str(param_name) + "/adam_v",
158 | shape=param.shape.as_list(),
159 | dtype=tf.float32,
160 | trainable=False,
161 | initializer=tf.zeros_initializer())
162 |
163 | # Standard Adam update.
164 | next_m = (
165 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
166 | next_v = (
167 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
168 | tf.square(grad)))
169 |
170 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
171 |
172 | # Just adding the square of the weights to the loss function is *not*
173 | # the correct way of using L2 regularization/weight decay with Adam,
174 | # since that will interact with the m and v parameters in strange ways.
175 | #
176 |       # Instead we want to decay the weights in a manner that doesn't interact
177 | # with the m/v parameters. This is equivalent to adding the square
178 | # of the weights to the loss with plain (non-momentum) SGD.
179 | if self._do_use_weight_decay(param_name):
180 | update += self.weight_decay_rate * param
181 |
182 | update_with_lr = self.learning_rate * update
183 |
184 | next_param = param - update_with_lr
185 |
186 | assignments.extend(
187 | [param.assign(next_param),
188 | m.assign(next_m),
189 | v.assign(next_v)])
190 | return tf.group(*assignments, name=name)
191 |
192 | def _do_use_weight_decay(self, param_name):
193 | """Whether to use L2 weight decay for `param_name`."""
194 | if not self.weight_decay_rate:
195 | return False
196 | if self.exclude_from_weight_decay:
197 | for r in self.exclude_from_weight_decay:
198 | if re.search(r, param_name) is not None:
199 | return False
200 | return True
201 |
202 | def _get_variable_name(self, param_name):
203 | """Get the variable name from the tensor name."""
204 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
205 | if m is not None:
206 | param_name = m.group(1)
207 | return param_name
208 |
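The warmup-then-decay schedule assembled in create_optimizer can be stated in
plain Python. A sketch assuming start_warmup_step=0 and end_learning_rate=0.0
(the defaults used elsewhere in this repo):

    def learning_rate_at(step, init_lr, num_warmup_steps, num_train_steps,
                         poly_power=1.0):
      """Mirrors tf.train.polynomial_decay plus the linear warmup override."""
      if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
      frac = min(step, num_train_steps) / float(num_train_steps)
      return init_lr * (1.0 - frac) ** poly_power

    # For example, with init_lr=1e-4, 100 warmup steps, 1000 total steps:
    # step 50   -> 5e-5 (half-way through warmup),
    # step 100  -> 9e-5 (decayed value: 1e-4 * (1 - 100/1000)),
    # step 1000 -> 0.0.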
--------------------------------------------------------------------------------
/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | from albert import optimization
19 | from six.moves import range
20 | from six.moves import zip
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | class OptimizationTest(tf.test.TestCase):
25 |
26 | def test_adam(self):
27 | with self.test_session() as sess:
28 | w = tf.get_variable(
29 | "w",
30 | shape=[3],
31 | initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
32 | x = tf.constant([0.4, 0.2, -0.5])
33 | loss = tf.reduce_mean(tf.square(x - w))
34 | tvars = tf.trainable_variables()
35 | grads = tf.gradients(loss, tvars)
36 | global_step = tf.train.get_or_create_global_step()
37 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
38 | train_op = optimizer.apply_gradients(list(zip(grads, tvars)), global_step)
39 | init_op = tf.group(tf.global_variables_initializer(),
40 | tf.local_variables_initializer())
41 | sess.run(init_op)
42 | for _ in range(100):
43 | sess.run(train_op)
44 | w_np = sess.run(w)
45 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
46 |
47 |
48 | if __name__ == "__main__":
49 | tf.test.main()
50 |
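Beyond the direct AdamWeightDecayOptimizer check above, create_optimizer can
be driven end-to-end on the same toy problem. A minimal sketch (the
hyperparameter values here are arbitrary):

    import tensorflow.compat.v1 as tf
    from albert import optimization

    w = tf.get_variable("w", shape=[3], initializer=tf.zeros_initializer())
    loss = tf.reduce_mean(tf.square(tf.constant([0.4, 0.2, -0.5]) - w))

    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=1e-3,
        num_train_steps=1000,
        num_warmup_steps=100,
        use_tpu=False,
        optimizer="adamw")

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(100):
        sess.run(train_op)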
--------------------------------------------------------------------------------
/race_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Utility functions for RACE dataset."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import json
23 | import os
24 | from albert import classifier_utils
25 | from albert import fine_tuning_utils
26 | from albert import modeling
27 | from albert import optimization
28 | from albert import tokenization
29 | import tensorflow.compat.v1 as tf
30 | from tensorflow.compat.v1 import estimator as tf_estimator
31 | from tensorflow.contrib import tpu as contrib_tpu
32 |
33 |
34 | class InputExample(object):
35 | """A single training/test example for the RACE dataset."""
36 |
37 | def __init__(self,
38 | example_id,
39 | context_sentence,
40 | start_ending,
41 | endings,
42 | label=None):
43 | self.example_id = example_id
44 | self.context_sentence = context_sentence
45 | self.start_ending = start_ending
46 | self.endings = endings
47 | self.label = label
48 |
49 | def __str__(self):
50 | return self.__repr__()
51 |
52 | def __repr__(self):
53 | l = [
54 | "id: {}".format(self.example_id),
55 | "context_sentence: {}".format(self.context_sentence),
56 | "start_ending: {}".format(self.start_ending),
57 | "ending_0: {}".format(self.endings[0]),
58 | "ending_1: {}".format(self.endings[1]),
59 | "ending_2: {}".format(self.endings[2]),
60 | "ending_3: {}".format(self.endings[3]),
61 | ]
62 |
63 | if self.label is not None:
64 | l.append("label: {}".format(self.label))
65 |
66 | return ", ".join(l)
67 |
68 |
69 | class RaceProcessor(object):
70 | """Processor for the RACE data set."""
71 |
72 | def __init__(self, use_spm, do_lower_case, high_only, middle_only):
73 | super(RaceProcessor, self).__init__()
74 | self.use_spm = use_spm
75 | self.do_lower_case = do_lower_case
76 | self.high_only = high_only
77 | self.middle_only = middle_only
78 |
79 | def get_train_examples(self, data_dir):
80 | """Gets a collection of `InputExample`s for the train set."""
81 | return self.read_examples(
82 | os.path.join(data_dir, "RACE", "train"))
83 |
84 | def get_dev_examples(self, data_dir):
85 | """Gets a collection of `InputExample`s for the dev set."""
86 | return self.read_examples(
87 | os.path.join(data_dir, "RACE", "dev"))
88 |
89 | def get_test_examples(self, data_dir):
90 | """Gets a collection of `InputExample`s for prediction."""
91 | return self.read_examples(
92 | os.path.join(data_dir, "RACE", "test"))
93 |
94 | def get_labels(self):
95 | """Gets the list of labels for this data set."""
96 | return ["A", "B", "C", "D"]
97 |
98 | def process_text(self, text):
99 | if self.use_spm:
100 | return tokenization.preprocess_text(text, lower=self.do_lower_case)
101 | else:
102 | return tokenization.convert_to_unicode(text)
103 |
104 | def read_examples(self, data_dir):
105 | """Read examples from RACE json files."""
106 | examples = []
107 | for level in ["middle", "high"]:
108 | if level == "middle" and self.high_only: continue
109 | if level == "high" and self.middle_only: continue
110 | cur_dir = os.path.join(data_dir, level)
111 |
112 | cur_path = os.path.join(cur_dir, "all.txt")
113 | with tf.gfile.Open(cur_path) as f:
114 | for line in f:
115 | cur_data = json.loads(line.strip())
116 |
117 | answers = cur_data["answers"]
118 | options = cur_data["options"]
119 | questions = cur_data["questions"]
120 | context = self.process_text(cur_data["article"])
121 |
122 | for i in range(len(answers)):
123 | label = ord(answers[i]) - ord("A")
124 | qa_list = []
125 |
126 | question = self.process_text(questions[i])
127 | for j in range(4):
128 | option = self.process_text(options[i][j])
129 |
130 | if "_" in question:
131 | qa_cat = question.replace("_", option)
132 | else:
133 | qa_cat = " ".join([question, option])
134 |
135 | qa_list.append(qa_cat)
136 |
137 | examples.append(
138 | InputExample(
139 | example_id=cur_data["id"],
140 | context_sentence=context,
141 | start_ending=None,
142 | endings=[qa_list[0], qa_list[1], qa_list[2], qa_list[3]],
143 | label=label
144 | )
145 | )
146 |
147 | return examples
148 |
149 |
150 | def convert_single_example(example_index, example, label_size, max_seq_length,
151 | tokenizer, max_qa_length):
152 | """Loads a data file into a list of `InputBatch`s."""
153 |
154 |   # RACE is a multiple choice task. To perform this task using ALBERT,
155 | # we will use the formatting proposed in "Improving Language
156 | # Understanding by Generative Pre-Training" and suggested by
157 | # @jacobdevlin-google in this issue
158 | # https://github.com/google-research/bert/issues/38.
159 | #
160 | # Each choice will correspond to a sample on which we run the
161 | # inference. For a given RACE example, we will create the 4
162 | # following inputs:
163 | # - [CLS] context [SEP] choice_1 [SEP]
164 | # - [CLS] context [SEP] choice_2 [SEP]
165 | # - [CLS] context [SEP] choice_3 [SEP]
166 | # - [CLS] context [SEP] choice_4 [SEP]
167 | # The model will output a single value for each input. To get the
168 | # final decision of the model, we will run a softmax over these 4
169 | # outputs.
170 | if isinstance(example, classifier_utils.PaddingInputExample):
171 | return classifier_utils.InputFeatures(
172 | example_id=0,
173 | input_ids=[[0] * max_seq_length] * label_size,
174 | input_mask=[[0] * max_seq_length] * label_size,
175 | segment_ids=[[0] * max_seq_length] * label_size,
176 | label_id=0,
177 | is_real_example=False)
178 | else:
179 | context_tokens = tokenizer.tokenize(example.context_sentence)
180 | if example.start_ending is not None:
181 | start_ending_tokens = tokenizer.tokenize(example.start_ending)
182 |
183 | all_input_tokens = []
184 | all_input_ids = []
185 | all_input_mask = []
186 | all_segment_ids = []
187 | for ending in example.endings:
188 | # We create a copy of the context tokens in order to be
189 | # able to shrink it according to ending_tokens
190 | context_tokens_choice = context_tokens[:]
191 | if example.start_ending is not None:
192 | ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
193 | else:
194 | ending_tokens = tokenizer.tokenize(ending)
195 |       # Truncate `ending_tokens` to `max_qa_length`, then shrink
196 |       # `context_tokens_choice` so that the total length fits within
197 |       # max_seq_length. Account for [CLS], [SEP], [SEP] with
198 |       # "- 3".
199 |       ending_tokens = ending_tokens[-max_qa_length:]
200 |
201 | if len(context_tokens_choice) + len(ending_tokens) > max_seq_length - 3:
202 | context_tokens_choice = context_tokens_choice[: (
203 | max_seq_length - 3 - len(ending_tokens))]
204 | tokens = ["[CLS]"] + context_tokens_choice + (
205 | ["[SEP]"] + ending_tokens + ["[SEP]"])
206 | segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (
207 | len(ending_tokens) + 1)
208 |
209 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
210 | input_mask = [1] * len(input_ids)
211 |
212 | # Zero-pad up to the sequence length.
213 | padding = [0] * (max_seq_length - len(input_ids))
214 | input_ids += padding
215 | input_mask += padding
216 | segment_ids += padding
217 |
218 | assert len(input_ids) == max_seq_length
219 | assert len(input_mask) == max_seq_length
220 | assert len(segment_ids) == max_seq_length
221 |
222 | all_input_tokens.append(tokens)
223 | all_input_ids.append(input_ids)
224 | all_input_mask.append(input_mask)
225 | all_segment_ids.append(segment_ids)
226 |
227 | label = example.label
228 | if example_index < 5:
229 | tf.logging.info("*** Example ***")
230 | tf.logging.info("id: {}".format(example.example_id))
231 |       for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(
232 |           zip(all_input_tokens, all_input_ids, all_input_mask, all_segment_ids)):
233 | tf.logging.info("choice: {}".format(choice_idx))
234 | tf.logging.info("tokens: {}".format(" ".join(tokens)))
235 | tf.logging.info(
236 | "input_ids: {}".format(" ".join(map(str, input_ids))))
237 | tf.logging.info(
238 | "input_mask: {}".format(" ".join(map(str, input_mask))))
239 | tf.logging.info(
240 | "segment_ids: {}".format(" ".join(map(str, segment_ids))))
241 | tf.logging.info("label: {}".format(label))
242 |
243 | return classifier_utils.InputFeatures(
244 | example_id=example.example_id,
245 | input_ids=all_input_ids,
246 | input_mask=all_input_mask,
247 | segment_ids=all_segment_ids,
248 | label_id=label
249 | )
250 |
251 |
252 | def file_based_convert_examples_to_features(
253 | examples, label_list, max_seq_length, tokenizer,
254 | output_file, max_qa_length):
255 | """Convert a set of `InputExample`s to a TFRecord file."""
256 |
257 | writer = tf.python_io.TFRecordWriter(output_file)
258 |
259 | for (ex_index, example) in enumerate(examples):
260 | if ex_index % 10000 == 0:
261 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
262 |
263 | feature = convert_single_example(ex_index, example, len(label_list),
264 | max_seq_length, tokenizer, max_qa_length)
265 |
266 | def create_int_feature(values):
267 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
268 | return f
269 |
270 | features = collections.OrderedDict()
271 | features["input_ids"] = create_int_feature(sum(feature.input_ids, []))
272 | features["input_mask"] = create_int_feature(sum(feature.input_mask, []))
273 | features["segment_ids"] = create_int_feature(sum(feature.segment_ids, []))
274 | features["label_ids"] = create_int_feature([feature.label_id])
275 | features["is_real_example"] = create_int_feature(
276 | [int(feature.is_real_example)])
277 |
278 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
279 | writer.write(tf_example.SerializeToString())
280 | writer.close()
281 |
282 |
283 | def create_model(albert_config, is_training, input_ids, input_mask, segment_ids,
284 | labels, num_labels, use_one_hot_embeddings, max_seq_length,
285 | dropout_prob, hub_module):
286 | """Creates a classification model."""
287 | bsz_per_core = tf.shape(input_ids)[0]
288 |
289 | input_ids = tf.reshape(input_ids, [bsz_per_core * num_labels, max_seq_length])
290 | input_mask = tf.reshape(input_mask,
291 | [bsz_per_core * num_labels, max_seq_length])
292 | token_type_ids = tf.reshape(segment_ids,
293 | [bsz_per_core * num_labels, max_seq_length])
294 |
295 | (output_layer, _) = fine_tuning_utils.create_albert(
296 | albert_config=albert_config,
297 | is_training=is_training,
298 | input_ids=input_ids,
299 | input_mask=input_mask,
300 | segment_ids=token_type_ids,
301 | use_one_hot_embeddings=use_one_hot_embeddings,
302 | use_einsum=True,
303 | hub_module=hub_module)
304 |
305 | hidden_size = output_layer.shape[-1].value
306 |
307 | output_weights = tf.get_variable(
308 | "output_weights", [1, hidden_size],
309 | initializer=tf.truncated_normal_initializer(stddev=0.02))
310 |
311 | output_bias = tf.get_variable(
312 | "output_bias", [1],
313 | initializer=tf.zeros_initializer())
314 |
315 | with tf.variable_scope("loss"):
316 | if is_training:
317 |       # Apply dropout (rate `dropout_prob`, typically 0.1) during training.
318 | output_layer = tf.nn.dropout(
319 | output_layer, keep_prob=1 - dropout_prob)
320 |
321 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
322 | logits = tf.nn.bias_add(logits, output_bias)
323 | logits = tf.reshape(logits, [bsz_per_core, num_labels])
324 | probabilities = tf.nn.softmax(logits, axis=-1)
325 | predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32)
326 | log_probs = tf.nn.log_softmax(logits, axis=-1)
327 |
328 | one_hot_labels = tf.one_hot(
329 | labels, depth=tf.cast(num_labels, dtype=tf.int32), dtype=tf.float32)
330 |
331 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
332 | loss = tf.reduce_mean(per_example_loss)
333 |
334 | return (loss, per_example_loss, probabilities, logits, predictions)
335 |
336 |
337 | def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate,
338 | num_train_steps, num_warmup_steps, use_tpu,
339 | use_one_hot_embeddings, max_seq_length, dropout_prob,
340 | hub_module):
341 | """Returns `model_fn` closure for TPUEstimator."""
342 |
343 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
344 | """The `model_fn` for TPUEstimator."""
345 |
346 | tf.logging.info("*** Features ***")
347 | for name in sorted(features.keys()):
348 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
349 |
350 | input_ids = features["input_ids"]
351 | input_mask = features["input_mask"]
352 | segment_ids = features["segment_ids"]
353 | label_ids = features["label_ids"]
354 | is_real_example = None
355 | if "is_real_example" in features:
356 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
357 | else:
358 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)
359 |
360 | is_training = (mode == tf_estimator.ModeKeys.TRAIN)
361 |
362 | (total_loss, per_example_loss, probabilities, logits, predictions) = \
363 | create_model(albert_config, is_training, input_ids, input_mask,
364 | segment_ids, label_ids, num_labels,
365 | use_one_hot_embeddings, max_seq_length, dropout_prob,
366 | hub_module)
367 |
368 | tvars = tf.trainable_variables()
369 | initialized_variable_names = {}
370 | scaffold_fn = None
371 | if init_checkpoint:
372 | (assignment_map, initialized_variable_names
373 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
374 | if use_tpu:
375 |
376 | def tpu_scaffold():
377 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
378 | return tf.train.Scaffold()
379 |
380 | scaffold_fn = tpu_scaffold
381 | else:
382 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
383 |
384 | tf.logging.info("**** Trainable Variables ****")
385 | for var in tvars:
386 | init_string = ""
387 | if var.name in initialized_variable_names:
388 | init_string = ", *INIT_FROM_CKPT*"
389 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
390 | init_string)
391 |
392 | output_spec = None
393 | if mode == tf_estimator.ModeKeys.TRAIN:
394 |
395 | train_op = optimization.create_optimizer(
396 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
397 |
398 | output_spec = contrib_tpu.TPUEstimatorSpec(
399 | mode=mode,
400 | loss=total_loss,
401 | train_op=train_op,
402 | scaffold_fn=scaffold_fn)
403 | elif mode == tf_estimator.ModeKeys.EVAL:
404 | def metric_fn(per_example_loss, label_ids, logits, is_real_example):
405 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
406 | accuracy = tf.metrics.accuracy(
407 | labels=label_ids, predictions=predictions,
408 | weights=is_real_example)
409 | loss = tf.metrics.mean(
410 | values=per_example_loss, weights=is_real_example)
411 | return {
412 | "eval_accuracy": accuracy,
413 | "eval_loss": loss,
414 | }
415 |
416 | eval_metrics = (metric_fn,
417 | [per_example_loss, label_ids, logits, is_real_example])
418 | output_spec = contrib_tpu.TPUEstimatorSpec(
419 | mode=mode,
420 | loss=total_loss,
421 | eval_metrics=eval_metrics,
422 | scaffold_fn=scaffold_fn)
423 | else:
424 | output_spec = contrib_tpu.TPUEstimatorSpec(
425 | mode=mode,
426 | predictions={"probabilities": probabilities,
427 | "predictions": predictions},
428 | scaffold_fn=scaffold_fn)
429 | return output_spec
430 |
431 | return model_fn
432 |
433 |
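A toy sketch of the multiple-choice packing performed above: one
"[CLS] context [SEP] choice [SEP]" sequence per option, scored jointly by a
softmax over the four logits. These are plain illustrative strings, not real
tokenizer output:

    context = "The quick brown fox jumps over the lazy dog ."
    question = "What does the fox jump over ?"
    options = ["the dog", "the cat", "a fence", "the moon"]

    sequences = []
    for option in options:
      # Cloze-style questions containing "_" get the option substituted in;
      # all other questions get the option appended.
      if "_" in question:
        qa = question.replace("_", option)
      else:
        qa = " ".join([question, option])
      sequences.append("[CLS] " + context + " [SEP] " + qa + " [SEP]")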
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Run pip install --upgrade pip if tensorflow 1.15 cannot be found
2 | tensorflow==1.15.2 # CPU Version of TensorFlow
3 | tensorflow_hub==0.7
4 | # tensorflow-gpu==1.15 # GPU version of TensorFlow
5 | sentencepiece
6 |
--------------------------------------------------------------------------------
/run_classifier.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ALBERT finetuning on classification tasks."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import math
22 | import os
23 | import time
24 | from albert import classifier_utils
25 | from albert import fine_tuning_utils
26 | from albert import modeling
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.compat.v1 import estimator as tf_estimator
29 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
30 | from tensorflow.contrib import tpu as contrib_tpu
31 |
32 | flags = tf.flags
33 |
34 | FLAGS = flags.FLAGS
35 |
36 | ## Required parameters
37 | flags.DEFINE_string(
38 | "data_dir", None,
39 | "The input data dir. Should contain the .tsv files (or other data files) "
40 | "for the task.")
41 |
42 | flags.DEFINE_string(
43 | "albert_config_file", None,
44 | "The config json file corresponding to the pre-trained ALBERT model. "
45 | "This specifies the model architecture.")
46 |
47 | flags.DEFINE_string("task_name", None, "The name of the task to train.")
48 |
49 | flags.DEFINE_string(
50 | "vocab_file", None,
51 | "The vocabulary file that the ALBERT model was trained on.")
52 |
53 | flags.DEFINE_string("spm_model_file", None,
54 | "The model file for sentence piece tokenization.")
55 |
56 | flags.DEFINE_string(
57 | "output_dir", None,
58 | "The output directory where the model checkpoints will be written.")
59 |
60 | flags.DEFINE_string("cached_dir", None,
61 |                     "Path to cached training and dev tfrecord files. "
62 |                     "The files will be generated if they do not exist.")
63 |
64 | ## Other parameters
65 |
66 | flags.DEFINE_string(
67 | "init_checkpoint", None,
68 | "Initial checkpoint (usually from a pre-trained BERT model).")
69 |
70 | flags.DEFINE_string(
71 | "albert_hub_module_handle", None,
72 | "If set, the ALBERT hub module to use.")
73 |
74 | flags.DEFINE_bool(
75 | "do_lower_case", True,
76 | "Whether to lower case the input text. Should be True for uncased "
77 | "models and False for cased models.")
78 |
79 | flags.DEFINE_integer(
80 | "max_seq_length", 512,
81 | "The maximum total input sequence length after WordPiece tokenization. "
82 | "Sequences longer than this will be truncated, and sequences shorter "
83 | "than this will be padded.")
84 |
85 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
86 |
87 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
88 |
89 | flags.DEFINE_bool(
90 | "do_predict", False,
91 | "Whether to run the model in inference mode on the test set.")
92 |
93 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
94 |
95 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
96 |
97 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
98 |
99 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
100 |
101 | flags.DEFINE_integer("train_step", 1000,
102 | "Total number of training steps to perform.")
103 |
104 | flags.DEFINE_integer(
105 | "warmup_step", 0,
106 |     "Number of steps to perform linear learning rate warmup for.")
107 |
108 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
109 | "How often to save the model checkpoint.")
110 |
111 | flags.DEFINE_integer("keep_checkpoint_max", 5,
112 | "How many checkpoints to keep.")
113 |
114 | flags.DEFINE_integer("iterations_per_loop", 1000,
115 | "How many steps to make in each estimator call.")
116 |
117 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
118 |
119 | flags.DEFINE_string("optimizer", "adamw", "Optimizer to use")
120 |
121 | tf.flags.DEFINE_string(
122 | "tpu_name", None,
123 | "The Cloud TPU to use for training. This should be either the name "
124 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
125 | "url.")
126 |
127 | tf.flags.DEFINE_string(
128 | "tpu_zone", None,
129 |     "[Optional] GCE zone where the Cloud TPU is located. If not "
130 |     "specified, we will attempt to automatically detect the GCE zone "
131 |     "from metadata.")
132 |
133 | tf.flags.DEFINE_string(
134 | "gcp_project", None,
135 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
136 | "specified, we will attempt to automatically detect the GCE project from "
137 | "metadata.")
138 |
139 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
140 |
141 | flags.DEFINE_integer(
142 | "num_tpu_cores", 8,
143 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
144 |
145 | flags.DEFINE_string(
146 | "export_dir", None,
147 | "The directory where the exported SavedModel will be stored.")
148 |
149 | flags.DEFINE_float(
150 | "threshold_to_export", float("nan"),
151 | "The threshold value that should be used with the exported classifier. "
152 | "When specified, the threshold will be attached to the exported "
153 | "SavedModel, and served along with the predictions. Please use the "
154 | "saved model cli ("
155 | "https://www.tensorflow.org/guide/saved_model#details_of_the_savedmodel_command_line_interface"
156 | ") to view the output signature of the threshold.")
157 |
158 |
159 | def _serving_input_receiver_fn():
160 | """Creates an input function for serving."""
161 | seq_len = FLAGS.max_seq_length
162 | serialized_example = tf.placeholder(
163 | dtype=tf.string, shape=[None], name="serialized_example")
164 | features = {
165 | "input_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64),
166 | "input_mask": tf.FixedLenFeature([seq_len], dtype=tf.int64),
167 | "segment_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64),
168 | }
169 | feature_map = tf.parse_example(serialized_example, features=features)
170 | feature_map["is_real_example"] = tf.constant(1, dtype=tf.int32)
171 | feature_map["label_ids"] = tf.constant(0, dtype=tf.int32)
172 |
173 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
174 | # So cast all int64 to int32.
175 | for name in feature_map.keys():
176 | t = feature_map[name]
177 | if t.dtype == tf.int64:
178 | t = tf.to_int32(t)
179 | feature_map[name] = t
180 |
181 | return tf_estimator.export.ServingInputReceiver(
182 | features=feature_map, receiver_tensors=serialized_example)
183 |
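# A serving request must be a serialized tf.Example containing exactly the
# three int64 features parsed above. A minimal sketch of building one (an
# illustrative helper, not used elsewhere; all-zero padding for brevity, and
# `seq_len` must match the FLAGS.max_seq_length used at export time):
def _example_serialized_request(seq_len):
  def _int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              "input_ids": _int64_feature([0] * seq_len),
              "input_mask": _int64_feature([0] * seq_len),
              "segment_ids": _int64_feature([0] * seq_len),
          }))
  return example.SerializeToString()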
184 |
185 | def _add_threshold_to_model_fn(model_fn, threshold):
186 | """Adds the classifier threshold to the given model_fn."""
187 |
188 | def new_model_fn(features, labels, mode, params):
189 | spec = model_fn(features, labels, mode, params)
190 | threshold_tensor = tf.constant(threshold, dtype=tf.float32)
191 | default_serving_export = spec.export_outputs[
192 | tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
193 | default_serving_export.outputs["threshold"] = threshold_tensor
194 | return spec
195 |
196 | return new_model_fn
197 |
198 |
199 | def main(_):
200 | tf.logging.set_verbosity(tf.logging.INFO)
201 |
202 | processors = {
203 | "cola": classifier_utils.ColaProcessor,
204 | "mnli": classifier_utils.MnliProcessor,
205 | "mismnli": classifier_utils.MisMnliProcessor,
206 | "mrpc": classifier_utils.MrpcProcessor,
207 | "rte": classifier_utils.RteProcessor,
208 | "sst-2": classifier_utils.Sst2Processor,
209 | "sts-b": classifier_utils.StsbProcessor,
210 | "qqp": classifier_utils.QqpProcessor,
211 | "qnli": classifier_utils.QnliProcessor,
212 | "wnli": classifier_utils.WnliProcessor,
213 | }
214 |
215 | if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict or
216 | FLAGS.export_dir):
217 | raise ValueError(
218 |         "At least one of `do_train`, `do_eval`, `do_predict` or "
219 |         "`export_dir` must be set.")
220 |
221 | if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
222 | raise ValueError("At least one of `--albert_config_file` and "
223 | "`--albert_hub_module_handle` must be set")
224 |
225 | if FLAGS.albert_config_file:
226 | albert_config = modeling.AlbertConfig.from_json_file(
227 | FLAGS.albert_config_file)
228 | if FLAGS.max_seq_length > albert_config.max_position_embeddings:
229 | raise ValueError(
230 | "Cannot use sequence length %d because the ALBERT model "
231 | "was only trained up to sequence length %d" %
232 | (FLAGS.max_seq_length, albert_config.max_position_embeddings))
233 | else:
234 | albert_config = None # Get the config from TF-Hub.
235 |
236 | tf.gfile.MakeDirs(FLAGS.output_dir)
237 |
238 | task_name = FLAGS.task_name.lower()
239 |
240 | if task_name not in processors:
241 | raise ValueError("Task not found: %s" % (task_name))
242 |
243 | processor = processors[task_name](
244 |       use_spm=bool(FLAGS.spm_model_file),
245 | do_lower_case=FLAGS.do_lower_case)
246 |
247 | label_list = processor.get_labels()
248 |
249 | tokenizer = fine_tuning_utils.create_vocab(
250 | vocab_file=FLAGS.vocab_file,
251 | do_lower_case=FLAGS.do_lower_case,
252 | spm_model_file=FLAGS.spm_model_file,
253 | hub_module=FLAGS.albert_hub_module_handle)
254 |
255 | tpu_cluster_resolver = None
256 | if FLAGS.use_tpu and FLAGS.tpu_name:
257 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
258 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
259 |
260 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
261 | if FLAGS.do_train:
262 | iterations_per_loop = int(min(FLAGS.iterations_per_loop,
263 | FLAGS.save_checkpoints_steps))
264 | else:
265 | iterations_per_loop = FLAGS.iterations_per_loop
266 | run_config = contrib_tpu.RunConfig(
267 | cluster=tpu_cluster_resolver,
268 | master=FLAGS.master,
269 | model_dir=FLAGS.output_dir,
270 | save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
271 |       keep_checkpoint_max=0,  # 0 keeps all checkpoints for later eval.
272 | tpu_config=contrib_tpu.TPUConfig(
273 | iterations_per_loop=iterations_per_loop,
274 | num_shards=FLAGS.num_tpu_cores,
275 | per_host_input_for_training=is_per_host))
276 |
277 | train_examples = None
278 | if FLAGS.do_train:
279 | train_examples = processor.get_train_examples(FLAGS.data_dir)
280 | model_fn = classifier_utils.model_fn_builder(
281 | albert_config=albert_config,
282 | num_labels=len(label_list),
283 | init_checkpoint=FLAGS.init_checkpoint,
284 | learning_rate=FLAGS.learning_rate,
285 | num_train_steps=FLAGS.train_step,
286 | num_warmup_steps=FLAGS.warmup_step,
287 | use_tpu=FLAGS.use_tpu,
288 | use_one_hot_embeddings=FLAGS.use_tpu,
289 | task_name=task_name,
290 | hub_module=FLAGS.albert_hub_module_handle,
291 | optimizer=FLAGS.optimizer)
292 |
293 | if not math.isnan(FLAGS.threshold_to_export):
294 | model_fn = _add_threshold_to_model_fn(model_fn, FLAGS.threshold_to_export)
295 |
296 | # If TPU is not available, this will fall back to normal Estimator on CPU
297 | # or GPU.
298 | estimator = contrib_tpu.TPUEstimator(
299 | use_tpu=FLAGS.use_tpu,
300 | model_fn=model_fn,
301 | config=run_config,
302 | train_batch_size=FLAGS.train_batch_size,
303 | eval_batch_size=FLAGS.eval_batch_size,
304 | predict_batch_size=FLAGS.predict_batch_size,
305 | export_to_tpu=False) # http://yaqs/4707241341091840
306 |
307 | if FLAGS.do_train:
308 | cached_dir = FLAGS.cached_dir
309 | if not cached_dir:
310 | cached_dir = FLAGS.output_dir
311 | train_file = os.path.join(cached_dir, task_name + "_train.tf_record")
312 | if not tf.gfile.Exists(train_file):
313 | classifier_utils.file_based_convert_examples_to_features(
314 | train_examples, label_list, FLAGS.max_seq_length, tokenizer,
315 | train_file, task_name)
316 | tf.logging.info("***** Running training *****")
317 | tf.logging.info(" Num examples = %d", len(train_examples))
318 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
319 | tf.logging.info(" Num steps = %d", FLAGS.train_step)
320 | train_input_fn = classifier_utils.file_based_input_fn_builder(
321 | input_file=train_file,
322 | seq_length=FLAGS.max_seq_length,
323 | is_training=True,
324 | drop_remainder=True,
325 | task_name=task_name,
326 | use_tpu=FLAGS.use_tpu,
327 | bsz=FLAGS.train_batch_size)
328 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step)
329 |
330 | if FLAGS.do_eval:
331 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
332 | num_actual_eval_examples = len(eval_examples)
333 | if FLAGS.use_tpu:
334 | # TPU requires a fixed batch size for all batches, therefore the number
335 | # of examples must be a multiple of the batch size, or else examples
336 | # will get dropped. So we pad with fake examples which are ignored
337 | # later on. These do NOT count towards the metric (all tf.metrics
338 | # support a per-instance weight, and these get a weight of 0.0).
339 | while len(eval_examples) % FLAGS.eval_batch_size != 0:
340 | eval_examples.append(classifier_utils.PaddingInputExample())
341 |
342 | cached_dir = FLAGS.cached_dir
343 | if not cached_dir:
344 | cached_dir = FLAGS.output_dir
345 | eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record")
346 | if not tf.gfile.Exists(eval_file):
347 | classifier_utils.file_based_convert_examples_to_features(
348 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
349 | eval_file, task_name)
350 |
351 | tf.logging.info("***** Running evaluation *****")
352 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
353 | len(eval_examples), num_actual_eval_examples,
354 | len(eval_examples) - num_actual_eval_examples)
355 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
356 |
357 | # This tells the estimator to run through the entire set.
358 | eval_steps = None
359 | # However, if running eval on the TPU, you will need to specify the
360 | # number of steps.
361 | if FLAGS.use_tpu:
362 | assert len(eval_examples) % FLAGS.eval_batch_size == 0
363 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
364 |
365 |     eval_drop_remainder = FLAGS.use_tpu
366 | eval_input_fn = classifier_utils.file_based_input_fn_builder(
367 | input_file=eval_file,
368 | seq_length=FLAGS.max_seq_length,
369 | is_training=False,
370 | drop_remainder=eval_drop_remainder,
371 | task_name=task_name,
372 | use_tpu=FLAGS.use_tpu,
373 | bsz=FLAGS.eval_batch_size)
374 |
375 | best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt")
376 |
377 | def _best_trial_info():
378 |       """Returns information about which checkpoints have been evaluated so far."""
379 | if tf.gfile.Exists(best_trial_info_file):
380 | with tf.gfile.GFile(best_trial_info_file, "r") as best_info:
381 | global_step, best_metric_global_step, metric_value = (
382 | best_info.read().split(":"))
383 | global_step = int(global_step)
384 | best_metric_global_step = int(best_metric_global_step)
385 | metric_value = float(metric_value)
386 | else:
387 | metric_value = -1
388 | best_metric_global_step = -1
389 | global_step = -1
390 | tf.logging.info(
391 | "Best trial info: Step: %s, Best Value Step: %s, "
392 | "Best Value: %s", global_step, best_metric_global_step, metric_value)
393 | return global_step, best_metric_global_step, metric_value
394 |
395 | def _remove_checkpoint(checkpoint_path):
396 | for ext in ["meta", "data-00000-of-00001", "index"]:
397 | src_ckpt = checkpoint_path + ".{}".format(ext)
398 | tf.logging.info("removing {}".format(src_ckpt))
399 | tf.gfile.Remove(src_ckpt)
400 |
401 | def _find_valid_cands(curr_step):
402 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
403 | candidates = []
404 | for filename in filenames:
405 | if filename.endswith(".index"):
406 | ckpt_name = filename[:-6]
407 | idx = ckpt_name.split("-")[-1]
408 | if int(idx) > curr_step:
409 | candidates.append(filename)
410 | return candidates
411 |
412 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
413 |
414 | if task_name == "sts-b":
415 | key_name = "pearson"
416 | elif task_name == "cola":
417 | key_name = "matthew_corr"
418 | else:
419 | key_name = "eval_accuracy"
420 |
421 | global_step, best_perf_global_step, best_perf = _best_trial_info()
422 | writer = tf.gfile.GFile(output_eval_file, "w")
423 | while global_step < FLAGS.train_step:
424 | steps_and_files = {}
425 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
426 | for filename in filenames:
427 | if filename.endswith(".index"):
428 | ckpt_name = filename[:-6]
429 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
430 | if cur_filename.split("-")[-1] == "best":
431 | continue
432 | gstep = int(cur_filename.split("-")[-1])
433 | if gstep not in steps_and_files:
434 | tf.logging.info("Add {} to eval list.".format(cur_filename))
435 | steps_and_files[gstep] = cur_filename
436 | tf.logging.info("found {} files.".format(len(steps_and_files)))
437 | if not steps_and_files:
438 |         tf.logging.info("found 0 files, global step: {}. Sleeping."
439 | .format(global_step))
440 | time.sleep(60)
441 | else:
442 | for checkpoint in sorted(steps_and_files.items()):
443 | step, checkpoint_path = checkpoint
444 | if global_step >= step:
445 | if (best_perf_global_step != step and
446 | len(_find_valid_cands(step)) > 1):
447 | _remove_checkpoint(checkpoint_path)
448 | continue
449 | result = estimator.evaluate(
450 | input_fn=eval_input_fn,
451 | steps=eval_steps,
452 | checkpoint_path=checkpoint_path)
453 | global_step = result["global_step"]
454 | tf.logging.info("***** Eval results *****")
455 | for key in sorted(result.keys()):
456 | tf.logging.info(" %s = %s", key, str(result[key]))
457 | writer.write("%s = %s\n" % (key, str(result[key])))
458 | writer.write("best = {}\n".format(best_perf))
459 | if result[key_name] > best_perf:
460 | best_perf = result[key_name]
461 | best_perf_global_step = global_step
462 | elif len(_find_valid_cands(global_step)) > 1:
463 | _remove_checkpoint(checkpoint_path)
464 | writer.write("=" * 50 + "\n")
465 | writer.flush()
466 | with tf.gfile.GFile(best_trial_info_file, "w") as best_info:
467 | best_info.write("{}:{}:{}".format(
468 | global_step, best_perf_global_step, best_perf))
469 | writer.close()
470 |
471 | for ext in ["meta", "data-00000-of-00001", "index"]:
472 | src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext)
473 | tgt_ckpt = "model.ckpt-best.{}".format(ext)
474 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
475 | tf.io.gfile.rename(
476 | os.path.join(FLAGS.output_dir, src_ckpt),
477 | os.path.join(FLAGS.output_dir, tgt_ckpt),
478 | overwrite=True)
479 |
480 | if FLAGS.do_predict:
481 | predict_examples = processor.get_test_examples(FLAGS.data_dir)
482 | num_actual_predict_examples = len(predict_examples)
483 | if FLAGS.use_tpu:
484 | # TPU requires a fixed batch size for all batches, therefore the number
485 | # of examples must be a multiple of the batch size, or else examples
486 | # will get dropped. So we pad with fake examples which are ignored
487 | # later on.
488 | while len(predict_examples) % FLAGS.predict_batch_size != 0:
489 | predict_examples.append(classifier_utils.PaddingInputExample())
490 |
491 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
492 | classifier_utils.file_based_convert_examples_to_features(
493 | predict_examples, label_list,
494 | FLAGS.max_seq_length, tokenizer,
495 | predict_file, task_name)
496 |
497 |     tf.logging.info("***** Running prediction *****")
498 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
499 | len(predict_examples), num_actual_predict_examples,
500 | len(predict_examples) - num_actual_predict_examples)
501 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
502 |
503 |     predict_drop_remainder = FLAGS.use_tpu
504 | predict_input_fn = classifier_utils.file_based_input_fn_builder(
505 | input_file=predict_file,
506 | seq_length=FLAGS.max_seq_length,
507 | is_training=False,
508 | drop_remainder=predict_drop_remainder,
509 | task_name=task_name,
510 | use_tpu=FLAGS.use_tpu,
511 | bsz=FLAGS.predict_batch_size)
512 |
513 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
514 | result = estimator.predict(
515 | input_fn=predict_input_fn,
516 | checkpoint_path=checkpoint_path)
517 |
518 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
519 | output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv")
520 | with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\
521 | tf.gfile.GFile(output_submit_file, "w") as sub_writer:
522 | sub_writer.write("index" + "\t" + "prediction\n")
523 | num_written_lines = 0
524 | tf.logging.info("***** Predict results *****")
525 | for (i, (example, prediction)) in\
526 | enumerate(zip(predict_examples, result)):
527 | probabilities = prediction["probabilities"]
528 | if i >= num_actual_predict_examples:
529 | break
530 | output_line = "\t".join(
531 | str(class_probability)
532 | for class_probability in probabilities) + "\n"
533 | pred_writer.write(output_line)
534 |
535 | if task_name != "sts-b":
536 | actual_label = label_list[int(prediction["predictions"])]
537 | else:
538 | actual_label = str(prediction["predictions"])
539 | sub_writer.write(example.guid + "\t" + actual_label + "\n")
540 | num_written_lines += 1
541 | assert num_written_lines == num_actual_predict_examples
542 |
543 | if FLAGS.export_dir:
544 | tf.gfile.MakeDirs(FLAGS.export_dir)
545 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
546 | tf.logging.info("Starting to export model.")
547 | subfolder = estimator.export_saved_model(
548 | export_dir_base=FLAGS.export_dir,
549 | serving_input_receiver_fn=_serving_input_receiver_fn,
550 | checkpoint_path=checkpoint_path)
551 | tf.logging.info("Model exported to %s.", subfolder)
552 |
553 |
554 | if __name__ == "__main__":
555 | flags.mark_flag_as_required("data_dir")
556 | flags.mark_flag_as_required("task_name")
557 | flags.mark_flag_as_required("spm_model_file")
558 | flags.mark_flag_as_required("output_dir")
559 | tf.app.run()
560 |
--------------------------------------------------------------------------------
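The eval loop in run_classifier.py above survives restarts by persisting a single "global_step:best_metric_step:best_metric" line in best_trial.txt and re-reading it on startup. A minimal self-contained sketch of that round-trip; the file name and field order come from _best_trial_info above, while the helper names and sample values are illustrative:

import os
import tempfile

def write_trial_info(path, global_step, best_step, best_metric):
  # Same colon-separated layout that the eval loop writes after each pass.
  with open(path, "w") as f:
    f.write("{}:{}:{}".format(global_step, best_step, best_metric))

def read_trial_info(path):
  # Mirrors _best_trial_info(): a missing file means nothing evaluated yet.
  if not os.path.exists(path):
    return -1, -1, -1.0
  with open(path) as f:
    step, best_step, metric = f.read().split(":")
  return int(step), int(best_step), float(metric)

path = os.path.join(tempfile.mkdtemp(), "best_trial.txt")
write_trial_info(path, 2000, 1500, 0.9123)
assert read_trial_info(path) == (2000, 1500, 0.9123)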
/run_glue.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This is a convenience script for evaluating ALBERT on the GLUE benchmark.
3 | #
4 | # By default, this script uses a pretrained ALBERT v1 BASE model, but you may
5 | # use a custom checkpoint or any compatible TF-Hub checkpoint with minimal
6 | # edits to environment variables (see ALBERT_HUB_MODULE_HANDLE below).
7 | #
8 | # This script does fine-tuning and evaluation on 8 tasks, so it may take a
9 | # while to complete if you do not have a hardware accelerator.
10 |
11 | set -ex
12 |
13 | python3 -m venv $HOME/albertenv
14 | . $HOME/albertenv/bin/activate
15 |
16 | OUTPUT_DIR_BASE="$(mktemp -d)"
17 | OUTPUT_DIR="${OUTPUT_DIR_BASE}/output"
18 |
19 | # To start from a custom pretrained checkpoint, set ALBERT_HUB_MODULE_HANDLE
20 | # below to an empty string and set INIT_CHECKPOINT to your checkpoint path.
21 | ALBERT_HUB_MODULE_HANDLE="https://tfhub.dev/google/albert_base/1"
22 | INIT_CHECKPOINT=""
23 |
24 | pip3 install --upgrade pip
25 | pip3 install numpy
26 | pip3 install -r requirements.txt
27 |
28 | function run_task() {
29 |   COMMON_ARGS="--output_dir=${OUTPUT_DIR}/$1 --data_dir=${ALBERT_ROOT}/glue --vocab_file=${ALBERT_ROOT}/vocab.txt --spm_model_file=${ALBERT_ROOT}/30k-clean.model --do_lower_case --max_seq_length=512 --optimizer=adamw --task_name=$1 --warmup_step=$2 --learning_rate=$3 --train_step=$4 --save_checkpoints_steps=$5 --train_batch_size=$6"
30 | python3 -m run_classifier \
31 | ${COMMON_ARGS} \
32 | --do_train \
33 | --nodo_eval \
34 | --nodo_predict \
35 | --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}" \
36 | --init_checkpoint="${INIT_CHECKPOINT}"
37 | python3 -m run_classifier \
38 | ${COMMON_ARGS} \
39 | --nodo_train \
40 | --do_eval \
41 | --do_predict \
42 | --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}"
43 | }
44 |
45 | run_task SST-2 1256 1e-5 20935 100 32
46 | run_task MNLI 1000 3e-5 10000 100 128
47 | run_task CoLA 320 1e-5 5336 100 16
48 | run_task QNLI 1986 1e-5 33112 200 32
49 | run_task QQP 1000 5e-5 14000 100 128
50 | run_task RTE 200 3e-5 800 100 32
51 | run_task STS-B 214 2e-5 3598 100 16
52 | run_task MRPC 200 2e-5 800 100 32
53 |
--------------------------------------------------------------------------------
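The positional arguments that run_glue.sh passes to run_task above encode the per-task fine-tuning schedule and are easy to misread. The same settings, transcribed into a Python mapping purely for reference; the values are copied from the calls above and the dictionary name is illustrative:

# task -> (warmup_step, learning_rate, train_step,
#          save_checkpoints_steps, train_batch_size)
GLUE_HPARAMS = {
    "SST-2": (1256, 1e-5, 20935, 100, 32),
    "MNLI": (1000, 3e-5, 10000, 100, 128),
    "CoLA": (320, 1e-5, 5336, 100, 16),
    "QNLI": (1986, 1e-5, 33112, 200, 32),
    "QQP": (1000, 5e-5, 14000, 100, 128),
    "RTE": (200, 3e-5, 800, 100, 32),
    "STS-B": (214, 2e-5, 3598, 100, 16),
    "MRPC": (200, 2e-5, 800, 100, 32),
}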
/run_pretraining_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tests for run_pretraining."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import random
23 | import tempfile
24 | from absl.testing import flagsaver
25 | from albert import modeling
26 | from albert import run_pretraining
27 | import tensorflow.compat.v1 as tf
28 |
29 | FLAGS = tf.app.flags.FLAGS
30 |
31 |
32 | def _create_config_file(filename, max_seq_length, vocab_size):
33 | """Creates an AlbertConfig and saves it to file."""
34 | albert_config = modeling.AlbertConfig(
35 | vocab_size,
36 | embedding_size=5,
37 | hidden_size=14,
38 | num_hidden_layers=3,
39 | num_hidden_groups=1,
40 | num_attention_heads=2,
41 | intermediate_size=19,
42 | inner_group_num=1,
43 | down_scale_factor=1,
44 | hidden_act="gelu",
45 | hidden_dropout_prob=0,
46 | attention_probs_dropout_prob=0,
47 | max_position_embeddings=max_seq_length,
48 | type_vocab_size=2,
49 | initializer_range=0.02)
50 | with tf.gfile.Open(filename, "w") as outfile:
51 | outfile.write(albert_config.to_json_string())
52 |
53 |
54 | def _create_record(max_predictions_per_seq, max_seq_length, vocab_size):
55 | """Returns a tf.train.Example containing random data."""
56 | example = tf.train.Example()
57 | example.features.feature["input_ids"].int64_list.value.extend(
58 | [random.randint(0, vocab_size - 1) for _ in range(max_seq_length)])
59 | example.features.feature["input_mask"].int64_list.value.extend(
60 | [random.randint(0, 1) for _ in range(max_seq_length)])
61 | example.features.feature["masked_lm_positions"].int64_list.value.extend([
62 | random.randint(0, max_seq_length - 1)
63 | for _ in range(max_predictions_per_seq)
64 | ])
65 | example.features.feature["masked_lm_ids"].int64_list.value.extend([
66 | random.randint(0, vocab_size - 1) for _ in range(max_predictions_per_seq)
67 | ])
68 | example.features.feature["masked_lm_weights"].float_list.value.extend(
69 | [1. for _ in range(max_predictions_per_seq)])
70 | example.features.feature["segment_ids"].int64_list.value.extend(
71 | [0 for _ in range(max_seq_length)])
72 | example.features.feature["next_sentence_labels"].int64_list.value.append(
73 | random.randint(0, 1))
74 | return example
75 |
76 |
77 | def _create_input_file(filename,
78 | max_predictions_per_seq,
79 | max_seq_length,
80 | vocab_size,
81 | size=1000):
82 | """Creates an input TFRecord file of specified size."""
83 | with tf.io.TFRecordWriter(filename) as writer:
84 | for _ in range(size):
85 | ex = _create_record(max_predictions_per_seq, max_seq_length, vocab_size)
86 | writer.write(ex.SerializeToString())
87 |
88 |
89 | class RunPretrainingTest(tf.test.TestCase):
90 |
91 | def _verify_output_file(self, basename):
92 | self.assertTrue(tf.gfile.Exists(os.path.join(FLAGS.output_dir, basename)))
93 |
94 | def _verify_checkpoint_files(self, name):
95 | self._verify_output_file(name + ".meta")
96 | self._verify_output_file(name + ".index")
97 | self._verify_output_file(name + ".data-00000-of-00001")
98 |
99 | @flagsaver.flagsaver
100 | def test_pretraining(self):
101 | # Set up required flags.
102 | vocab_size = 97
103 | FLAGS.max_predictions_per_seq = 7
104 | FLAGS.max_seq_length = 13
105 | FLAGS.output_dir = tempfile.mkdtemp("output_dir")
106 | FLAGS.albert_config_file = os.path.join(
107 | tempfile.mkdtemp("config_dir"), "albert_config.json")
108 | FLAGS.input_file = os.path.join(
109 | tempfile.mkdtemp("input_dir"), "input_data.tfrecord")
110 | FLAGS.do_train = True
111 | FLAGS.do_eval = True
112 | FLAGS.num_train_steps = 1
113 | FLAGS.save_checkpoints_steps = 1
114 |
115 | # Construct requisite input files.
116 | _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length,
117 | vocab_size)
118 | _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq,
119 | FLAGS.max_seq_length, vocab_size)
120 |
121 | # Run the pretraining.
122 | run_pretraining.main(None)
123 |
124 | # Verify output.
125 | self._verify_checkpoint_files("model.ckpt-best")
126 | self._verify_checkpoint_files("model.ckpt-1")
127 | self._verify_output_file("eval_results.txt")
128 | self._verify_output_file("checkpoint")
129 |
130 |
131 | if __name__ == "__main__":
132 | tf.test.main()
133 |
--------------------------------------------------------------------------------
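_create_record above serializes fixed-length pretraining features, but the test never parses them back directly. A minimal sketch of the inverse parse under TF1-compat; the feature names and lengths come from _create_record and the test's flag values (max_seq_length=13, max_predictions_per_seq=7), while the helper itself is illustrative:

import tensorflow.compat.v1 as tf

MAX_SEQ_LENGTH = 13
MAX_PREDICTIONS_PER_SEQ = 7

name_to_features = {
    "input_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "input_mask": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "segment_ids": tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "masked_lm_positions": tf.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ],
                                              tf.int64),
    "masked_lm_ids": tf.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.int64),
    "masked_lm_weights": tf.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ],
                                            tf.float32),
    "next_sentence_labels": tf.FixedLenFeature([1], tf.int64),
}

def parse_pretraining_record(serialized_example):
  # Inverse of _create_record: recovers the fixed-length feature tensors.
  return tf.parse_single_example(serialized_example, name_to_features)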
/run_race.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ALBERT fine-tuning runner with sentence piece tokenization."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import time
23 | from albert import classifier_utils
24 | from albert import fine_tuning_utils
25 | from albert import modeling
26 | from albert import race_utils
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
29 | from tensorflow.contrib import tpu as contrib_tpu
30 |
31 | flags = tf.flags
32 |
33 | FLAGS = flags.FLAGS
34 |
35 | ## Required parameters
36 | flags.DEFINE_string(
37 | "data_dir", None,
38 | "The input data dir. Should contain the .tsv files (or other data files) "
39 | "for the task.")
40 |
41 | flags.DEFINE_string(
42 | "albert_config_file", None,
43 | "The config json file corresponding to the pre-trained ALBERT model. "
44 | "This specifies the model architecture.")
45 |
46 | flags.DEFINE_string("task_name", "race", "The name of the task to train.")
47 |
48 | flags.DEFINE_string("vocab_file", None,
49 | "The vocabulary file that the ALBERT model was trained on.")
50 |
51 | flags.DEFINE_string("train_file", None,
52 |                     "Path to the preprocessed tfrecord file. "
53 |                     "The file will be generated if it does not exist.")
54 |
55 | flags.DEFINE_string("eval_file", None,
56 |                     "Path to the preprocessed tfrecord file. "
57 |                     "The file will be generated if it does not exist.")
58 |
59 | flags.DEFINE_string("predict_file", None,
60 |                     "Path to the preprocessed tfrecord file. "
61 |                     "The file will be generated if it does not exist.")
62 |
63 | flags.DEFINE_string("spm_model_file", None,
64 | "The model file for sentence piece tokenization.")
65 |
66 | flags.DEFINE_string(
67 | "output_dir", None,
68 | "The output directory where the model checkpoints will be written.")
69 |
70 | ## Other parameters
71 |
72 | flags.DEFINE_string(
73 | "init_checkpoint", None,
74 | "Initial checkpoint (usually from a pre-trained ALBERT model).")
75 |
76 | flags.DEFINE_string(
77 | "albert_hub_module_handle", None,
78 | "If set, the ALBERT hub module to use.")
79 |
80 | flags.DEFINE_bool(
81 | "do_lower_case", True,
82 | "Whether to lower case the input text. Should be True for uncased "
83 | "models and False for cased models.")
84 |
85 | flags.DEFINE_float("dropout_prob", 0.1, "dropout probability.")
86 |
87 | flags.DEFINE_integer(
88 | "max_seq_length", 512,
89 | "The maximum total input sequence length after WordPiece tokenization. "
90 | "Sequences longer than this will be truncated, and sequences shorter "
91 | "than this will be padded.")
92 |
93 | flags.DEFINE_integer(
94 | "max_qa_length", 128,
95 |     "The maximum total length of the question-and-answer portion after "
96 |     "WordPiece tokenization. Question-answer sequences longer than this "
97 |     "will be truncated.")
98 |
99 | flags.DEFINE_integer(
100 | "num_keep_checkpoint", 5,
101 |     "Maximum number of checkpoints to keep.")
102 |
103 |
104 | flags.DEFINE_bool(
105 | "high_only", False,
106 | "Whether to only run the model on the high school set.")
107 |
108 | flags.DEFINE_bool(
109 | "middle_only", False,
110 | "Whether to only run the model on the middle school set.")
111 |
112 | flags.DEFINE_bool("do_train", True, "Whether to run training.")
113 |
114 | flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")
115 |
116 | flags.DEFINE_bool(
117 | "do_predict", False,
118 | "Whether to run the model in inference mode on the test set.")
119 |
120 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
121 |
122 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
123 |
124 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
125 |
126 | flags.DEFINE_float("learning_rate", 1e-5, "The initial learning rate for Adam.")
127 |
128 | flags.DEFINE_integer("train_step", 12000,
129 |                      "Total number of training steps to perform.")
130 |
131 | flags.DEFINE_integer(
132 | "warmup_step", 1000,
133 |     "Number of steps to perform linear learning rate warmup for.")
134 |
135 | flags.DEFINE_integer("save_checkpoints_steps", 100,
136 | "How often to save the model checkpoint.")
137 |
138 | flags.DEFINE_integer("iterations_per_loop", 1000,
139 | "How many steps to make in each estimator call.")
140 |
141 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
142 |
143 | tf.flags.DEFINE_string(
144 | "tpu_name", None,
145 | "The Cloud TPU to use for training. This should be either the name "
146 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
147 | "url.")
148 |
149 | tf.flags.DEFINE_string(
150 | "tpu_zone", None,
151 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
152 | "specified, we will attempt to automatically detect the GCE project from "
153 | "metadata.")
154 |
155 | tf.flags.DEFINE_string(
156 | "gcp_project", None,
157 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
158 | "specified, we will attempt to automatically detect the GCE project from "
159 | "metadata.")
160 |
161 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
162 |
163 | flags.DEFINE_integer(
164 | "num_tpu_cores", 8,
165 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
166 |
167 |
168 | def main(_):
169 | tf.logging.set_verbosity(tf.logging.INFO)
170 |
171 | processors = {
172 | "race": race_utils.RaceProcessor
173 | }
174 |
175 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
176 | raise ValueError(
177 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.")
178 |
179 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
180 |
181 | if FLAGS.max_seq_length > albert_config.max_position_embeddings:
182 | raise ValueError(
183 | "Cannot use sequence length %d because the ALBERT model "
184 | "was only trained up to sequence length %d" %
185 | (FLAGS.max_seq_length, albert_config.max_position_embeddings))
186 |
187 | tf.gfile.MakeDirs(FLAGS.output_dir)
188 |
189 | task_name = FLAGS.task_name.lower()
190 |
191 | if task_name not in processors:
192 | raise ValueError("Task not found: %s" % (task_name))
193 |
194 | processor = processors[task_name](
195 |       use_spm=bool(FLAGS.spm_model_file),
196 | do_lower_case=FLAGS.do_lower_case,
197 | high_only=FLAGS.high_only,
198 | middle_only=FLAGS.middle_only)
199 |
200 | label_list = processor.get_labels()
201 |
202 | tokenizer = fine_tuning_utils.create_vocab(
203 | vocab_file=FLAGS.vocab_file,
204 | do_lower_case=FLAGS.do_lower_case,
205 | spm_model_file=FLAGS.spm_model_file,
206 | hub_module=FLAGS.albert_hub_module_handle)
207 |
208 | tpu_cluster_resolver = None
209 | if FLAGS.use_tpu and FLAGS.tpu_name:
210 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
211 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
212 |
213 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
214 | if FLAGS.do_train:
215 | iterations_per_loop = int(min(FLAGS.iterations_per_loop,
216 | FLAGS.save_checkpoints_steps))
217 | else:
218 | iterations_per_loop = FLAGS.iterations_per_loop
219 | run_config = contrib_tpu.RunConfig(
220 | cluster=tpu_cluster_resolver,
221 | master=FLAGS.master,
222 | model_dir=FLAGS.output_dir,
223 | save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
224 | keep_checkpoint_max=0,
225 | tpu_config=contrib_tpu.TPUConfig(
226 | iterations_per_loop=iterations_per_loop,
227 | num_shards=FLAGS.num_tpu_cores,
228 | per_host_input_for_training=is_per_host))
229 |
230 | train_examples = None
231 | if FLAGS.do_train:
232 | train_examples = processor.get_train_examples(FLAGS.data_dir)
233 |
234 | model_fn = race_utils.model_fn_builder(
235 | albert_config=albert_config,
236 | num_labels=len(label_list),
237 | init_checkpoint=FLAGS.init_checkpoint,
238 | learning_rate=FLAGS.learning_rate,
239 | num_train_steps=FLAGS.train_step,
240 | num_warmup_steps=FLAGS.warmup_step,
241 | use_tpu=FLAGS.use_tpu,
242 | use_one_hot_embeddings=FLAGS.use_tpu,
243 | max_seq_length=FLAGS.max_seq_length,
244 | dropout_prob=FLAGS.dropout_prob,
245 | hub_module=FLAGS.albert_hub_module_handle)
246 |
247 | # If TPU is not available, this will fall back to normal Estimator on CPU
248 | # or GPU.
249 | estimator = contrib_tpu.TPUEstimator(
250 | use_tpu=FLAGS.use_tpu,
251 | model_fn=model_fn,
252 | config=run_config,
253 | train_batch_size=FLAGS.train_batch_size,
254 | eval_batch_size=FLAGS.eval_batch_size,
255 | predict_batch_size=FLAGS.predict_batch_size)
256 |
257 | if FLAGS.do_train:
258 | if not tf.gfile.Exists(FLAGS.train_file):
259 | race_utils.file_based_convert_examples_to_features(
260 | train_examples, label_list, FLAGS.max_seq_length, tokenizer,
261 | FLAGS.train_file, FLAGS.max_qa_length)
262 | tf.logging.info("***** Running training *****")
263 | tf.logging.info(" Num examples = %d", len(train_examples))
264 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
265 | tf.logging.info(" Num steps = %d", FLAGS.train_step)
266 | train_input_fn = classifier_utils.file_based_input_fn_builder(
267 | input_file=FLAGS.train_file,
268 | seq_length=FLAGS.max_seq_length,
269 | is_training=True,
270 | drop_remainder=True,
271 | task_name=task_name,
272 | use_tpu=FLAGS.use_tpu,
273 | bsz=FLAGS.train_batch_size,
274 | multiple=len(label_list))
275 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step)
276 |
277 | if FLAGS.do_eval:
278 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
279 | num_actual_eval_examples = len(eval_examples)
280 | if FLAGS.use_tpu:
281 | # TPU requires a fixed batch size for all batches, therefore the number
282 | # of examples must be a multiple of the batch size, or else examples
283 | # will get dropped. So we pad with fake examples which are ignored
284 | # later on. These do NOT count towards the metric (all tf.metrics
285 | # support a per-instance weight, and these get a weight of 0.0).
286 | while len(eval_examples) % FLAGS.eval_batch_size != 0:
287 | eval_examples.append(classifier_utils.PaddingInputExample())
288 |
289 | if not tf.gfile.Exists(FLAGS.eval_file):
290 | race_utils.file_based_convert_examples_to_features(
291 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer,
292 | FLAGS.eval_file, FLAGS.max_qa_length)
293 |
294 | tf.logging.info("***** Running evaluation *****")
295 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
296 | len(eval_examples), num_actual_eval_examples,
297 | len(eval_examples) - num_actual_eval_examples)
298 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
299 |
300 | # This tells the estimator to run through the entire set.
301 | eval_steps = None
302 | # However, if running eval on the TPU, you will need to specify the
303 | # number of steps.
304 | if FLAGS.use_tpu:
305 | assert len(eval_examples) % FLAGS.eval_batch_size == 0
306 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)
307 |
308 |     eval_drop_remainder = FLAGS.use_tpu
309 | eval_input_fn = classifier_utils.file_based_input_fn_builder(
310 | input_file=FLAGS.eval_file,
311 | seq_length=FLAGS.max_seq_length,
312 | is_training=False,
313 | drop_remainder=eval_drop_remainder,
314 | task_name=task_name,
315 | use_tpu=FLAGS.use_tpu,
316 | bsz=FLAGS.eval_batch_size,
317 | multiple=len(label_list))
318 |
319 | def _find_valid_cands(curr_step):
320 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
321 | candidates = []
322 | for filename in filenames:
323 | if filename.endswith(".index"):
324 | ckpt_name = filename[:-6]
325 | idx = ckpt_name.split("-")[-1]
326 | if idx != "best" and int(idx) > curr_step:
327 | candidates.append(filename)
328 | return candidates
329 |
330 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
331 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
332 | key_name = "eval_accuracy"
333 | if tf.gfile.Exists(checkpoint_path + ".index"):
334 | result = estimator.evaluate(
335 | input_fn=eval_input_fn,
336 | steps=eval_steps,
337 | checkpoint_path=checkpoint_path)
338 | best_perf = result[key_name]
339 | global_step = result["global_step"]
340 | else:
341 | global_step = -1
342 | best_perf = -1
343 | checkpoint_path = None
344 | writer = tf.gfile.GFile(output_eval_file, "w")
345 | while global_step < FLAGS.train_step:
346 | steps_and_files = {}
347 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
348 | for filename in filenames:
349 | if filename.endswith(".index"):
350 | ckpt_name = filename[:-6]
351 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
352 | if cur_filename.split("-")[-1] == "best":
353 | continue
354 | gstep = int(cur_filename.split("-")[-1])
355 | if gstep not in steps_and_files:
356 | tf.logging.info("Add {} to eval list.".format(cur_filename))
357 | steps_and_files[gstep] = cur_filename
358 | tf.logging.info("found {} files.".format(len(steps_and_files)))
359 | # steps_and_files = sorted(steps_and_files, key=lambda x: x[0])
360 | if not steps_and_files:
361 |       tf.logging.info("found 0 files, global step: {}. Sleeping."
362 | .format(global_step))
363 | time.sleep(1)
364 | else:
365 | for ele in sorted(steps_and_files.items()):
366 | step, checkpoint_path = ele
367 | if global_step >= step:
368 | if len(_find_valid_cands(step)) > 1:
369 | for ext in ["meta", "data-00000-of-00001", "index"]:
370 | src_ckpt = checkpoint_path + ".{}".format(ext)
371 | tf.logging.info("removing {}".format(src_ckpt))
372 | tf.gfile.Remove(src_ckpt)
373 | continue
374 | result = estimator.evaluate(
375 | input_fn=eval_input_fn,
376 | steps=eval_steps,
377 | checkpoint_path=checkpoint_path)
378 | global_step = result["global_step"]
379 | tf.logging.info("***** Eval results *****")
380 | for key in sorted(result.keys()):
381 | tf.logging.info(" %s = %s", key, str(result[key]))
382 | writer.write("%s = %s\n" % (key, str(result[key])))
383 | writer.write("best = {}\n".format(best_perf))
384 | if result[key_name] > best_perf:
385 | best_perf = result[key_name]
386 | for ext in ["meta", "data-00000-of-00001", "index"]:
387 | src_ckpt = checkpoint_path + ".{}".format(ext)
388 | tgt_ckpt = checkpoint_path.rsplit("-", 1)[0] + "-best.{}".format(ext)
389 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
390 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
391 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
392 |
393 | if len(_find_valid_cands(global_step)) > 1:
394 | for ext in ["meta", "data-00000-of-00001", "index"]:
395 | src_ckpt = checkpoint_path + ".{}".format(ext)
396 | tf.logging.info("removing {}".format(src_ckpt))
397 | tf.gfile.Remove(src_ckpt)
398 | writer.write("=" * 50 + "\n")
399 | writer.close()
400 | if FLAGS.do_predict:
401 | predict_examples = processor.get_test_examples(FLAGS.data_dir)
402 | num_actual_predict_examples = len(predict_examples)
403 | if FLAGS.use_tpu:
404 | # TPU requires a fixed batch size for all batches, therefore the number
405 | # of examples must be a multiple of the batch size, or else examples
406 | # will get dropped. So we pad with fake examples which are ignored
407 | # later on.
408 | while len(predict_examples) % FLAGS.predict_batch_size != 0:
409 | predict_examples.append(classifier_utils.PaddingInputExample())
410 | assert len(predict_examples) % FLAGS.predict_batch_size == 0
411 | predict_steps = int(len(predict_examples) // FLAGS.predict_batch_size)
412 |
413 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
414 | race_utils.file_based_convert_examples_to_features(
415 | predict_examples, label_list,
416 | FLAGS.max_seq_length, tokenizer,
417 | predict_file, FLAGS.max_qa_length)
418 |
419 |     tf.logging.info("***** Running prediction *****")
420 | tf.logging.info(" Num examples = %d (%d actual, %d padding)",
421 | len(predict_examples), num_actual_predict_examples,
422 | len(predict_examples) - num_actual_predict_examples)
423 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
424 |
425 |     predict_drop_remainder = FLAGS.use_tpu
426 | predict_input_fn = classifier_utils.file_based_input_fn_builder(
427 | input_file=predict_file,
428 | seq_length=FLAGS.max_seq_length,
429 | is_training=False,
430 | drop_remainder=predict_drop_remainder,
431 | task_name=task_name,
432 | use_tpu=FLAGS.use_tpu,
433 | bsz=FLAGS.predict_batch_size,
434 | multiple=len(label_list))
435 |
436 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
437 | result = estimator.evaluate(
438 | input_fn=predict_input_fn,
439 | steps=predict_steps,
440 | checkpoint_path=checkpoint_path)
441 |
442 | output_predict_file = os.path.join(FLAGS.output_dir, "predict_results.txt")
443 | with tf.gfile.GFile(output_predict_file, "w") as pred_writer:
444 | # num_written_lines = 0
445 | tf.logging.info("***** Predict results *****")
446 | pred_writer.write("***** Predict results *****\n")
447 | for key in sorted(result.keys()):
448 | tf.logging.info(" %s = %s", key, str(result[key]))
449 | pred_writer.write("%s = %s\n" % (key, str(result[key])))
450 | pred_writer.write("best = {}\n".format(best_perf))
451 |
452 |
453 | if __name__ == "__main__":
454 | flags.mark_flag_as_required("data_dir")
455 | flags.mark_flag_as_required("spm_model_file")
456 | flags.mark_flag_as_required("albert_config_file")
457 | flags.mark_flag_as_required("output_dir")
458 | tf.app.run()
459 |
--------------------------------------------------------------------------------
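The checkpoint-cleanup policy in the eval loops above is driven entirely by filename parsing: a checkpoint may be deleted once the loop has caught up to it and strictly newer candidates still exist on disk. A self-contained sketch of that parsing; the logic mirrors _find_valid_cands above, and the example filenames are illustrative:

def find_newer_checkpoints(filenames, curr_step):
  # Keep ".index" files whose numeric step suffix is newer than curr_step;
  # the promoted "model.ckpt-best" checkpoint is never a deletion candidate.
  candidates = []
  for filename in filenames:
    if filename.endswith(".index"):
      idx = filename[:-6].split("-")[-1]
      if idx != "best" and int(idx) > curr_step:
        candidates.append(filename)
  return candidates

names = ["model.ckpt-100.index", "model.ckpt-200.index",
         "model.ckpt-best.index", "checkpoint"]
assert find_newer_checkpoints(names, 100) == ["model.ckpt-200.index"]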
/run_squad_v1.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run ALBERT on SQuAD v1.1 using sentence piece tokenization."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | import json
23 | import os
24 | import random
25 | import time
26 | from albert import fine_tuning_utils
27 | from albert import modeling
28 | from albert import squad_utils
29 | import six
30 | import tensorflow.compat.v1 as tf
31 | from tensorflow.compat.v1 import estimator as tf_estimator
32 |
33 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
34 | from tensorflow.contrib import tpu as contrib_tpu
35 |
36 |
37 | # pylint: disable=g-import-not-at-top
38 | if six.PY2:
39 | import six.moves.cPickle as pickle
40 | else:
41 | import pickle
42 | # pylint: enable=g-import-not-at-top
43 |
44 | flags = tf.flags
45 |
46 | FLAGS = flags.FLAGS
47 |
48 | ## Required parameters
49 | flags.DEFINE_string(
50 | "albert_config_file", None,
51 |     "The config json file corresponding to the pre-trained ALBERT model. "
52 | "This specifies the model architecture.")
53 |
54 | flags.DEFINE_string("vocab_file", None,
55 |                     "The vocabulary file that the ALBERT model was trained on.")
56 |
57 | flags.DEFINE_string("spm_model_file", None,
58 | "The model file for sentence piece tokenization.")
59 |
60 | flags.DEFINE_string(
61 | "output_dir", None,
62 | "The output directory where the model checkpoints will be written.")
63 |
64 | ## Other parameters
65 | flags.DEFINE_string("train_file", None,
66 | "SQuAD json for training. E.g., train-v1.1.json")
67 |
68 | flags.DEFINE_string(
69 | "predict_file", None,
70 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
71 |
72 | flags.DEFINE_string("train_feature_file", None,
73 |                     "Location of training features. If it doesn't exist, it will be written.")
74 |
75 | flags.DEFINE_string(
76 | "predict_feature_file", None,
77 | "Location of predict features. If it doesn't exist, it will be written. "
78 | "If it does exist, it will be read.")
79 |
80 | flags.DEFINE_string(
81 | "predict_feature_left_file", None,
82 | "Location of predict features not passed to TPU. If it doesn't exist, it "
83 | "will be written. If it does exist, it will be read.")
84 |
85 | flags.DEFINE_string(
86 | "init_checkpoint", None,
87 |     "Initial checkpoint (usually from a pre-trained ALBERT model).")
88 |
89 | flags.DEFINE_string(
90 | "albert_hub_module_handle", None,
91 | "If set, the ALBERT hub module to use.")
92 |
93 | flags.DEFINE_bool(
94 | "do_lower_case", True,
95 | "Whether to lower case the input text. Should be True for uncased "
96 | "models and False for cased models.")
97 |
98 | flags.DEFINE_integer(
99 | "max_seq_length", 384,
100 | "The maximum total input sequence length after WordPiece tokenization. "
101 | "Sequences longer than this will be truncated, and sequences shorter "
102 | "than this will be padded.")
103 |
104 | flags.DEFINE_integer(
105 | "doc_stride", 128,
106 | "When splitting up a long document into chunks, how much stride to "
107 | "take between chunks.")
108 |
109 | flags.DEFINE_integer(
110 | "max_query_length", 64,
111 | "The maximum number of tokens for the question. Questions longer than "
112 | "this will be truncated to this length.")
113 |
114 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
115 |
116 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
117 |
118 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
119 |
120 | flags.DEFINE_integer("predict_batch_size", 8,
121 | "Total batch size for predictions.")
122 |
123 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
124 |
125 | flags.DEFINE_float("num_train_epochs", 3.0,
126 | "Total number of training epochs to perform.")
127 |
128 | flags.DEFINE_float(
129 | "warmup_proportion", 0.1,
130 | "Proportion of training to perform linear learning rate warmup for. "
131 | "E.g., 0.1 = 10% of training.")
132 |
133 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
134 | "How often to save the model checkpoint.")
135 |
136 | flags.DEFINE_integer("iterations_per_loop", 1000,
137 | "How many steps to make in each estimator call.")
138 |
139 | flags.DEFINE_integer(
140 | "n_best_size", 20,
141 | "The total number of n-best predictions to generate in the "
142 | "nbest_predictions.json output file.")
143 |
144 | flags.DEFINE_integer(
145 | "max_answer_length", 30,
146 | "The maximum length of an answer that can be generated. This is needed "
147 | "because the start and end predictions are not conditioned on one another.")
148 |
149 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
150 |
151 | tf.flags.DEFINE_string(
152 | "tpu_name", None,
153 | "The Cloud TPU to use for training. This should be either the name "
154 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
155 | "url.")
156 |
157 | tf.flags.DEFINE_string(
158 | "tpu_zone", None,
159 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
160 | "specified, we will attempt to automatically detect the GCE project from "
161 | "metadata.")
162 |
163 | tf.flags.DEFINE_string(
164 | "gcp_project", None,
165 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
166 | "specified, we will attempt to automatically detect the GCE project from "
167 | "metadata.")
168 |
169 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
170 |
171 | flags.DEFINE_integer(
172 | "num_tpu_cores", 8,
173 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
174 |
175 | flags.DEFINE_bool(
176 | "use_einsum", True,
177 | "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
178 | "be set to False for TFLite compatibility.")
179 |
180 | flags.DEFINE_string(
181 | "export_dir",
182 | default=None,
183 | help=("The directory where the exported SavedModel will be stored."))
184 |
185 |
186 | def validate_flags_or_throw(albert_config):
187 | """Validate the input FLAGS or throw an exception."""
188 |
189 | if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_dir:
190 |     err_msg = "At least one of `do_train`, `do_predict` or `export_dir` must be True."
191 | raise ValueError(err_msg)
192 |
193 | if FLAGS.do_train:
194 | if not FLAGS.train_file:
195 | raise ValueError(
196 | "If `do_train` is True, then `train_file` must be specified.")
197 | if FLAGS.do_predict:
198 | if not FLAGS.predict_file:
199 | raise ValueError(
200 | "If `do_predict` is True, then `predict_file` must be specified.")
201 | if not FLAGS.predict_feature_file:
202 | raise ValueError(
203 | "If `do_predict` is True, then `predict_feature_file` must be "
204 | "specified.")
205 | if not FLAGS.predict_feature_left_file:
206 | raise ValueError(
207 | "If `do_predict` is True, then `predict_feature_left_file` must be "
208 | "specified.")
209 |
210 | if FLAGS.max_seq_length > albert_config.max_position_embeddings:
211 | raise ValueError(
212 | "Cannot use sequence length %d because the ALBERT model "
213 | "was only trained up to sequence length %d" %
214 | (FLAGS.max_seq_length, albert_config.max_position_embeddings))
215 |
216 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
217 | raise ValueError(
218 | "The max_seq_length (%d) must be greater than max_query_length "
219 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
220 |
221 |
222 | def build_squad_serving_input_fn(seq_length):
223 | """Builds a serving input fn for raw input."""
224 |
225 | def _seq_serving_input_fn():
226 |     """Serving input fn for raw input tensors."""
227 | input_ids = tf.placeholder(
228 | shape=[1, seq_length], name="input_ids", dtype=tf.int32)
229 | input_mask = tf.placeholder(
230 | shape=[1, seq_length], name="input_mask", dtype=tf.int32)
231 | segment_ids = tf.placeholder(
232 | shape=[1, seq_length], name="segment_ids", dtype=tf.int32)
233 |
234 | inputs = {
235 | "input_ids": input_ids,
236 | "input_mask": input_mask,
237 | "segment_ids": segment_ids
238 | }
239 | return tf_estimator.export.ServingInputReceiver(features=inputs,
240 | receiver_tensors=inputs)
241 |
242 | return _seq_serving_input_fn
243 |
244 |
245 | def main(_):
246 | tf.logging.set_verbosity(tf.logging.INFO)
247 |
248 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
249 |
250 | validate_flags_or_throw(albert_config)
251 |
252 | tf.gfile.MakeDirs(FLAGS.output_dir)
253 |
254 | tokenizer = fine_tuning_utils.create_vocab(
255 | vocab_file=FLAGS.vocab_file,
256 | do_lower_case=FLAGS.do_lower_case,
257 | spm_model_file=FLAGS.spm_model_file,
258 | hub_module=FLAGS.albert_hub_module_handle)
259 |
260 | tpu_cluster_resolver = None
261 | if FLAGS.use_tpu and FLAGS.tpu_name:
262 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
263 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
264 |
265 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
266 | if FLAGS.do_train:
267 | iterations_per_loop = int(min(FLAGS.iterations_per_loop,
268 | FLAGS.save_checkpoints_steps))
269 | else:
270 | iterations_per_loop = FLAGS.iterations_per_loop
271 | run_config = contrib_tpu.RunConfig(
272 | cluster=tpu_cluster_resolver,
273 | master=FLAGS.master,
274 | model_dir=FLAGS.output_dir,
275 | keep_checkpoint_max=0,
276 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
277 | tpu_config=contrib_tpu.TPUConfig(
278 | iterations_per_loop=iterations_per_loop,
279 | num_shards=FLAGS.num_tpu_cores,
280 | per_host_input_for_training=is_per_host))
281 |
282 | train_examples = None
283 | num_train_steps = None
284 | num_warmup_steps = None
285 | if FLAGS.do_train:
286 | train_examples = squad_utils.read_squad_examples(
287 | input_file=FLAGS.train_file, is_training=True)
288 | num_train_steps = int(
289 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
290 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
291 |
292 | # Pre-shuffle the input to avoid having to make a very large shuffle
293 |     # buffer in the `input_fn`.
294 | rng = random.Random(12345)
295 | rng.shuffle(train_examples)
296 |
297 | model_fn = squad_utils.v1_model_fn_builder(
298 | albert_config=albert_config,
299 | init_checkpoint=FLAGS.init_checkpoint,
300 | learning_rate=FLAGS.learning_rate,
301 | num_train_steps=num_train_steps,
302 | num_warmup_steps=num_warmup_steps,
303 | use_tpu=FLAGS.use_tpu,
304 | use_one_hot_embeddings=FLAGS.use_tpu,
305 | use_einsum=FLAGS.use_einsum,
306 | hub_module=FLAGS.albert_hub_module_handle)
307 |
308 | # If TPU is not available, this will fall back to normal Estimator on CPU
309 | # or GPU.
310 | estimator = contrib_tpu.TPUEstimator(
311 | use_tpu=FLAGS.use_tpu,
312 | model_fn=model_fn,
313 | config=run_config,
314 | train_batch_size=FLAGS.train_batch_size,
315 | predict_batch_size=FLAGS.predict_batch_size)
316 |
317 | if FLAGS.do_train:
318 | # We write to a temporary file to avoid storing very large constant tensors
319 | # in memory.
320 |
321 | if not tf.gfile.Exists(FLAGS.train_feature_file):
322 | train_writer = squad_utils.FeatureWriter(
323 | filename=os.path.join(FLAGS.train_feature_file), is_training=True)
324 | squad_utils.convert_examples_to_features(
325 | examples=train_examples,
326 | tokenizer=tokenizer,
327 | max_seq_length=FLAGS.max_seq_length,
328 | doc_stride=FLAGS.doc_stride,
329 | max_query_length=FLAGS.max_query_length,
330 | is_training=True,
331 | output_fn=train_writer.process_feature,
332 | do_lower_case=FLAGS.do_lower_case)
333 | train_writer.close()
334 |
335 | tf.logging.info("***** Running training *****")
336 | tf.logging.info(" Num orig examples = %d", len(train_examples))
337 | # tf.logging.info(" Num split examples = %d", train_writer.num_features)
338 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
339 | tf.logging.info(" Num steps = %d", num_train_steps)
340 | del train_examples
341 |
342 | train_input_fn = squad_utils.input_fn_builder(
343 | input_file=FLAGS.train_feature_file,
344 | seq_length=FLAGS.max_seq_length,
345 | is_training=True,
346 | drop_remainder=True,
347 | use_tpu=FLAGS.use_tpu,
348 | bsz=FLAGS.train_batch_size,
349 | is_v2=False)
350 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
351 |
352 | if FLAGS.do_predict:
353 | with tf.gfile.Open(FLAGS.predict_file) as predict_file:
354 | prediction_json = json.load(predict_file)["data"]
355 |
356 | eval_examples = squad_utils.read_squad_examples(
357 | input_file=FLAGS.predict_file, is_training=False)
358 |
359 | if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
360 | FLAGS.predict_feature_left_file)):
361 | tf.logging.info("Loading eval features from {}".format(
362 | FLAGS.predict_feature_left_file))
363 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
364 | eval_features = pickle.load(fin)
365 | else:
366 | eval_writer = squad_utils.FeatureWriter(
367 | filename=FLAGS.predict_feature_file, is_training=False)
368 | eval_features = []
369 |
370 | def append_feature(feature):
371 | eval_features.append(feature)
372 | eval_writer.process_feature(feature)
373 |
374 | squad_utils.convert_examples_to_features(
375 | examples=eval_examples,
376 | tokenizer=tokenizer,
377 | max_seq_length=FLAGS.max_seq_length,
378 | doc_stride=FLAGS.doc_stride,
379 | max_query_length=FLAGS.max_query_length,
380 | is_training=False,
381 | output_fn=append_feature,
382 | do_lower_case=FLAGS.do_lower_case)
383 | eval_writer.close()
384 |
385 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
386 | pickle.dump(eval_features, fout)
387 |
388 | tf.logging.info("***** Running predictions *****")
389 | tf.logging.info(" Num orig examples = %d", len(eval_examples))
390 | tf.logging.info(" Num split examples = %d", len(eval_features))
391 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
392 |
393 | predict_input_fn = squad_utils.input_fn_builder(
394 | input_file=FLAGS.predict_feature_file,
395 | seq_length=FLAGS.max_seq_length,
396 | is_training=False,
397 | drop_remainder=False,
398 | use_tpu=FLAGS.use_tpu,
399 | bsz=FLAGS.predict_batch_size,
400 | is_v2=False)
401 |
402 | def get_result(checkpoint):
403 |       """Evaluate the checkpoint on SQuAD v1.1."""
404 | # If running eval on the TPU, you will need to specify the number of
405 | # steps.
406 | reader = tf.train.NewCheckpointReader(checkpoint)
407 | global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
408 | all_results = []
409 | for result in estimator.predict(
410 | predict_input_fn, yield_single_examples=True,
411 | checkpoint_path=checkpoint):
412 | if len(all_results) % 1000 == 0:
413 | tf.logging.info("Processing example: %d" % (len(all_results)))
414 | unique_id = int(result["unique_ids"])
415 | start_log_prob = [float(x) for x in result["start_log_prob"].flat]
416 | end_log_prob = [float(x) for x in result["end_log_prob"].flat]
417 | all_results.append(
418 | squad_utils.RawResult(
419 | unique_id=unique_id,
420 | start_log_prob=start_log_prob,
421 | end_log_prob=end_log_prob))
422 |
423 | output_prediction_file = os.path.join(
424 | FLAGS.output_dir, "predictions.json")
425 | output_nbest_file = os.path.join(
426 | FLAGS.output_dir, "nbest_predictions.json")
427 |
428 | result_dict = {}
429 | squad_utils.accumulate_predictions_v1(
430 | result_dict, eval_examples, eval_features,
431 | all_results, FLAGS.n_best_size, FLAGS.max_answer_length)
432 | predictions = squad_utils.write_predictions_v1(
433 | result_dict, eval_examples, eval_features, all_results,
434 | FLAGS.n_best_size, FLAGS.max_answer_length,
435 | output_prediction_file, output_nbest_file)
436 |
437 | return squad_utils.evaluate_v1(
438 | prediction_json, predictions), int(global_step)
439 |
440 | def _find_valid_cands(curr_step):
441 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
442 | candidates = []
443 | for filename in filenames:
444 | if filename.endswith(".index"):
445 | ckpt_name = filename[:-6]
446 | idx = ckpt_name.split("-")[-1]
447 | if idx != "best" and int(idx) > curr_step:
448 | candidates.append(filename)
449 | return candidates
450 |
451 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
452 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
453 | key_name = "f1"
454 | writer = tf.gfile.GFile(output_eval_file, "w")
455 | if tf.gfile.Exists(checkpoint_path + ".index"):
456 | result = get_result(checkpoint_path)
457 | best_perf = result[0][key_name]
458 | global_step = result[1]
459 | else:
460 | global_step = -1
461 | best_perf = -1
462 | checkpoint_path = None
463 | while global_step < num_train_steps:
464 | steps_and_files = {}
465 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
466 | for filename in filenames:
467 | if filename.endswith(".index"):
468 | ckpt_name = filename[:-6]
469 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
470 | if cur_filename.split("-")[-1] == "best":
471 | continue
472 | gstep = int(cur_filename.split("-")[-1])
473 | if gstep not in steps_and_files:
474 | tf.logging.info("Add {} to eval list.".format(cur_filename))
475 | steps_and_files[gstep] = cur_filename
476 | tf.logging.info("found {} files.".format(len(steps_and_files)))
477 | if not steps_and_files:
478 |         tf.logging.info("found 0 files, global step: {}. Sleeping."
479 | .format(global_step))
480 | time.sleep(60)
481 | else:
482 | for ele in sorted(steps_and_files.items()):
483 | step, checkpoint_path = ele
484 | if global_step >= step:
485 | if len(_find_valid_cands(step)) > 1:
486 | for ext in ["meta", "data-00000-of-00001", "index"]:
487 | src_ckpt = checkpoint_path + ".{}".format(ext)
488 | tf.logging.info("removing {}".format(src_ckpt))
489 | tf.gfile.Remove(src_ckpt)
490 | continue
491 | result, global_step = get_result(checkpoint_path)
492 | tf.logging.info("***** Eval results *****")
493 | for key in sorted(result.keys()):
494 | tf.logging.info(" %s = %s", key, str(result[key]))
495 | writer.write("%s = %s\n" % (key, str(result[key])))
496 | if result[key_name] > best_perf:
497 | best_perf = result[key_name]
498 | for ext in ["meta", "data-00000-of-00001", "index"]:
499 | src_ckpt = checkpoint_path + ".{}".format(ext)
500 | tgt_ckpt = checkpoint_path.rsplit(
501 | "-", 1)[0] + "-best.{}".format(ext)
502 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
503 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
504 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
505 | writer.write("best {} = {}\n".format(key_name, best_perf))
506 | tf.logging.info(" best {} = {}\n".format(key_name, best_perf))
507 |
508 | if len(_find_valid_cands(global_step)) > 2:
509 | for ext in ["meta", "data-00000-of-00001", "index"]:
510 | src_ckpt = checkpoint_path + ".{}".format(ext)
511 | tf.logging.info("removing {}".format(src_ckpt))
512 | tf.gfile.Remove(src_ckpt)
513 | writer.write("=" * 50 + "\n")
514 |
515 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
516 | result, global_step = get_result(checkpoint_path)
517 | tf.logging.info("***** Final Eval results *****")
518 | for key in sorted(result.keys()):
519 | tf.logging.info(" %s = %s", key, str(result[key]))
520 | writer.write("%s = %s\n" % (key, str(result[key])))
521 | writer.write("best perf happened at step: {}".format(global_step))
522 |
523 | if FLAGS.export_dir:
524 | tf.gfile.MakeDirs(FLAGS.export_dir)
525 | squad_serving_input_fn = (
526 | build_squad_serving_input_fn(FLAGS.max_seq_length))
527 | tf.logging.info("Starting to export model.")
528 | subfolder = estimator.export_saved_model(
529 | export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"),
530 | serving_input_receiver_fn=squad_serving_input_fn)
531 |
532 | tf.logging.info("Starting to export TFLite.")
533 | converter = tf.lite.TFLiteConverter.from_saved_model(
534 | subfolder,
535 | input_arrays=["input_ids", "input_mask", "segment_ids"],
536 | output_arrays=["start_logits", "end_logits"])
537 | float_model = converter.convert()
538 | tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite")
539 | with tf.gfile.GFile(tflite_file, "wb") as f:
540 | f.write(float_model)
541 |
542 |
543 | if __name__ == "__main__":
544 | flags.mark_flag_as_required("spm_model_file")
545 | flags.mark_flag_as_required("albert_config_file")
546 | flags.mark_flag_as_required("output_dir")
547 | tf.app.run()
548 |
--------------------------------------------------------------------------------
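
Note: the eval loop that ends run_squad_v1.py above discovers checkpoints by their ".index" files, parses the step out of the "model.ckpt-<step>" name, skips the tracked "model.ckpt-best" copy, and deletes stale checkpoints once newer candidates exist. A minimal, dependency-free sketch of that name parsing (using os.listdir instead of tf.gfile, so it only covers local paths; output_dir is a hypothetical directory):

import os

def list_checkpoint_steps(output_dir):
  """Returns the sorted global steps that have a checkpoint on disk."""
  steps = []
  for filename in os.listdir(output_dir):
    if not filename.endswith(".index"):
      continue
    # "model.ckpt-1000.index" -> "model.ckpt-1000" -> step 1000.
    suffix = filename[:-len(".index")].split("-")[-1]
    if suffix == "best":  # the best-so-far copy is never re-evaluated
      continue
    steps.append(int(suffix))
  return sorted(steps)
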
/run_squad_v2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run ALBERT on SQuAD v2.0 using sentence piece tokenization."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | import json
23 | import os
24 | import random
25 | import time
26 |
27 | from albert import fine_tuning_utils
28 | from albert import modeling
29 | from albert import squad_utils
30 | import six
31 | import tensorflow.compat.v1 as tf
32 |
33 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
34 | from tensorflow.contrib import tpu as contrib_tpu
35 |
36 |
37 | # pylint: disable=g-import-not-at-top
38 | if six.PY2:
39 | import six.moves.cPickle as pickle
40 | else:
41 | import pickle
42 | # pylint: enable=g-import-not-at-top
43 |
44 | flags = tf.flags
45 |
46 | FLAGS = flags.FLAGS
47 |
48 | ## Required parameters
49 | flags.DEFINE_string(
50 | "albert_config_file", None,
51 | "The config json file corresponding to the pre-trained ALBERT model. "
52 | "This specifies the model architecture.")
53 |
54 | flags.DEFINE_string("vocab_file", None,
55 | "The vocabulary file that the ALBERT model was trained on.")
56 |
57 | flags.DEFINE_string("spm_model_file", None,
58 | "The model file for sentence piece tokenization.")
59 |
60 | flags.DEFINE_string(
61 | "output_dir", None,
62 | "The output directory where the model checkpoints will be written.")
63 |
64 | ## Other parameters
65 | flags.DEFINE_string("train_file", None,
66 | "SQuAD json for training. E.g., train-v1.1.json")
67 |
68 | flags.DEFINE_string(
69 | "predict_file", None,
70 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
71 |
72 | flags.DEFINE_string("train_feature_file", None,
73 |                     "Training feature file.")
74 |
75 | flags.DEFINE_string(
76 | "predict_feature_file", None,
77 | "Location of predict features. If it doesn't exist, it will be written. "
78 | "If it does exist, it will be read.")
79 |
80 | flags.DEFINE_string(
81 | "predict_feature_left_file", None,
82 | "Location of predict features not passed to TPU. If it doesn't exist, it "
83 | "will be written. If it does exist, it will be read.")
84 |
85 | flags.DEFINE_string(
86 | "init_checkpoint", None,
87 |     "Initial checkpoint (usually from a pre-trained ALBERT model).")
88 |
89 | flags.DEFINE_string(
90 | "albert_hub_module_handle", None,
91 | "If set, the ALBERT hub module to use.")
92 |
93 | flags.DEFINE_bool(
94 | "do_lower_case", True,
95 | "Whether to lower case the input text. Should be True for uncased "
96 | "models and False for cased models.")
97 |
98 | flags.DEFINE_integer(
99 | "max_seq_length", 384,
100 |     "The maximum total input sequence length after SentencePiece tokenization. "
101 | "Sequences longer than this will be truncated, and sequences shorter "
102 | "than this will be padded.")
103 |
104 | flags.DEFINE_integer(
105 | "doc_stride", 128,
106 | "When splitting up a long document into chunks, how much stride to "
107 | "take between chunks.")
108 |
109 | flags.DEFINE_integer(
110 | "max_query_length", 64,
111 | "The maximum number of tokens for the question. Questions longer than "
112 | "this will be truncated to this length.")
113 |
114 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
115 |
116 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
117 |
118 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
119 |
120 | flags.DEFINE_integer("predict_batch_size", 8,
121 | "Total batch size for predictions.")
122 |
123 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
124 |
125 | flags.DEFINE_float("num_train_epochs", 3.0,
126 | "Total number of training epochs to perform.")
127 |
128 | flags.DEFINE_float(
129 | "warmup_proportion", 0.1,
130 | "Proportion of training to perform linear learning rate warmup for. "
131 | "E.g., 0.1 = 10% of training.")
132 |
133 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
134 | "How often to save the model checkpoint.")
135 |
136 | flags.DEFINE_integer("iterations_per_loop", 1000,
137 | "How many steps to make in each estimator call.")
138 |
139 | flags.DEFINE_integer(
140 | "n_best_size", 20,
141 | "The total number of n-best predictions to generate in the "
142 | "nbest_predictions.json output file.")
143 |
144 | flags.DEFINE_integer(
145 | "max_answer_length", 30,
146 | "The maximum length of an answer that can be generated. This is needed "
147 | "because the start and end predictions are not conditioned on one another.")
148 |
149 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
150 |
151 | tf.flags.DEFINE_string(
152 | "tpu_name", None,
153 | "The Cloud TPU to use for training. This should be either the name "
154 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
155 | "url.")
156 |
157 | tf.flags.DEFINE_string(
158 | "tpu_zone", None,
159 |     "[Optional] GCE zone where the Cloud TPU is located. If not "
160 |     "specified, we will attempt to automatically detect the zone from "
161 | "metadata.")
162 |
163 | tf.flags.DEFINE_string(
164 | "gcp_project", None,
165 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
166 | "specified, we will attempt to automatically detect the GCE project from "
167 | "metadata.")
168 |
169 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
170 |
171 | flags.DEFINE_integer(
172 | "num_tpu_cores", 8,
173 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
174 |
175 |
176 | flags.DEFINE_integer("start_n_top", 5, "Beam size for the start positions.")
177 |
178 | flags.DEFINE_integer("end_n_top", 5, "Beam size for the end positions.")
179 |
180 | flags.DEFINE_float("dropout_prob", 0.1, "Dropout probability.")
181 |
182 |
183 | def validate_flags_or_throw(albert_config):
184 | """Validate the input FLAGS or throw an exception."""
185 |
186 | if not FLAGS.do_train and not FLAGS.do_predict:
187 | raise ValueError("At least one of `do_train` or `do_predict` must be True.")
188 |
189 | if FLAGS.do_train:
190 | if not FLAGS.train_file:
191 | raise ValueError(
192 | "If `do_train` is True, then `train_file` must be specified.")
193 | if FLAGS.do_predict:
194 | if not FLAGS.predict_file:
195 | raise ValueError(
196 | "If `do_predict` is True, then `predict_file` must be specified.")
197 | if not FLAGS.predict_feature_file:
198 | raise ValueError(
199 | "If `do_predict` is True, then `predict_feature_file` must be "
200 | "specified.")
201 | if not FLAGS.predict_feature_left_file:
202 | raise ValueError(
203 | "If `do_predict` is True, then `predict_feature_left_file` must be "
204 | "specified.")
205 |
206 | if FLAGS.max_seq_length > albert_config.max_position_embeddings:
207 | raise ValueError(
208 | "Cannot use sequence length %d because the ALBERT model "
209 | "was only trained up to sequence length %d" %
210 | (FLAGS.max_seq_length, albert_config.max_position_embeddings))
211 |
212 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
213 | raise ValueError(
214 | "The max_seq_length (%d) must be greater than max_query_length "
215 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
216 |
217 |
218 | def main(_):
219 | tf.logging.set_verbosity(tf.logging.INFO)
220 |
221 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
222 |
223 | validate_flags_or_throw(albert_config)
224 |
225 | tf.gfile.MakeDirs(FLAGS.output_dir)
226 |
227 | tokenizer = fine_tuning_utils.create_vocab(
228 | vocab_file=FLAGS.vocab_file,
229 | do_lower_case=FLAGS.do_lower_case,
230 | spm_model_file=FLAGS.spm_model_file,
231 | hub_module=FLAGS.albert_hub_module_handle)
232 |
233 | tpu_cluster_resolver = None
234 | if FLAGS.use_tpu and FLAGS.tpu_name:
235 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
236 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
237 |
238 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
239 | if FLAGS.do_train:
240 | iterations_per_loop = int(min(FLAGS.iterations_per_loop,
241 | FLAGS.save_checkpoints_steps))
242 | else:
243 | iterations_per_loop = FLAGS.iterations_per_loop
244 | run_config = contrib_tpu.RunConfig(
245 | cluster=tpu_cluster_resolver,
246 | master=FLAGS.master,
247 | model_dir=FLAGS.output_dir,
248 | keep_checkpoint_max=0,
249 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
250 | tpu_config=contrib_tpu.TPUConfig(
251 | iterations_per_loop=iterations_per_loop,
252 | num_shards=FLAGS.num_tpu_cores,
253 | per_host_input_for_training=is_per_host))
254 |
255 | train_examples = None
256 | num_train_steps = None
257 | num_warmup_steps = None
258 | train_examples = squad_utils.read_squad_examples(
259 | input_file=FLAGS.train_file, is_training=True)
260 | num_train_steps = int(
261 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
262 | if FLAGS.do_train:
263 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
264 |
265 | # Pre-shuffle the input to avoid having to make a very large shuffle
266 |     # buffer in the `input_fn`.
267 | rng = random.Random(12345)
268 | rng.shuffle(train_examples)
269 |
270 | model_fn = squad_utils.v2_model_fn_builder(
271 | albert_config=albert_config,
272 | init_checkpoint=FLAGS.init_checkpoint,
273 | learning_rate=FLAGS.learning_rate,
274 | num_train_steps=num_train_steps,
275 | num_warmup_steps=num_warmup_steps,
276 | use_tpu=FLAGS.use_tpu,
277 | use_one_hot_embeddings=FLAGS.use_tpu,
278 | max_seq_length=FLAGS.max_seq_length,
279 | start_n_top=FLAGS.start_n_top,
280 | end_n_top=FLAGS.end_n_top,
281 | dropout_prob=FLAGS.dropout_prob,
282 | hub_module=FLAGS.albert_hub_module_handle)
283 |
284 | # If TPU is not available, this will fall back to normal Estimator on CPU
285 | # or GPU.
286 | estimator = contrib_tpu.TPUEstimator(
287 | use_tpu=FLAGS.use_tpu,
288 | model_fn=model_fn,
289 | config=run_config,
290 | train_batch_size=FLAGS.train_batch_size,
291 | predict_batch_size=FLAGS.predict_batch_size)
292 |
293 | if FLAGS.do_train:
294 | # We write to a temporary file to avoid storing very large constant tensors
295 | # in memory.
296 |
297 | if not tf.gfile.Exists(FLAGS.train_feature_file):
298 | train_writer = squad_utils.FeatureWriter(
299 |         filename=FLAGS.train_feature_file, is_training=True)
300 | squad_utils.convert_examples_to_features(
301 | examples=train_examples,
302 | tokenizer=tokenizer,
303 | max_seq_length=FLAGS.max_seq_length,
304 | doc_stride=FLAGS.doc_stride,
305 | max_query_length=FLAGS.max_query_length,
306 | is_training=True,
307 | output_fn=train_writer.process_feature,
308 | do_lower_case=FLAGS.do_lower_case)
309 | train_writer.close()
310 |
311 | tf.logging.info("***** Running training *****")
312 | tf.logging.info(" Num orig examples = %d", len(train_examples))
313 | # tf.logging.info(" Num split examples = %d", train_writer.num_features)
314 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
315 | tf.logging.info(" Num steps = %d", num_train_steps)
316 | del train_examples
317 |
318 | train_input_fn = squad_utils.input_fn_builder(
319 | input_file=FLAGS.train_feature_file,
320 | seq_length=FLAGS.max_seq_length,
321 | is_training=True,
322 | drop_remainder=True,
323 | use_tpu=FLAGS.use_tpu,
324 | bsz=FLAGS.train_batch_size,
325 | is_v2=True)
326 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
327 |
328 | if FLAGS.do_predict:
329 | with tf.gfile.Open(FLAGS.predict_file) as predict_file:
330 | prediction_json = json.load(predict_file)["data"]
331 | eval_examples = squad_utils.read_squad_examples(
332 | input_file=FLAGS.predict_file, is_training=False)
333 |
334 | if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
335 | FLAGS.predict_feature_left_file)):
336 | tf.logging.info("Loading eval features from {}".format(
337 | FLAGS.predict_feature_left_file))
338 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
339 | eval_features = pickle.load(fin)
340 | else:
341 | eval_writer = squad_utils.FeatureWriter(
342 | filename=FLAGS.predict_feature_file, is_training=False)
343 | eval_features = []
344 |
345 | def append_feature(feature):
346 | eval_features.append(feature)
347 | eval_writer.process_feature(feature)
348 |
349 | squad_utils.convert_examples_to_features(
350 | examples=eval_examples,
351 | tokenizer=tokenizer,
352 | max_seq_length=FLAGS.max_seq_length,
353 | doc_stride=FLAGS.doc_stride,
354 | max_query_length=FLAGS.max_query_length,
355 | is_training=False,
356 | output_fn=append_feature,
357 | do_lower_case=FLAGS.do_lower_case)
358 | eval_writer.close()
359 |
360 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
361 | pickle.dump(eval_features, fout)
362 |
363 | tf.logging.info("***** Running predictions *****")
364 | tf.logging.info(" Num orig examples = %d", len(eval_examples))
365 | tf.logging.info(" Num split examples = %d", len(eval_features))
366 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
367 |
368 | predict_input_fn = squad_utils.input_fn_builder(
369 | input_file=FLAGS.predict_feature_file,
370 | seq_length=FLAGS.max_seq_length,
371 | is_training=False,
372 | drop_remainder=False,
373 | use_tpu=FLAGS.use_tpu,
374 | bsz=FLAGS.predict_batch_size,
375 | is_v2=True)
376 |
377 | def get_result(checkpoint):
378 | """Evaluate the checkpoint on SQuAD v2.0."""
379 | # If running eval on the TPU, you will need to specify the number of
380 | # steps.
381 | reader = tf.train.NewCheckpointReader(checkpoint)
382 | global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
383 | all_results = []
384 | for result in estimator.predict(
385 | predict_input_fn, yield_single_examples=True,
386 | checkpoint_path=checkpoint):
387 | if len(all_results) % 1000 == 0:
388 | tf.logging.info("Processing example: %d" % (len(all_results)))
389 | unique_id = int(result["unique_ids"])
390 | start_top_log_probs = (
391 | [float(x) for x in result["start_top_log_probs"].flat])
392 | start_top_index = [int(x) for x in result["start_top_index"].flat]
393 | end_top_log_probs = (
394 | [float(x) for x in result["end_top_log_probs"].flat])
395 | end_top_index = [int(x) for x in result["end_top_index"].flat]
396 |
397 | cls_logits = float(result["cls_logits"].flat[0])
398 | all_results.append(
399 | squad_utils.RawResultV2(
400 | unique_id=unique_id,
401 | start_top_log_probs=start_top_log_probs,
402 | start_top_index=start_top_index,
403 | end_top_log_probs=end_top_log_probs,
404 | end_top_index=end_top_index,
405 | cls_logits=cls_logits))
406 |
407 | output_prediction_file = os.path.join(
408 | FLAGS.output_dir, "predictions.json")
409 | output_nbest_file = os.path.join(
410 | FLAGS.output_dir, "nbest_predictions.json")
411 | output_null_log_odds_file = os.path.join(
412 | FLAGS.output_dir, "null_odds.json")
413 |
414 | result_dict = {}
415 | cls_dict = {}
416 | squad_utils.accumulate_predictions_v2(
417 | result_dict, cls_dict, eval_examples, eval_features,
418 | all_results, FLAGS.n_best_size, FLAGS.max_answer_length,
419 | FLAGS.start_n_top, FLAGS.end_n_top)
420 |
421 | return squad_utils.evaluate_v2(
422 | result_dict, cls_dict, prediction_json, eval_examples,
423 | eval_features, all_results, FLAGS.n_best_size,
424 | FLAGS.max_answer_length, output_prediction_file, output_nbest_file,
425 | output_null_log_odds_file), int(global_step)
426 |
427 | def _find_valid_cands(curr_step):
428 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
429 | candidates = []
430 | for filename in filenames:
431 | if filename.endswith(".index"):
432 | ckpt_name = filename[:-6]
433 | idx = ckpt_name.split("-")[-1]
434 | if idx != "best" and int(idx) > curr_step:
435 | candidates.append(filename)
436 | return candidates
437 |
438 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
439 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
440 | key_name = "f1"
441 | writer = tf.gfile.GFile(output_eval_file, "w")
442 | if tf.gfile.Exists(checkpoint_path + ".index"):
443 | result = get_result(checkpoint_path)
444 | best_perf = result[0][key_name]
445 | global_step = result[1]
446 | else:
447 | global_step = -1
448 | best_perf = -1
449 | checkpoint_path = None
450 | while global_step < num_train_steps:
451 | steps_and_files = {}
452 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
453 | for filename in filenames:
454 | if filename.endswith(".index"):
455 | ckpt_name = filename[:-6]
456 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
457 | if cur_filename.split("-")[-1] == "best":
458 | continue
459 | gstep = int(cur_filename.split("-")[-1])
460 | if gstep not in steps_and_files:
461 | tf.logging.info("Add {} to eval list.".format(cur_filename))
462 | steps_and_files[gstep] = cur_filename
463 | tf.logging.info("found {} files.".format(len(steps_and_files)))
464 | if not steps_and_files:
465 |       tf.logging.info("found 0 files, global step: {}. Sleeping."
466 | .format(global_step))
467 | time.sleep(60)
468 | else:
469 | for ele in sorted(steps_and_files.items()):
470 | step, checkpoint_path = ele
471 | if global_step >= step:
472 | if len(_find_valid_cands(step)) > 1:
473 | for ext in ["meta", "data-00000-of-00001", "index"]:
474 | src_ckpt = checkpoint_path + ".{}".format(ext)
475 | tf.logging.info("removing {}".format(src_ckpt))
476 | tf.gfile.Remove(src_ckpt)
477 | continue
478 | result, global_step = get_result(checkpoint_path)
479 | tf.logging.info("***** Eval results *****")
480 | for key in sorted(result.keys()):
481 | tf.logging.info(" %s = %s", key, str(result[key]))
482 | writer.write("%s = %s\n" % (key, str(result[key])))
483 | if result[key_name] > best_perf:
484 | best_perf = result[key_name]
485 | for ext in ["meta", "data-00000-of-00001", "index"]:
486 | src_ckpt = checkpoint_path + ".{}".format(ext)
487 | tgt_ckpt = checkpoint_path.rsplit(
488 | "-", 1)[0] + "-best.{}".format(ext)
489 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt))
490 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True)
491 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt))
492 | writer.write("best {} = {}\n".format(key_name, best_perf))
493 | tf.logging.info(" best {} = {}\n".format(key_name, best_perf))
494 |
495 | if len(_find_valid_cands(global_step)) > 2:
496 | for ext in ["meta", "data-00000-of-00001", "index"]:
497 | src_ckpt = checkpoint_path + ".{}".format(ext)
498 | tf.logging.info("removing {}".format(src_ckpt))
499 | tf.gfile.Remove(src_ckpt)
500 | writer.write("=" * 50 + "\n")
501 |
502 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
503 | result, global_step = get_result(checkpoint_path)
504 | tf.logging.info("***** Final Eval results *****")
505 | for key in sorted(result.keys()):
506 | tf.logging.info(" %s = %s", key, str(result[key]))
507 | writer.write("%s = %s\n" % (key, str(result[key])))
508 | writer.write("best perf happened at step: {}".format(global_step))
509 |
510 |
511 | if __name__ == "__main__":
512 | flags.mark_flag_as_required("spm_model_file")
513 | flags.mark_flag_as_required("albert_config_file")
514 | flags.mark_flag_as_required("output_dir")
515 | tf.app.run()
516 |
--------------------------------------------------------------------------------
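
Note: the do_predict branch in run_squad_v2.py caches eval features twice: the TFRecord at predict_feature_file feeds the input pipeline, while the pickle at predict_feature_left_file preserves the in-memory feature objects needed for post-processing. A sketch of the same load-or-build pattern, under the assumption of hypothetical cache_path/build_features names (the script itself goes through tf.gfile and FeatureWriter):

import os
import pickle

def load_or_build(cache_path, build_features):
  """Rebuilds features only when no cached pickle exists."""
  if os.path.exists(cache_path):
    with open(cache_path, "rb") as fin:
      return pickle.load(fin)
  features = build_features()
  with open(cache_path, "wb") as fout:
    pickle.dump(features, fout)
  return features
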
/run_trivial_model_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Small integration test script.
3 | # The values in this file are **not** meant for reproducing actual results.
4 |
5 | set -e
6 | set -x
7 |
8 | virtualenv -p python3 .
9 | source ./bin/activate
10 |
11 | OUTPUT_DIR_BASE="$(mktemp -d)"
12 | OUTPUT_DIR="${OUTPUT_DIR_BASE}/output"
13 |
14 | pip install numpy
15 | pip install -r requirements.txt
16 | python -m run_pretraining_test \
17 | --output_dir="${OUTPUT_DIR}" \
18 | --do_train \
19 | --do_eval \
20 | --nouse_tpu \
21 | --train_batch_size=2 \
22 | --eval_batch_size=1 \
23 | --max_seq_length=4 \
24 | --num_train_steps=2 \
25 | --max_eval_steps=3
26 |
27 |
28 |
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # coding=utf-8
16 | """Tokenization classes."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 | import unicodedata
24 | import six
25 | from six.moves import range
26 | import tensorflow.compat.v1 as tf
27 | import tensorflow_hub as hub
28 | import sentencepiece as spm
29 |
30 | SPIECE_UNDERLINE = u"▁".encode("utf-8")
31 |
32 |
33 | def preprocess_text(inputs, remove_space=True, lower=False):
34 |   """Preprocesses data by removing extra spaces and normalizing text."""
35 | outputs = inputs
36 | if remove_space:
37 | outputs = " ".join(inputs.strip().split())
38 |
39 | if six.PY2 and isinstance(outputs, str):
40 | try:
41 | outputs = six.ensure_text(outputs, "utf-8")
42 | except UnicodeDecodeError:
43 | outputs = six.ensure_text(outputs, "latin-1")
44 |
45 | outputs = unicodedata.normalize("NFKD", outputs)
46 | outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
47 | if lower:
48 | outputs = outputs.lower()
49 |
50 | return outputs
51 |
52 |
53 | def encode_pieces(sp_model, text, return_unicode=True, sample=False):
54 |   """Turns sentences into word pieces."""
55 |
56 | if six.PY2 and isinstance(text, six.text_type):
57 | text = six.ensure_binary(text, "utf-8")
58 |
59 | if not sample:
60 | pieces = sp_model.EncodeAsPieces(text)
61 | else:
62 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
63 | new_pieces = []
64 | for piece in pieces:
65 | piece = printable_text(piece)
66 | if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
67 | cur_pieces = sp_model.EncodeAsPieces(
68 | six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
69 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
70 | if len(cur_pieces[0]) == 1:
71 | cur_pieces = cur_pieces[1:]
72 | else:
73 | cur_pieces[0] = cur_pieces[0][1:]
74 | cur_pieces.append(piece[-1])
75 | new_pieces.extend(cur_pieces)
76 | else:
77 | new_pieces.append(piece)
78 |
79 | # note(zhiliny): convert back to unicode for py2
80 | if six.PY2 and return_unicode:
81 | ret_pieces = []
82 | for piece in new_pieces:
83 | if isinstance(piece, str):
84 | piece = six.ensure_text(piece, "utf-8")
85 | ret_pieces.append(piece)
86 | new_pieces = ret_pieces
87 |
88 | return new_pieces
89 |
90 |
91 | def encode_ids(sp_model, text, sample=False):
92 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
93 | ids = [sp_model.PieceToId(piece) for piece in pieces]
94 | return ids
95 |
96 |
97 | def convert_to_unicode(text):
98 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
99 | if six.PY3:
100 | if isinstance(text, str):
101 | return text
102 | elif isinstance(text, bytes):
103 | return six.ensure_text(text, "utf-8", "ignore")
104 | else:
105 | raise ValueError("Unsupported string type: %s" % (type(text)))
106 | elif six.PY2:
107 | if isinstance(text, str):
108 | return six.ensure_text(text, "utf-8", "ignore")
109 | elif isinstance(text, six.text_type):
110 | return text
111 | else:
112 | raise ValueError("Unsupported string type: %s" % (type(text)))
113 | else:
114 | raise ValueError("Not running on Python2 or Python 3?")
115 |
116 |
117 | def printable_text(text):
118 | """Returns text encoded in a way suitable for print or `tf.logging`."""
119 |
120 | # These functions want `str` for both Python2 and Python3, but in one case
121 | # it's a Unicode string and in the other it's a byte string.
122 | if six.PY3:
123 | if isinstance(text, str):
124 | return text
125 | elif isinstance(text, bytes):
126 | return six.ensure_text(text, "utf-8", "ignore")
127 | else:
128 | raise ValueError("Unsupported string type: %s" % (type(text)))
129 | elif six.PY2:
130 | if isinstance(text, str):
131 | return text
132 | elif isinstance(text, six.text_type):
133 | return six.ensure_binary(text, "utf-8")
134 | else:
135 | raise ValueError("Unsupported string type: %s" % (type(text)))
136 | else:
137 | raise ValueError("Not running on Python2 or Python 3?")
138 |
139 |
140 | def load_vocab(vocab_file):
141 | """Loads a vocabulary file into a dictionary."""
142 | vocab = collections.OrderedDict()
143 | with tf.gfile.GFile(vocab_file, "r") as reader:
144 | while True:
145 | token = convert_to_unicode(reader.readline())
146 | if not token:
147 | break
148 | token = token.strip().split()[0] if token.strip() else " "
149 | if token not in vocab:
150 | vocab[token] = len(vocab)
151 | return vocab
152 |
153 |
154 | def convert_by_vocab(vocab, items):
155 | """Converts a sequence of [tokens|ids] using the vocab."""
156 | output = []
157 | for item in items:
158 | output.append(vocab[item])
159 | return output
160 |
161 |
162 | def convert_tokens_to_ids(vocab, tokens):
163 | return convert_by_vocab(vocab, tokens)
164 |
165 |
166 | def convert_ids_to_tokens(inv_vocab, ids):
167 | return convert_by_vocab(inv_vocab, ids)
168 |
169 |
170 | def whitespace_tokenize(text):
171 | """Runs basic whitespace cleaning and splitting on a piece of text."""
172 | text = text.strip()
173 | if not text:
174 | return []
175 | tokens = text.split()
176 | return tokens
177 |
178 |
179 | class FullTokenizer(object):
180 |   """Runs end-to-end tokenization."""
181 |
182 | def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None):
183 | self.vocab = None
184 | self.sp_model = None
185 | if spm_model_file:
186 | self.sp_model = spm.SentencePieceProcessor()
187 | tf.logging.info("loading sentence piece model")
188 | # Handle cases where SP can't load the file, but gfile can.
189 | sp_model_ = tf.gfile.GFile(spm_model_file, "rb").read()
190 | self.sp_model.LoadFromSerializedProto(sp_model_)
191 |       # Note(mingdachen): For the purpose of a consistent API, we are
192 | # generating a vocabulary for the sentence piece tokenizer.
193 | self.vocab = {self.sp_model.IdToPiece(i): i for i
194 | in range(self.sp_model.GetPieceSize())}
195 | else:
196 | self.vocab = load_vocab(vocab_file)
197 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
198 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
199 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
200 |
201 | @classmethod
202 | def from_scratch(cls, vocab_file, do_lower_case, spm_model_file):
203 | return FullTokenizer(vocab_file, do_lower_case, spm_model_file)
204 |
205 | @classmethod
206 | def from_hub_module(cls, hub_module, use_spm=True):
207 | """Get the vocab file and casing info from the Hub module."""
208 | with tf.Graph().as_default():
209 | albert_module = hub.Module(hub_module)
210 | tokenization_info = albert_module(signature="tokenization_info",
211 | as_dict=True)
212 | with tf.Session() as sess:
213 | vocab_file, do_lower_case = sess.run(
214 | [tokenization_info["vocab_file"],
215 | tokenization_info["do_lower_case"]])
216 | if use_spm:
217 | spm_model_file = vocab_file
218 | vocab_file = None
219 | return FullTokenizer(
220 | vocab_file=vocab_file, do_lower_case=do_lower_case,
221 | spm_model_file=spm_model_file)
222 |
223 | def tokenize(self, text):
224 | if self.sp_model:
225 | split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
226 | else:
227 | split_tokens = []
228 | for token in self.basic_tokenizer.tokenize(text):
229 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
230 | split_tokens.append(sub_token)
231 |
232 | return split_tokens
233 |
234 | def convert_tokens_to_ids(self, tokens):
235 | if self.sp_model:
236 |       tf.logging.info("using sentence piece tokenizer.")
237 | return [self.sp_model.PieceToId(
238 | printable_text(token)) for token in tokens]
239 | else:
240 | return convert_by_vocab(self.vocab, tokens)
241 |
242 | def convert_ids_to_tokens(self, ids):
243 | if self.sp_model:
244 |       tf.logging.info("using sentence piece tokenizer.")
245 | return [self.sp_model.IdToPiece(id_) for id_ in ids]
246 | else:
247 | return convert_by_vocab(self.inv_vocab, ids)
248 |
249 |
250 | class BasicTokenizer(object):
251 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
252 |
253 | def __init__(self, do_lower_case=True):
254 | """Constructs a BasicTokenizer.
255 |
256 | Args:
257 | do_lower_case: Whether to lower case the input.
258 | """
259 | self.do_lower_case = do_lower_case
260 |
261 | def tokenize(self, text):
262 | """Tokenizes a piece of text."""
263 | text = convert_to_unicode(text)
264 | text = self._clean_text(text)
265 |
266 | # This was added on November 1st, 2018 for the multilingual and Chinese
267 | # models. This is also applied to the English models now, but it doesn't
268 | # matter since the English models were not trained on any Chinese data
269 | # and generally don't have any Chinese data in them (there are Chinese
270 | # characters in the vocabulary because Wikipedia does have some Chinese
271 | # words in the English Wikipedia.).
272 | text = self._tokenize_chinese_chars(text)
273 |
274 | orig_tokens = whitespace_tokenize(text)
275 | split_tokens = []
276 | for token in orig_tokens:
277 | if self.do_lower_case:
278 | token = token.lower()
279 | token = self._run_strip_accents(token)
280 | split_tokens.extend(self._run_split_on_punc(token))
281 |
282 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
283 | return output_tokens
284 |
285 | def _run_strip_accents(self, text):
286 | """Strips accents from a piece of text."""
287 | text = unicodedata.normalize("NFD", text)
288 | output = []
289 | for char in text:
290 | cat = unicodedata.category(char)
291 | if cat == "Mn":
292 | continue
293 | output.append(char)
294 | return "".join(output)
295 |
296 | def _run_split_on_punc(self, text):
297 | """Splits punctuation on a piece of text."""
298 | chars = list(text)
299 | i = 0
300 | start_new_word = True
301 | output = []
302 | while i < len(chars):
303 | char = chars[i]
304 | if _is_punctuation(char):
305 | output.append([char])
306 | start_new_word = True
307 | else:
308 | if start_new_word:
309 | output.append([])
310 | start_new_word = False
311 | output[-1].append(char)
312 | i += 1
313 |
314 | return ["".join(x) for x in output]
315 |
316 | def _tokenize_chinese_chars(self, text):
317 | """Adds whitespace around any CJK character."""
318 | output = []
319 | for char in text:
320 | cp = ord(char)
321 | if self._is_chinese_char(cp):
322 | output.append(" ")
323 | output.append(char)
324 | output.append(" ")
325 | else:
326 | output.append(char)
327 | return "".join(output)
328 |
329 | def _is_chinese_char(self, cp):
330 | """Checks whether CP is the codepoint of a CJK character."""
331 | # This defines a "chinese character" as anything in the CJK Unicode block:
332 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
333 | #
334 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
335 | # despite its name. The modern Korean Hangul alphabet is a different block,
336 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
337 | # space-separated words, so they are not treated specially and handled
338 |     # like all of the other languages.
339 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
340 | (cp >= 0x3400 and cp <= 0x4DBF) or #
341 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
342 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
343 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
344 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
345 | (cp >= 0xF900 and cp <= 0xFAFF) or #
346 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
347 | return True
348 |
349 | return False
350 |
351 | def _clean_text(self, text):
352 | """Performs invalid character removal and whitespace cleanup on text."""
353 | output = []
354 | for char in text:
355 | cp = ord(char)
356 | if cp == 0 or cp == 0xfffd or _is_control(char):
357 | continue
358 | if _is_whitespace(char):
359 | output.append(" ")
360 | else:
361 | output.append(char)
362 | return "".join(output)
363 |
364 |
365 | class WordpieceTokenizer(object):
366 |   """Runs WordPiece tokenization."""
367 |
368 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
369 | self.vocab = vocab
370 | self.unk_token = unk_token
371 | self.max_input_chars_per_word = max_input_chars_per_word
372 |
373 | def tokenize(self, text):
374 | """Tokenizes a piece of text into its word pieces.
375 |
376 | This uses a greedy longest-match-first algorithm to perform tokenization
377 | using the given vocabulary.
378 |
379 | For example:
380 | input = "unaffable"
381 | output = ["un", "##aff", "##able"]
382 |
383 | Args:
384 | text: A single token or whitespace separated tokens. This should have
385 |         already been passed through `BasicTokenizer`.
386 |
387 | Returns:
388 | A list of wordpiece tokens.
389 | """
390 |
391 | text = convert_to_unicode(text)
392 |
393 | output_tokens = []
394 | for token in whitespace_tokenize(text):
395 | chars = list(token)
396 | if len(chars) > self.max_input_chars_per_word:
397 | output_tokens.append(self.unk_token)
398 | continue
399 |
400 | is_bad = False
401 | start = 0
402 | sub_tokens = []
403 | while start < len(chars):
404 | end = len(chars)
405 | cur_substr = None
406 | while start < end:
407 | substr = "".join(chars[start:end])
408 | if start > 0:
409 | substr = "##" + six.ensure_str(substr)
410 | if substr in self.vocab:
411 | cur_substr = substr
412 | break
413 | end -= 1
414 | if cur_substr is None:
415 | is_bad = True
416 | break
417 | sub_tokens.append(cur_substr)
418 | start = end
419 |
420 | if is_bad:
421 | output_tokens.append(self.unk_token)
422 | else:
423 | output_tokens.extend(sub_tokens)
424 | return output_tokens
425 |
426 |
427 | def _is_whitespace(char):
428 |   """Checks whether `char` is a whitespace character."""
429 | # \t, \n, and \r are technically control characters but we treat them
430 | # as whitespace since they are generally considered as such.
431 | if char == " " or char == "\t" or char == "\n" or char == "\r":
432 | return True
433 | cat = unicodedata.category(char)
434 | if cat == "Zs":
435 | return True
436 | return False
437 |
438 |
439 | def _is_control(char):
440 |   """Checks whether `char` is a control character."""
441 | # These are technically control characters but we count them as whitespace
442 | # characters.
443 | if char == "\t" or char == "\n" or char == "\r":
444 | return False
445 | cat = unicodedata.category(char)
446 | if cat in ("Cc", "Cf"):
447 | return True
448 | return False
449 |
450 |
451 | def _is_punctuation(char):
452 |   """Checks whether `char` is a punctuation character."""
453 | cp = ord(char)
454 | # We treat all non-letter/number ASCII as punctuation.
455 | # Characters such as "^", "$", and "`" are not in the Unicode
456 | # Punctuation class but we treat them as punctuation anyways, for
457 | # consistency.
458 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
459 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
460 | return True
461 | cat = unicodedata.category(char)
462 | if cat.startswith("P"):
463 | return True
464 | return False
465 |
--------------------------------------------------------------------------------
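
Note: WordpieceTokenizer.tokenize above is the standard greedy longest-match-first algorithm: for each whitespace-split token it repeatedly takes the longest vocabulary prefix (continuation pieces carry a "##" prefix) and falls back to the unknown token if any position fails to match. A self-contained sketch of the same algorithm, using the toy vocabulary from tokenization_test.py:

def greedy_wordpiece(token, vocab, unk="[UNK]"):
  chars, start, pieces = list(token), 0, []
  while start < len(chars):
    end = len(chars)
    cur = None
    while start < end:  # shrink the window until a vocab entry matches
      substr = "".join(chars[start:end])
      if start > 0:
        substr = "##" + substr
      if substr in vocab:
        cur = substr
        break
      end -= 1
    if cur is None:  # no prefix matched: the whole token is unknown
      return [unk]
    pieces.append(cur)
    start = end
  return pieces

vocab = {"un", "##want", "##ed", "runn", "##ing"}
assert greedy_wordpiece("unwanted", vocab) == ["un", "##want", "##ed"]
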
/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | import os
19 | import tempfile
20 | from albert import tokenization
21 | import six
22 | import tensorflow.compat.v1 as tf
23 |
24 |
25 | class TokenizationTest(tf.test.TestCase):
26 |
27 | def test_full_tokenizer(self):
28 | vocab_tokens = [
29 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
30 | "##ing", ","
31 | ]
32 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
33 | if six.PY2:
34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
35 | else:
36 | contents = "".join([six.ensure_str(x) + "\n" for x in vocab_tokens])
37 | vocab_writer.write(six.ensure_binary(contents, "utf-8"))
38 |
39 | vocab_file = vocab_writer.name
40 |
41 | tokenizer = tokenization.FullTokenizer(vocab_file)
42 | os.unlink(vocab_file)
43 |
44 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
45 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
46 |
47 | self.assertAllEqual(
48 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
49 |
50 | def test_chinese(self):
51 | tokenizer = tokenization.BasicTokenizer()
52 |
53 | self.assertAllEqual(
54 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
55 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
56 |
57 | def test_basic_tokenizer_lower(self):
58 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
59 |
60 | self.assertAllEqual(
61 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
62 | ["hello", "!", "how", "are", "you", "?"])
63 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
64 |
65 | def test_basic_tokenizer_no_lower(self):
66 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
67 |
68 | self.assertAllEqual(
69 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
70 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
71 |
72 | def test_wordpiece_tokenizer(self):
73 | vocab_tokens = [
74 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
75 | "##ing"
76 | ]
77 |
78 | vocab = {}
79 | for (i, token) in enumerate(vocab_tokens):
80 | vocab[token] = i
81 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
82 |
83 | self.assertAllEqual(tokenizer.tokenize(""), [])
84 |
85 | self.assertAllEqual(
86 | tokenizer.tokenize("unwanted running"),
87 | ["un", "##want", "##ed", "runn", "##ing"])
88 |
89 | self.assertAllEqual(
90 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
91 |
92 | def test_convert_tokens_to_ids(self):
93 | vocab_tokens = [
94 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
95 | "##ing"
96 | ]
97 |
98 | vocab = {}
99 | for (i, token) in enumerate(vocab_tokens):
100 | vocab[token] = i
101 |
102 | self.assertAllEqual(
103 | tokenization.convert_tokens_to_ids(
104 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
105 |
106 | def test_is_whitespace(self):
107 | self.assertTrue(tokenization._is_whitespace(u" "))
108 | self.assertTrue(tokenization._is_whitespace(u"\t"))
109 | self.assertTrue(tokenization._is_whitespace(u"\r"))
110 | self.assertTrue(tokenization._is_whitespace(u"\n"))
111 | self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
112 |
113 | self.assertFalse(tokenization._is_whitespace(u"A"))
114 | self.assertFalse(tokenization._is_whitespace(u"-"))
115 |
116 | def test_is_control(self):
117 | self.assertTrue(tokenization._is_control(u"\u0005"))
118 |
119 | self.assertFalse(tokenization._is_control(u"A"))
120 | self.assertFalse(tokenization._is_control(u" "))
121 | self.assertFalse(tokenization._is_control(u"\t"))
122 | self.assertFalse(tokenization._is_control(u"\r"))
123 | self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
124 |
125 | def test_is_punctuation(self):
126 | self.assertTrue(tokenization._is_punctuation(u"-"))
127 | self.assertTrue(tokenization._is_punctuation(u"$"))
128 | self.assertTrue(tokenization._is_punctuation(u"`"))
129 | self.assertTrue(tokenization._is_punctuation(u"."))
130 |
131 | self.assertFalse(tokenization._is_punctuation(u"A"))
132 | self.assertFalse(tokenization._is_punctuation(u" "))
133 |
134 |
135 | if __name__ == "__main__":
136 | tf.test.main()
137 |
--------------------------------------------------------------------------------
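
Note: the tests above only exercise the vocab-file (WordPiece) path of FullTokenizer. A minimal sketch of the sentencepiece path, assuming a trained model at the hypothetical local path "30k-clean.model":

from albert import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file=None, do_lower_case=True, spm_model_file="30k-clean.model")
pieces = tokenizer.tokenize("unwanted running")  # e.g. ["▁unwanted", "▁running"]
ids = tokenizer.convert_tokens_to_ids(pieces)
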