├── requirements.txt ├── __init__.py ├── run_trivial_model_test.sh ├── CONTRIBUTING.md ├── .gitignore ├── optimization_test.py ├── run_glue.sh ├── fine_tuning_utils.py ├── tokenization_test.py ├── run_pretraining_test.py ├── lamb_optimizer.py ├── export_checkpoints.py ├── optimization.py ├── export_to_tfhub.py ├── modeling_test.py ├── LICENSE ├── albert_glue_fine_tuning_tutorial.ipynb ├── README.md ├── tokenization.py ├── race_utils.py ├── run_race.py ├── run_squad_v2.py ├── run_squad_v1.py └── run_classifier.py /requirements.txt: -------------------------------------------------------------------------------- 1 | # Run pip install --upgrade pip if tensorflow 1.15 cannot be found 2 | tensorflow==1.15.2 # CPU Version of TensorFlow 3 | tensorflow_hub==0.7 4 | # tensorflow-gpu==1.15 # GPU version of TensorFlow 5 | sentencepiece 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /run_trivial_model_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Small integration test script. 3 | # The values in this file are **not** meant for reproducing actual results. 4 | 5 | set -e 6 | set -x 7 | 8 | virtualenv -p python3 . 9 | source ./bin/activate 10 | 11 | OUTPUT_DIR_BASE="$(mktemp -d)" 12 | OUTPUT_DIR="${OUTPUT_DIR_BASE}/output" 13 | 14 | pip install numpy 15 | pip install -r requirements.txt 16 | python -m run_pretraining_test \ 17 | --output_dir="${OUTPUT_DIR}" \ 18 | --do_train \ 19 | --do_eval \ 20 | --nouse_tpu \ 21 | --train_batch_size=2 \ 22 | --eval_batch_size=1 \ 23 | --max_seq_length=4 \ 24 | --num_train_steps=2 \ 25 | --max_eval_steps=3 26 | 27 | 28 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review.
We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows 28 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from GitHub's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | -------------------------------------------------------------------------------- /optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # Lint as: python2, python3 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | from albert import optimization 20 | from six.moves import range 21 | from six.moves import zip 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | class OptimizationTest(tf.test.TestCase): 26 | 27 | def test_adam(self): 28 | with self.test_session() as sess: 29 | w = tf.get_variable( 30 | "w", 31 | shape=[3], 32 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 33 | x = tf.constant([0.4, 0.2, -0.5]) 34 | loss = tf.reduce_mean(tf.square(x - w)) 35 | tvars = tf.trainable_variables() 36 | grads = tf.gradients(loss, tvars) 37 | global_step = tf.train.get_or_create_global_step() 38 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 39 | train_op = optimizer.apply_gradients(list(zip(grads, tvars)), global_step) 40 | init_op = tf.group(tf.global_variables_initializer(), 41 | tf.local_variables_initializer()) 42 | sess.run(init_op) 43 | for _ in range(100): 44 | sess.run(train_op) 45 | w_np = sess.run(w) 46 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 47 | 48 | 49 | if __name__ == "__main__": 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /run_glue.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This is a convenience script for evaluating ALBERT on the GLUE benchmark. 3 | # 4 | # By default, this script uses a pretrained ALBERT v1 BASE model, but you may 5 | # use a custom checkpoint or any compatible TF-Hub checkpoint with minimal 6 | # edits to environment variables (see ALBERT_HUB_MODULE_HANDLE below). 7 | # 8 | # This script does fine-tuning and evaluation on 8 tasks, so it may take a 9 | # while to complete if you do not have a hardware accelerator. 10 | 11 | set -ex 12 | 13 | python3 -m venv $HOME/albertenv 14 | . $HOME/albertenv/bin/activate 15 | 16 | OUTPUT_DIR_BASE="$(mktemp -d)" 17 | OUTPUT_DIR="${OUTPUT_DIR_BASE}/output" 18 | 19 | # To start from a custom pretrained checkpoint, set ALBERT_HUB_MODULE_HANDLE 20 | # below to an empty string and set INIT_CHECKPOINT to your checkpoint path. 
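# For example (the checkpoint path below is illustrative):
#   ALBERT_HUB_MODULE_HANDLE=""
#   INIT_CHECKPOINT="${ALBERT_ROOT}/model.ckpt-best"
# Note that the data, vocab, and sentencepiece paths passed to run_classifier
# below assume ALBERT_ROOT is already exported in your environment.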
21 | ALBERT_HUB_MODULE_HANDLE="https://tfhub.dev/google/albert_base/1" 22 | INIT_CHECKPOINT="" 23 | 24 | pip3 install --upgrade pip 25 | pip3 install numpy 26 | pip3 install -r requirements.txt 27 | 28 | function run_task() { 29 | COMMON_ARGS="--output_dir="${OUTPUT_DIR}/$1" --data_dir="${ALBERT_ROOT}/glue" --vocab_file="${ALBERT_ROOT}/vocab.txt" --spm_model_file="${ALBERT_ROOT}/30k-clean.model" --do_lower_case --max_seq_length=512 --optimizer=adamw --task_name=$1 --warmup_step=$2 --learning_rate=$3 --train_step=$4 --save_checkpoints_steps=$5 --train_batch_size=$6" 30 | python3 -m run_classifier \ 31 | ${COMMON_ARGS} \ 32 | --do_train \ 33 | --nodo_eval \ 34 | --nodo_predict \ 35 | --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}" \ 36 | --init_checkpoint="${INIT_CHECKPOINT}" 37 | python3 -m run_classifier \ 38 | ${COMMON_ARGS} \ 39 | --nodo_train \ 40 | --do_eval \ 41 | --do_predict \ 42 | --albert_hub_module_handle="${ALBERT_HUB_MODULE_HANDLE}" 43 | } 44 | 45 | run_task SST-2 1256 1e-5 20935 100 32 46 | run_task MNLI 1000 3e-5 10000 100 128 47 | run_task CoLA 320 1e-5 5336 100 16 48 | run_task QNLI 1986 1e-5 33112 200 32 49 | run_task QQP 1000 5e-5 14000 100 128 50 | run_task RTE 200 3e-5 800 100 32 51 | run_task STS-B 214 2e-5 3598 100 16 52 | run_task MRPC 200 2e-5 800 100 32 53 | -------------------------------------------------------------------------------- /fine_tuning_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Lint as: python3 16 | """Helper library for ALBERT fine-tuning. 17 | 18 | This library can be used to construct ALBERT models for fine-tuning, either from 19 | json config files or from TF-Hub modules. 
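A minimal usage sketch (the tensor and config names here are placeholders,
not part of this library):

  albert_config = modeling.AlbertConfig.from_json_file("albert_config.json")
  pooled_output, sequence_output = create_albert(
      albert_config, is_training=False, input_ids=input_ids,
      input_mask=input_mask, segment_ids=segment_ids,
      use_one_hot_embeddings=False, use_einsum=True, hub_module=None)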
20 | """ 21 | 22 | from albert import modeling 23 | from albert import tokenization 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_hub as hub 26 | 27 | 28 | def _create_model_from_hub(hub_module, is_training, input_ids, input_mask, 29 | segment_ids): 30 | """Creates an ALBERT model from TF-Hub.""" 31 | tags = set() 32 | if is_training: 33 | tags.add("train") 34 | albert_module = hub.Module(hub_module, tags=tags, trainable=True) 35 | albert_inputs = dict( 36 | input_ids=input_ids, 37 | input_mask=input_mask, 38 | segment_ids=segment_ids) 39 | albert_outputs = albert_module( 40 | inputs=albert_inputs, 41 | signature="tokens", 42 | as_dict=True) 43 | return (albert_outputs["pooled_output"], albert_outputs["sequence_output"]) 44 | 45 | 46 | def _create_model_from_scratch(albert_config, is_training, input_ids, 47 | input_mask, segment_ids, use_one_hot_embeddings, 48 | use_einsum): 49 | """Creates an ALBERT model from scratch/config.""" 50 | model = modeling.AlbertModel( 51 | config=albert_config, 52 | is_training=is_training, 53 | input_ids=input_ids, 54 | input_mask=input_mask, 55 | token_type_ids=segment_ids, 56 | use_one_hot_embeddings=use_one_hot_embeddings, 57 | use_einsum=use_einsum) 58 | return (model.get_pooled_output(), model.get_sequence_output()) 59 | 60 | 61 | def create_albert(albert_config, is_training, input_ids, input_mask, 62 | segment_ids, use_one_hot_embeddings, use_einsum, hub_module): 63 | """Creates an ALBERT, either from TF-Hub or from scratch.""" 64 | if hub_module: 65 | tf.logging.info("creating model from hub_module: %s", hub_module) 66 | return _create_model_from_hub(hub_module, is_training, input_ids, 67 | input_mask, segment_ids) 68 | else: 69 | tf.logging.info("creating model from albert_config") 70 | return _create_model_from_scratch(albert_config, is_training, input_ids, 71 | input_mask, segment_ids, 72 | use_one_hot_embeddings, use_einsum) 73 | 74 | 75 | def create_vocab(vocab_file, do_lower_case, spm_model_file, hub_module): 76 | """Creates a vocab, either from vocab file or from a TF-Hub module.""" 77 | if hub_module: 78 | use_spm = True if spm_model_file else False 79 | return tokenization.FullTokenizer.from_hub_module( 80 | hub_module=hub_module, use_spm=use_spm) 81 | else: 82 | return tokenization.FullTokenizer.from_scratch( 83 | vocab_file=vocab_file, do_lower_case=do_lower_case, 84 | spm_model_file=spm_model_file) 85 | 86 | -------------------------------------------------------------------------------- /tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # Lint as: python2, python3 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | import os 20 | import tempfile 21 | from albert import tokenization 22 | import six 23 | import tensorflow.compat.v1 as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | if six.PY2: 35 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 36 | else: 37 | contents = "".join([six.ensure_str(x) + "\n" for x in vocab_tokens]) 38 | vocab_writer.write(six.ensure_binary(contents, "utf-8")) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = tokenization.FullTokenizer(vocab_file) 43 | os.unlink(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertAllEqual( 49 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | def test_chinese(self): 52 | tokenizer = tokenization.BasicTokenizer() 53 | 54 | self.assertAllEqual( 55 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 56 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 57 | 58 | def test_basic_tokenizer_lower(self): 59 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 60 | 61 | self.assertAllEqual( 62 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 63 | ["hello", "!", "how", "are", "you", "?"]) 64 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 65 | 66 | def test_basic_tokenizer_no_lower(self): 67 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 68 | 69 | self.assertAllEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | self.assertFalse(tokenization._is_control(u"\U0001F4A9")) 125 | 126 | def test_is_punctuation(self): 127 | self.assertTrue(tokenization._is_punctuation(u"-")) 128 | self.assertTrue(tokenization._is_punctuation(u"$")) 129 | self.assertTrue(tokenization._is_punctuation(u"`")) 130 | self.assertTrue(tokenization._is_punctuation(u".")) 131 | 132 | self.assertFalse(tokenization._is_punctuation(u"A")) 133 | self.assertFalse(tokenization._is_punctuation(u" ")) 134 | 135 | 136 | if __name__ == "__main__": 137 | tf.test.main() 138 | -------------------------------------------------------------------------------- /run_pretraining_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # Lint as: python2, python3 16 | """Tests for run_pretraining.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import random 24 | import tempfile 25 | from absl.testing import flagsaver 26 | from albert import modeling 27 | from albert import run_pretraining 28 | import tensorflow.compat.v1 as tf 29 | 30 | FLAGS = tf.app.flags.FLAGS 31 | 32 | 33 | def _create_config_file(filename, max_seq_length, vocab_size): 34 | """Creates an AlbertConfig and saves it to file.""" 35 | albert_config = modeling.AlbertConfig( 36 | vocab_size, 37 | embedding_size=5, 38 | hidden_size=14, 39 | num_hidden_layers=3, 40 | num_hidden_groups=1, 41 | num_attention_heads=2, 42 | intermediate_size=19, 43 | inner_group_num=1, 44 | down_scale_factor=1, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0, 47 | attention_probs_dropout_prob=0, 48 | max_position_embeddings=max_seq_length, 49 | type_vocab_size=2, 50 | initializer_range=0.02) 51 | with tf.gfile.Open(filename, "w") as outfile: 52 | outfile.write(albert_config.to_json_string()) 53 | 54 | 55 | def _create_record(max_predictions_per_seq, max_seq_length, vocab_size): 56 | """Returns a tf.train.Example containing random data.""" 57 | example = tf.train.Example() 58 | example.features.feature["input_ids"].int64_list.value.extend( 59 | [random.randint(0, vocab_size - 1) for _ in range(max_seq_length)]) 60 | example.features.feature["input_mask"].int64_list.value.extend( 61 | [random.randint(0, 1) for _ in range(max_seq_length)]) 62 | example.features.feature["masked_lm_positions"].int64_list.value.extend([ 63 | random.randint(0, max_seq_length - 1) 64 | for _ in range(max_predictions_per_seq) 65 | ]) 66 | example.features.feature["masked_lm_ids"].int64_list.value.extend([ 67 | random.randint(0, vocab_size - 1) for _ in range(max_predictions_per_seq) 68 | ]) 69 | example.features.feature["masked_lm_weights"].float_list.value.extend( 70 | [1. for _ in range(max_predictions_per_seq)]) 71 | example.features.feature["segment_ids"].int64_list.value.extend( 72 | [0 for _ in range(max_seq_length)]) 73 | example.features.feature["next_sentence_labels"].int64_list.value.append( 74 | random.randint(0, 1)) 75 | return example 76 | 77 | 78 | def _create_input_file(filename, 79 | max_predictions_per_seq, 80 | max_seq_length, 81 | vocab_size, 82 | size=1000): 83 | """Creates an input TFRecord file of specified size.""" 84 | with tf.io.TFRecordWriter(filename) as writer: 85 | for _ in range(size): 86 | ex = _create_record(max_predictions_per_seq, max_seq_length, vocab_size) 87 | writer.write(ex.SerializeToString()) 88 | 89 | 90 | class RunPretrainingTest(tf.test.TestCase): 91 | 92 | def _verify_output_file(self, basename): 93 | self.assertTrue(tf.gfile.Exists(os.path.join(FLAGS.output_dir, basename))) 94 | 95 | def _verify_checkpoint_files(self, name): 96 | self._verify_output_file(name + ".meta") 97 | self._verify_output_file(name + ".index") 98 | self._verify_output_file(name + ".data-00000-of-00001") 99 | 100 | @flagsaver.flagsaver 101 | def test_pretraining(self): 102 | # Set up required flags. 
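# The sizes below are deliberately tiny (a 13-token sequence, a 97-word
# vocab, one training step) so the end-to-end test finishes in seconds;
# they are not meaningful model settings.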
103 | vocab_size = 97 104 | FLAGS.max_predictions_per_seq = 7 105 | FLAGS.max_seq_length = 13 106 | FLAGS.output_dir = tempfile.mkdtemp("output_dir") 107 | FLAGS.albert_config_file = os.path.join( 108 | tempfile.mkdtemp("config_dir"), "albert_config.json") 109 | FLAGS.input_file = os.path.join( 110 | tempfile.mkdtemp("input_dir"), "input_data.tfrecord") 111 | FLAGS.do_train = True 112 | FLAGS.do_eval = True 113 | FLAGS.num_train_steps = 1 114 | FLAGS.save_checkpoints_steps = 1 115 | 116 | # Construct requisite input files. 117 | _create_config_file(FLAGS.albert_config_file, FLAGS.max_seq_length, 118 | vocab_size) 119 | _create_input_file(FLAGS.input_file, FLAGS.max_predictions_per_seq, 120 | FLAGS.max_seq_length, vocab_size) 121 | 122 | # Run the pretraining. 123 | run_pretraining.main(None) 124 | 125 | # Verify output. 126 | self._verify_checkpoint_files("model.ckpt-best") 127 | self._verify_checkpoint_files("model.ckpt-1") 128 | self._verify_output_file("eval_results.txt") 129 | self._verify_output_file("checkpoint") 130 | 131 | 132 | if __name__ == "__main__": 133 | tf.test.main() 134 | -------------------------------------------------------------------------------- /lamb_optimizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Lint as: python2, python3 16 | """Functions and classes related to optimization (weight updates).""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import re 23 | import six 24 | import tensorflow.compat.v1 as tf 25 | 26 | # pylint: disable=g-direct-tensorflow-import 27 | from tensorflow.python.ops import array_ops 28 | from tensorflow.python.ops import linalg_ops 29 | from tensorflow.python.ops import math_ops 30 | # pylint: enable=g-direct-tensorflow-import 31 | 32 | 33 | class LAMBOptimizer(tf.train.Optimizer): 34 | """LAMB (Layer-wise Adaptive Moments optimizer for Batch training).""" 35 | # A new optimizer that includes correct L2 weight decay, adaptive 36 | # element-wise updating, and layer-wise adaptation.
The LAMB optimizer 37 | # was proposed by Yang You, Jing Li, Jonathan Hseu, Xiaodan Song, 38 | # James Demmel, and Cho-Jui Hsieh in the paper "Reducing BERT 39 | # Pre-Training Time from 3 Days to 76 Minutes" (arxiv.org/abs/1904.00962). 40 | 41 | def __init__(self, 42 | learning_rate, 43 | weight_decay_rate=0.0, 44 | beta_1=0.9, 45 | beta_2=0.999, 46 | epsilon=1e-6, 47 | exclude_from_weight_decay=None, 48 | exclude_from_layer_adaptation=None, 49 | name="LAMBOptimizer"): 50 | """Constructs a LAMBOptimizer.""" 51 | super(LAMBOptimizer, self).__init__(False, name) 52 | 53 | self.learning_rate = learning_rate 54 | self.weight_decay_rate = weight_decay_rate 55 | self.beta_1 = beta_1 56 | self.beta_2 = beta_2 57 | self.epsilon = epsilon 58 | self.exclude_from_weight_decay = exclude_from_weight_decay 59 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the 60 | # arg is None. 61 | # TODO(jingli): validate if exclude_from_layer_adaptation is necessary. 62 | if exclude_from_layer_adaptation: 63 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation 64 | else: 65 | self.exclude_from_layer_adaptation = exclude_from_weight_decay 66 | 67 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 68 | """See base class.""" 69 | assignments = [] 70 | for (grad, param) in grads_and_vars: 71 | if grad is None or param is None: 72 | continue 73 | 74 | param_name = self._get_variable_name(param.name) 75 | 76 | m = tf.get_variable( 77 | name=six.ensure_str(param_name) + "/adam_m", 78 | shape=param.shape.as_list(), 79 | dtype=tf.float32, 80 | trainable=False, 81 | initializer=tf.zeros_initializer()) 82 | v = tf.get_variable( 83 | name=six.ensure_str(param_name) + "/adam_v", 84 | shape=param.shape.as_list(), 85 | dtype=tf.float32, 86 | trainable=False, 87 | initializer=tf.zeros_initializer()) 88 | 89 | # Standard Adam update. 90 | next_m = ( 91 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 92 | next_v = ( 93 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 94 | tf.square(grad))) 95 | 96 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 97 | 98 | # Just adding the square of the weights to the loss function is *not* 99 | # the correct way of using L2 regularization/weight decay with Adam, 100 | # since that will interact with the m and v parameters in strange ways. 101 | # 102 | # Instead we want to decay the weights in a manner that doesn't interact 103 | # with the m/v parameters. This is equivalent to adding the square 104 | # of the weights to the loss with plain (non-momentum) SGD.
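# After the decoupled weight decay below, LAMB also rescales each layer's
# update by the trust ratio ||w|| / ||update|| (guarded to 1.0 when either
# norm is zero), so the step size adapts to the scale of each layer's
# weights.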
105 | if self._do_use_weight_decay(param_name): 106 | update += self.weight_decay_rate * param 107 | 108 | ratio = 1.0 109 | if self._do_layer_adaptation(param_name): 110 | w_norm = linalg_ops.norm(param, ord=2) 111 | g_norm = linalg_ops.norm(update, ord=2) 112 | ratio = array_ops.where(math_ops.greater(w_norm, 0), array_ops.where( 113 | math_ops.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) 114 | 115 | update_with_lr = ratio * self.learning_rate * update 116 | 117 | next_param = param - update_with_lr 118 | 119 | assignments.extend( 120 | [param.assign(next_param), 121 | m.assign(next_m), 122 | v.assign(next_v)]) 123 | return tf.group(*assignments, name=name) 124 | 125 | def _do_use_weight_decay(self, param_name): 126 | """Whether to use L2 weight decay for `param_name`.""" 127 | if not self.weight_decay_rate: 128 | return False 129 | if self.exclude_from_weight_decay: 130 | for r in self.exclude_from_weight_decay: 131 | if re.search(r, param_name) is not None: 132 | return False 133 | return True 134 | 135 | def _do_layer_adaptation(self, param_name): 136 | """Whether to do layer-wise learning rate adaptation for `param_name`.""" 137 | if self.exclude_from_layer_adaptation: 138 | for r in self.exclude_from_layer_adaptation: 139 | if re.search(r, param_name) is not None: 140 | return False 141 | return True 142 | 143 | def _get_variable_name(self, param_name): 144 | """Get the variable name from the tensor name.""" 145 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name)) 146 | if m is not None: 147 | param_name = m.group(1) 148 | return param_name 149 | -------------------------------------------------------------------------------- /export_checkpoints.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Exports a minimal module for ALBERT models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import os 21 | from absl import app 22 | from absl import flags 23 | from albert import modeling 24 | import tensorflow.compat.v1 as tf 25 | 26 | flags.DEFINE_string( 27 | "albert_directory", None, 28 | "The directory containing the pre-trained ALBERT checkpoint and its " 29 | "albert_config.json, which specifies the model architecture.") 30 | 31 | flags.DEFINE_string( 32 | "checkpoint_name", "model.ckpt-best", 33 | "Name of the checkpoint under albert_directory to be exported.") 34 | 35 | flags.DEFINE_bool( 36 | "do_lower_case", True, 37 | "Whether to lower case the input text.
Should be True for uncased " 38 | "models and False for cased models.") 39 | 40 | flags.DEFINE_string("export_path", None, "Path to the output module.") 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | 45 | def gather_indexes(sequence_tensor, positions): 46 | """Gathers the vectors at the specific positions over a minibatch.""" 47 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 48 | batch_size = sequence_shape[0] 49 | seq_length = sequence_shape[1] 50 | width = sequence_shape[2] 51 | 52 | flat_offsets = tf.reshape( 53 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 54 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 55 | flat_sequence_tensor = tf.reshape(sequence_tensor, 56 | [batch_size * seq_length, width]) 57 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 58 | return output_tensor 59 | 60 | 61 | def get_mlm_logits(input_tensor, albert_config, mlm_positions, output_weights): 62 | """From run_pretraining.py.""" 63 | input_tensor = gather_indexes(input_tensor, mlm_positions) 64 | with tf.variable_scope("cls/predictions"): 65 | # We apply one more non-linear transformation before the output layer. 66 | # This matrix is not used after pre-training. 67 | with tf.variable_scope("transform"): 68 | input_tensor = tf.layers.dense( 69 | input_tensor, 70 | units=albert_config.embedding_size, 71 | activation=modeling.get_activation(albert_config.hidden_act), 72 | kernel_initializer=modeling.create_initializer( 73 | albert_config.initializer_range)) 74 | input_tensor = modeling.layer_norm(input_tensor) 75 | 76 | # The output weights are the same as the input embeddings, but there is 77 | # an output-only bias for each token. 78 | output_bias = tf.get_variable( 79 | "output_bias", 80 | shape=[albert_config.vocab_size], 81 | initializer=tf.zeros_initializer()) 82 | logits = tf.matmul( 83 | input_tensor, output_weights, transpose_b=True) 84 | logits = tf.nn.bias_add(logits, output_bias) 85 | return logits 86 | 87 | 88 | def get_sentence_order_logits(input_tensor, albert_config): 89 | """Gets logits for the sentence-order prediction (SOP) head.""" 90 | 91 | # Simple binary classification. Note that 0 means the two segments are in 92 | # the correct order and 1 means they are swapped. This weight matrix is not used after pre-training.
93 | with tf.variable_scope("cls/seq_relationship"): 94 | output_weights = tf.get_variable( 95 | "output_weights", 96 | shape=[2, albert_config.hidden_size], 97 | initializer=modeling.create_initializer( 98 | albert_config.initializer_range)) 99 | output_bias = tf.get_variable( 100 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 101 | 102 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 103 | logits = tf.nn.bias_add(logits, output_bias) 104 | return logits 105 | 106 | 107 | def build_model(sess): 108 | """Builds the ALBERT graph and initializes it from the checkpoint.""" 109 | input_ids = tf.placeholder(tf.int32, [None, None], "input_ids") 110 | input_mask = tf.placeholder(tf.int32, [None, None], "input_mask") 111 | segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids") 112 | mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions") 113 | 114 | albert_config_path = os.path.join( 115 | FLAGS.albert_directory, "albert_config.json") 116 | albert_config = modeling.AlbertConfig.from_json_file(albert_config_path) 117 | model = modeling.AlbertModel( 118 | config=albert_config, 119 | is_training=False, 120 | input_ids=input_ids, 121 | input_mask=input_mask, 122 | token_type_ids=segment_ids, 123 | use_one_hot_embeddings=False) 124 | 125 | get_mlm_logits(model.get_sequence_output(), albert_config, 126 | mlm_positions, model.get_embedding_table()) 127 | get_sentence_order_logits(model.get_pooled_output(), albert_config) 128 | 129 | checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name) 130 | tvars = tf.trainable_variables() 131 | (assignment_map, initialized_variable_names 132 | ) = modeling.get_assignment_map_from_checkpoint(tvars, checkpoint_path) 133 | 134 | tf.logging.info("**** Trainable Variables ****") 135 | for var in tvars: 136 | init_string = "" 137 | if var.name in initialized_variable_names: 138 | init_string = ", *INIT_FROM_CKPT*" 139 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 140 | init_string) 141 | tf.train.init_from_checkpoint(checkpoint_path, assignment_map) 142 | init = tf.global_variables_initializer() 143 | sess.run(init) 144 | return sess 145 | 146 | 147 | def main(_): 148 | sess = tf.Session() 149 | tf.train.get_or_create_global_step() 150 | sess = build_model(sess) 151 | my_vars = [] 152 | for var in tf.global_variables(): 153 | if "lamb_v" not in var.name and "lamb_m" not in var.name: 154 | my_vars.append(var) 155 | saver = tf.train.Saver(my_vars) 156 | saver.save(sess, FLAGS.export_path) 157 | 158 | 159 | if __name__ == "__main__": 160 | flags.mark_flag_as_required("albert_directory") 161 | flags.mark_flag_as_required("export_path") 162 | app.run(main) 163 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # Lint as: python2, python3 16 | """Functions and classes related to optimization (weight updates).""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | import re 22 | from albert import lamb_optimizer 23 | import six 24 | from six.moves import zip 25 | import tensorflow.compat.v1 as tf 26 | from tensorflow.contrib import tpu as contrib_tpu 27 | 28 | 29 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, 30 | optimizer="adamw", poly_power=1.0, start_warmup_step=0, 31 | colocate_gradients_with_ops=False): 32 | """Creates an optimizer training op.""" 33 | global_step = tf.train.get_or_create_global_step() 34 | 35 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 36 | 37 | # Implements linear decay of the learning rate. 38 | learning_rate = tf.train.polynomial_decay( 39 | learning_rate, 40 | global_step, 41 | num_train_steps, 42 | end_learning_rate=0.0, 43 | power=poly_power, 44 | cycle=False) 45 | 46 | # Implements linear warmup. I.e., if global_step - start_warmup_step < 47 | # num_warmup_steps, the learning rate will be 48 | # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`. 49 | if num_warmup_steps: 50 | tf.logging.info("++++++ warmup starts at step " + str(start_warmup_step) 51 | + ", for " + str(num_warmup_steps) + " steps ++++++") 52 | global_steps_int = tf.cast(global_step, tf.int32) 53 | start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32) 54 | global_steps_int = global_steps_int - start_warm_int 55 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 56 | 57 | global_steps_float = tf.cast(global_steps_int, tf.float32) 58 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 59 | 60 | warmup_percent_done = global_steps_float / warmup_steps_float 61 | warmup_learning_rate = init_lr * warmup_percent_done 62 | 63 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 64 | learning_rate = ( 65 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 66 | 67 | # It is OK to use this optimizer for fine-tuning, since this is how the 68 | # model was trained (note that the Adam m/v variables are NOT loaded from 69 | # init_checkpoint). 70 | # It is also OK to fine-tune with AdamW even if the model was pretrained 71 | # with LAMB. As reported in the public BERT GitHub repo, typical learning 72 | # rates for SQuAD 1.1 fine-tuning are 3e-5, 4e-5, or 5e-5; for LAMB, use 73 | # 3e-4, 4e-4, or 5e-4 at a fine-tuning batch size of 64. 74 | if optimizer == "adamw": 75 | tf.logging.info("using adamw") 76 | optimizer = AdamWeightDecayOptimizer( 77 | learning_rate=learning_rate, 78 | weight_decay_rate=0.01, 79 | beta_1=0.9, 80 | beta_2=0.999, 81 | epsilon=1e-6, 82 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 83 | elif optimizer == "lamb": 84 | tf.logging.info("using lamb") 85 | optimizer = lamb_optimizer.LAMBOptimizer( 86 | learning_rate=learning_rate, 87 | weight_decay_rate=0.01, 88 | beta_1=0.9, 89 | beta_2=0.999, 90 | epsilon=1e-6, 91 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 92 | else: 93 | raise ValueError("Unsupported optimizer: ", optimizer) 94 | 95 | if use_tpu: 96 | optimizer = contrib_tpu.CrossShardOptimizer(optimizer) 97 | 98 | tvars = tf.trainable_variables() 99 | grads = tf.gradients( 100 | loss, tvars, colocate_gradients_with_ops=colocate_gradients_with_ops) 101 | 102 | # This is how the model was pre-trained.
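# tf.clip_by_global_norm rescales every gradient by
# clip_norm / max(global_norm, clip_norm), where global_norm is the L2 norm
# of all gradients taken together: e.g. a global norm of 4.0 with
# clip_norm=1.0 scales each gradient by 0.25, while gradients whose global
# norm is already below 1.0 pass through unchanged.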
103 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 104 | 105 | train_op = optimizer.apply_gradients( 106 | list(zip(grads, tvars)), global_step=global_step) 107 | 108 | # Normally the global step update is done inside of `apply_gradients`. 109 | # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` does this, 110 | # so it is done here. If you use a different optimizer, you should probably 111 | # take this line out. 112 | new_global_step = global_step + 1 113 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 114 | return train_op 115 | 116 | 117 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 118 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 119 | 120 | def __init__(self, 121 | learning_rate, 122 | weight_decay_rate=0.0, 123 | beta_1=0.9, 124 | beta_2=0.999, 125 | epsilon=1e-6, 126 | exclude_from_weight_decay=None, 127 | name="AdamWeightDecayOptimizer"): 128 | """Constructs an AdamWeightDecayOptimizer.""" 129 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 130 | 131 | self.learning_rate = learning_rate 132 | self.weight_decay_rate = weight_decay_rate 133 | self.beta_1 = beta_1 134 | self.beta_2 = beta_2 135 | self.epsilon = epsilon 136 | self.exclude_from_weight_decay = exclude_from_weight_decay 137 | 138 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 139 | """See base class.""" 140 | assignments = [] 141 | for (grad, param) in grads_and_vars: 142 | if grad is None or param is None: 143 | continue 144 | 145 | param_name = self._get_variable_name(param.name) 146 | 147 | m = tf.get_variable( 148 | name=six.ensure_str(param_name) + "/adam_m", 149 | shape=param.shape.as_list(), 150 | dtype=tf.float32, 151 | trainable=False, 152 | initializer=tf.zeros_initializer()) 153 | v = tf.get_variable( 154 | name=six.ensure_str(param_name) + "/adam_v", 155 | shape=param.shape.as_list(), 156 | dtype=tf.float32, 157 | trainable=False, 158 | initializer=tf.zeros_initializer()) 159 | 160 | # Standard Adam update. 161 | next_m = ( 162 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 163 | next_v = ( 164 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 165 | tf.square(grad))) 166 | 167 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 168 | 169 | # Just adding the square of the weights to the loss function is *not* 170 | # the correct way of using L2 regularization/weight decay with Adam, 171 | # since that will interact with the m and v parameters in strange ways. 172 | # 173 | # Instead we want to decay the weights in a manner that doesn't interact 174 | # with the m/v parameters. This is equivalent to adding the square 175 | # of the weights to the loss with plain (non-momentum) SGD.
176 | if self._do_use_weight_decay(param_name): 177 | update += self.weight_decay_rate * param 178 | 179 | update_with_lr = self.learning_rate * update 180 | 181 | next_param = param - update_with_lr 182 | 183 | assignments.extend( 184 | [param.assign(next_param), 185 | m.assign(next_m), 186 | v.assign(next_v)]) 187 | return tf.group(*assignments, name=name) 188 | 189 | def _do_use_weight_decay(self, param_name): 190 | """Whether to use L2 weight decay for `param_name`.""" 191 | if not self.weight_decay_rate: 192 | return False 193 | if self.exclude_from_weight_decay: 194 | for r in self.exclude_from_weight_decay: 195 | if re.search(r, param_name) is not None: 196 | return False 197 | return True 198 | 199 | def _get_variable_name(self, param_name): 200 | """Get the variable name from the tensor name.""" 201 | m = re.match("^(.*):\\d+$", six.ensure_str(param_name)) 202 | if m is not None: 203 | param_name = m.group(1) 204 | return param_name 205 | -------------------------------------------------------------------------------- /export_to_tfhub.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | r"""Exports a minimal TF-Hub module for ALBERT models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import os 21 | from absl import app 22 | from absl import flags 23 | from albert import modeling 24 | import tensorflow.compat.v1 as tf 25 | import tensorflow_hub as hub 26 | 27 | flags.DEFINE_string( 28 | "albert_directory", None, 29 | "The directory containing the pre-trained ALBERT checkpoint and its " 30 | "albert_config.json, which specifies the model architecture.") 31 | 32 | flags.DEFINE_string( 33 | "checkpoint_name", "model.ckpt-best", 34 | "Name of the checkpoint under albert_directory to be exported.") 35 | 36 | flags.DEFINE_bool( 37 | "do_lower_case", True, 38 | "Whether to lower case the input text. Should be True for uncased " 39 | "models and False for cased models.") 40 | 41 | flags.DEFINE_bool( 42 | "use_einsum", True, 43 | "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers.
Must " 44 | "be set to False for TFLite compatibility.") 45 | 46 | flags.DEFINE_string("export_path", None, "Path to the output TF-Hub module.") 47 | 48 | FLAGS = flags.FLAGS 49 | 50 | 51 | def gather_indexes(sequence_tensor, positions): 52 | """Gathers the vectors at the specific positions over a minibatch.""" 53 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 54 | batch_size = sequence_shape[0] 55 | seq_length = sequence_shape[1] 56 | width = sequence_shape[2] 57 | 58 | flat_offsets = tf.reshape( 59 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 60 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 61 | flat_sequence_tensor = tf.reshape(sequence_tensor, 62 | [batch_size * seq_length, width]) 63 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 64 | return output_tensor 65 | 66 | 67 | def get_mlm_logits(model, albert_config, mlm_positions): 68 | """From run_pretraining.py.""" 69 | input_tensor = gather_indexes(model.get_sequence_output(), mlm_positions) 70 | with tf.variable_scope("cls/predictions"): 71 | # We apply one more non-linear transformation before the output layer. 72 | # This matrix is not used after pre-training. 73 | with tf.variable_scope("transform"): 74 | input_tensor = tf.layers.dense( 75 | input_tensor, 76 | units=albert_config.embedding_size, 77 | activation=modeling.get_activation(albert_config.hidden_act), 78 | kernel_initializer=modeling.create_initializer( 79 | albert_config.initializer_range)) 80 | input_tensor = modeling.layer_norm(input_tensor) 81 | 82 | # The output weights are the same as the input embeddings, but there is 83 | # an output-only bias for each token. 84 | output_bias = tf.get_variable( 85 | "output_bias", 86 | shape=[albert_config.vocab_size], 87 | initializer=tf.zeros_initializer()) 88 | logits = tf.matmul( 89 | input_tensor, model.get_embedding_table(), transpose_b=True) 90 | logits = tf.nn.bias_add(logits, output_bias) 91 | return logits 92 | 93 | 94 | def get_sop_log_probs(model, albert_config): 95 | """Gets log probs for the sentence-order prediction (SOP) head.""" 96 | input_tensor = model.get_pooled_output() 97 | # Simple binary classification. Note that 0 means the two segments are in 98 | # the correct order and 1 means they are swapped. This weight matrix is not used after pre-training.
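# The log probabilities returned below come from tf.nn.log_softmax over the
# two classes; callers can exponentiate them (or take an argmax) to score
# sentence order.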
99 | with tf.variable_scope("cls/seq_relationship"): 100 | output_weights = tf.get_variable( 101 | "output_weights", 102 | shape=[2, albert_config.hidden_size], 103 | initializer=modeling.create_initializer( 104 | albert_config.initializer_range)) 105 | output_bias = tf.get_variable( 106 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 107 | 108 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 109 | logits = tf.nn.bias_add(logits, output_bias) 110 | log_probs = tf.nn.log_softmax(logits, axis=-1) 111 | return log_probs 112 | 113 | 114 | def module_fn(is_training): 115 | """Module function.""" 116 | input_ids = tf.placeholder(tf.int32, [None, None], "input_ids") 117 | input_mask = tf.placeholder(tf.int32, [None, None], "input_mask") 118 | segment_ids = tf.placeholder(tf.int32, [None, None], "segment_ids") 119 | mlm_positions = tf.placeholder(tf.int32, [None, None], "mlm_positions") 120 | 121 | albert_config_path = os.path.join( 122 | FLAGS.albert_directory, "albert_config.json") 123 | albert_config = modeling.AlbertConfig.from_json_file(albert_config_path) 124 | model = modeling.AlbertModel( 125 | config=albert_config, 126 | is_training=is_training, 127 | input_ids=input_ids, 128 | input_mask=input_mask, 129 | token_type_ids=segment_ids, 130 | use_one_hot_embeddings=False, 131 | use_einsum=FLAGS.use_einsum) 132 | 133 | mlm_logits = get_mlm_logits(model, albert_config, mlm_positions) 134 | sop_log_probs = get_sop_log_probs(model, albert_config) 135 | 136 | vocab_model_path = os.path.join(FLAGS.albert_directory, "30k-clean.model") 137 | vocab_file_path = os.path.join(FLAGS.albert_directory, "30k-clean.vocab") 138 | 139 | config_file = tf.constant( 140 | value=albert_config_path, dtype=tf.string, name="config_file") 141 | vocab_model = tf.constant( 142 | value=vocab_model_path, dtype=tf.string, name="vocab_model") 143 | # This is only for visualization purposes. 144 | vocab_file = tf.constant( 145 | value=vocab_file_path, dtype=tf.string, name="vocab_file") 146 | 147 | # By adding `config_file`, `vocab_model`, and `vocab_file` 148 | # to the ASSET_FILEPATHS collection, TF-Hub will 149 | # rewrite these tensors so that the assets are portable.
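# (On export, files registered in ASSET_FILEPATHS are copied into the
# module's assets/ directory, and these path tensors are remapped to point
# at the bundled copies.)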
150 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, config_file) 151 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_model) 152 | tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file) 153 | 154 | hub.add_signature( 155 | name="tokens", 156 | inputs=dict( 157 | input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids), 158 | outputs=dict( 159 | sequence_output=model.get_sequence_output(), 160 | pooled_output=model.get_pooled_output())) 161 | 162 | hub.add_signature( 163 | name="sop", 164 | inputs=dict( 165 | input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids), 166 | outputs=dict( 167 | sequence_output=model.get_sequence_output(), 168 | pooled_output=model.get_pooled_output(), 169 | sop_log_probs=sop_log_probs)) 170 | 171 | hub.add_signature( 172 | name="mlm", 173 | inputs=dict( 174 | input_ids=input_ids, 175 | input_mask=input_mask, 176 | segment_ids=segment_ids, 177 | mlm_positions=mlm_positions), 178 | outputs=dict( 179 | sequence_output=model.get_sequence_output(), 180 | pooled_output=model.get_pooled_output(), 181 | mlm_logits=mlm_logits)) 182 | 183 | hub.add_signature( 184 | name="tokenization_info", 185 | inputs={}, 186 | outputs=dict( 187 | vocab_file=vocab_model, 188 | do_lower_case=tf.constant(FLAGS.do_lower_case))) 189 | 190 | 191 | def main(_): 192 | tags_and_args = [] 193 | for is_training in (True, False): 194 | tags = set() 195 | if is_training: 196 | tags.add("train") 197 | tags_and_args.append((tags, dict(is_training=is_training))) 198 | spec = hub.create_module_spec(module_fn, tags_and_args=tags_and_args) 199 | checkpoint_path = os.path.join(FLAGS.albert_directory, FLAGS.checkpoint_name) 200 | tf.logging.info("Using checkpoint {}".format(checkpoint_path)) 201 | spec.export(FLAGS.export_path, checkpoint_path=checkpoint_path) 202 | 203 | 204 | if __name__ == "__main__": 205 | flags.mark_flag_as_required("albert_directory") 206 | flags.mark_flag_as_required("export_path") 207 | app.run(main) 208 | -------------------------------------------------------------------------------- /modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # Lint as: python2, python3 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import collections 21 | import json 22 | import random 23 | import re 24 | 25 | from albert import modeling 26 | import numpy as np 27 | import six 28 | from six.moves import range 29 | import tensorflow.compat.v1 as tf 30 | 31 | 32 | class AlbertModelTest(tf.test.TestCase): 33 | 34 | class AlbertModelTester(object): 35 | 36 | def __init__(self, 37 | parent, 38 | batch_size=13, 39 | seq_length=7, 40 | is_training=True, 41 | use_input_mask=True, 42 | use_token_type_ids=True, 43 | vocab_size=99, 44 | embedding_size=32, 45 | hidden_size=32, 46 | num_hidden_layers=5, 47 | num_attention_heads=4, 48 | intermediate_size=37, 49 | hidden_act="gelu", 50 | hidden_dropout_prob=0.1, 51 | attention_probs_dropout_prob=0.1, 52 | max_position_embeddings=512, 53 | type_vocab_size=16, 54 | initializer_range=0.02, 55 | scope=None): 56 | self.parent = parent 57 | self.batch_size = batch_size 58 | self.seq_length = seq_length 59 | self.is_training = is_training 60 | self.use_input_mask = use_input_mask 61 | self.use_token_type_ids = use_token_type_ids 62 | self.vocab_size = vocab_size 63 | self.embedding_size = embedding_size 64 | self.hidden_size = hidden_size 65 | self.num_hidden_layers = num_hidden_layers 66 | self.num_attention_heads = num_attention_heads 67 | self.intermediate_size = intermediate_size 68 | self.hidden_act = hidden_act 69 | self.hidden_dropout_prob = hidden_dropout_prob 70 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 71 | self.max_position_embeddings = max_position_embeddings 72 | self.type_vocab_size = type_vocab_size 73 | self.initializer_range = initializer_range 74 | self.scope = scope 75 | 76 | def create_model(self): 77 | input_ids = AlbertModelTest.ids_tensor([self.batch_size, self.seq_length], 78 | self.vocab_size) 79 | 80 | input_mask = None 81 | if self.use_input_mask: 82 | input_mask = AlbertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], vocab_size=2) 84 | 85 | token_type_ids = None 86 | if self.use_token_type_ids: 87 | token_type_ids = AlbertModelTest.ids_tensor( 88 | [self.batch_size, self.seq_length], self.type_vocab_size) 89 | 90 | config = modeling.AlbertConfig( 91 | vocab_size=self.vocab_size, 92 | embedding_size=self.embedding_size, 93 | hidden_size=self.hidden_size, 94 | num_hidden_layers=self.num_hidden_layers, 95 | num_attention_heads=self.num_attention_heads, 96 | intermediate_size=self.intermediate_size, 97 | hidden_act=self.hidden_act, 98 | hidden_dropout_prob=self.hidden_dropout_prob, 99 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 100 | max_position_embeddings=self.max_position_embeddings, 101 | type_vocab_size=self.type_vocab_size, 102 | initializer_range=self.initializer_range) 103 | 104 | model = modeling.AlbertModel( 105 | config=config, 106 | is_training=self.is_training, 107 | input_ids=input_ids, 108 | input_mask=input_mask, 109 | token_type_ids=token_type_ids, 110 | scope=self.scope) 111 | 112 | outputs = { 113 | "embedding_output": model.get_embedding_output(), 114 | "sequence_output": model.get_sequence_output(), 115 | "pooled_output": model.get_pooled_output(), 116 | "all_encoder_layers": model.get_all_encoder_layers(), 117 | } 118 | return outputs 119 | 120 | def check_output(self, result): 121 | self.parent.assertAllEqual( 122 | result["embedding_output"].shape, 123 | [self.batch_size, self.seq_length, self.embedding_size]) 124 | 
125 | self.parent.assertAllEqual( 126 | result["sequence_output"].shape, 127 | [self.batch_size, self.seq_length, self.hidden_size]) 128 | 129 | self.parent.assertAllEqual(result["pooled_output"].shape, 130 | [self.batch_size, self.hidden_size]) 131 | 132 | def test_default(self): 133 | self.run_tester(AlbertModelTest.AlbertModelTester(self)) 134 | 135 | def test_config_to_json_string(self): 136 | config = modeling.AlbertConfig(vocab_size=99, hidden_size=37) 137 | obj = json.loads(config.to_json_string()) 138 | self.assertEqual(obj["vocab_size"], 99) 139 | self.assertEqual(obj["hidden_size"], 37) 140 | 141 | def test_einsum_via_matmul(self): 142 | batch_size = 8 143 | seq_length = 12 144 | num_attention_heads = 3 145 | head_size = 6 146 | hidden_size = 10 147 | 148 | input_tensor = np.random.uniform(0, 1, 149 | [batch_size, seq_length, hidden_size]) 150 | input_tensor = tf.constant(input_tensor, dtype=tf.float32) 151 | w = np.random.uniform(0, 1, [hidden_size, num_attention_heads, head_size]) 152 | w = tf.constant(w, dtype=tf.float32) 153 | ret1 = tf.einsum("BFH,HND->BFND", input_tensor, w) 154 | ret2 = modeling.einsum_via_matmul(input_tensor, w, 1) 155 | self.assertAllClose(ret1, ret2) 156 | 157 | input_tensor = np.random.uniform(0, 1, 158 | [batch_size, seq_length, 159 | num_attention_heads, head_size]) 160 | input_tensor = tf.constant(input_tensor, dtype=tf.float32) 161 | w = np.random.uniform(0, 1, [num_attention_heads, head_size, hidden_size]) 162 | w = tf.constant(w, dtype=tf.float32) 163 | ret1 = tf.einsum("BFND,NDH->BFH", input_tensor, w) 164 | ret2 = modeling.einsum_via_matmul(input_tensor, w, 2) 165 | self.assertAllClose(ret1, ret2) 166 | 167 | def run_tester(self, tester): 168 | with self.test_session() as sess: 169 | ops = tester.create_model() 170 | init_op = tf.group(tf.global_variables_initializer(), 171 | tf.local_variables_initializer()) 172 | sess.run(init_op) 173 | output_result = sess.run(ops) 174 | tester.check_output(output_result) 175 | 176 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 177 | 178 | @classmethod 179 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 180 | """Creates a random int32 tensor of the shape within the vocab size.""" 181 | if rng is None: 182 | rng = random.Random() 183 | 184 | total_dims = 1 185 | for dim in shape: 186 | total_dims *= dim 187 | 188 | values = [] 189 | for _ in range(total_dims): 190 | values.append(rng.randint(0, vocab_size - 1)) 191 | 192 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 193 | 194 | def assert_all_tensors_reachable(self, sess, outputs): 195 | """Checks that all the tensors in the graph are reachable from outputs.""" 196 | graph = sess.graph 197 | 198 | ignore_strings = [ 199 | "^.*/assert_less_equal/.*$", 200 | "^.*/dilation_rate$", 201 | "^.*/Tensordot/concat$", 202 | "^.*/Tensordot/concat/axis$", 203 | "^testing/.*$", 204 | ] 205 | 206 | ignore_regexes = [re.compile(x) for x in ignore_strings] 207 | 208 | unreachable = self.get_unreachable_ops(graph, outputs) 209 | filtered_unreachable = [] 210 | for x in unreachable: 211 | do_ignore = False 212 | for r in ignore_regexes: 213 | m = r.match(six.ensure_str(x.name)) 214 | if m is not None: 215 | do_ignore = True 216 | if do_ignore: 217 | continue 218 | filtered_unreachable.append(x) 219 | unreachable = filtered_unreachable 220 | 221 | self.assertEqual( 222 | len(unreachable), 0, "The following ops are unreachable: %s" % 223 | (" ".join([x.name for x in unreachable]))) 224 | 225 | @classmethod 226 | def 
get_unreachable_ops(cls, graph, outputs): 227 | """Finds all of the tensors in graph that are unreachable from outputs.""" 228 | outputs = cls.flatten_recursive(outputs) 229 | output_to_op = collections.defaultdict(list) 230 | op_to_all = collections.defaultdict(list) 231 | assign_out_to_in = collections.defaultdict(list) 232 | 233 | for op in graph.get_operations(): 234 | for x in op.inputs: 235 | op_to_all[op.name].append(x.name) 236 | for y in op.outputs: 237 | output_to_op[y.name].append(op.name) 238 | op_to_all[op.name].append(y.name) 239 | if str(op.type) == "Assign": 240 | for y in op.outputs: 241 | for x in op.inputs: 242 | assign_out_to_in[y.name].append(x.name) 243 | 244 | assign_groups = collections.defaultdict(list) 245 | for out_name in assign_out_to_in.keys(): 246 | name_group = assign_out_to_in[out_name] 247 | for n1 in name_group: 248 | assign_groups[n1].append(out_name) 249 | for n2 in name_group: 250 | if n1 != n2: 251 | assign_groups[n1].append(n2) 252 | 253 | seen_tensors = {} 254 | stack = [x.name for x in outputs] 255 | while stack: 256 | name = stack.pop() 257 | if name in seen_tensors: 258 | continue 259 | seen_tensors[name] = True 260 | 261 | if name in output_to_op: 262 | for op_name in output_to_op[name]: 263 | if op_name in op_to_all: 264 | for input_name in op_to_all[op_name]: 265 | if input_name not in stack: 266 | stack.append(input_name) 267 | 268 | expanded_names = [] 269 | if name in assign_groups: 270 | for assign_name in assign_groups[name]: 271 | expanded_names.append(assign_name) 272 | 273 | for expanded_name in expanded_names: 274 | if expanded_name not in stack: 275 | stack.append(expanded_name) 276 | 277 | unreachable_ops = [] 278 | for op in graph.get_operations(): 279 | is_unreachable = False 280 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 281 | for name in all_names: 282 | if name not in seen_tensors: 283 | is_unreachable = True 284 | if is_unreachable: 285 | unreachable_ops.append(op) 286 | return unreachable_ops 287 | 288 | @classmethod 289 | def flatten_recursive(cls, item): 290 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 291 | output = [] 292 | if isinstance(item, list): 293 | output.extend(item) 294 | elif isinstance(item, tuple): 295 | output.extend(list(item)) 296 | elif isinstance(item, dict): 297 | for (_, v) in six.iteritems(item): 298 | output.append(v) 299 | else: 300 | return [item] 301 | 302 | flat_output = [] 303 | for x in output: 304 | flat_output.extend(cls.flatten_recursive(x)) 305 | return flat_output 306 | 307 | 308 | if __name__ == "__main__": 309 | tf.test.main() 310 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at
195 | 
196 | http://www.apache.org/licenses/LICENSE-2.0
197 | 
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 | 
--------------------------------------------------------------------------------
/albert_glue_fine_tuning_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "albert_glue_fine_tuning_tutorial",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "accelerator": "TPU"
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "y8SJfpgTccDB",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "\n",
26 | "\"Open"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "wHQH4OCHZ9bq",
33 | "colab_type": "code",
34 | "cellView": "form",
35 | "colab": {}
36 | },
37 | "source": [
38 | "# @title Copyright 2020 The ALBERT Authors. All Rights Reserved.\n",
39 | "#\n",
40 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
41 | "# you may not use this file except in compliance with the License.\n",
42 | "# You may obtain a copy of the License at\n",
43 | "#\n",
44 | "# http://www.apache.org/licenses/LICENSE-2.0\n",
45 | "#\n",
46 | "# Unless required by applicable law or agreed to in writing, software\n",
47 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
48 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
49 | "# See the License for the specific language governing permissions and\n",
50 | "# limitations under the License.\n",
51 | "# =============================================================================="
52 | ],
53 | "execution_count": 0,
54 | "outputs": []
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {
59 | "id": "rkTLZ3I4_7c_",
60 | "colab_type": "text"
61 | },
62 | "source": [
63 | "# ALBERT End to End (Fine-tuning + Predicting) with Cloud TPU"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "id": "1wtjs1QDb3DX",
70 | "colab_type": "text"
71 | },
72 | "source": [
73 | "## Overview\n",
74 | "\n",
75 | "ALBERT is \"A Lite\" version of BERT, a popular unsupervised language representation learning algorithm. ALBERT uses parameter-reduction techniques that allow for large-scale configurations, overcome previous memory limitations, and achieve better behavior with respect to model degradation.\n",
76 | "\n",
77 | "For a technical description of the algorithm, see our paper:\n",
78 | "\n",
79 | "https://arxiv.org/abs/1909.11942\n",
80 | "\n",
81 | "Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut\n",
82 | "\n",
83 | "This Colab demonstrates using a free Colab Cloud TPU to fine-tune GLUE tasks built on top of pretrained ALBERT models and \n",
84 | "run predictions on the tuned model.
The Colab demonstrates loading pretrained ALBERT models from both [TF Hub](https://www.tensorflow.org/hub) and checkpoints.\n",
85 | "\n",
86 | "**Note:** You will need a GCP (Google Cloud Platform) account and a GCS (Google Cloud \n",
87 | "Storage) bucket for this Colab to run.\n",
88 | "\n",
89 | "Please follow the [Google Cloud TPU quickstart](https://cloud.google.com/tpu/docs/quickstart) for how to create a GCP account and a GCS bucket. You have [$300 free credit](https://cloud.google.com/free/) to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.\n",
90 | "\n",
91 | "This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select **File > View on GitHub**."
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {
97 | "id": "Ld-JXlueIuPH",
98 | "colab_type": "text"
99 | },
100 | "source": [
101 | "## Instructions"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {
107 | "id": "POkof5uHaQ_c",
108 | "colab_type": "text"
109 | },
110 | "source": [
111 | "
  Train on TPU
\n", 112 | "\n", 113 | " 1. Create a Cloud Storage bucket for your TensorBoard logs at http://console.cloud.google.com/storage and fill in the BUCKET parameter in the \"Parameters\" section below.\n", 114 | " \n", 115 | " 1. On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n", 116 | " 1. Click Runtime again and select **Runtime > Run All** (Watch out: the \"Colab-only auth for this notebook and the TPU\" cell requires user input). You can also run the cells manually with Shift-ENTER." 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": { 122 | "id": "UdMmwCJFaT8F", 123 | "colab_type": "text" 124 | }, 125 | "source": [ 126 | "### Set up your TPU environment\n", 127 | "\n", 128 | "In this section, you perform the following tasks:\n", 129 | "\n", 130 | "* Set up a Colab TPU running environment\n", 131 | "* Verify that you are connected to a TPU device\n", 132 | "* Upload your credentials to TPU to access your GCS bucket." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "metadata": { 138 | "id": "191zq3ZErihP", 139 | "colab_type": "code", 140 | "colab": {} 141 | }, 142 | "source": [ 143 | "# TODO(lanzhzh): Add support for 2.x.\n", 144 | "%tensorflow_version 1.x\n", 145 | "import os\n", 146 | "import pprint\n", 147 | "import json\n", 148 | "import tensorflow as tf\n", 149 | "\n", 150 | "assert \"COLAB_TPU_ADDR\" in os.environ, \"ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!\"\n", 151 | "TPU_ADDRESS = \"grpc://\" + os.environ[\"COLAB_TPU_ADDR\"] \n", 152 | "TPU_TOPOLOGY = \"2x2\"\n", 153 | "print(\"TPU address is\", TPU_ADDRESS)\n", 154 | "\n", 155 | "from google.colab import auth\n", 156 | "auth.authenticate_user()\n", 157 | "with tf.Session(TPU_ADDRESS) as session:\n", 158 | " print('TPU devices:')\n", 159 | " pprint.pprint(session.list_devices())\n", 160 | "\n", 161 | " # Upload credentials to TPU.\n", 162 | " with open('/content/adc.json', 'r') as f:\n", 163 | " auth_info = json.load(f)\n", 164 | " tf.contrib.cloud.configure_gcs(session, credentials=auth_info)\n", 165 | " # Now credentials are set for all future sessions on this TPU." 166 | ], 167 | "execution_count": 0, 168 | "outputs": [] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": { 173 | "id": "HUBP35oCDmbF", 174 | "colab_type": "text" 175 | }, 176 | "source": [ 177 | "### Prepare and import ALBERT modules\n", 178 | "​\n", 179 | "With your environment configured, you can now prepare and import the ALBERT modules. The following step clones the source code from GitHub." 
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "metadata": {
185 | "id": "7wzwke0sxS6W",
186 | "colab_type": "code",
187 | "colab": {},
188 | "cellView": "code"
189 | },
190 | "source": [
191 | "#TODO(lanzhzh): Add pip support\n",
192 | "import sys\n",
193 | "\n",
194 | "!test -d albert || git clone https://github.com/google-research/albert albert\n",
195 | "if 'albert' not in sys.path:\n",
196 | " sys.path += ['albert']\n",
197 | " \n",
198 | "!pip install sentencepiece\n"
199 | ],
200 | "execution_count": 0,
201 | "outputs": []
202 | },
203 | {
204 | "cell_type": "markdown",
205 | "metadata": {
206 | "id": "RRu1aKO1D7-Z",
207 | "colab_type": "text"
208 | },
209 | "source": [
210 | "### Prepare for training\n",
211 | "\n",
212 | "This next section of code performs the following tasks:\n",
213 | "\n",
214 | "* Specify GCS bucket, create output directory for model checkpoints and eval results.\n",
215 | "* Specify task and download training data.\n",
216 | "* Specify ALBERT pretrained model.\n",
217 | "\n",
218 | "\n",
219 | "\n"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "metadata": {
225 | "id": "tYkaAlJNfhul",
226 | "colab_type": "code",
227 | "colab": {},
228 | "cellView": "form"
229 | },
230 | "source": [
231 | "# Please find the full list of tasks and their fine-tuning hyperparameters\n",
232 | "# here https://github.com/google-research/albert/blob/master/run_glue.sh\n",
233 | "\n",
234 | "BUCKET = \"albert_tutorial_glue\" #@param { type: \"string\" }\n",
235 | "TASK = 'MRPC' #@param {type:\"string\"}\n",
236 | "# Available pretrained model checkpoints:\n",
237 | "# base, large, xlarge, xxlarge\n",
238 | "ALBERT_MODEL = 'base' #@param {type:\"string\"}\n",
239 | "\n",
240 | "TASK_DATA_DIR = 'glue_data'\n",
241 | "\n",
242 | "BASE_DIR = \"gs://\" + BUCKET\n",
243 | "if not BASE_DIR or BASE_DIR == \"gs://\":\n",
244 | " raise ValueError(\"You must enter a BUCKET.\")\n",
245 | "DATA_DIR = os.path.join(BASE_DIR, \"data\")\n",
246 | "MODELS_DIR = os.path.join(BASE_DIR, \"models\")\n",
247 | "OUTPUT_DIR = 'gs://{}/albert-tfhub/models/{}'.format(BUCKET, TASK)\n",
248 | "tf.gfile.MakeDirs(OUTPUT_DIR)\n",
249 | "print('***** Model output directory: {} *****'.format(OUTPUT_DIR))\n",
250 | "\n",
251 | "# Download glue data.\n",
252 | "! test -d download_glue_repo || git clone https://gist.github.com/60c2bdb54d156a41194446737ce03e2e.git download_glue_repo\n",
253 | "!python download_glue_repo/download_glue_data.py --data_dir=$TASK_DATA_DIR --tasks=$TASK\n",
254 | "print('***** Task data directory: {} *****'.format(TASK_DATA_DIR))\n",
255 | "\n",
256 | "ALBERT_MODEL_HUB = 'https://tfhub.dev/google/albert_' + ALBERT_MODEL + '/3'"
257 | ],
258 | "execution_count": 0,
259 | "outputs": []
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {
264 | "id": "Hcpfl4N2EdOk",
265 | "colab_type": "text"
266 | },
267 | "source": [
268 | "Now let's run the fine-tuning scripts. If you use the default MRPC task, this should be finished in around 10 minutes and you will get an accuracy of around 86.5."
269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "o8qXPxv8-kBO", 275 | "colab_type": "code", 276 | "colab": {} 277 | }, 278 | "source": [ 279 | "os.environ['TFHUB_CACHE_DIR'] = OUTPUT_DIR\n", 280 | "!python -m albert.run_classifier \\\n", 281 | " --data_dir=\"glue_data/\" \\\n", 282 | " --output_dir=$OUTPUT_DIR \\\n", 283 | " --albert_hub_module_handle=$ALBERT_MODEL_HUB \\\n", 284 | " --spm_model_file=\"from_tf_hub\" \\\n", 285 | " --do_train=True \\\n", 286 | " --do_eval=True \\\n", 287 | " --do_predict=False \\\n", 288 | " --max_seq_length=512 \\\n", 289 | " --optimizer=adamw \\\n", 290 | " --task_name=$TASK \\\n", 291 | " --warmup_step=200 \\\n", 292 | " --learning_rate=2e-5 \\\n", 293 | " --train_step=800 \\\n", 294 | " --save_checkpoints_steps=100 \\\n", 295 | " --train_batch_size=32 \\\n", 296 | " --tpu_name=$TPU_ADDRESS \\\n", 297 | " --use_tpu=True" 298 | ], 299 | "execution_count": 0, 300 | "outputs": [] 301 | } 302 | ] 303 | } 304 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ALBERT 2 | ====== 3 | 4 | ***************New March 28, 2020 *************** 5 | 6 | Add a colab [tutorial](https://github.com/google-research/albert/blob/master/albert_glue_fine_tuning_tutorial.ipynb) to run fine-tuning for GLUE datasets. 7 | 8 | ***************New January 7, 2020 *************** 9 | 10 | v2 TF-Hub models should be working now with TF 1.15, as we removed the 11 | native Einsum op from the graph. See updated TF-Hub links below. 12 | 13 | ***************New December 30, 2019 *************** 14 | 15 | Chinese models are released. We would like to thank [CLUE team ](https://github.com/CLUEbenchmark/CLUE) for providing the training data. 16 | 17 | - [Base](https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz) 18 | - [Large](https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz) 19 | - [Xlarge](https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz) 20 | - [Xxlarge](https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz) 21 | 22 | Version 2 of ALBERT models is released. 23 | 24 | - Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/3)] 25 | - Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/3)] 26 | - Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/3)] 27 | - Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v2.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/3)] 28 | 29 | In this version, we apply 'no dropout', 'additional training data' and 'long training time' strategies to all models. We train ALBERT-base for 10M steps and other models for 3M steps. 
30 | 
31 | The result comparison to the v1 models is as follows:
32 | 
33 | | | Average | SQuAD1.1 | SQuAD2.0 | MNLI | SST-2 | RACE |
34 | |----------------|----------|----------|----------|----------|----------|----------|
35 | |V2 |
36 | |ALBERT-base |82.3 |90.2/83.2 |82.1/79.3 |84.6 |92.9 |66.8 |
37 | |ALBERT-large |85.7 |91.8/85.2 |84.9/81.8 |86.5 |94.9 |75.2 |
38 | |ALBERT-xlarge |87.9 |92.9/86.4 |87.9/84.1 |87.9 |95.4 |80.7 |
39 | |ALBERT-xxlarge |90.9 |94.6/89.1 |89.8/86.9 |90.6 |96.8 |86.8 |
40 | |V1 |
41 | |ALBERT-base |80.1 |89.3/82.3 |80.0/77.1 |81.6 |90.3 |64.0 |
42 | |ALBERT-large |82.4 |90.6/83.9 |82.3/79.4 |83.5 |91.7 |68.5 |
43 | |ALBERT-xlarge |85.5 |92.5/86.1 |86.1/83.1 |86.4 |92.4 |74.8 |
44 | |ALBERT-xxlarge |91.0 |94.8/89.3 |90.2/87.4 |90.8 |96.9 |86.5 |
45 | 
46 | The comparison shows that for ALBERT-base, ALBERT-large, and ALBERT-xlarge, v2 is much better than v1, indicating the importance of applying the above three strategies. On average, ALBERT-xxlarge v2 is slightly worse than v1, for the following two reasons: 1) Training for an additional 1.5M steps (the only difference between these two models is training for 1.5M steps vs. 3M steps) did not lead to significant performance improvement. 2) For v1, we did a small hyperparameter search among the parameter sets given by BERT, RoBERTa, and XLNet. For v2, we simply adopt the parameters from v1, except for RACE, where we use a learning rate of 1e-5 and an [ALBERT DR](https://arxiv.org/pdf/1909.11942.pdf) (dropout rate for ALBERT in fine-tuning) of 0. The original (v1) RACE hyperparameters cause model divergence for v2 models. Given that the downstream tasks are sensitive to the fine-tuning hyperparameters, we should be careful about such so-called slight improvements.
47 | 
48 | ALBERT is "A Lite" version of BERT, a popular unsupervised language
49 | representation learning algorithm. ALBERT uses parameter-reduction techniques
50 | that allow for large-scale configurations, overcome previous memory limitations,
51 | and achieve better behavior with respect to model degradation.
52 | 
53 | For a technical description of the algorithm, see our paper:
54 | 
55 | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942)
56 | 
57 | Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut
58 | 
59 | Release Notes
60 | =============
61 | 
62 | - Initial release: 10/9/2019
63 | 
64 | Results
65 | =======
66 | 
67 | Performance of ALBERT on the GLUE benchmark using a single-model setup on the
68 | dev set:
69 | 
70 | | Models | MNLI | QNLI | QQP | RTE | SST | MRPC | CoLA | STS |
71 | |-------------------|----------|----------|----------|----------|----------|----------|----------|----------|
72 | | BERT-large | 86.6 | 92.3 | 91.3 | 70.4 | 93.2 | 88.0 | 60.6 | 90.0 |
73 | | XLNet-large | 89.8 | 93.9 | 91.8 | 83.8 | 95.6 | 89.2 | 63.6 | 91.8 |
74 | | RoBERTa-large | 90.2 | 94.7 | **92.2** | 86.6 | 96.4 | **90.9** | 68.0 | 92.4 |
75 | | ALBERT (1M) | 90.4 | 95.2 | 92.0 | 88.1 | 96.8 | 90.2 | 68.7 | 92.7 |
76 | | ALBERT (1.5M) | **90.8** | **95.3** | **92.2** | **89.2** | **96.9** | **90.9** | **71.4** | **93.0** |
77 | 
78 | Performance of ALBERT-xxlarge on the SQuAD and RACE benchmarks using a
79 | single-model setup:
80 | 
81 | |Models | SQuAD1.1 dev | SQuAD2.0 dev | SQuAD2.0 test | RACE test (Middle/High) |
82 | |--------------------------|---------------|---------------|---------------|-------------------------|
83 | |BERT-large | 90.9/84.1 | 81.8/79.0 | 89.1/86.3 | 72.0 (76.6/70.1) |
84 | |XLNet | 94.5/89.0 | 88.8/86.1 | 89.1/86.3 | 81.8 (85.5/80.2) |
85 | |RoBERTa | 94.6/88.9 | 89.4/86.5 | 89.8/86.8 | 83.2 (86.5/81.3) |
86 | |UPM | - | - | 89.9/87.2 | - |
87 | |XLNet + SG-Net Verifier++ | - | - | 90.1/87.2 | - |
88 | |ALBERT (1M) | 94.8/89.2 | 89.9/87.2 | - | 86.0 (88.2/85.1) |
89 | |ALBERT (1.5M) | **94.8/89.3** | **90.2/87.4** | **90.9/88.1** | **86.5 (89.0/85.5)** |
90 | 
91 | 
92 | Pre-trained Models
93 | ==================
94 | TF-Hub modules are available:
95 | 
96 | - Base: [[Tar file](https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_base/1)]
97 | - Large: [[Tar file](https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_large/1)]
98 | - Xlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xlarge/1)]
99 | - Xxlarge: [[Tar file](https://storage.googleapis.com/albert_models/albert_xxlarge_v1.tar.gz)] [[TF-Hub](https://tfhub.dev/google/albert_xxlarge/1)]
100 | 
101 | Example usage of the TF-Hub module in code:
102 | 
103 | ```
104 | tags = set()
105 | if is_training:
106 |   tags.add("train")
107 | albert_module = hub.Module("https://tfhub.dev/google/albert_base/1", tags=tags,
108 |                            trainable=True)
109 | albert_inputs = dict(
110 |     input_ids=input_ids,
111 |     input_mask=input_mask,
112 |     segment_ids=segment_ids)
113 | albert_outputs = albert_module(
114 |     inputs=albert_inputs,
115 |     signature="tokens",
116 |     as_dict=True)
117 | 
118 | # If you want to use the token-level output, use
119 | # albert_outputs["sequence_output"] instead.
120 | output_layer = albert_outputs["pooled_output"]
121 | ```
122 | 
123 | Most of the fine-tuning scripts in this repository support TF-Hub modules
124 | via the `--albert_hub_module_handle` flag.
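For example, a minimal sketch of such an invocation (the task and hyperparameter values here are illustrative only, borrowed from the Colab tutorial above; `--spm_model_file="from_tf_hub"` tells the script to read the sentencepiece model out of the module itself):

```
pip install -r albert/requirements.txt
python -m albert.run_classifier \
  --data_dir=... \
  --output_dir=... \
  --albert_hub_module_handle=https://tfhub.dev/google/albert_base/1 \
  --spm_model_file="from_tf_hub" \
  --do_train \
  --do_eval \
  --task_name=MRPC \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --train_step=800 \
  --warmup_step=200
```

Any of the task-specific flags shown in the sections below can be combined with the handle in the same way.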
125 | 126 | Pre-training Instructions 127 | ========================= 128 | To pretrain ALBERT, use `run_pretraining.py`: 129 | 130 | ``` 131 | pip install -r albert/requirements.txt 132 | python -m albert.run_pretraining \ 133 | --input_file=... \ 134 | --output_dir=... \ 135 | --init_checkpoint=... \ 136 | --albert_config_file=... \ 137 | --do_train \ 138 | --do_eval \ 139 | --train_batch_size=4096 \ 140 | --eval_batch_size=64 \ 141 | --max_seq_length=512 \ 142 | --max_predictions_per_seq=20 \ 143 | --optimizer='lamb' \ 144 | --learning_rate=.00176 \ 145 | --num_train_steps=125000 \ 146 | --num_warmup_steps=3125 \ 147 | --save_checkpoints_steps=5000 148 | ``` 149 | 150 | Fine-tuning on GLUE 151 | =================== 152 | To fine-tune and evaluate a pretrained ALBERT on GLUE, please see the 153 | convenience script `run_glue.sh`. 154 | 155 | Lower-level use cases may want to use the `run_classifier.py` script directly. 156 | The `run_classifier.py` script is used both for fine-tuning and evaluation of 157 | ALBERT on individual GLUE benchmark tasks, such as MNLI: 158 | 159 | ``` 160 | pip install -r albert/requirements.txt 161 | python -m albert.run_classifier \ 162 | --data_dir=... \ 163 | --output_dir=... \ 164 | --init_checkpoint=... \ 165 | --albert_config_file=... \ 166 | --spm_model_file=... \ 167 | --do_train \ 168 | --do_eval \ 169 | --do_predict \ 170 | --do_lower_case \ 171 | --max_seq_length=128 \ 172 | --optimizer=adamw \ 173 | --task_name=MNLI \ 174 | --warmup_step=1000 \ 175 | --learning_rate=3e-5 \ 176 | --train_step=10000 \ 177 | --save_checkpoints_steps=100 \ 178 | --train_batch_size=128 179 | ``` 180 | 181 | Good default flag values for each GLUE task can be found in `run_glue.sh`. 182 | 183 | You can fine-tune the model starting from TF-Hub modules instead of raw 184 | checkpoints by setting e.g. 185 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead 186 | of `--init_checkpoint`. 187 | 188 | You can find the spm_model_file in the tar files or under the assets folder of 189 | the tf-hub module. The name of the model file is "30k-clean.model". 190 | 191 | After evaluation, the script should report some output like this: 192 | 193 | ``` 194 | ***** Eval results ***** 195 | global_step = ... 196 | loss = ... 197 | masked_lm_accuracy = ... 198 | masked_lm_loss = ... 199 | sentence_order_accuracy = ... 200 | sentence_order_loss = ... 201 | ``` 202 | 203 | Fine-tuning on SQuAD 204 | ==================== 205 | To fine-tune and evaluate a pretrained model on SQuAD v1, use the 206 | `run_squad_v1.py` script: 207 | 208 | ``` 209 | pip install -r albert/requirements.txt 210 | python -m albert.run_squad_v1 \ 211 | --albert_config_file=... \ 212 | --output_dir=... \ 213 | --train_file=... \ 214 | --predict_file=... \ 215 | --train_feature_file=... \ 216 | --predict_feature_file=... \ 217 | --predict_feature_left_file=... \ 218 | --init_checkpoint=... \ 219 | --spm_model_file=... \ 220 | --do_lower_case \ 221 | --max_seq_length=384 \ 222 | --doc_stride=128 \ 223 | --max_query_length=64 \ 224 | --do_train=true \ 225 | --do_predict=true \ 226 | --train_batch_size=48 \ 227 | --predict_batch_size=8 \ 228 | --learning_rate=5e-5 \ 229 | --num_train_epochs=2.0 \ 230 | --warmup_proportion=.1 \ 231 | --save_checkpoints_steps=5000 \ 232 | --n_best_size=20 \ 233 | --max_answer_length=30 234 | ``` 235 | 236 | You can fine-tune the model starting from TF-Hub modules instead of raw 237 | checkpoints by setting e.g. 
238 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
239 | of `--init_checkpoint`.
240 | 
241 | For SQuAD v2, use the `run_squad_v2.py` script:
242 | 
243 | ```
244 | pip install -r albert/requirements.txt
245 | python -m albert.run_squad_v2 \
246 |   --albert_config_file=... \
247 |   --output_dir=... \
248 |   --train_file=... \
249 |   --predict_file=... \
250 |   --train_feature_file=... \
251 |   --predict_feature_file=... \
252 |   --predict_feature_left_file=... \
253 |   --init_checkpoint=... \
254 |   --spm_model_file=... \
255 |   --do_lower_case \
256 |   --max_seq_length=384 \
257 |   --doc_stride=128 \
258 |   --max_query_length=64 \
259 |   --do_train \
260 |   --do_predict \
261 |   --train_batch_size=48 \
262 |   --predict_batch_size=8 \
263 |   --learning_rate=5e-5 \
264 |   --num_train_epochs=2.0 \
265 |   --warmup_proportion=.1 \
266 |   --save_checkpoints_steps=5000 \
267 |   --n_best_size=20 \
268 |   --max_answer_length=30
269 | ```
270 | 
271 | You can fine-tune the model starting from TF-Hub modules instead of raw
272 | checkpoints by setting e.g.
273 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
274 | of `--init_checkpoint`.
275 | 
276 | Fine-tuning on RACE
277 | ===================
278 | For RACE, use the `run_race.py` script:
279 | 
280 | ```
281 | pip install -r albert/requirements.txt
282 | python -m albert.run_race \
283 |   --albert_config_file=... \
284 |   --output_dir=... \
285 |   --train_file=... \
286 |   --eval_file=... \
287 |   --data_dir=... \
288 |   --init_checkpoint=... \
289 |   --spm_model_file=... \
290 |   --max_seq_length=512 \
291 |   --max_qa_length=128 \
292 |   --do_train \
293 |   --do_eval \
294 |   --train_batch_size=32 \
295 |   --eval_batch_size=8 \
296 |   --learning_rate=1e-5 \
297 |   --train_step=12000 \
298 |   --warmup_step=1000 \
299 |   --save_checkpoints_steps=100
300 | ```
301 | 
302 | You can fine-tune the model starting from TF-Hub modules instead of raw
303 | checkpoints by setting e.g.
304 | `--albert_hub_module_handle=https://tfhub.dev/google/albert_base/1` instead
305 | of `--init_checkpoint`.
306 | 
307 | SentencePiece
308 | =============
309 | Command for generating the sentence piece vocabulary:
310 | 
311 | ```
312 | spm_train \
313 |   --input all.txt --model_prefix=30k-clean --vocab_size=30000 --logtostderr \
314 |   --pad_id=0 --unk_id=1 --eos_id=-1 --bos_id=-1 \
315 |   --control_symbols=[CLS],[SEP],[MASK] \
316 |   --user_defined_symbols="(,),\",-,.,–,£,€" \
317 |   --shuffle_input_sentence=true --input_sentence_size=10000000 \
318 |   --character_coverage=0.99995 --model_type=unigram
319 | ```
320 | 
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
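# A minimal usage sketch of the SentencePiece APIs that the tokenizers in this
# module build on, assuming a model file named "30k-clean.model" as produced by
# the spm_train command shown in the README:
#
#   import sentencepiece as spm
#   sp = spm.SentencePieceProcessor()
#   sp.Load("30k-clean.model")
#   pieces = sp.EncodeAsPieces("ALBERT uses a unigram vocabulary.")
#   ids = [sp.PieceToId(piece) for piece in pieces]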
15 | # Lint as: python2, python3
16 | # coding=utf-8
17 | """Tokenization classes."""
18 | 
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 | 
23 | import collections
24 | import unicodedata
25 | import six
26 | from six.moves import range
27 | import tensorflow.compat.v1 as tf
28 | import tensorflow_hub as hub
29 | import sentencepiece as spm
30 | 
31 | SPIECE_UNDERLINE = u"▁".encode("utf-8")
32 | 
33 | 
34 | def preprocess_text(inputs, remove_space=True, lower=False):
35 |   """Preprocesses data by removing extra spaces and normalizing text."""
36 |   outputs = inputs
37 |   if remove_space:
38 |     outputs = " ".join(inputs.strip().split())
39 | 
40 |   if six.PY2 and isinstance(outputs, str):
41 |     try:
42 |       outputs = six.ensure_text(outputs, "utf-8")
43 |     except UnicodeDecodeError:
44 |       outputs = six.ensure_text(outputs, "latin-1")
45 | 
46 |   outputs = unicodedata.normalize("NFKD", outputs)
47 |   outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
48 |   if lower:
49 |     outputs = outputs.lower()
50 | 
51 |   return outputs
52 | 
53 | 
54 | def encode_pieces(sp_model, text, return_unicode=True, sample=False):
55 |   """Turns sentences into word pieces."""
56 | 
57 |   if six.PY2 and isinstance(text, six.text_type):
58 |     text = six.ensure_binary(text, "utf-8")
59 | 
60 |   if not sample:
61 |     pieces = sp_model.EncodeAsPieces(text)
62 |   else:
63 |     pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
64 |   new_pieces = []
65 |   for piece in pieces:
66 |     piece = printable_text(piece)
67 |     if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
68 |       cur_pieces = sp_model.EncodeAsPieces(
69 |           six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b""))
70 |       if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
71 |         if len(cur_pieces[0]) == 1:
72 |           cur_pieces = cur_pieces[1:]
73 |         else:
74 |           cur_pieces[0] = cur_pieces[0][1:]
75 |       cur_pieces.append(piece[-1])
76 |       new_pieces.extend(cur_pieces)
77 |     else:
78 |       new_pieces.append(piece)
79 | 
80 |   # note(zhiliny): convert back to unicode for py2
81 |   if six.PY2 and return_unicode:
82 |     ret_pieces = []
83 |     for piece in new_pieces:
84 |       if isinstance(piece, str):
85 |         piece = six.ensure_text(piece, "utf-8")
86 |       ret_pieces.append(piece)
87 |     new_pieces = ret_pieces
88 | 
89 |   return new_pieces
90 | 
91 | 
92 | def encode_ids(sp_model, text, sample=False):
93 |   pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
94 |   ids = [sp_model.PieceToId(piece) for piece in pieces]
95 |   return ids
96 | 
97 | 
98 | def convert_to_unicode(text):
99 |   """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
100 |   if six.PY3:
101 |     if isinstance(text, str):
102 |       return text
103 |     elif isinstance(text, bytes):
104 |       return six.ensure_text(text, "utf-8", "ignore")
105 |     else:
106 |       raise ValueError("Unsupported string type: %s" % (type(text)))
107 |   elif six.PY2:
108 |     if isinstance(text, str):
109 |       return six.ensure_text(text, "utf-8", "ignore")
110 |     elif isinstance(text, six.text_type):
111 |       return text
112 |     else:
113 |       raise ValueError("Unsupported string type: %s" % (type(text)))
114 |   else:
115 |     raise ValueError("Not running on Python2 or Python 3?")
116 | 
117 | 
118 | def printable_text(text):
119 |   """Returns text encoded in a way suitable for print or `tf.logging`."""
120 | 
121 |   # These functions want `str` for both Python2 and Python3, but in one case
122 |   # it's a Unicode string and in the other it's a byte string.
123 |   if six.PY3:
124 |     if isinstance(text, str):
125 |       return text
126 |     elif isinstance(text, bytes):
127 |       return six.ensure_text(text, "utf-8", "ignore")
128 |     else:
129 |       raise ValueError("Unsupported string type: %s" % (type(text)))
130 |   elif six.PY2:
131 |     if isinstance(text, str):
132 |       return text
133 |     elif isinstance(text, six.text_type):
134 |       return six.ensure_binary(text, "utf-8")
135 |     else:
136 |       raise ValueError("Unsupported string type: %s" % (type(text)))
137 |   else:
138 |     raise ValueError("Not running on Python2 or Python 3?")
139 | 
140 | 
141 | def load_vocab(vocab_file):
142 |   """Loads a vocabulary file into a dictionary."""
143 |   vocab = collections.OrderedDict()
144 |   with tf.gfile.GFile(vocab_file, "r") as reader:
145 |     while True:
146 |       token = convert_to_unicode(reader.readline())
147 |       if not token:
148 |         break
149 |       token = token.strip().split()[0] if token.strip() else " "
150 |       if token not in vocab:
151 |         vocab[token] = len(vocab)
152 |   return vocab
153 | 
154 | 
155 | def convert_by_vocab(vocab, items):
156 |   """Converts a sequence of [tokens|ids] using the vocab."""
157 |   output = []
158 |   for item in items:
159 |     output.append(vocab[item])
160 |   return output
161 | 
162 | 
163 | def convert_tokens_to_ids(vocab, tokens):
164 |   return convert_by_vocab(vocab, tokens)
165 | 
166 | 
167 | def convert_ids_to_tokens(inv_vocab, ids):
168 |   return convert_by_vocab(inv_vocab, ids)
169 | 
170 | 
171 | def whitespace_tokenize(text):
172 |   """Runs basic whitespace cleaning and splitting on a piece of text."""
173 |   text = text.strip()
174 |   if not text:
175 |     return []
176 |   tokens = text.split()
177 |   return tokens
178 | 
179 | 
180 | class FullTokenizer(object):
181 |   """Runs end-to-end tokenization."""
182 | 
183 |   def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None):
184 |     self.vocab = None
185 |     self.sp_model = None
186 |     if spm_model_file:
187 |       self.sp_model = spm.SentencePieceProcessor()
188 |       tf.logging.info("loading sentence piece model")
189 |       # Handle cases where SP can't load the file, but gfile can.
190 |       sp_model_ = tf.gfile.GFile(spm_model_file, "rb").read()
191 |       self.sp_model.LoadFromSerializedProto(sp_model_)
192 |       # Note(mingdachen): For the purpose of a consistent API, we are
193 |       # generating a vocabulary for the sentence piece tokenizer.
194 |       self.vocab = {self.sp_model.IdToPiece(i): i for i
195 |                     in range(self.sp_model.GetPieceSize())}
196 |     else:
197 |       self.vocab = load_vocab(vocab_file)
198 |       self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
199 |       self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
200 |     self.inv_vocab = {v: k for k, v in self.vocab.items()}
201 | 
202 |   @classmethod
203 |   def from_scratch(cls, vocab_file, do_lower_case, spm_model_file):
204 |     return FullTokenizer(vocab_file, do_lower_case, spm_model_file)
205 | 
206 |   @classmethod
207 |   def from_hub_module(cls, hub_module, use_spm=True):
208 |     """Gets the vocab file and casing info from the Hub module."""
209 |     with tf.Graph().as_default():
210 |       albert_module = hub.Module(hub_module)
211 |       tokenization_info = albert_module(signature="tokenization_info",
212 |                                         as_dict=True)
213 |       with tf.Session() as sess:
214 |         vocab_file, do_lower_case = sess.run(
215 |             [tokenization_info["vocab_file"],
216 |              tokenization_info["do_lower_case"]])
217 |     if use_spm:
218 |       spm_model_file = vocab_file
219 |       vocab_file = None
220 |     return FullTokenizer(
221 |         vocab_file=vocab_file, do_lower_case=do_lower_case,
222 |         spm_model_file=spm_model_file)
223 | 
224 |   def tokenize(self, text):
225 |     if self.sp_model:
226 |       split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
227 |     else:
228 |       split_tokens = []
229 |       for token in self.basic_tokenizer.tokenize(text):
230 |         for sub_token in self.wordpiece_tokenizer.tokenize(token):
231 |           split_tokens.append(sub_token)
232 | 
233 |     return split_tokens
234 | 
235 |   def convert_tokens_to_ids(self, tokens):
236 |     if self.sp_model:
237 |       tf.logging.info("using sentence piece tokenizer.")
238 |       return [self.sp_model.PieceToId(
239 |           printable_text(token)) for token in tokens]
240 |     else:
241 |       return convert_by_vocab(self.vocab, tokens)
242 | 
243 |   def convert_ids_to_tokens(self, ids):
244 |     if self.sp_model:
245 |       tf.logging.info("using sentence piece tokenizer.")
246 |       return [self.sp_model.IdToPiece(id_) for id_ in ids]
247 |     else:
248 |       return convert_by_vocab(self.inv_vocab, ids)
249 | 
250 | 
251 | class BasicTokenizer(object):
252 |   """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
253 | 
254 |   def __init__(self, do_lower_case=True):
255 |     """Constructs a BasicTokenizer.
256 | 
257 |     Args:
258 |       do_lower_case: Whether to lower case the input.
259 |     """
260 |     self.do_lower_case = do_lower_case
261 | 
262 |   def tokenize(self, text):
263 |     """Tokenizes a piece of text."""
264 |     text = convert_to_unicode(text)
265 |     text = self._clean_text(text)
266 | 
267 |     # This was added on November 1st, 2018 for the multilingual and Chinese
268 |     # models. This is also applied to the English models now, but it doesn't
269 |     # matter since the English models were not trained on any Chinese data
270 |     # and generally don't have any Chinese data in them (there are Chinese
271 |     # characters in the vocabulary because Wikipedia does have some Chinese
272 |     # words in the English Wikipedia.).
273 |     text = self._tokenize_chinese_chars(text)
274 | 
275 |     orig_tokens = whitespace_tokenize(text)
276 |     split_tokens = []
277 |     for token in orig_tokens:
278 |       if self.do_lower_case:
279 |         token = token.lower()
280 |         token = self._run_strip_accents(token)
281 |       split_tokens.extend(self._run_split_on_punc(token))
282 | 
283 |     output_tokens = whitespace_tokenize(" ".join(split_tokens))
284 |     return output_tokens
285 | 
286 |   def _run_strip_accents(self, text):
287 |     """Strips accents from a piece of text."""
288 |     text = unicodedata.normalize("NFD", text)
289 |     output = []
290 |     for char in text:
291 |       cat = unicodedata.category(char)
292 |       if cat == "Mn":
293 |         continue
294 |       output.append(char)
295 |     return "".join(output)
296 | 
297 |   def _run_split_on_punc(self, text):
298 |     """Splits punctuation on a piece of text."""
299 |     chars = list(text)
300 |     i = 0
301 |     start_new_word = True
302 |     output = []
303 |     while i < len(chars):
304 |       char = chars[i]
305 |       if _is_punctuation(char):
306 |         output.append([char])
307 |         start_new_word = True
308 |       else:
309 |         if start_new_word:
310 |           output.append([])
311 |         start_new_word = False
312 |         output[-1].append(char)
313 |       i += 1
314 | 
315 |     return ["".join(x) for x in output]
316 | 
317 |   def _tokenize_chinese_chars(self, text):
318 |     """Adds whitespace around any CJK character."""
319 |     output = []
320 |     for char in text:
321 |       cp = ord(char)
322 |       if self._is_chinese_char(cp):
323 |         output.append(" ")
324 |         output.append(char)
325 |         output.append(" ")
326 |       else:
327 |         output.append(char)
328 |     return "".join(output)
329 | 
330 |   def _is_chinese_char(self, cp):
331 |     """Checks whether CP is the codepoint of a CJK character."""
332 |     # This defines a "chinese character" as anything in the CJK Unicode block:
333 |     # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
334 |     #
335 |     # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
336 |     # despite its name. The modern Korean Hangul alphabet is a different block,
337 |     # as is Japanese Hiragana and Katakana. Those alphabets are used to write
338 |     # space-separated words, so they are not treated specially and handled
339 |     # like all of the other languages.
340 |     if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
341 |         (cp >= 0x3400 and cp <= 0x4DBF) or  #
342 |         (cp >= 0x20000 and cp <= 0x2A6DF) or  #
343 |         (cp >= 0x2A700 and cp <= 0x2B73F) or  #
344 |         (cp >= 0x2B740 and cp <= 0x2B81F) or  #
345 |         (cp >= 0x2B820 and cp <= 0x2CEAF) or
346 |         (cp >= 0xF900 and cp <= 0xFAFF) or  #
347 |         (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
348 |       return True
349 | 
350 |     return False
351 | 
352 |   def _clean_text(self, text):
353 |     """Performs invalid character removal and whitespace cleanup on text."""
354 |     output = []
355 |     for char in text:
356 |       cp = ord(char)
357 |       if cp == 0 or cp == 0xfffd or _is_control(char):
358 |         continue
359 |       if _is_whitespace(char):
360 |         output.append(" ")
361 |       else:
362 |         output.append(char)
363 |     return "".join(output)
364 | 
365 | 
366 | class WordpieceTokenizer(object):
367 |   """Runs WordPiece tokenization."""
368 | 
369 |   def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
370 |     self.vocab = vocab
371 |     self.unk_token = unk_token
372 |     self.max_input_chars_per_word = max_input_chars_per_word
373 | 
374 |   def tokenize(self, text):
375 |     """Tokenizes a piece of text into its word pieces.
376 | 
377 |     This uses a greedy longest-match-first algorithm to perform tokenization
378 |     using the given vocabulary.
379 | 
380 |     For example:
381 |       input = "unaffable"
382 |       output = ["un", "##aff", "##able"]
383 | 
384 |     Args:
385 |       text: A single token or whitespace separated tokens. This should have
386 |         already been passed through `BasicTokenizer`.
387 | 
388 |     Returns:
389 |       A list of wordpiece tokens.
390 |     """
391 | 
392 |     text = convert_to_unicode(text)
393 | 
394 |     output_tokens = []
395 |     for token in whitespace_tokenize(text):
396 |       chars = list(token)
397 |       if len(chars) > self.max_input_chars_per_word:
398 |         output_tokens.append(self.unk_token)
399 |         continue
400 | 
401 |       is_bad = False
402 |       start = 0
403 |       sub_tokens = []
404 |       while start < len(chars):
405 |         end = len(chars)
406 |         cur_substr = None
407 |         while start < end:
408 |           substr = "".join(chars[start:end])
409 |           if start > 0:
410 |             substr = "##" + six.ensure_str(substr)
411 |           if substr in self.vocab:
412 |             cur_substr = substr
413 |             break
414 |           end -= 1
415 |         if cur_substr is None:
416 |           is_bad = True
417 |           break
418 |         sub_tokens.append(cur_substr)
419 |         start = end
420 | 
421 |       if is_bad:
422 |         output_tokens.append(self.unk_token)
423 |       else:
424 |         output_tokens.extend(sub_tokens)
425 |     return output_tokens
426 | 
427 | 
428 | def _is_whitespace(char):
429 |   """Checks whether `char` is a whitespace character."""
430 |   # \t, \n, and \r are technically control characters but we treat them
431 |   # as whitespace since they are generally considered as such.
432 |   if char == " " or char == "\t" or char == "\n" or char == "\r":
433 |     return True
434 |   cat = unicodedata.category(char)
435 |   if cat == "Zs":
436 |     return True
437 |   return False
438 | 
439 | 
440 | def _is_control(char):
441 |   """Checks whether `char` is a control character."""
442 |   # These are technically control characters but we count them as whitespace
443 |   # characters.
444 |   if char == "\t" or char == "\n" or char == "\r":
445 |     return False
446 |   cat = unicodedata.category(char)
447 |   if cat in ("Cc", "Cf"):
448 |     return True
449 |   return False
450 | 
451 | 
452 | def _is_punctuation(char):
453 |   """Checks whether `char` is a punctuation character."""
454 |   cp = ord(char)
455 |   # We treat all non-letter/number ASCII as punctuation.
456 |   # Characters such as "^", "$", and "`" are not in the Unicode
457 |   # Punctuation class but we treat them as punctuation anyways, for
458 |   # consistency.
459 |   if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
460 |       (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
461 |     return True
462 |   cat = unicodedata.category(char)
463 |   if cat.startswith("P"):
464 |     return True
465 |   return False
466 | 
--------------------------------------------------------------------------------
/race_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Utility functions for RACE dataset.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | # from __future__ import google_type_annotations 20 | from __future__ import print_function 21 | 22 | import collections 23 | import json 24 | import os 25 | from albert import classifier_utils 26 | from albert import fine_tuning_utils 27 | from albert import modeling 28 | from albert import optimization 29 | from albert import tokenization 30 | import tensorflow.compat.v1 as tf 31 | from tensorflow.contrib import tpu as contrib_tpu 32 | 33 | 34 | class InputExample(object): 35 | """A single training/test example for the RACE dataset.""" 36 | 37 | def __init__(self, 38 | example_id, 39 | context_sentence, 40 | start_ending, 41 | endings, 42 | label=None): 43 | self.example_id = example_id 44 | self.context_sentence = context_sentence 45 | self.start_ending = start_ending 46 | self.endings = endings 47 | self.label = label 48 | 49 | def __str__(self): 50 | return self.__repr__() 51 | 52 | def __repr__(self): 53 | l = [ 54 | "id: {}".format(self.example_id), 55 | "context_sentence: {}".format(self.context_sentence), 56 | "start_ending: {}".format(self.start_ending), 57 | "ending_0: {}".format(self.endings[0]), 58 | "ending_1: {}".format(self.endings[1]), 59 | "ending_2: {}".format(self.endings[2]), 60 | "ending_3: {}".format(self.endings[3]), 61 | ] 62 | 63 | if self.label is not None: 64 | l.append("label: {}".format(self.label)) 65 | 66 | return ", ".join(l) 67 | 68 | 69 | class RaceProcessor(object): 70 | """Processor for the RACE data set.""" 71 | 72 | def __init__(self, use_spm, do_lower_case, high_only, middle_only): 73 | super(RaceProcessor, self).__init__() 74 | self.use_spm = use_spm 75 | self.do_lower_case = do_lower_case 76 | self.high_only = high_only 77 | self.middle_only = middle_only 78 | 79 | def get_train_examples(self, data_dir): 80 | """Gets a collection of `InputExample`s for the train set.""" 81 | return self.read_examples( 82 | os.path.join(data_dir, "RACE", "train")) 83 | 84 | def get_dev_examples(self, data_dir): 85 | """Gets a collection of `InputExample`s for the dev set.""" 86 | return self.read_examples( 87 | os.path.join(data_dir, "RACE", "dev")) 88 | 89 | def get_test_examples(self, data_dir): 90 | """Gets a collection of `InputExample`s for prediction.""" 91 | return self.read_examples( 92 | os.path.join(data_dir, "RACE", "test")) 93 | 94 | def get_labels(self): 95 | """Gets the list of labels for this data set.""" 96 | return ["A", "B", "C", "D"] 97 | 98 | def process_text(self, text): 99 | if self.use_spm: 100 | return tokenization.preprocess_text(text, lower=self.do_lower_case) 101 | else: 102 | return tokenization.convert_to_unicode(text) 103 | 104 | def read_examples(self, data_dir): 105 | """Read examples from RACE json files.""" 106 | examples = [] 107 | for level in ["middle", "high"]: 108 | if level == "middle" and self.high_only: continue 109 | if level == "high" and self.middle_only: continue 110 | cur_dir = os.path.join(data_dir, level) 111 | 112 | cur_path = os.path.join(cur_dir, "all.txt") 113 | with tf.gfile.Open(cur_path) as f: 114 | for line in f: 115 | cur_data = json.loads(line.strip()) 116 | 117 | answers = cur_data["answers"] 118 | options = cur_data["options"] 119 | questions = cur_data["questions"] 120 | context = self.process_text(cur_data["article"]) 121 | 122 | for i in range(len(answers)): 123 | label = ord(answers[i]) - ord("A") 124 | qa_list = [] 125 | 126 | question = 
self.process_text(questions[i])
127 |             for j in range(4):
128 |               option = self.process_text(options[i][j])
129 | 
130 |               if "_" in question:
131 |                 qa_cat = question.replace("_", option)
132 |               else:
133 |                 qa_cat = " ".join([question, option])
134 | 
135 |               qa_list.append(qa_cat)
136 | 
137 |             examples.append(
138 |                 InputExample(
139 |                     example_id=cur_data["id"],
140 |                     context_sentence=context,
141 |                     start_ending=None,
142 |                     endings=[qa_list[0], qa_list[1], qa_list[2], qa_list[3]],
143 |                     label=label
144 |                 )
145 |             )
146 | 
147 |     return examples
148 | 
149 | 
150 | def convert_single_example(example_index, example, label_size, max_seq_length,
151 |                            tokenizer, max_qa_length):
152 |   """Loads a data file into a list of `InputBatch`s."""
153 | 
154 |   # RACE is a multiple choice task. To perform this task using ALBERT,
155 |   # we will use the formatting proposed in "Improving Language
156 |   # Understanding by Generative Pre-Training" and suggested by
157 |   # @jacobdevlin-google in this issue
158 |   # https://github.com/google-research/bert/issues/38.
159 |   #
160 |   # Each choice will correspond to a sample on which we run the
161 |   # inference. For a given RACE example, we will create the 4
162 |   # following inputs:
163 |   # - [CLS] context [SEP] choice_1 [SEP]
164 |   # - [CLS] context [SEP] choice_2 [SEP]
165 |   # - [CLS] context [SEP] choice_3 [SEP]
166 |   # - [CLS] context [SEP] choice_4 [SEP]
167 |   # The model will output a single value for each input. To get the
168 |   # final decision of the model, we will run a softmax over these 4
169 |   # outputs.
170 |   if isinstance(example, classifier_utils.PaddingInputExample):
171 |     return classifier_utils.InputFeatures(
172 |         example_id=0,
173 |         input_ids=[[0] * max_seq_length] * label_size,
174 |         input_mask=[[0] * max_seq_length] * label_size,
175 |         segment_ids=[[0] * max_seq_length] * label_size,
176 |         label_id=0,
177 |         is_real_example=False)
178 |   else:
179 |     context_tokens = tokenizer.tokenize(example.context_sentence)
180 |     if example.start_ending is not None:
181 |       start_ending_tokens = tokenizer.tokenize(example.start_ending)
182 | 
183 |     all_input_tokens = []
184 |     all_input_ids = []
185 |     all_input_mask = []
186 |     all_segment_ids = []
187 |     for ending in example.endings:
188 |       # We create a copy of the context tokens in order to be
189 |       # able to shrink it according to ending_tokens
190 |       context_tokens_choice = context_tokens[:]
191 |       if example.start_ending is not None:
192 |         ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
193 |       else:
194 |         ending_tokens = tokenizer.tokenize(ending)
195 |       # Truncate `context_tokens_choice` and `ending_tokens` so that
196 |       # the total length is at most the specified
197 |       # length, accounting for [CLS], [SEP], [SEP] with
198 |       # "- 3".
199 |       ending_tokens = ending_tokens[- max_qa_length:]
200 | 
201 |       if len(context_tokens_choice) + len(ending_tokens) > max_seq_length - 3:
202 |         context_tokens_choice = context_tokens_choice[: (
203 |             max_seq_length - 3 - len(ending_tokens))]
204 |       tokens = ["[CLS]"] + context_tokens_choice + (
205 |           ["[SEP]"] + ending_tokens + ["[SEP]"])
206 |       segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (
207 |           len(ending_tokens) + 1)
208 | 
209 |       input_ids = tokenizer.convert_tokens_to_ids(tokens)
210 |       input_mask = [1] * len(input_ids)
211 | 
212 |       # Zero-pad up to the sequence length.
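      # As a hypothetical illustration (made-up token ids), with
      # max_seq_length=8 and tokens [CLS] a b [SEP] c [SEP]:
      #   input_ids   = [2, 10, 11, 3, 12, 3, 0, 0]
      #   input_mask  = [1, 1, 1, 1, 1, 1, 0, 0]
      #   segment_ids = [0, 0, 0, 0, 1, 1, 0, 0]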
213 | padding = [0] * (max_seq_length - len(input_ids)) 214 | input_ids += padding 215 | input_mask += padding 216 | segment_ids += padding 217 | 218 | assert len(input_ids) == max_seq_length 219 | assert len(input_mask) == max_seq_length 220 | assert len(segment_ids) == max_seq_length 221 | 222 | all_input_tokens.append(tokens) 223 | all_input_ids.append(input_ids) 224 | all_input_mask.append(input_mask) 225 | all_segment_ids.append(segment_ids) 226 | 227 | label = example.label 228 | if example_index < 5: 229 | tf.logging.info("*** Example ***") 230 | tf.logging.info("id: {}".format(example.example_id)) 231 | for choice_idx, (tokens, input_ids, input_mask, segment_ids) in \ 232 | enumerate(zip(all_input_tokens, all_input_ids, all_input_mask, all_segment_ids)): 233 | tf.logging.info("choice: {}".format(choice_idx)) 234 | tf.logging.info("tokens: {}".format(" ".join(tokens))) 235 | tf.logging.info( 236 | "input_ids: {}".format(" ".join(map(str, input_ids)))) 237 | tf.logging.info( 238 | "input_mask: {}".format(" ".join(map(str, input_mask)))) 239 | tf.logging.info( 240 | "segment_ids: {}".format(" ".join(map(str, segment_ids)))) 241 | tf.logging.info("label: {}".format(label)) 242 | 243 | return classifier_utils.InputFeatures( 244 | example_id=example.example_id, 245 | input_ids=all_input_ids, 246 | input_mask=all_input_mask, 247 | segment_ids=all_segment_ids, 248 | label_id=label 249 | ) 250 | 251 | 252 | def file_based_convert_examples_to_features( 253 | examples, label_list, max_seq_length, tokenizer, 254 | output_file, max_qa_length): 255 | """Convert a set of `InputExample`s to a TFRecord file.""" 256 | 257 | writer = tf.python_io.TFRecordWriter(output_file) 258 | 259 | for (ex_index, example) in enumerate(examples): 260 | if ex_index % 10000 == 0: 261 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 262 | 263 | feature = convert_single_example(ex_index, example, len(label_list), 264 | max_seq_length, tokenizer, max_qa_length) 265 | 266 | def create_int_feature(values): 267 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 268 | return f 269 | 270 | features = collections.OrderedDict() 271 | features["input_ids"] = create_int_feature(sum(feature.input_ids, [])) 272 | features["input_mask"] = create_int_feature(sum(feature.input_mask, [])) 273 | features["segment_ids"] = create_int_feature(sum(feature.segment_ids, [])) 274 | features["label_ids"] = create_int_feature([feature.label_id]) 275 | features["is_real_example"] = create_int_feature( 276 | [int(feature.is_real_example)]) 277 | 278 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 279 | writer.write(tf_example.SerializeToString()) 280 | writer.close() 281 | 282 | 283 | def create_model(albert_config, is_training, input_ids, input_mask, segment_ids, 284 | labels, num_labels, use_one_hot_embeddings, max_seq_length, 285 | dropout_prob, hub_module): 286 | """Creates a classification model.""" 287 | bsz_per_core = tf.shape(input_ids)[0] 288 | 289 | input_ids = tf.reshape(input_ids, [bsz_per_core * num_labels, max_seq_length]) 290 | input_mask = tf.reshape(input_mask, 291 | [bsz_per_core * num_labels, max_seq_length]) 292 | token_type_ids = tf.reshape(segment_ids, 293 | [bsz_per_core * num_labels, max_seq_length]) 294 | 295 | (output_layer, _) = fine_tuning_utils.create_albert( 296 | albert_config=albert_config, 297 | is_training=is_training, 298 | input_ids=input_ids, 299 | input_mask=input_mask, 300 | segment_ids=token_type_ids, 301 | 
use_one_hot_embeddings=use_one_hot_embeddings, 302 | use_einsum=True, 303 | hub_module=hub_module) 304 | 305 | hidden_size = output_layer.shape[-1].value 306 | 307 | output_weights = tf.get_variable( 308 | "output_weights", [1, hidden_size], 309 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 310 | 311 | output_bias = tf.get_variable( 312 | "output_bias", [1], 313 | initializer=tf.zeros_initializer()) 314 | 315 | with tf.variable_scope("loss"): 316 | if is_training: 317 | # I.e., 0.1 dropout 318 | output_layer = tf.nn.dropout( 319 | output_layer, keep_prob=1 - dropout_prob) 320 | 321 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 322 | logits = tf.nn.bias_add(logits, output_bias) 323 | logits = tf.reshape(logits, [bsz_per_core, num_labels]) 324 | probabilities = tf.nn.softmax(logits, axis=-1) 325 | predictions = tf.argmax(probabilities, axis=-1, output_type=tf.int32) 326 | log_probs = tf.nn.log_softmax(logits, axis=-1) 327 | 328 | one_hot_labels = tf.one_hot( 329 | labels, depth=tf.cast(num_labels, dtype=tf.int32), dtype=tf.float32) 330 | 331 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 332 | loss = tf.reduce_mean(per_example_loss) 333 | 334 | return (loss, per_example_loss, probabilities, logits, predictions) 335 | 336 | 337 | def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate, 338 | num_train_steps, num_warmup_steps, use_tpu, 339 | use_one_hot_embeddings, max_seq_length, dropout_prob, 340 | hub_module): 341 | """Returns `model_fn` closure for TPUEstimator.""" 342 | 343 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 344 | """The `model_fn` for TPUEstimator.""" 345 | 346 | tf.logging.info("*** Features ***") 347 | for name in sorted(features.keys()): 348 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 349 | 350 | input_ids = features["input_ids"] 351 | input_mask = features["input_mask"] 352 | segment_ids = features["segment_ids"] 353 | label_ids = features["label_ids"] 354 | is_real_example = None 355 | if "is_real_example" in features: 356 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 357 | else: 358 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 359 | 360 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 361 | 362 | (total_loss, per_example_loss, probabilities, logits, predictions) = \ 363 | create_model(albert_config, is_training, input_ids, input_mask, 364 | segment_ids, label_ids, num_labels, 365 | use_one_hot_embeddings, max_seq_length, dropout_prob, 366 | hub_module) 367 | 368 | tvars = tf.trainable_variables() 369 | initialized_variable_names = {} 370 | scaffold_fn = None 371 | if init_checkpoint: 372 | (assignment_map, initialized_variable_names 373 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 374 | if use_tpu: 375 | 376 | def tpu_scaffold(): 377 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 378 | return tf.train.Scaffold() 379 | 380 | scaffold_fn = tpu_scaffold 381 | else: 382 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 383 | 384 | tf.logging.info("**** Trainable Variables ****") 385 | for var in tvars: 386 | init_string = "" 387 | if var.name in initialized_variable_names: 388 | init_string = ", *INIT_FROM_CKPT*" 389 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 390 | init_string) 391 | 392 | output_spec = None 393 | if mode == tf.estimator.ModeKeys.TRAIN: 394 | 395 | train_op = 
optimization.create_optimizer(
396 |           total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
397 | 
398 |       output_spec = contrib_tpu.TPUEstimatorSpec(
399 |           mode=mode,
400 |           loss=total_loss,
401 |           train_op=train_op,
402 |           scaffold_fn=scaffold_fn)
403 |     elif mode == tf.estimator.ModeKeys.EVAL:
404 |       def metric_fn(per_example_loss, label_ids, logits, is_real_example):
405 |         predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
406 |         accuracy = tf.metrics.accuracy(
407 |             labels=label_ids, predictions=predictions,
408 |             weights=is_real_example)
409 |         loss = tf.metrics.mean(
410 |             values=per_example_loss, weights=is_real_example)
411 |         return {
412 |             "eval_accuracy": accuracy,
413 |             "eval_loss": loss,
414 |         }
415 | 
416 |       eval_metrics = (metric_fn,
417 |                       [per_example_loss, label_ids, logits, is_real_example])
418 |       output_spec = contrib_tpu.TPUEstimatorSpec(
419 |           mode=mode,
420 |           loss=total_loss,
421 |           eval_metrics=eval_metrics,
422 |           scaffold_fn=scaffold_fn)
423 |     else:
424 |       output_spec = contrib_tpu.TPUEstimatorSpec(
425 |           mode=mode,
426 |           predictions={"probabilities": probabilities,
427 |                        "predictions": predictions},
428 |           scaffold_fn=scaffold_fn)
429 |     return output_spec
430 | 
431 |   return model_fn
432 | 
433 | 
--------------------------------------------------------------------------------
/run_race.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ALBERT finetuning runner with sentence piece tokenization."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import os
22 | import time
23 | from albert import classifier_utils
24 | from albert import fine_tuning_utils
25 | from albert import modeling
26 | from albert import race_utils
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
29 | from tensorflow.contrib import tpu as contrib_tpu
30 | 
31 | flags = tf.flags
32 | 
33 | FLAGS = flags.FLAGS
34 | 
35 | ## Required parameters
36 | flags.DEFINE_string(
37 |     "data_dir", None,
38 |     "The input data dir. Should contain the .tsv files (or other data files) "
39 |     "for the task.")
40 | 
41 | flags.DEFINE_string(
42 |     "albert_config_file", None,
43 |     "The config json file corresponding to the pre-trained ALBERT model. "
44 |     "This specifies the model architecture.")
45 | 
46 | flags.DEFINE_string("task_name", "race", "The name of the task to train.")
47 | 
48 | flags.DEFINE_string("vocab_file", None,
49 |                     "The vocabulary file that the ALBERT model was trained on.")
50 | 
51 | flags.DEFINE_string("train_file", None,
52 |                     "Path to preprocessed TFRecord file. "
53 |                     "The file will be generated if it does not exist.")
54 | 
55 | flags.DEFINE_string("eval_file", None,
56 |                     "Path to preprocessed TFRecord file. "
57 |                     "The file will be generated if it does not exist.")
58 | 
59 | flags.DEFINE_string("predict_file", None,
60 |                     "Path to preprocessed TFRecord file. "
61 |                     "The file will be generated if it does not exist.")
62 | 
63 | flags.DEFINE_string("spm_model_file", None,
64 |                     "The model file for sentence piece tokenization.")
65 | 
66 | flags.DEFINE_string(
67 |     "output_dir", None,
68 |     "The output directory where the model checkpoints will be written.")
69 | 
70 | ## Other parameters
71 | 
72 | flags.DEFINE_string(
73 |     "init_checkpoint", None,
74 |     "Initial checkpoint (usually from a pre-trained ALBERT model).")
75 | 
76 | flags.DEFINE_string(
77 |     "albert_hub_module_handle", None,
78 |     "If set, the ALBERT hub module to use.")
79 | 
80 | flags.DEFINE_bool(
81 |     "do_lower_case", True,
82 |     "Whether to lower case the input text. Should be True for uncased "
83 |     "models and False for cased models.")
84 | 
85 | flags.DEFINE_float("dropout_prob", 0.1, "dropout probability.")
86 | 
87 | flags.DEFINE_integer(
88 |     "max_seq_length", 512,
89 |     "The maximum total input sequence length after WordPiece tokenization. "
90 |     "Sequences longer than this will be truncated, and sequences shorter "
91 |     "than this will be padded.")
92 | 
93 | flags.DEFINE_integer(
94 |     "max_qa_length", 128,
95 |     "The maximum total length of the question-plus-option text after "
96 |     "tokenization. Question/option pairs longer than this will be "
97 |     "truncated to this length.")
98 | 
99 | flags.DEFINE_integer(
100 |     "num_keep_checkpoint", 5,
101 |     "Maximum number of checkpoints to keep.")
102 | 
103 | 
104 | flags.DEFINE_bool(
105 |     "high_only", False,
106 |     "Whether to only run the model on the high school set.")
107 | 
108 | flags.DEFINE_bool(
109 |     "middle_only", False,
110 |     "Whether to only run the model on the middle school set.")
111 | 
112 | flags.DEFINE_bool("do_train", True, "Whether to run training.")
113 | 
114 | flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")
115 | 
116 | flags.DEFINE_bool(
117 |     "do_predict", False,
118 |     "Whether to run the model in inference mode on the test set.")
119 | 
120 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
121 | 
122 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
123 | 
124 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for prediction.")
125 | 
126 | flags.DEFINE_float("learning_rate", 1e-5, "The initial learning rate for Adam.")
127 | 
128 | flags.DEFINE_integer("train_step", 12000,
129 |                      "Total number of training steps to perform.")
130 | 
131 | flags.DEFINE_integer(
132 |     "warmup_step", 1000,
133 |     "Number of steps to perform linear learning rate warmup for.")
134 | 
135 | flags.DEFINE_integer("save_checkpoints_steps", 100,
136 |                      "How often to save the model checkpoint.")
137 | 
138 | flags.DEFINE_integer("iterations_per_loop", 1000,
139 |                      "How many steps to make in each estimator call.")
140 | 
141 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
142 | 
143 | tf.flags.DEFINE_string(
144 |     "tpu_name", None,
145 |     "The Cloud TPU to use for training. This should be either the name "
146 |     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
147 |     "url.")
148 | 
149 | tf.flags.DEFINE_string(
150 |     "tpu_zone", None,
151 |     "[Optional] GCE zone where the Cloud TPU is located. 
If not "
152 |     "specified, we will attempt to automatically detect the GCE zone from "
153 |     "metadata.")
154 | 
155 | tf.flags.DEFINE_string(
156 |     "gcp_project", None,
157 |     "[Optional] Project name for the Cloud TPU-enabled project. If not "
158 |     "specified, we will attempt to automatically detect the GCE project from "
159 |     "metadata.")
160 | 
161 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
162 | 
163 | flags.DEFINE_integer(
164 |     "num_tpu_cores", 8,
165 |     "Only used if `use_tpu` is True. Total number of TPU cores to use.")
166 | 
167 | 
168 | def main(_):
169 |   tf.logging.set_verbosity(tf.logging.INFO)
170 | 
171 |   processors = {
172 |       "race": race_utils.RaceProcessor
173 |   }
174 | 
175 |   if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
176 |     raise ValueError(
177 |         "At least one of `do_train`, `do_eval` or `do_predict` must be True.")
178 | 
179 |   albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
180 | 
181 |   if FLAGS.max_seq_length > albert_config.max_position_embeddings:
182 |     raise ValueError(
183 |         "Cannot use sequence length %d because the ALBERT model "
184 |         "was only trained up to sequence length %d" %
185 |         (FLAGS.max_seq_length, albert_config.max_position_embeddings))
186 | 
187 |   tf.gfile.MakeDirs(FLAGS.output_dir)
188 | 
189 |   task_name = FLAGS.task_name.lower()
190 | 
191 |   if task_name not in processors:
192 |     raise ValueError("Task not found: %s" % (task_name))
193 | 
194 |   processor = processors[task_name](
195 |       use_spm=True if FLAGS.spm_model_file else False,
196 |       do_lower_case=FLAGS.do_lower_case,
197 |       high_only=FLAGS.high_only,
198 |       middle_only=FLAGS.middle_only)
199 | 
200 |   label_list = processor.get_labels()
201 | 
202 |   tokenizer = fine_tuning_utils.create_vocab(
203 |       vocab_file=FLAGS.vocab_file,
204 |       do_lower_case=FLAGS.do_lower_case,
205 |       spm_model_file=FLAGS.spm_model_file,
206 |       hub_module=FLAGS.albert_hub_module_handle)
207 | 
208 |   tpu_cluster_resolver = None
209 |   if FLAGS.use_tpu and FLAGS.tpu_name:
210 |     tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
211 |         FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
212 | 
213 |   is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
214 |   if FLAGS.do_train:
215 |     iterations_per_loop = int(min(FLAGS.iterations_per_loop,
216 |                                   FLAGS.save_checkpoints_steps))
217 |   else:
218 |     iterations_per_loop = FLAGS.iterations_per_loop
219 |   run_config = contrib_tpu.RunConfig(
220 |       cluster=tpu_cluster_resolver,
221 |       master=FLAGS.master,
222 |       model_dir=FLAGS.output_dir,
223 |       save_checkpoints_steps=int(FLAGS.save_checkpoints_steps),
224 |       keep_checkpoint_max=0,
225 |       tpu_config=contrib_tpu.TPUConfig(
226 |           iterations_per_loop=iterations_per_loop,
227 |           num_shards=FLAGS.num_tpu_cores,
228 |           per_host_input_for_training=is_per_host))
229 | 
230 |   train_examples = None
231 |   if FLAGS.do_train:
232 |     train_examples = processor.get_train_examples(FLAGS.data_dir)
233 | 
234 |   model_fn = race_utils.model_fn_builder(
235 |       albert_config=albert_config,
236 |       num_labels=len(label_list),
237 |       init_checkpoint=FLAGS.init_checkpoint,
238 |       learning_rate=FLAGS.learning_rate,
239 |       num_train_steps=FLAGS.train_step,
240 |       num_warmup_steps=FLAGS.warmup_step,
241 |       use_tpu=FLAGS.use_tpu,
242 |       use_one_hot_embeddings=FLAGS.use_tpu,
243 |       max_seq_length=FLAGS.max_seq_length,
244 |       dropout_prob=FLAGS.dropout_prob,
245 |       hub_module=FLAGS.albert_hub_module_handle)
246 | 
247 |   # If TPU is not available, this will fall back to normal 
Estimator on CPU 248 | # or GPU. 249 | estimator = contrib_tpu.TPUEstimator( 250 | use_tpu=FLAGS.use_tpu, 251 | model_fn=model_fn, 252 | config=run_config, 253 | train_batch_size=FLAGS.train_batch_size, 254 | eval_batch_size=FLAGS.eval_batch_size, 255 | predict_batch_size=FLAGS.predict_batch_size) 256 | 257 | if FLAGS.do_train: 258 | if not tf.gfile.Exists(FLAGS.train_file): 259 | race_utils.file_based_convert_examples_to_features( 260 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, 261 | FLAGS.train_file, FLAGS.max_qa_length) 262 | tf.logging.info("***** Running training *****") 263 | tf.logging.info(" Num examples = %d", len(train_examples)) 264 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 265 | tf.logging.info(" Num steps = %d", FLAGS.train_step) 266 | train_input_fn = classifier_utils.file_based_input_fn_builder( 267 | input_file=FLAGS.train_file, 268 | seq_length=FLAGS.max_seq_length, 269 | is_training=True, 270 | drop_remainder=True, 271 | task_name=task_name, 272 | use_tpu=FLAGS.use_tpu, 273 | bsz=FLAGS.train_batch_size, 274 | multiple=len(label_list)) 275 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step) 276 | 277 | if FLAGS.do_eval: 278 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 279 | num_actual_eval_examples = len(eval_examples) 280 | if FLAGS.use_tpu: 281 | # TPU requires a fixed batch size for all batches, therefore the number 282 | # of examples must be a multiple of the batch size, or else examples 283 | # will get dropped. So we pad with fake examples which are ignored 284 | # later on. These do NOT count towards the metric (all tf.metrics 285 | # support a per-instance weight, and these get a weight of 0.0). 286 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 287 | eval_examples.append(classifier_utils.PaddingInputExample()) 288 | 289 | if not tf.gfile.Exists(FLAGS.eval_file): 290 | race_utils.file_based_convert_examples_to_features( 291 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, 292 | FLAGS.eval_file, FLAGS.max_qa_length) 293 | 294 | tf.logging.info("***** Running evaluation *****") 295 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 296 | len(eval_examples), num_actual_eval_examples, 297 | len(eval_examples) - num_actual_eval_examples) 298 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 299 | 300 | # This tells the estimator to run through the entire set. 301 | eval_steps = None 302 | # However, if running eval on the TPU, you will need to specify the 303 | # number of steps. 
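    # For example (hypothetical counts): 1000 dev examples with
    # eval_batch_size=64 would be padded to 1024 examples, giving
    # eval_steps = 1024 // 64 = 16.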
304 | if FLAGS.use_tpu: 305 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 306 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 307 | 308 | eval_drop_remainder = True if FLAGS.use_tpu else False 309 | eval_input_fn = classifier_utils.file_based_input_fn_builder( 310 | input_file=FLAGS.eval_file, 311 | seq_length=FLAGS.max_seq_length, 312 | is_training=False, 313 | drop_remainder=eval_drop_remainder, 314 | task_name=task_name, 315 | use_tpu=FLAGS.use_tpu, 316 | bsz=FLAGS.eval_batch_size, 317 | multiple=len(label_list)) 318 | 319 | def _find_valid_cands(curr_step): 320 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 321 | candidates = [] 322 | for filename in filenames: 323 | if filename.endswith(".index"): 324 | ckpt_name = filename[:-6] 325 | idx = ckpt_name.split("-")[-1] 326 | if idx != "best" and int(idx) > curr_step: 327 | candidates.append(filename) 328 | return candidates 329 | 330 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 331 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 332 | key_name = "eval_accuracy" 333 | if tf.gfile.Exists(checkpoint_path + ".index"): 334 | result = estimator.evaluate( 335 | input_fn=eval_input_fn, 336 | steps=eval_steps, 337 | checkpoint_path=checkpoint_path) 338 | best_perf = result[key_name] 339 | global_step = result["global_step"] 340 | else: 341 | global_step = -1 342 | best_perf = -1 343 | checkpoint_path = None 344 | writer = tf.gfile.GFile(output_eval_file, "w") 345 | while global_step < FLAGS.train_step: 346 | steps_and_files = {} 347 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 348 | for filename in filenames: 349 | if filename.endswith(".index"): 350 | ckpt_name = filename[:-6] 351 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 352 | if cur_filename.split("-")[-1] == "best": 353 | continue 354 | gstep = int(cur_filename.split("-")[-1]) 355 | if gstep not in steps_and_files: 356 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 357 | steps_and_files[gstep] = cur_filename 358 | tf.logging.info("found {} files.".format(len(steps_and_files))) 359 | # steps_and_files = sorted(steps_and_files, key=lambda x: x[0]) 360 | if not steps_and_files: 361 | tf.logging.info("found 0 file, global step: {}. Sleeping." 
362 | .format(global_step)) 363 | time.sleep(1) 364 | else: 365 | for ele in sorted(steps_and_files.items()): 366 | step, checkpoint_path = ele 367 | if global_step >= step: 368 | if len(_find_valid_cands(step)) > 1: 369 | for ext in ["meta", "data-00000-of-00001", "index"]: 370 | src_ckpt = checkpoint_path + ".{}".format(ext) 371 | tf.logging.info("removing {}".format(src_ckpt)) 372 | tf.gfile.Remove(src_ckpt) 373 | continue 374 | result = estimator.evaluate( 375 | input_fn=eval_input_fn, 376 | steps=eval_steps, 377 | checkpoint_path=checkpoint_path) 378 | global_step = result["global_step"] 379 | tf.logging.info("***** Eval results *****") 380 | for key in sorted(result.keys()): 381 | tf.logging.info(" %s = %s", key, str(result[key])) 382 | writer.write("%s = %s\n" % (key, str(result[key]))) 383 | writer.write("best = {}\n".format(best_perf)) 384 | if result[key_name] > best_perf: 385 | best_perf = result[key_name] 386 | for ext in ["meta", "data-00000-of-00001", "index"]: 387 | src_ckpt = checkpoint_path + ".{}".format(ext) 388 | tgt_ckpt = checkpoint_path.rsplit("-", 1)[0] + "-best.{}".format(ext) 389 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 390 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True) 391 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt)) 392 | 393 | if len(_find_valid_cands(global_step)) > 1: 394 | for ext in ["meta", "data-00000-of-00001", "index"]: 395 | src_ckpt = checkpoint_path + ".{}".format(ext) 396 | tf.logging.info("removing {}".format(src_ckpt)) 397 | tf.gfile.Remove(src_ckpt) 398 | writer.write("=" * 50 + "\n") 399 | writer.close() 400 | if FLAGS.do_predict: 401 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 402 | num_actual_predict_examples = len(predict_examples) 403 | if FLAGS.use_tpu: 404 | # TPU requires a fixed batch size for all batches, therefore the number 405 | # of examples must be a multiple of the batch size, or else examples 406 | # will get dropped. So we pad with fake examples which are ignored 407 | # later on. 
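      # For example (hypothetical counts): 4934 test examples with
      # predict_batch_size=8 would be padded with 2 fake examples to 4936,
      # so predict_steps = 4936 // 8 = 617.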
408 |       while len(predict_examples) % FLAGS.predict_batch_size != 0:
409 |         predict_examples.append(classifier_utils.PaddingInputExample())
410 |       assert len(predict_examples) % FLAGS.predict_batch_size == 0
411 |       predict_steps = int(len(predict_examples) // FLAGS.predict_batch_size)
412 | 
413 |     predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
414 |     race_utils.file_based_convert_examples_to_features(
415 |         predict_examples, label_list,
416 |         FLAGS.max_seq_length, tokenizer,
417 |         predict_file, FLAGS.max_qa_length)
418 | 
419 |     tf.logging.info("***** Running prediction *****")
420 |     tf.logging.info("  Num examples = %d (%d actual, %d padding)",
421 |                     len(predict_examples), num_actual_predict_examples,
422 |                     len(predict_examples) - num_actual_predict_examples)
423 |     tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
424 | 
425 |     predict_drop_remainder = True if FLAGS.use_tpu else False
426 |     predict_input_fn = classifier_utils.file_based_input_fn_builder(
427 |         input_file=predict_file,
428 |         seq_length=FLAGS.max_seq_length,
429 |         is_training=False,
430 |         drop_remainder=predict_drop_remainder,
431 |         task_name=task_name,
432 |         use_tpu=FLAGS.use_tpu,
433 |         bsz=FLAGS.predict_batch_size,
434 |         multiple=len(label_list))
435 | 
436 |     checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
437 |     result = estimator.evaluate(
438 |         input_fn=predict_input_fn,
439 |         steps=predict_steps,
440 |         checkpoint_path=checkpoint_path)
441 | 
442 |     output_predict_file = os.path.join(FLAGS.output_dir, "predict_results.txt")
443 |     with tf.gfile.GFile(output_predict_file, "w") as pred_writer:
444 |       # num_written_lines = 0
445 |       tf.logging.info("***** Predict results *****")
446 |       pred_writer.write("***** Predict results *****\n")
447 |       for key in sorted(result.keys()):
448 |         tf.logging.info("  %s = %s", key, str(result[key]))
449 |         pred_writer.write("%s = %s\n" % (key, str(result[key])))
450 |       pred_writer.write("best = {}\n".format(best_perf))
451 | 
452 | 
453 | if __name__ == "__main__":
454 |   flags.mark_flag_as_required("data_dir")
455 |   flags.mark_flag_as_required("spm_model_file")
456 |   flags.mark_flag_as_required("albert_config_file")
457 |   flags.mark_flag_as_required("output_dir")
458 |   tf.app.run()
459 | 
--------------------------------------------------------------------------------
/run_squad_v2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
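# A minimal sketch (not part of this repo) of how the checkpoint-selection
# loops in run_race.py above, and in this file below, recover global steps
# from checkpoint filenames of the form "model.ckpt-<step>.index", assuming
# a hypothetical directory listing:
#
#   filenames = ["model.ckpt-100.index", "model.ckpt-200.index",
#                "model.ckpt-best.index", "checkpoint"]
#   steps = []
#   for filename in filenames:
#     if filename.endswith(".index"):
#       ckpt_name = filename[:-6]       # strip the ".index" suffix
#       idx = ckpt_name.split("-")[-1]  # "100", "200", or "best"
#       if idx != "best":
#         steps.append(int(idx))
#   print(sorted(steps))  # [100, 200]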
15 | # Lint as: python2, python3
16 | """Run ALBERT on SQuAD v2.0 using sentence piece tokenization."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | 
23 | import json
24 | import os
25 | import random
26 | import time
27 | 
28 | from albert import fine_tuning_utils
29 | from albert import modeling
30 | from albert import squad_utils
31 | import six
32 | import tensorflow.compat.v1 as tf
33 | 
34 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
35 | from tensorflow.contrib import tpu as contrib_tpu
36 | 
37 | 
38 | # pylint: disable=g-import-not-at-top
39 | if six.PY2:
40 |   import six.moves.cPickle as pickle
41 | else:
42 |   import pickle
43 | # pylint: enable=g-import-not-at-top
44 | 
45 | flags = tf.flags
46 | 
47 | FLAGS = flags.FLAGS
48 | 
49 | ## Required parameters
50 | flags.DEFINE_string(
51 |     "albert_config_file", None,
52 |     "The config json file corresponding to the pre-trained ALBERT model. "
53 |     "This specifies the model architecture.")
54 | 
55 | flags.DEFINE_string("vocab_file", None,
56 |                     "The vocabulary file that the ALBERT model was trained on.")
57 | 
58 | flags.DEFINE_string("spm_model_file", None,
59 |                     "The model file for sentence piece tokenization.")
60 | 
61 | flags.DEFINE_string(
62 |     "output_dir", None,
63 |     "The output directory where the model checkpoints will be written.")
64 | 
65 | ## Other parameters
66 | flags.DEFINE_string("train_file", None,
67 |                     "SQuAD json for training. E.g., train-v1.1.json")
68 | 
69 | flags.DEFINE_string(
70 |     "predict_file", None,
71 |     "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
72 | 
73 | flags.DEFINE_string("train_feature_file", None,
74 |                     "Path to the training feature file.")
75 | 
76 | flags.DEFINE_string(
77 |     "predict_feature_file", None,
78 |     "Location of predict features. If it doesn't exist, it will be written. "
79 |     "If it does exist, it will be read.")
80 | 
81 | flags.DEFINE_string(
82 |     "predict_feature_left_file", None,
83 |     "Location of predict features not passed to TPU. If it doesn't exist, it "
84 |     "will be written. If it does exist, it will be read.")
85 | 
86 | flags.DEFINE_string(
87 |     "init_checkpoint", None,
88 |     "Initial checkpoint (usually from a pre-trained ALBERT model).")
89 | 
90 | flags.DEFINE_string(
91 |     "albert_hub_module_handle", None,
92 |     "If set, the ALBERT hub module to use.")
93 | 
94 | flags.DEFINE_bool(
95 |     "do_lower_case", True,
96 |     "Whether to lower case the input text. Should be True for uncased "
97 |     "models and False for cased models.")
98 | 
99 | flags.DEFINE_integer(
100 |     "max_seq_length", 384,
101 |     "The maximum total input sequence length after WordPiece tokenization. "
102 |     "Sequences longer than this will be truncated, and sequences shorter "
103 |     "than this will be padded.")
104 | 
105 | flags.DEFINE_integer(
106 |     "doc_stride", 128,
107 |     "When splitting up a long document into chunks, how much stride to "
108 |     "take between chunks.")
109 | 
110 | flags.DEFINE_integer(
111 |     "max_query_length", 64,
112 |     "The maximum number of tokens for the question. 
Questions longer than "
113 |     "this will be truncated to this length.")
114 | 
115 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
116 | 
117 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
118 | 
119 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
120 | 
121 | flags.DEFINE_integer("predict_batch_size", 8,
122 |                      "Total batch size for predictions.")
123 | 
124 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
125 | 
126 | flags.DEFINE_float("num_train_epochs", 3.0,
127 |                    "Total number of training epochs to perform.")
128 | 
129 | flags.DEFINE_float(
130 |     "warmup_proportion", 0.1,
131 |     "Proportion of training to perform linear learning rate warmup for. "
132 |     "E.g., 0.1 = 10% of training.")
133 | 
134 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
135 |                      "How often to save the model checkpoint.")
136 | 
137 | flags.DEFINE_integer("iterations_per_loop", 1000,
138 |                      "How many steps to make in each estimator call.")
139 | 
140 | flags.DEFINE_integer(
141 |     "n_best_size", 20,
142 |     "The total number of n-best predictions to generate in the "
143 |     "nbest_predictions.json output file.")
144 | 
145 | flags.DEFINE_integer(
146 |     "max_answer_length", 30,
147 |     "The maximum length of an answer that can be generated. This is needed "
148 |     "because the start and end predictions are not conditioned on one another.")
149 | 
150 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
151 | 
152 | tf.flags.DEFINE_string(
153 |     "tpu_name", None,
154 |     "The Cloud TPU to use for training. This should be either the name "
155 |     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
156 |     "url.")
157 | 
158 | tf.flags.DEFINE_string(
159 |     "tpu_zone", None,
160 |     "[Optional] GCE zone where the Cloud TPU is located. If not "
161 |     "specified, we will attempt to automatically detect the GCE zone from "
162 |     "metadata.")
163 | 
164 | tf.flags.DEFINE_string(
165 |     "gcp_project", None,
166 |     "[Optional] Project name for the Cloud TPU-enabled project. If not "
167 |     "specified, we will attempt to automatically detect the GCE project from "
168 |     "metadata.")
169 | 
170 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
171 | 
172 | flags.DEFINE_integer(
173 |     "num_tpu_cores", 8,
174 |     "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 175 | 176 | 177 | flags.DEFINE_integer("start_n_top", 5, "beam size for the start positions.") 178 | 179 | flags.DEFINE_integer("end_n_top", 5, "beam size for the end positions.") 180 | 181 | flags.DEFINE_float("dropout_prob", 0.1, "dropout probability.") 182 | 183 | 184 | def validate_flags_or_throw(albert_config): 185 | """Validate the input FLAGS or throw an exception.""" 186 | 187 | if not FLAGS.do_train and not FLAGS.do_predict: 188 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 189 | 190 | if FLAGS.do_train: 191 | if not FLAGS.train_file: 192 | raise ValueError( 193 | "If `do_train` is True, then `train_file` must be specified.") 194 | if FLAGS.do_predict: 195 | if not FLAGS.predict_file: 196 | raise ValueError( 197 | "If `do_predict` is True, then `predict_file` must be specified.") 198 | if not FLAGS.predict_feature_file: 199 | raise ValueError( 200 | "If `do_predict` is True, then `predict_feature_file` must be " 201 | "specified.") 202 | if not FLAGS.predict_feature_left_file: 203 | raise ValueError( 204 | "If `do_predict` is True, then `predict_feature_left_file` must be " 205 | "specified.") 206 | 207 | if FLAGS.max_seq_length > albert_config.max_position_embeddings: 208 | raise ValueError( 209 | "Cannot use sequence length %d because the ALBERT model " 210 | "was only trained up to sequence length %d" % 211 | (FLAGS.max_seq_length, albert_config.max_position_embeddings)) 212 | 213 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: 214 | raise ValueError( 215 | "The max_seq_length (%d) must be greater than max_query_length " 216 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) 217 | 218 | 219 | def main(_): 220 | tf.logging.set_verbosity(tf.logging.INFO) 221 | 222 | albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) 223 | 224 | validate_flags_or_throw(albert_config) 225 | 226 | tf.gfile.MakeDirs(FLAGS.output_dir) 227 | 228 | tokenizer = fine_tuning_utils.create_vocab( 229 | vocab_file=FLAGS.vocab_file, 230 | do_lower_case=FLAGS.do_lower_case, 231 | spm_model_file=FLAGS.spm_model_file, 232 | hub_module=FLAGS.albert_hub_module_handle) 233 | 234 | tpu_cluster_resolver = None 235 | if FLAGS.use_tpu and FLAGS.tpu_name: 236 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 237 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 238 | 239 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 240 | if FLAGS.do_train: 241 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 242 | FLAGS.save_checkpoints_steps)) 243 | else: 244 | iterations_per_loop = FLAGS.iterations_per_loop 245 | run_config = contrib_tpu.RunConfig( 246 | cluster=tpu_cluster_resolver, 247 | master=FLAGS.master, 248 | model_dir=FLAGS.output_dir, 249 | keep_checkpoint_max=0, 250 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 251 | tpu_config=contrib_tpu.TPUConfig( 252 | iterations_per_loop=iterations_per_loop, 253 | num_shards=FLAGS.num_tpu_cores, 254 | per_host_input_for_training=is_per_host)) 255 | 256 | train_examples = None 257 | num_train_steps = None 258 | num_warmup_steps = None 259 | train_examples = squad_utils.read_squad_examples( 260 | input_file=FLAGS.train_file, is_training=True) 261 | num_train_steps = int( 262 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 263 | if FLAGS.do_train: 264 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 265 | 266 | # Pre-shuffle the input to avoid having to make a 
very large shuffle
267 |     # buffer in the `input_fn`.
268 |     rng = random.Random(12345)
269 |     rng.shuffle(train_examples)
270 | 
271 |   model_fn = squad_utils.v2_model_fn_builder(
272 |       albert_config=albert_config,
273 |       init_checkpoint=FLAGS.init_checkpoint,
274 |       learning_rate=FLAGS.learning_rate,
275 |       num_train_steps=num_train_steps,
276 |       num_warmup_steps=num_warmup_steps,
277 |       use_tpu=FLAGS.use_tpu,
278 |       use_one_hot_embeddings=FLAGS.use_tpu,
279 |       max_seq_length=FLAGS.max_seq_length,
280 |       start_n_top=FLAGS.start_n_top,
281 |       end_n_top=FLAGS.end_n_top,
282 |       dropout_prob=FLAGS.dropout_prob,
283 |       hub_module=FLAGS.albert_hub_module_handle)
284 | 
285 |   # If TPU is not available, this will fall back to normal Estimator on CPU
286 |   # or GPU.
287 |   estimator = contrib_tpu.TPUEstimator(
288 |       use_tpu=FLAGS.use_tpu,
289 |       model_fn=model_fn,
290 |       config=run_config,
291 |       train_batch_size=FLAGS.train_batch_size,
292 |       predict_batch_size=FLAGS.predict_batch_size)
293 | 
294 |   if FLAGS.do_train:
295 |     # We write to a temporary file to avoid storing very large constant tensors
296 |     # in memory.
297 | 
298 |     if not tf.gfile.Exists(FLAGS.train_feature_file):
299 |       train_writer = squad_utils.FeatureWriter(
300 |           filename=os.path.join(FLAGS.train_feature_file), is_training=True)
301 |       squad_utils.convert_examples_to_features(
302 |           examples=train_examples,
303 |           tokenizer=tokenizer,
304 |           max_seq_length=FLAGS.max_seq_length,
305 |           doc_stride=FLAGS.doc_stride,
306 |           max_query_length=FLAGS.max_query_length,
307 |           is_training=True,
308 |           output_fn=train_writer.process_feature,
309 |           do_lower_case=FLAGS.do_lower_case)
310 |       train_writer.close()
311 | 
312 |     tf.logging.info("***** Running training *****")
313 |     tf.logging.info("  Num orig examples = %d", len(train_examples))
314 |     # tf.logging.info("  Num split examples = %d", train_writer.num_features)
315 |     tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
316 |     tf.logging.info("  Num steps = %d", num_train_steps)
317 |     del train_examples
318 | 
319 |     train_input_fn = squad_utils.input_fn_builder(
320 |         input_file=FLAGS.train_feature_file,
321 |         seq_length=FLAGS.max_seq_length,
322 |         is_training=True,
323 |         drop_remainder=True,
324 |         use_tpu=FLAGS.use_tpu,
325 |         bsz=FLAGS.train_batch_size,
326 |         is_v2=True)
327 |     estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
328 | 
329 |   if FLAGS.do_predict:
330 |     with tf.gfile.Open(FLAGS.predict_file) as predict_file:
331 |       prediction_json = json.load(predict_file)["data"]
332 |     eval_examples = squad_utils.read_squad_examples(
333 |         input_file=FLAGS.predict_file, is_training=False)
334 | 
335 |     if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
336 |         FLAGS.predict_feature_left_file)):
337 |       tf.logging.info("Loading eval features from {}".format(
338 |           FLAGS.predict_feature_left_file))
339 |       with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
340 |         eval_features = pickle.load(fin)
341 |     else:
342 |       eval_writer = squad_utils.FeatureWriter(
343 |           filename=FLAGS.predict_feature_file, is_training=False)
344 |       eval_features = []
345 | 
346 |       def append_feature(feature):
347 |         eval_features.append(feature)
348 |         eval_writer.process_feature(feature)
349 | 
350 |       squad_utils.convert_examples_to_features(
351 |           examples=eval_examples,
352 |           tokenizer=tokenizer,
353 |           max_seq_length=FLAGS.max_seq_length,
354 |           doc_stride=FLAGS.doc_stride,
355 |           max_query_length=FLAGS.max_query_length,
356 |           is_training=False,
357 |           output_fn=append_feature,
358 | 
do_lower_case=FLAGS.do_lower_case) 359 | eval_writer.close() 360 | 361 | with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout: 362 | pickle.dump(eval_features, fout) 363 | 364 | tf.logging.info("***** Running predictions *****") 365 | tf.logging.info(" Num orig examples = %d", len(eval_examples)) 366 | tf.logging.info(" Num split examples = %d", len(eval_features)) 367 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 368 | 369 | predict_input_fn = squad_utils.input_fn_builder( 370 | input_file=FLAGS.predict_feature_file, 371 | seq_length=FLAGS.max_seq_length, 372 | is_training=False, 373 | drop_remainder=False, 374 | use_tpu=FLAGS.use_tpu, 375 | bsz=FLAGS.predict_batch_size, 376 | is_v2=True) 377 | 378 | def get_result(checkpoint): 379 | """Evaluate the checkpoint on SQuAD v2.0.""" 380 | # If running eval on the TPU, you will need to specify the number of 381 | # steps. 382 | reader = tf.train.NewCheckpointReader(checkpoint) 383 | global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP) 384 | all_results = [] 385 | for result in estimator.predict( 386 | predict_input_fn, yield_single_examples=True, 387 | checkpoint_path=checkpoint): 388 | if len(all_results) % 1000 == 0: 389 | tf.logging.info("Processing example: %d" % (len(all_results))) 390 | unique_id = int(result["unique_ids"]) 391 | start_top_log_probs = ( 392 | [float(x) for x in result["start_top_log_probs"].flat]) 393 | start_top_index = [int(x) for x in result["start_top_index"].flat] 394 | end_top_log_probs = ( 395 | [float(x) for x in result["end_top_log_probs"].flat]) 396 | end_top_index = [int(x) for x in result["end_top_index"].flat] 397 | 398 | cls_logits = float(result["cls_logits"].flat[0]) 399 | all_results.append( 400 | squad_utils.RawResultV2( 401 | unique_id=unique_id, 402 | start_top_log_probs=start_top_log_probs, 403 | start_top_index=start_top_index, 404 | end_top_log_probs=end_top_log_probs, 405 | end_top_index=end_top_index, 406 | cls_logits=cls_logits)) 407 | 408 | output_prediction_file = os.path.join( 409 | FLAGS.output_dir, "predictions.json") 410 | output_nbest_file = os.path.join( 411 | FLAGS.output_dir, "nbest_predictions.json") 412 | output_null_log_odds_file = os.path.join( 413 | FLAGS.output_dir, "null_odds.json") 414 | 415 | result_dict = {} 416 | cls_dict = {} 417 | squad_utils.accumulate_predictions_v2( 418 | result_dict, cls_dict, eval_examples, eval_features, 419 | all_results, FLAGS.n_best_size, FLAGS.max_answer_length, 420 | FLAGS.start_n_top, FLAGS.end_n_top) 421 | 422 | return squad_utils.evaluate_v2( 423 | result_dict, cls_dict, prediction_json, eval_examples, 424 | eval_features, all_results, FLAGS.n_best_size, 425 | FLAGS.max_answer_length, output_prediction_file, output_nbest_file, 426 | output_null_log_odds_file), int(global_step) 427 | 428 | def _find_valid_cands(curr_step): 429 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 430 | candidates = [] 431 | for filename in filenames: 432 | if filename.endswith(".index"): 433 | ckpt_name = filename[:-6] 434 | idx = ckpt_name.split("-")[-1] 435 | if idx != "best" and int(idx) > curr_step: 436 | candidates.append(filename) 437 | return candidates 438 | 439 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 440 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 441 | key_name = "f1" 442 | writer = tf.gfile.GFile(output_eval_file, "w") 443 | if tf.gfile.Exists(checkpoint_path + ".index"): 444 | result = get_result(checkpoint_path) 445 | best_perf = result[0][key_name] 446 | 
global_step = result[1] 447 | else: 448 | global_step = -1 449 | best_perf = -1 450 | checkpoint_path = None 451 | while global_step < num_train_steps: 452 | steps_and_files = {} 453 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 454 | for filename in filenames: 455 | if filename.endswith(".index"): 456 | ckpt_name = filename[:-6] 457 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 458 | if cur_filename.split("-")[-1] == "best": 459 | continue 460 | gstep = int(cur_filename.split("-")[-1]) 461 | if gstep not in steps_and_files: 462 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 463 | steps_and_files[gstep] = cur_filename 464 | tf.logging.info("found {} files.".format(len(steps_and_files))) 465 | if not steps_and_files: 466 | tf.logging.info("found 0 file, global step: {}. Sleeping." 467 | .format(global_step)) 468 | time.sleep(60) 469 | else: 470 | for ele in sorted(steps_and_files.items()): 471 | step, checkpoint_path = ele 472 | if global_step >= step: 473 | if len(_find_valid_cands(step)) > 1: 474 | for ext in ["meta", "data-00000-of-00001", "index"]: 475 | src_ckpt = checkpoint_path + ".{}".format(ext) 476 | tf.logging.info("removing {}".format(src_ckpt)) 477 | tf.gfile.Remove(src_ckpt) 478 | continue 479 | result, global_step = get_result(checkpoint_path) 480 | tf.logging.info("***** Eval results *****") 481 | for key in sorted(result.keys()): 482 | tf.logging.info(" %s = %s", key, str(result[key])) 483 | writer.write("%s = %s\n" % (key, str(result[key]))) 484 | if result[key_name] > best_perf: 485 | best_perf = result[key_name] 486 | for ext in ["meta", "data-00000-of-00001", "index"]: 487 | src_ckpt = checkpoint_path + ".{}".format(ext) 488 | tgt_ckpt = checkpoint_path.rsplit( 489 | "-", 1)[0] + "-best.{}".format(ext) 490 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 491 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True) 492 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt)) 493 | writer.write("best {} = {}\n".format(key_name, best_perf)) 494 | tf.logging.info(" best {} = {}\n".format(key_name, best_perf)) 495 | 496 | if len(_find_valid_cands(global_step)) > 2: 497 | for ext in ["meta", "data-00000-of-00001", "index"]: 498 | src_ckpt = checkpoint_path + ".{}".format(ext) 499 | tf.logging.info("removing {}".format(src_ckpt)) 500 | tf.gfile.Remove(src_ckpt) 501 | writer.write("=" * 50 + "\n") 502 | 503 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 504 | result, global_step = get_result(checkpoint_path) 505 | tf.logging.info("***** Final Eval results *****") 506 | for key in sorted(result.keys()): 507 | tf.logging.info(" %s = %s", key, str(result[key])) 508 | writer.write("%s = %s\n" % (key, str(result[key]))) 509 | writer.write("best perf happened at step: {}".format(global_step)) 510 | 511 | 512 | if __name__ == "__main__": 513 | flags.mark_flag_as_required("spm_model_file") 514 | flags.mark_flag_as_required("albert_config_file") 515 | flags.mark_flag_as_required("output_dir") 516 | tf.app.run() 517 | -------------------------------------------------------------------------------- /run_squad_v1.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # Lint as: python2, python3
16 | """Run ALBERT on SQuAD v1.1 using sentence piece tokenization."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | 
23 | import json
24 | import os
25 | import random
26 | import time
27 | from albert import fine_tuning_utils
28 | from albert import modeling
29 | from albert import squad_utils
30 | import six
31 | import tensorflow.compat.v1 as tf
32 | 
33 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
34 | from tensorflow.contrib import tpu as contrib_tpu
35 | 
36 | 
37 | # pylint: disable=g-import-not-at-top
38 | if six.PY2:
39 |   import six.moves.cPickle as pickle
40 | else:
41 |   import pickle
42 | # pylint: enable=g-import-not-at-top
43 | 
44 | flags = tf.flags
45 | 
46 | FLAGS = flags.FLAGS
47 | 
48 | ## Required parameters
49 | flags.DEFINE_string(
50 |     "albert_config_file", None,
51 |     "The config json file corresponding to the pre-trained ALBERT model. "
52 |     "This specifies the model architecture.")
53 | 
54 | flags.DEFINE_string("vocab_file", None,
55 |                     "The vocabulary file that the ALBERT model was trained on.")
56 | 
57 | flags.DEFINE_string("spm_model_file", None,
58 |                     "The model file for sentence piece tokenization.")
59 | 
60 | flags.DEFINE_string(
61 |     "output_dir", None,
62 |     "The output directory where the model checkpoints will be written.")
63 | 
64 | ## Other parameters
65 | flags.DEFINE_string("train_file", None,
66 |                     "SQuAD json for training. E.g., train-v1.1.json")
67 | 
68 | flags.DEFINE_string(
69 |     "predict_file", None,
70 |     "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
71 | 
72 | flags.DEFINE_string("train_feature_file", None,
73 |                     "Path to the training feature file.")
74 | 
75 | flags.DEFINE_string(
76 |     "predict_feature_file", None,
77 |     "Location of predict features. If it doesn't exist, it will be written. "
78 |     "If it does exist, it will be read.")
79 | 
80 | flags.DEFINE_string(
81 |     "predict_feature_left_file", None,
82 |     "Location of predict features not passed to TPU. If it doesn't exist, it "
83 |     "will be written. If it does exist, it will be read.")
84 | 
85 | flags.DEFINE_string(
86 |     "init_checkpoint", None,
87 |     "Initial checkpoint (usually from a pre-trained ALBERT model).")
88 | 
89 | flags.DEFINE_string(
90 |     "albert_hub_module_handle", None,
91 |     "If set, the ALBERT hub module to use.")
92 | 
93 | flags.DEFINE_bool(
94 |     "do_lower_case", True,
95 |     "Whether to lower case the input text. Should be True for uncased "
96 |     "models and False for cased models.")
97 | 
98 | flags.DEFINE_integer(
99 |     "max_seq_length", 384,
100 |     "The maximum total input sequence length after WordPiece tokenization. "
101 |     "Sequences longer than this will be truncated, and sequences shorter "
102 |     "than this will be padded.")
103 | 
104 | flags.DEFINE_integer(
105 |     "doc_stride", 128,
106 |     "When splitting up a long document into chunks, how much stride to "
107 |     "take between chunks.")
108 | 
109 | flags.DEFINE_integer(
110 |     "max_query_length", 64,
111 |     "The maximum number of tokens for the question. Questions longer than "
112 |     "this will be truncated to this length.")
113 | 
114 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
115 | 
116 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
117 | 
118 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
119 | 
120 | flags.DEFINE_integer("predict_batch_size", 8,
121 |                      "Total batch size for predictions.")
122 | 
123 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
124 | 
125 | flags.DEFINE_float("num_train_epochs", 3.0,
126 |                    "Total number of training epochs to perform.")
127 | 
128 | flags.DEFINE_float(
129 |     "warmup_proportion", 0.1,
130 |     "Proportion of training to perform linear learning rate warmup for. "
131 |     "E.g., 0.1 = 10% of training.")
132 | 
133 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
134 |                      "How often to save the model checkpoint.")
135 | 
136 | flags.DEFINE_integer("iterations_per_loop", 1000,
137 |                      "How many steps to make in each estimator call.")
138 | 
139 | flags.DEFINE_integer(
140 |     "n_best_size", 20,
141 |     "The total number of n-best predictions to generate in the "
142 |     "nbest_predictions.json output file.")
143 | 
144 | flags.DEFINE_integer(
145 |     "max_answer_length", 30,
146 |     "The maximum length of an answer that can be generated. This is needed "
147 |     "because the start and end predictions are not conditioned on one another.")
148 | 
149 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
150 | 
151 | tf.flags.DEFINE_string(
152 |     "tpu_name", None,
153 |     "The Cloud TPU to use for training. This should be either the name "
154 |     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
155 |     "url.")
156 | 
157 | tf.flags.DEFINE_string(
158 |     "tpu_zone", None,
159 |     "[Optional] GCE zone where the Cloud TPU is located. If not "
160 |     "specified, we will attempt to automatically detect the GCE zone from "
161 |     "metadata.")
162 | 
163 | tf.flags.DEFINE_string(
164 |     "gcp_project", None,
165 |     "[Optional] Project name for the Cloud TPU-enabled project. If not "
166 |     "specified, we will attempt to automatically detect the GCE project from "
167 |     "metadata.")
168 | 
169 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
170 | 
171 | flags.DEFINE_integer(
172 |     "num_tpu_cores", 8,
173 |     "Only used if `use_tpu` is True. Total number of TPU cores to use.")
174 | 
175 | flags.DEFINE_bool(
176 |     "use_einsum", True,
177 |     "Whether to use tf.einsum or tf.reshape+tf.matmul for dense layers. Must "
178 |     "be set to False for TFLite compatibility.")
179 | 
180 | flags.DEFINE_string(
181 |     "export_dir",
182 |     default=None,
183 |     help=("The directory where the exported SavedModel will be stored."))
184 | 
185 | 
186 | def validate_flags_or_throw(albert_config):
187 |   """Validate the input FLAGS or throw an exception."""
188 | 
189 |   if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_dir:
190 |     err_msg = "At least one of `do_train` or `do_predict` or `export_dir`" + " must be True."
191 |     raise ValueError(err_msg)
192 | 
193 |   if FLAGS.do_train:
194 |     if not FLAGS.train_file:
195 |       raise ValueError(
196 |           "If `do_train` is True, then `train_file` must be specified.")
197 |   if FLAGS.do_predict:
198 |     if not FLAGS.predict_file:
199 |       raise ValueError(
200 |           "If `do_predict` is True, then `predict_file` must be specified.")
201 |     if not FLAGS.predict_feature_file:
202 |       raise ValueError(
203 |           "If `do_predict` is True, then `predict_feature_file` must be "
204 |           "specified.")
205 |     if not FLAGS.predict_feature_left_file:
206 |       raise ValueError(
207 |           "If `do_predict` is True, then `predict_feature_left_file` must be "
208 |           "specified.")
209 | 
210 |   if FLAGS.max_seq_length > albert_config.max_position_embeddings:
211 |     raise ValueError(
212 |         "Cannot use sequence length %d because the ALBERT model "
213 |         "was only trained up to sequence length %d" %
214 |         (FLAGS.max_seq_length, albert_config.max_position_embeddings))
215 | 
216 |   if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
217 |     raise ValueError(
218 |         "The max_seq_length (%d) must be greater than max_query_length "
219 |         "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))
220 | 
221 | 
222 | def build_squad_serving_input_fn(seq_length):
223 |   """Builds a serving input fn for raw input."""
224 | 
225 |   def _seq_serving_input_fn():
226 |     """Serving input fn for raw SQuAD inputs."""
227 |     input_ids = tf.placeholder(
228 |         shape=[1, seq_length], name="input_ids", dtype=tf.int32)
229 |     input_mask = tf.placeholder(
230 |         shape=[1, seq_length], name="input_mask", dtype=tf.int32)
231 |     segment_ids = tf.placeholder(
232 |         shape=[1, seq_length], name="segment_ids", dtype=tf.int32)
233 | 
234 |     inputs = {
235 |         "input_ids": input_ids,
236 |         "input_mask": input_mask,
237 |         "segment_ids": segment_ids
238 |     }
239 |     return tf.estimator.export.ServingInputReceiver(features=inputs,
240 |                                                     receiver_tensors=inputs)
241 | 
242 |   return _seq_serving_input_fn
243 | 
244 | 
245 | def main(_):
246 |   tf.logging.set_verbosity(tf.logging.INFO)
247 | 
248 |   albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
249 | 
250 |   validate_flags_or_throw(albert_config)
251 | 
252 |   tf.gfile.MakeDirs(FLAGS.output_dir)
253 | 
254 |   tokenizer = fine_tuning_utils.create_vocab(
255 |       vocab_file=FLAGS.vocab_file,
256 |       do_lower_case=FLAGS.do_lower_case,
257 |       spm_model_file=FLAGS.spm_model_file,
258 |       hub_module=FLAGS.albert_hub_module_handle)
259 | 
260 |   tpu_cluster_resolver = None
261 |   if FLAGS.use_tpu and FLAGS.tpu_name:
262 |     tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
263 |         FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
264 | 
265 |   is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2
266 |   if FLAGS.do_train:
267 |     iterations_per_loop = int(min(FLAGS.iterations_per_loop,
268 |                                   FLAGS.save_checkpoints_steps))
269 |   else:
270 |     iterations_per_loop = FLAGS.iterations_per_loop
271 |   run_config = contrib_tpu.RunConfig(
272 |       cluster=tpu_cluster_resolver,
273 |       master=FLAGS.master,
274 |       model_dir=FLAGS.output_dir,
275 |       keep_checkpoint_max=0,
276 |       save_checkpoints_steps=FLAGS.save_checkpoints_steps,
277 |       tpu_config=contrib_tpu.TPUConfig(
278 |           iterations_per_loop=iterations_per_loop,
279 |           num_shards=FLAGS.num_tpu_cores,
280 |           per_host_input_for_training=is_per_host))
281 | 
282 |   train_examples = None
283 |   num_train_steps = None
284 |   num_warmup_steps = None
285 |   if FLAGS.do_train:
286 |     train_examples = squad_utils.read_squad_examples(
287 |         input_file=FLAGS.train_file, is_training=True)
288 |     num_train_steps = int(
289 |         len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
290 |     num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
291 | 
292 |     # Pre-shuffle the input to avoid having to make a very large shuffle
293 |     # buffer in the `input_fn`.
294 |     rng = random.Random(12345)
295 |     rng.shuffle(train_examples)
296 | 
297 |   model_fn = squad_utils.v1_model_fn_builder(
298 |       albert_config=albert_config,
299 |       init_checkpoint=FLAGS.init_checkpoint,
300 |       learning_rate=FLAGS.learning_rate,
301 |       num_train_steps=num_train_steps,
302 |       num_warmup_steps=num_warmup_steps,
303 |       use_tpu=FLAGS.use_tpu,
304 |       use_one_hot_embeddings=FLAGS.use_tpu,
305 |       use_einsum=FLAGS.use_einsum,
306 |       hub_module=FLAGS.albert_hub_module_handle)
307 | 
308 |   # If TPU is not available, this will fall back to normal Estimator on CPU
309 |   # or GPU.
310 |   estimator = contrib_tpu.TPUEstimator(
311 |       use_tpu=FLAGS.use_tpu,
312 |       model_fn=model_fn,
313 |       config=run_config,
314 |       train_batch_size=FLAGS.train_batch_size,
315 |       predict_batch_size=FLAGS.predict_batch_size)
316 | 
317 |   if FLAGS.do_train:
318 |     # We write to a temporary file to avoid storing very large constant tensors
319 |     # in memory.
320 | 
321 |     if not tf.gfile.Exists(FLAGS.train_feature_file):
322 |       train_writer = squad_utils.FeatureWriter(
323 |           filename=os.path.join(FLAGS.train_feature_file), is_training=True)
324 |       squad_utils.convert_examples_to_features(
325 |           examples=train_examples,
326 |           tokenizer=tokenizer,
327 |           max_seq_length=FLAGS.max_seq_length,
328 |           doc_stride=FLAGS.doc_stride,
329 |           max_query_length=FLAGS.max_query_length,
330 |           is_training=True,
331 |           output_fn=train_writer.process_feature,
332 |           do_lower_case=FLAGS.do_lower_case)
333 |       train_writer.close()
334 | 
335 |     tf.logging.info("***** Running training *****")
336 |     tf.logging.info("  Num orig examples = %d", len(train_examples))
337 |     # tf.logging.info("  Num split examples = %d", train_writer.num_features)
338 |     tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
339 |     tf.logging.info("  Num steps = %d", num_train_steps)
340 |     del train_examples
341 | 
342 |     train_input_fn = squad_utils.input_fn_builder(
343 |         input_file=FLAGS.train_feature_file,
344 |         seq_length=FLAGS.max_seq_length,
345 |         is_training=True,
346 |         drop_remainder=True,
347 |         use_tpu=FLAGS.use_tpu,
348 |         bsz=FLAGS.train_batch_size,
349 |         is_v2=False)
350 |     estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
351 | 
352 |   if FLAGS.do_predict:
353 |     with tf.gfile.Open(FLAGS.predict_file) as predict_file:
354 |       prediction_json = json.load(predict_file)["data"]
355 | 
356 |     eval_examples = squad_utils.read_squad_examples(
357 |         input_file=FLAGS.predict_file, is_training=False)
358 | 
359 |     if (tf.gfile.Exists(FLAGS.predict_feature_file) and tf.gfile.Exists(
360 |         FLAGS.predict_feature_left_file)):
361 |       tf.logging.info("Loading eval features from {}".format(
362 |           FLAGS.predict_feature_left_file))
363 |       with tf.gfile.Open(FLAGS.predict_feature_left_file, "rb") as fin:
364 |         eval_features = pickle.load(fin)
365 |     else:
366 |       eval_writer = squad_utils.FeatureWriter(
367 |           filename=FLAGS.predict_feature_file, is_training=False)
368 |       eval_features = []
369 | 
370 |       def append_feature(feature):
371 |         eval_features.append(feature)
372 |         eval_writer.process_feature(feature)
373 | 
374 |       squad_utils.convert_examples_to_features(
375 |           examples=eval_examples,
376 |           tokenizer=tokenizer,
377 |           max_seq_length=FLAGS.max_seq_length,
378 |           doc_stride=FLAGS.doc_stride,
379 |           max_query_length=FLAGS.max_query_length,
380 |           is_training=False,
381 |           output_fn=append_feature,
382 |           do_lower_case=FLAGS.do_lower_case)
383 |       eval_writer.close()
384 | 
385 |       with tf.gfile.Open(FLAGS.predict_feature_left_file, "wb") as fout:
386 |         pickle.dump(eval_features, fout)
387 | 
388 |     tf.logging.info("***** Running predictions *****")
389 |     tf.logging.info("  Num orig examples = %d", len(eval_examples))
390 |     tf.logging.info("  Num split examples = %d", len(eval_features))
391 |     tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
392 | 
393 |     predict_input_fn = squad_utils.input_fn_builder(
394 |         input_file=FLAGS.predict_feature_file,
395 |         seq_length=FLAGS.max_seq_length,
396 |         is_training=False,
397 |         drop_remainder=False,
398 |         use_tpu=FLAGS.use_tpu,
399 |         bsz=FLAGS.predict_batch_size,
400 |         is_v2=False)
401 | 
402 |     def get_result(checkpoint):
403 |       """Evaluate the checkpoint on SQuAD v1.1."""
404 |       # If running eval on the TPU, you will need to specify the number of
405 |       # steps.
406 |       reader = tf.train.NewCheckpointReader(checkpoint)
407 |       global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
408 |       all_results = []
409 |       for result in estimator.predict(
410 |           predict_input_fn, yield_single_examples=True,
411 |           checkpoint_path=checkpoint):
412 |         if len(all_results) % 1000 == 0:
413 |           tf.logging.info("Processing example: %d" % (len(all_results)))
414 |         unique_id = int(result["unique_ids"])
415 |         start_log_prob = [float(x) for x in result["start_log_prob"].flat]
416 |         end_log_prob = [float(x) for x in result["end_log_prob"].flat]
417 |         all_results.append(
418 |             squad_utils.RawResult(
419 |                 unique_id=unique_id,
420 |                 start_log_prob=start_log_prob,
421 |                 end_log_prob=end_log_prob))
422 | 
423 |       output_prediction_file = os.path.join(
424 |           FLAGS.output_dir, "predictions.json")
425 |       output_nbest_file = os.path.join(
426 |           FLAGS.output_dir, "nbest_predictions.json")
427 | 
428 |       result_dict = {}
429 |       squad_utils.accumulate_predictions_v1(
430 |           result_dict, eval_examples, eval_features,
431 |           all_results, FLAGS.n_best_size, FLAGS.max_answer_length)
432 |       predictions = squad_utils.write_predictions_v1(
433 |           result_dict, eval_examples, eval_features, all_results,
434 |           FLAGS.n_best_size, FLAGS.max_answer_length,
435 |           output_prediction_file, output_nbest_file)
436 | 
437 |       return squad_utils.evaluate_v1(
438 |           prediction_json, predictions), int(global_step)
439 | 
440 |     def _find_valid_cands(curr_step):
441 |       filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
442 |       candidates = []
443 |       for filename in filenames:
444 |         if filename.endswith(".index"):
445 |           ckpt_name = filename[:-6]
446 |           idx = ckpt_name.split("-")[-1]
447 |           if idx != "best" and int(idx) > curr_step:
448 |             candidates.append(filename)
449 |       return candidates
450 | 
451 |     output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
452 |     checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
453 |     key_name = "f1"
454 |     writer = tf.gfile.GFile(output_eval_file, "w")
455 |     if tf.gfile.Exists(checkpoint_path + ".index"):
456 |       result = get_result(checkpoint_path)
457 |       best_perf = result[0][key_name]
458 |       global_step = result[1]
459 |     else:
460 |       global_step = -1
461 |       best_perf = -1
462 |       checkpoint_path = None
463 |     while global_step < num_train_steps:
464 |       steps_and_files = {}
465 |       filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
466 |       for filename in filenames:
467 |         if filename.endswith(".index"):
468 |           ckpt_name = filename[:-6]
469 |
cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 470 | if cur_filename.split("-")[-1] == "best": 471 | continue 472 | gstep = int(cur_filename.split("-")[-1]) 473 | if gstep not in steps_and_files: 474 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 475 | steps_and_files[gstep] = cur_filename 476 | tf.logging.info("found {} files.".format(len(steps_and_files))) 477 | if not steps_and_files: 478 | tf.logging.info("found 0 file, global step: {}. Sleeping." 479 | .format(global_step)) 480 | time.sleep(60) 481 | else: 482 | for ele in sorted(steps_and_files.items()): 483 | step, checkpoint_path = ele 484 | if global_step >= step: 485 | if len(_find_valid_cands(step)) > 1: 486 | for ext in ["meta", "data-00000-of-00001", "index"]: 487 | src_ckpt = checkpoint_path + ".{}".format(ext) 488 | tf.logging.info("removing {}".format(src_ckpt)) 489 | tf.gfile.Remove(src_ckpt) 490 | continue 491 | result, global_step = get_result(checkpoint_path) 492 | tf.logging.info("***** Eval results *****") 493 | for key in sorted(result.keys()): 494 | tf.logging.info(" %s = %s", key, str(result[key])) 495 | writer.write("%s = %s\n" % (key, str(result[key]))) 496 | if result[key_name] > best_perf: 497 | best_perf = result[key_name] 498 | for ext in ["meta", "data-00000-of-00001", "index"]: 499 | src_ckpt = checkpoint_path + ".{}".format(ext) 500 | tgt_ckpt = checkpoint_path.rsplit( 501 | "-", 1)[0] + "-best.{}".format(ext) 502 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 503 | tf.gfile.Copy(src_ckpt, tgt_ckpt, overwrite=True) 504 | writer.write("saved {} to {}\n".format(src_ckpt, tgt_ckpt)) 505 | writer.write("best {} = {}\n".format(key_name, best_perf)) 506 | tf.logging.info(" best {} = {}\n".format(key_name, best_perf)) 507 | 508 | if len(_find_valid_cands(global_step)) > 2: 509 | for ext in ["meta", "data-00000-of-00001", "index"]: 510 | src_ckpt = checkpoint_path + ".{}".format(ext) 511 | tf.logging.info("removing {}".format(src_ckpt)) 512 | tf.gfile.Remove(src_ckpt) 513 | writer.write("=" * 50 + "\n") 514 | 515 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 516 | result, global_step = get_result(checkpoint_path) 517 | tf.logging.info("***** Final Eval results *****") 518 | for key in sorted(result.keys()): 519 | tf.logging.info(" %s = %s", key, str(result[key])) 520 | writer.write("%s = %s\n" % (key, str(result[key]))) 521 | writer.write("best perf happened at step: {}".format(global_step)) 522 | 523 | if FLAGS.export_dir: 524 | tf.gfile.MakeDirs(FLAGS.export_dir) 525 | squad_serving_input_fn = ( 526 | build_squad_serving_input_fn(FLAGS.max_seq_length)) 527 | tf.logging.info("Starting to export model.") 528 | subfolder = estimator.export_saved_model( 529 | export_dir_base=os.path.join(FLAGS.export_dir, "saved_model"), 530 | serving_input_receiver_fn=squad_serving_input_fn) 531 | 532 | tf.logging.info("Starting to export TFLite.") 533 | converter = tf.lite.TFLiteConverter.from_saved_model( 534 | subfolder, 535 | input_arrays=["input_ids", "input_mask", "segment_ids"], 536 | output_arrays=["start_logits", "end_logits"]) 537 | float_model = converter.convert() 538 | tflite_file = os.path.join(FLAGS.export_dir, "albert_model.tflite") 539 | with tf.gfile.GFile(tflite_file, "wb") as f: 540 | f.write(float_model) 541 | 542 | 543 | if __name__ == "__main__": 544 | flags.mark_flag_as_required("spm_model_file") 545 | flags.mark_flag_as_required("albert_config_file") 546 | flags.mark_flag_as_required("output_dir") 547 | tf.app.run() 548 | 
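# Note (added sketch, not part of the original file): a minimal, self-contained
# illustration of the checkpoint bookkeeping the evaluation loop above relies
# on. tf.estimator writes checkpoints named "model.ckpt-<global_step>.{index,
# meta,data-...}"; the loop keys candidates by the integer step suffix and
# skips the promoted "model.ckpt-best" files. The helper name
# `list_eval_candidates` is illustrative only.
import os


def list_eval_candidates(output_dir, filenames):
  """Maps global step -> checkpoint prefix, mirroring the loop above."""
  steps_and_files = {}
  for filename in filenames:
    if not filename.endswith(".index"):
      continue
    ckpt_name = filename[:-len(".index")]
    suffix = ckpt_name.split("-")[-1]
    if suffix == "best":  # never re-evaluate the promoted best checkpoint
      continue
    steps_and_files[int(suffix)] = os.path.join(output_dir, ckpt_name)
  return steps_and_files


# Candidates are then visited in increasing global-step order:
names = ["model.ckpt-best.index", "model.ckpt-1000.index", "model.ckpt-500.index"]
print(sorted(list_eval_candidates("/tmp/out", names).items()))
# -> [(500, '/tmp/out/model.ckpt-500'), (1000, '/tmp/out/model.ckpt-1000')]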
--------------------------------------------------------------------------------
/run_classifier.py:
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ALBERT finetuning on classification tasks."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import math
22 | import os
23 | import time
24 | from albert import classifier_utils
25 | from albert import fine_tuning_utils
26 | from albert import modeling
27 | import tensorflow.compat.v1 as tf
28 | from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
29 | from tensorflow.contrib import tpu as contrib_tpu
30 | 
31 | flags = tf.flags
32 | 
33 | FLAGS = flags.FLAGS
34 | 
35 | ## Required parameters
36 | flags.DEFINE_string(
37 |     "data_dir", None,
38 |     "The input data dir. Should contain the .tsv files (or other data files) "
39 |     "for the task.")
40 | 
41 | flags.DEFINE_string(
42 |     "albert_config_file", None,
43 |     "The config json file corresponding to the pre-trained ALBERT model. "
44 |     "This specifies the model architecture.")
45 | 
46 | flags.DEFINE_string("task_name", None, "The name of the task to train.")
47 | 
48 | flags.DEFINE_string(
49 |     "vocab_file", None,
50 |     "The vocabulary file that the ALBERT model was trained on.")
51 | 
52 | flags.DEFINE_string("spm_model_file", None,
53 |                     "The model file for sentence piece tokenization.")
54 | 
55 | flags.DEFINE_string(
56 |     "output_dir", None,
57 |     "The output directory where the model checkpoints will be written.")
58 | 
59 | flags.DEFINE_string("cached_dir", None,
60 |                     "Path to cached training and dev tfrecord file. "
61 |                     "The file will be generated if it does not exist.")
62 | 
63 | ## Other parameters
64 | 
65 | flags.DEFINE_string(
66 |     "init_checkpoint", None,
67 |     "Initial checkpoint (usually from a pre-trained ALBERT model).")
68 | 
69 | flags.DEFINE_string(
70 |     "albert_hub_module_handle", None,
71 |     "If set, the ALBERT hub module to use.")
72 | 
73 | flags.DEFINE_bool(
74 |     "do_lower_case", True,
75 |     "Whether to lower case the input text. Should be True for uncased "
76 |     "models and False for cased models.")
77 | 
78 | flags.DEFINE_integer(
79 |     "max_seq_length", 512,
80 |     "The maximum total input sequence length after WordPiece tokenization. "
81 |     "Sequences longer than this will be truncated, and sequences shorter "
82 |     "than this will be padded.")
83 | 
84 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
85 | 
86 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
87 | 
88 | flags.DEFINE_bool(
89 |     "do_predict", False,
90 |     "Whether to run the model in inference mode on the test set.")
91 | 
92 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
93 | 
94 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
95 | 
96 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
97 | 
98 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
99 | 
100 | flags.DEFINE_integer("train_step", 1000,
101 |                      "Total number of training steps to perform.")
102 | 
103 | flags.DEFINE_integer(
104 |     "warmup_step", 0,
105 |     "Number of steps to perform linear learning rate warmup for.")
106 | 
107 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
108 |                      "How often to save the model checkpoint.")
109 | 
110 | flags.DEFINE_integer("keep_checkpoint_max", 5,
111 |                      "How many checkpoints to keep.")
112 | 
113 | flags.DEFINE_integer("iterations_per_loop", 1000,
114 |                      "How many steps to make in each estimator call.")
115 | 
116 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
117 | 
118 | flags.DEFINE_string("optimizer", "adamw", "Optimizer to use.")
119 | 
120 | tf.flags.DEFINE_string(
121 |     "tpu_name", None,
122 |     "The Cloud TPU to use for training. This should be either the name "
123 |     "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
124 |     "url.")
125 | 
126 | tf.flags.DEFINE_string(
127 |     "tpu_zone", None,
128 |     "[Optional] GCE zone where the Cloud TPU is located in. If not "
129 |     "specified, we will attempt to automatically detect the GCE zone from "
130 |     "metadata.")
131 | 
132 | tf.flags.DEFINE_string(
133 |     "gcp_project", None,
134 |     "[Optional] Project name for the Cloud TPU-enabled project. If not "
135 |     "specified, we will attempt to automatically detect the GCE project from "
136 |     "metadata.")
137 | 
138 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
139 | 
140 | flags.DEFINE_integer(
141 |     "num_tpu_cores", 8,
142 |     "Only used if `use_tpu` is True. Total number of TPU cores to use.")
143 | 
144 | flags.DEFINE_string(
145 |     "export_dir", None,
146 |     "The directory where the exported SavedModel will be stored.")
147 | 
148 | flags.DEFINE_float(
149 |     "threshold_to_export", float("nan"),
150 |     "The threshold value that should be used with the exported classifier. "
151 |     "When specified, the threshold will be attached to the exported "
152 |     "SavedModel, and served along with the predictions. Please use the "
153 |     "saved model cli ("
154 |     "https://www.tensorflow.org/guide/saved_model#details_of_the_savedmodel_command_line_interface"
155 |     ") to view the output signature of the threshold.")
156 | 
157 | 
158 | def _serving_input_receiver_fn():
159 |   """Creates an input function for serving."""
160 |   seq_len = FLAGS.max_seq_length
161 |   serialized_example = tf.placeholder(
162 |       dtype=tf.string, shape=[None], name="serialized_example")
163 |   features = {
164 |       "input_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64),
165 |       "input_mask": tf.FixedLenFeature([seq_len], dtype=tf.int64),
166 |       "segment_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64),
167 |   }
168 |   feature_map = tf.parse_example(serialized_example, features=features)
169 |   feature_map["is_real_example"] = tf.constant(1, dtype=tf.int32)
170 |   feature_map["label_ids"] = tf.constant(0, dtype=tf.int32)
171 | 
172 |   # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
173 |   # So cast all int64 to int32.
174 |   for name in feature_map.keys():
175 |     t = feature_map[name]
176 |     if t.dtype == tf.int64:
177 |       t = tf.to_int32(t)
178 |     feature_map[name] = t
179 | 
180 |   return tf.estimator.export.ServingInputReceiver(
181 |       features=feature_map, receiver_tensors=serialized_example)
182 | 
183 | 
184 | def _add_threshold_to_model_fn(model_fn, threshold):
185 |   """Adds the classifier threshold to the given model_fn."""
186 | 
187 |   def new_model_fn(features, labels, mode, params):
188 |     spec = model_fn(features, labels, mode, params)
189 |     threshold_tensor = tf.constant(threshold, dtype=tf.float32)
190 |     default_serving_export = spec.export_outputs[
191 |         tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
192 |     default_serving_export.outputs["threshold"] = threshold_tensor
193 |     return spec
194 | 
195 |   return new_model_fn
196 | 
197 | 
198 | def main(_):
199 |   tf.logging.set_verbosity(tf.logging.INFO)
200 | 
201 |   processors = {
202 |       "cola": classifier_utils.ColaProcessor,
203 |       "mnli": classifier_utils.MnliProcessor,
204 |       "mismnli": classifier_utils.MisMnliProcessor,
205 |       "mrpc": classifier_utils.MrpcProcessor,
206 |       "rte": classifier_utils.RteProcessor,
207 |       "sst-2": classifier_utils.Sst2Processor,
208 |       "sts-b": classifier_utils.StsbProcessor,
209 |       "qqp": classifier_utils.QqpProcessor,
210 |       "qnli": classifier_utils.QnliProcessor,
211 |       "wnli": classifier_utils.WnliProcessor,
212 |   }
213 | 
214 |   if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict or
215 |           FLAGS.export_dir):
216 |     raise ValueError(
217 |         "At least one of `do_train`, `do_eval`, `do_predict` or `export_dir` "
218 |         "must be True.")
219 | 
220 |   if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
221 |     raise ValueError("At least one of `--albert_config_file` and "
222 |                      "`--albert_hub_module_handle` must be set.")
223 | 
224 |   if FLAGS.albert_config_file:
225 |     albert_config = modeling.AlbertConfig.from_json_file(
226 |         FLAGS.albert_config_file)
227 |     if FLAGS.max_seq_length > albert_config.max_position_embeddings:
228 |       raise ValueError(
229 |           "Cannot use sequence length %d because the ALBERT model "
230 |           "was only trained up to sequence length %d" %
231 |           (FLAGS.max_seq_length, albert_config.max_position_embeddings))
232 |   else:
233 |     albert_config = None  # Get the config from TF-Hub.
234 | 235 | tf.gfile.MakeDirs(FLAGS.output_dir) 236 | 237 | task_name = FLAGS.task_name.lower() 238 | 239 | if task_name not in processors: 240 | raise ValueError("Task not found: %s" % (task_name)) 241 | 242 | processor = processors[task_name]( 243 | use_spm=True if FLAGS.spm_model_file else False, 244 | do_lower_case=FLAGS.do_lower_case) 245 | 246 | label_list = processor.get_labels() 247 | 248 | tokenizer = fine_tuning_utils.create_vocab( 249 | vocab_file=FLAGS.vocab_file, 250 | do_lower_case=FLAGS.do_lower_case, 251 | spm_model_file=FLAGS.spm_model_file, 252 | hub_module=FLAGS.albert_hub_module_handle) 253 | 254 | tpu_cluster_resolver = None 255 | if FLAGS.use_tpu and FLAGS.tpu_name: 256 | tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( 257 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 258 | 259 | is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 260 | if FLAGS.do_train: 261 | iterations_per_loop = int(min(FLAGS.iterations_per_loop, 262 | FLAGS.save_checkpoints_steps)) 263 | else: 264 | iterations_per_loop = FLAGS.iterations_per_loop 265 | run_config = contrib_tpu.RunConfig( 266 | cluster=tpu_cluster_resolver, 267 | master=FLAGS.master, 268 | model_dir=FLAGS.output_dir, 269 | save_checkpoints_steps=int(FLAGS.save_checkpoints_steps), 270 | keep_checkpoint_max=0, 271 | tpu_config=contrib_tpu.TPUConfig( 272 | iterations_per_loop=iterations_per_loop, 273 | num_shards=FLAGS.num_tpu_cores, 274 | per_host_input_for_training=is_per_host)) 275 | 276 | train_examples = None 277 | if FLAGS.do_train: 278 | train_examples = processor.get_train_examples(FLAGS.data_dir) 279 | model_fn = classifier_utils.model_fn_builder( 280 | albert_config=albert_config, 281 | num_labels=len(label_list), 282 | init_checkpoint=FLAGS.init_checkpoint, 283 | learning_rate=FLAGS.learning_rate, 284 | num_train_steps=FLAGS.train_step, 285 | num_warmup_steps=FLAGS.warmup_step, 286 | use_tpu=FLAGS.use_tpu, 287 | use_one_hot_embeddings=FLAGS.use_tpu, 288 | task_name=task_name, 289 | hub_module=FLAGS.albert_hub_module_handle, 290 | optimizer=FLAGS.optimizer) 291 | 292 | if not math.isnan(FLAGS.threshold_to_export): 293 | model_fn = _add_threshold_to_model_fn(model_fn, FLAGS.threshold_to_export) 294 | 295 | # If TPU is not available, this will fall back to normal Estimator on CPU 296 | # or GPU. 
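# Note (added comment, not in the original source): TPUEstimator takes the
# per-mode batch sizes here rather than inside the input functions, and on
# TPU it hands them to each input_fn via params["batch_size"]. The
# iterations_per_loop clamp above keeps a TPU training loop no longer than
# save_checkpoints_steps, so every scheduled checkpoint actually gets written.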
297 | estimator = contrib_tpu.TPUEstimator( 298 | use_tpu=FLAGS.use_tpu, 299 | model_fn=model_fn, 300 | config=run_config, 301 | train_batch_size=FLAGS.train_batch_size, 302 | eval_batch_size=FLAGS.eval_batch_size, 303 | predict_batch_size=FLAGS.predict_batch_size, 304 | export_to_tpu=False) # http://yaqs/4707241341091840 305 | 306 | if FLAGS.do_train: 307 | cached_dir = FLAGS.cached_dir 308 | if not cached_dir: 309 | cached_dir = FLAGS.output_dir 310 | train_file = os.path.join(cached_dir, task_name + "_train.tf_record") 311 | if not tf.gfile.Exists(train_file): 312 | classifier_utils.file_based_convert_examples_to_features( 313 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, 314 | train_file, task_name) 315 | tf.logging.info("***** Running training *****") 316 | tf.logging.info(" Num examples = %d", len(train_examples)) 317 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 318 | tf.logging.info(" Num steps = %d", FLAGS.train_step) 319 | train_input_fn = classifier_utils.file_based_input_fn_builder( 320 | input_file=train_file, 321 | seq_length=FLAGS.max_seq_length, 322 | is_training=True, 323 | drop_remainder=True, 324 | task_name=task_name, 325 | use_tpu=FLAGS.use_tpu, 326 | bsz=FLAGS.train_batch_size) 327 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_step) 328 | 329 | if FLAGS.do_eval: 330 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 331 | num_actual_eval_examples = len(eval_examples) 332 | if FLAGS.use_tpu: 333 | # TPU requires a fixed batch size for all batches, therefore the number 334 | # of examples must be a multiple of the batch size, or else examples 335 | # will get dropped. So we pad with fake examples which are ignored 336 | # later on. These do NOT count towards the metric (all tf.metrics 337 | # support a per-instance weight, and these get a weight of 0.0). 338 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 339 | eval_examples.append(classifier_utils.PaddingInputExample()) 340 | 341 | cached_dir = FLAGS.cached_dir 342 | if not cached_dir: 343 | cached_dir = FLAGS.output_dir 344 | eval_file = os.path.join(cached_dir, task_name + "_eval.tf_record") 345 | if not tf.gfile.Exists(eval_file): 346 | classifier_utils.file_based_convert_examples_to_features( 347 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, 348 | eval_file, task_name) 349 | 350 | tf.logging.info("***** Running evaluation *****") 351 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 352 | len(eval_examples), num_actual_eval_examples, 353 | len(eval_examples) - num_actual_eval_examples) 354 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 355 | 356 | # This tells the estimator to run through the entire set. 357 | eval_steps = None 358 | # However, if running eval on the TPU, you will need to specify the 359 | # number of steps. 
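# For instance (illustrative numbers, not from the source): 1,043 dev
# examples padded to 1,048 with --eval_batch_size=8 give
# eval_steps = 1048 // 8 = 131; the five padding examples carry a
# per-instance metric weight of 0.0, as noted above.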
360 | if FLAGS.use_tpu: 361 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 362 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 363 | 364 | eval_drop_remainder = True if FLAGS.use_tpu else False 365 | eval_input_fn = classifier_utils.file_based_input_fn_builder( 366 | input_file=eval_file, 367 | seq_length=FLAGS.max_seq_length, 368 | is_training=False, 369 | drop_remainder=eval_drop_remainder, 370 | task_name=task_name, 371 | use_tpu=FLAGS.use_tpu, 372 | bsz=FLAGS.eval_batch_size) 373 | 374 | best_trial_info_file = os.path.join(FLAGS.output_dir, "best_trial.txt") 375 | 376 | def _best_trial_info(): 377 | """Returns information about which checkpoints have been evaled so far.""" 378 | if tf.gfile.Exists(best_trial_info_file): 379 | with tf.gfile.GFile(best_trial_info_file, "r") as best_info: 380 | global_step, best_metric_global_step, metric_value = ( 381 | best_info.read().split(":")) 382 | global_step = int(global_step) 383 | best_metric_global_step = int(best_metric_global_step) 384 | metric_value = float(metric_value) 385 | else: 386 | metric_value = -1 387 | best_metric_global_step = -1 388 | global_step = -1 389 | tf.logging.info( 390 | "Best trial info: Step: %s, Best Value Step: %s, " 391 | "Best Value: %s", global_step, best_metric_global_step, metric_value) 392 | return global_step, best_metric_global_step, metric_value 393 | 394 | def _remove_checkpoint(checkpoint_path): 395 | for ext in ["meta", "data-00000-of-00001", "index"]: 396 | src_ckpt = checkpoint_path + ".{}".format(ext) 397 | tf.logging.info("removing {}".format(src_ckpt)) 398 | tf.gfile.Remove(src_ckpt) 399 | 400 | def _find_valid_cands(curr_step): 401 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 402 | candidates = [] 403 | for filename in filenames: 404 | if filename.endswith(".index"): 405 | ckpt_name = filename[:-6] 406 | idx = ckpt_name.split("-")[-1] 407 | if int(idx) > curr_step: 408 | candidates.append(filename) 409 | return candidates 410 | 411 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 412 | 413 | if task_name == "sts-b": 414 | key_name = "pearson" 415 | elif task_name == "cola": 416 | key_name = "matthew_corr" 417 | else: 418 | key_name = "eval_accuracy" 419 | 420 | global_step, best_perf_global_step, best_perf = _best_trial_info() 421 | writer = tf.gfile.GFile(output_eval_file, "w") 422 | while global_step < FLAGS.train_step: 423 | steps_and_files = {} 424 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 425 | for filename in filenames: 426 | if filename.endswith(".index"): 427 | ckpt_name = filename[:-6] 428 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 429 | if cur_filename.split("-")[-1] == "best": 430 | continue 431 | gstep = int(cur_filename.split("-")[-1]) 432 | if gstep not in steps_and_files: 433 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 434 | steps_and_files[gstep] = cur_filename 435 | tf.logging.info("found {} files.".format(len(steps_and_files))) 436 | if not steps_and_files: 437 | tf.logging.info("found 0 file, global step: {}. Sleeping." 
438 | .format(global_step)) 439 | time.sleep(60) 440 | else: 441 | for checkpoint in sorted(steps_and_files.items()): 442 | step, checkpoint_path = checkpoint 443 | if global_step >= step: 444 | if (best_perf_global_step != step and 445 | len(_find_valid_cands(step)) > 1): 446 | _remove_checkpoint(checkpoint_path) 447 | continue 448 | result = estimator.evaluate( 449 | input_fn=eval_input_fn, 450 | steps=eval_steps, 451 | checkpoint_path=checkpoint_path) 452 | global_step = result["global_step"] 453 | tf.logging.info("***** Eval results *****") 454 | for key in sorted(result.keys()): 455 | tf.logging.info(" %s = %s", key, str(result[key])) 456 | writer.write("%s = %s\n" % (key, str(result[key]))) 457 | writer.write("best = {}\n".format(best_perf)) 458 | if result[key_name] > best_perf: 459 | best_perf = result[key_name] 460 | best_perf_global_step = global_step 461 | elif len(_find_valid_cands(global_step)) > 1: 462 | _remove_checkpoint(checkpoint_path) 463 | writer.write("=" * 50 + "\n") 464 | writer.flush() 465 | with tf.gfile.GFile(best_trial_info_file, "w") as best_info: 466 | best_info.write("{}:{}:{}".format( 467 | global_step, best_perf_global_step, best_perf)) 468 | writer.close() 469 | 470 | for ext in ["meta", "data-00000-of-00001", "index"]: 471 | src_ckpt = "model.ckpt-{}.{}".format(best_perf_global_step, ext) 472 | tgt_ckpt = "model.ckpt-best.{}".format(ext) 473 | tf.logging.info("saving {} to {}".format(src_ckpt, tgt_ckpt)) 474 | tf.io.gfile.rename( 475 | os.path.join(FLAGS.output_dir, src_ckpt), 476 | os.path.join(FLAGS.output_dir, tgt_ckpt), 477 | overwrite=True) 478 | 479 | if FLAGS.do_predict: 480 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 481 | num_actual_predict_examples = len(predict_examples) 482 | if FLAGS.use_tpu: 483 | # TPU requires a fixed batch size for all batches, therefore the number 484 | # of examples must be a multiple of the batch size, or else examples 485 | # will get dropped. So we pad with fake examples which are ignored 486 | # later on. 
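# Unlike eval, padded test examples are filtered on output: the writer loop
# below breaks once i >= num_actual_predict_examples, so fake predictions
# never reach test_results.tsv or submit_results.tsv.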
487 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 488 | predict_examples.append(classifier_utils.PaddingInputExample()) 489 | 490 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 491 | classifier_utils.file_based_convert_examples_to_features( 492 | predict_examples, label_list, 493 | FLAGS.max_seq_length, tokenizer, 494 | predict_file, task_name) 495 | 496 | tf.logging.info("***** Running prediction*****") 497 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 498 | len(predict_examples), num_actual_predict_examples, 499 | len(predict_examples) - num_actual_predict_examples) 500 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 501 | 502 | predict_drop_remainder = True if FLAGS.use_tpu else False 503 | predict_input_fn = classifier_utils.file_based_input_fn_builder( 504 | input_file=predict_file, 505 | seq_length=FLAGS.max_seq_length, 506 | is_training=False, 507 | drop_remainder=predict_drop_remainder, 508 | task_name=task_name, 509 | use_tpu=FLAGS.use_tpu, 510 | bsz=FLAGS.predict_batch_size) 511 | 512 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 513 | result = estimator.predict( 514 | input_fn=predict_input_fn, 515 | checkpoint_path=checkpoint_path) 516 | 517 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 518 | output_submit_file = os.path.join(FLAGS.output_dir, "submit_results.tsv") 519 | with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\ 520 | tf.gfile.GFile(output_submit_file, "w") as sub_writer: 521 | sub_writer.write("index" + "\t" + "prediction\n") 522 | num_written_lines = 0 523 | tf.logging.info("***** Predict results *****") 524 | for (i, (example, prediction)) in\ 525 | enumerate(zip(predict_examples, result)): 526 | probabilities = prediction["probabilities"] 527 | if i >= num_actual_predict_examples: 528 | break 529 | output_line = "\t".join( 530 | str(class_probability) 531 | for class_probability in probabilities) + "\n" 532 | pred_writer.write(output_line) 533 | 534 | if task_name != "sts-b": 535 | actual_label = label_list[int(prediction["predictions"])] 536 | else: 537 | actual_label = str(prediction["predictions"]) 538 | sub_writer.write(example.guid + "\t" + actual_label + "\n") 539 | num_written_lines += 1 540 | assert num_written_lines == num_actual_predict_examples 541 | 542 | if FLAGS.export_dir: 543 | tf.gfile.MakeDirs(FLAGS.export_dir) 544 | checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best") 545 | tf.logging.info("Starting to export model.") 546 | subfolder = estimator.export_saved_model( 547 | export_dir_base=FLAGS.export_dir, 548 | serving_input_receiver_fn=_serving_input_receiver_fn, 549 | checkpoint_path=checkpoint_path) 550 | tf.logging.info("Model exported to %s.", subfolder) 551 | 552 | 553 | if __name__ == "__main__": 554 | flags.mark_flag_as_required("data_dir") 555 | flags.mark_flag_as_required("task_name") 556 | flags.mark_flag_as_required("spm_model_file") 557 | flags.mark_flag_as_required("output_dir") 558 | tf.app.run() 559 | --------------------------------------------------------------------------------
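# Note (appendix sketch, not part of the repository): a hedged illustration of
# the serving signature exported by run_classifier.py. _serving_input_receiver_fn
# above parses a batch of serialized tf.Example protos whose int64 features must
# each have length --max_seq_length. The helper below builds one such request;
# SEQ_LEN and the all-zero feature values are illustrative placeholders.
import tensorflow.compat.v1 as tf

SEQ_LEN = 512  # must match the --max_seq_length used when exporting


def make_serialized_request(input_ids, input_mask, segment_ids):
  """Serializes one classifier request for the exported SavedModel."""
  def int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

  example = tf.train.Example(features=tf.train.Features(feature={
      "input_ids": int64_feature(input_ids),
      "input_mask": int64_feature(input_mask),
      "segment_ids": int64_feature(segment_ids),
  }))
  return example.SerializeToString()


pad = [0] * SEQ_LEN
request = make_serialized_request(pad, pad, pad)
# Feed `request` (batched) to the "serialized_example" input of the exported
# model's default serving signature.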