├── requirements.txt ├── samples ├── input.txt └── sample_text.txt ├── __init__.py ├── CONTRIBUTING.md ├── tests ├── optimization_test.py ├── tokenization_test.py └── modeling_test.py ├── .gitignore ├── convert_tf_checkpoint_to_pytorch.py ├── optimization.py ├── tokenization.py ├── LICENSE ├── extract_features.py ├── README.md ├── modeling.py ├── run_classifier.py └── run_squad.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | tqdm -------------------------------------------------------------------------------- /samples/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | BERT needs to maintain permanent compatibility with the pre-trained model files, 4 | so we do not plan to make any major changes to this library (other than what was 5 | promised in the README). However, we can accept small patches related to 6 | re-factoring and documentation. To submit contributes, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | 21 | import torch 22 | 23 | import optimization 24 | 25 | class OptimizationTest(unittest.TestCase): 26 | 27 | def assertListAlmostEqual(self, list1, list2, tol): 28 | self.assertEqual(len(list1), len(list2)) 29 | for a, b in zip(list1, list2): 30 | self.assertAlmostEqual(a, b, delta=tol) 31 | 32 | def test_adam(self): 33 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 34 | target = torch.tensor([0.4, 0.2, -0.5]) 35 | criterion = torch.nn.MSELoss(reduction='elementwise_mean') 36 | # No warmup, constant schedule, no gradient clipping 37 | optimizer = optimization.BERTAdam(params=[w], lr=2e-1, 38 | weight_decay_rate=0.0, 39 | max_grad_norm=-1) 40 | for _ in range(100): 41 | loss = criterion(w, target) 42 | loss.backward() 43 | optimizer.step() 44 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 45 | w.grad.zero_() 46 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 47 | 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # vscode 119 | .vscode 120 | 121 | # TF code 122 | tensorflow_code -------------------------------------------------------------------------------- /convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if name[0] in ['redictions', 'eq_relationship']: 75 | print("Skipping") 76 | continue 77 | pointer = model 78 | for m_name in name: 79 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 80 | l = re.split(r'_(\d+)', m_name) 81 | else: 82 | l = [m_name] 83 | if l[0] == 'kernel': 84 | pointer = getattr(pointer, 'weight') 85 | else: 86 | pointer = getattr(pointer, l[0]) 87 | if len(l) >= 2: 88 | num = int(l[1]) 89 | pointer = pointer[num] 90 | if m_name[-11:] == '_embeddings': 91 | pointer = getattr(pointer, 'weight') 92 | elif m_name == 'kernel': 93 | array = np.transpose(array) 94 | try: 95 | assert pointer.shape == array.shape 96 | except AssertionError as e: 97 | e.args += (pointer.shape, array.shape) 98 | raise 99 | pointer.data = torch.from_numpy(array) 100 | 101 | # Save pytorch-model 102 | torch.save(model.state_dict(), args.pytorch_dump_path) 103 | 104 | if __name__ == "__main__": 105 | convert() 106 | -------------------------------------------------------------------------------- /samples/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 
11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /tests/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import unittest 21 | 22 | import tokenization 23 | 24 | 25 | class TokenizationTest(unittest.TestCase): 26 | 27 | def test_full_tokenizer(self): 28 | vocab_tokens = [ 29 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 30 | "##ing", "," 31 | ] 32 | with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer: 33 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 34 | 35 | vocab_file = vocab_writer.name 36 | 37 | tokenizer = tokenization.FullTokenizer(vocab_file) 38 | os.remove(vocab_file) 39 | 40 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 41 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 42 | 43 | self.assertListEqual( 44 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 45 | 46 | def test_basic_tokenizer_lower(self): 47 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 48 | 49 | self.assertListEqual( 50 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 51 | ["hello", "!", "how", "are", "you", "?"]) 52 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 53 | 54 | def test_basic_tokenizer_no_lower(self): 55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 56 | 57 | self.assertListEqual( 58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 59 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 60 | 61 | def test_wordpiece_tokenizer(self): 62 | vocab_tokens = [ 63 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 64 | "##ing" 65 | ] 66 | 67 | vocab = {} 68 | for (i, token) in enumerate(vocab_tokens): 69 | vocab[token] = i 70 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 71 | 72 | self.assertListEqual(tokenizer.tokenize(""), []) 73 | 74 | self.assertListEqual( 75 | tokenizer.tokenize("unwanted running"), 76 | ["un", "##want", "##ed", "runn", "##ing"]) 77 | 78 | self.assertListEqual( 79 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 80 | 81 | def test_convert_tokens_to_ids(self): 82 | vocab_tokens = [ 83 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 84 | "##ing" 85 | ] 86 | 87 | vocab = {} 88 | for (i, token) in enumerate(vocab_tokens): 89 | vocab[token] = i 90 | 91 | self.assertListEqual( 92 | tokenization.convert_tokens_to_ids( 93 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 94 | 95 | def test_is_whitespace(self): 96 | self.assertTrue(tokenization._is_whitespace(u" ")) 97 | self.assertTrue(tokenization._is_whitespace(u"\t")) 98 | self.assertTrue(tokenization._is_whitespace(u"\r")) 99 | self.assertTrue(tokenization._is_whitespace(u"\n")) 100 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 101 | 102 | self.assertFalse(tokenization._is_whitespace(u"A")) 103 | self.assertFalse(tokenization._is_whitespace(u"-")) 104 | 105 | def test_is_control(self): 106 | self.assertTrue(tokenization._is_control(u"\u0005")) 107 | 108 | self.assertFalse(tokenization._is_control(u"A")) 109 | self.assertFalse(tokenization._is_control(u" ")) 110 | self.assertFalse(tokenization._is_control(u"\t")) 111 | self.assertFalse(tokenization._is_control(u"\r")) 112 | 113 | def test_is_punctuation(self): 114 | self.assertTrue(tokenization._is_punctuation(u"-")) 115 | self.assertTrue(tokenization._is_punctuation(u"$")) 116 | self.assertTrue(tokenization._is_punctuation(u"`")) 117 | self.assertTrue(tokenization._is_punctuation(u".")) 118 | 119 | self.assertFalse(tokenization._is_punctuation(u"A")) 120 | self.assertFalse(tokenization._is_punctuation(u" ")) 121 | 122 | 123 | if __name__ == '__main__': 124 | unittest.main() 125 | -------------------------------------------------------------------------------- /tests/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import json 21 | import random 22 | 23 | import torch 24 | 25 | import modeling 26 | 27 | 28 | class BertModelTest(unittest.TestCase): 29 | class BertModelTester(object): 30 | 31 | def __init__(self, 32 | parent, 33 | batch_size=13, 34 | seq_length=7, 35 | is_training=True, 36 | use_input_mask=True, 37 | use_token_type_ids=True, 38 | vocab_size=99, 39 | hidden_size=32, 40 | num_hidden_layers=5, 41 | num_attention_heads=4, 42 | intermediate_size=37, 43 | hidden_act="gelu", 44 | hidden_dropout_prob=0.1, 45 | attention_probs_dropout_prob=0.1, 46 | max_position_embeddings=512, 47 | type_vocab_size=16, 48 | initializer_range=0.02, 49 | scope=None): 50 | self.parent = parent 51 | self.batch_size = batch_size 52 | self.seq_length = seq_length 53 | self.is_training = is_training 54 | self.use_input_mask = use_input_mask 55 | self.use_token_type_ids = use_token_type_ids 56 | self.vocab_size = vocab_size 57 | self.hidden_size = hidden_size 58 | self.num_hidden_layers = num_hidden_layers 59 | self.num_attention_heads = num_attention_heads 60 | self.intermediate_size = intermediate_size 61 | self.hidden_act = hidden_act 62 | self.hidden_dropout_prob = hidden_dropout_prob 63 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 64 | self.max_position_embeddings = max_position_embeddings 65 | self.type_vocab_size = type_vocab_size 66 | self.initializer_range = initializer_range 67 | self.scope = scope 68 | 69 | def create_model(self): 70 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 71 | 72 | input_mask = None 73 | if self.use_input_mask: 74 | input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2) 75 | 76 | token_type_ids = None 77 | if self.use_token_type_ids: 78 | token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) 79 | 80 | config = modeling.BertConfig( 81 | vocab_size=self.vocab_size, 82 | hidden_size=self.hidden_size, 83 | num_hidden_layers=self.num_hidden_layers, 84 | num_attention_heads=self.num_attention_heads, 85 | intermediate_size=self.intermediate_size, 86 | hidden_act=self.hidden_act, 87 | hidden_dropout_prob=self.hidden_dropout_prob, 88 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 89 | max_position_embeddings=self.max_position_embeddings, 90 | type_vocab_size=self.type_vocab_size, 91 | initializer_range=self.initializer_range) 92 | 93 | model = modeling.BertModel(config=config) 94 | 95 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 96 | 97 | outputs = { 98 | "sequence_output": all_encoder_layers[-1], 99 | "pooled_output": pooled_output, 100 | "all_encoder_layers": all_encoder_layers, 101 | } 102 | return outputs 103 | 104 | def check_output(self, result): 105 | self.parent.assertListEqual( 106 | list(result["sequence_output"].size()), 107 | [self.batch_size, self.seq_length, self.hidden_size]) 108 | 109 | self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) 110 | 111 | def test_default(self): 112 | self.run_tester(BertModelTest.BertModelTester(self)) 113 | 114 | def test_config_to_json_string(self): 115 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 116 | obj = json.loads(config.to_json_string()) 117 | self.assertEqual(obj["vocab_size"], 99) 118 | 
self.assertEqual(obj["hidden_size"], 37) 119 | 120 | def run_tester(self, tester): 121 | output_result = tester.create_model() 122 | tester.check_output(output_result) 123 | 124 | @classmethod 125 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 126 | """Creates a random int32 tensor of the shape within the vocab size.""" 127 | if rng is None: 128 | rng = random.Random() 129 | 130 | total_dims = 1 131 | for dim in shape: 132 | total_dims *= dim 133 | 134 | values = [] 135 | for _ in range(total_dims): 136 | values.append(rng.randint(0, vocab_size - 1)) 137 | 138 | return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() 139 | 140 | 141 | if __name__ == "__main__": 142 | unittest.main() 143 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | def warmup_cosine(x, warmup=0.002): 23 | if x < warmup: 24 | return x/warmup 25 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | return 1.0 31 | 32 | def warmup_linear(x, warmup=0.002): 33 | if x < warmup: 34 | return x/warmup 35 | return 1.0 - x 36 | 37 | SCHEDULES = { 38 | 'warmup_cosine':warmup_cosine, 39 | 'warmup_constant':warmup_constant, 40 | 'warmup_linear':warmup_linear, 41 | } 42 | 43 | 44 | class BERTAdam(Optimizer): 45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 46 | Params: 47 | lr: learning rate 48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 49 | t_total: total number of training steps for the learning 50 | rate schedule, -1 means constant learning rate. Default: -1 51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 52 | b1: Adams b1. Default: 0.9 53 | b2: Adams b2. Default: 0.999 54 | e: Adams epsilon. Default: 1e-6 55 | weight_decay_rate: Weight decay. Default: 0.01 56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 57 | """ 58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 60 | max_grad_norm=1.0): 61 | if not lr >= 0.0: 62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 63 | if schedule not in SCHEDULES: 64 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 65 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 67 | if not 0.0 <= b1 < 1.0: 68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 69 | if not 0.0 <= b2 < 1.0: 70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 71 | if not e >= 0.0: 72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 75 | max_grad_norm=max_grad_norm) 76 | super(BERTAdam, self).__init__(params, defaults) 77 | 78 | def get_lr(self): 79 | lr = [] 80 | for group in self.param_groups: 81 | for p in group['params']: 82 | state = self.state[p] 83 | if len(state) == 0: 84 | return [0] 85 | if group['t_total'] != -1: 86 | schedule_fct = SCHEDULES[group['schedule']] 87 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 88 | else: 89 | lr_scheduled = group['lr'] 90 | lr.append(lr_scheduled) 91 | return lr 92 | 93 | def to(self, device): 94 | """ Move the optimizer state to a specified device""" 95 | for state in self.state.values(): 96 | state['exp_avg'].to(device) 97 | state['exp_avg_sq'].to(device) 98 | 99 | def initialize_step(self, initial_step): 100 | """Initialize state with a defined step (but we don't have stored averaged). 101 | Arguments: 102 | initial_step (int): Initial step number. 103 | """ 104 | for group in self.param_groups: 105 | for p in group['params']: 106 | state = self.state[p] 107 | # State initialization 108 | state['step'] = initial_step 109 | # Exponential moving average of gradient values 110 | state['exp_avg'] = torch.zeros_like(p.data) 111 | # Exponential moving average of squared gradient values 112 | state['exp_avg_sq'] = torch.zeros_like(p.data) 113 | 114 | def step(self, closure=None): 115 | """Performs a single optimization step. 116 | 117 | Arguments: 118 | closure (callable, optional): A closure that reevaluates the model 119 | and returns the loss. 
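        Example (an illustrative fine-tuning loop; `model`, `train_dataloader`,
        `num_train_steps` and `compute_loss` are placeholders, not part of this
        module):

            optimizer = BERTAdam(model.parameters(), lr=5e-5,
                                 warmup=0.1, t_total=num_train_steps)
            for input_ids, labels in train_dataloader:
                loss = compute_loss(model, input_ids, labels)
                loss.backward()
                optimizer.step()       # applies the warmup-scheduled update
                optimizer.zero_grad()  # inherited from torch.optim.Optimizer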
120 | """ 121 | loss = None 122 | if closure is not None: 123 | loss = closure() 124 | 125 | for group in self.param_groups: 126 | for p in group['params']: 127 | if p.grad is None: 128 | continue 129 | grad = p.grad.data 130 | if grad.is_sparse: 131 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 132 | 133 | state = self.state[p] 134 | 135 | # State initialization 136 | if len(state) == 0: 137 | state['step'] = 0 138 | # Exponential moving average of gradient values 139 | state['next_m'] = torch.zeros_like(p.data) 140 | # Exponential moving average of squared gradient values 141 | state['next_v'] = torch.zeros_like(p.data) 142 | 143 | next_m, next_v = state['next_m'], state['next_v'] 144 | beta1, beta2 = group['b1'], group['b2'] 145 | 146 | # Add grad clipping 147 | if group['max_grad_norm'] > 0: 148 | clip_grad_norm_(p, group['max_grad_norm']) 149 | 150 | # Decay the first and second moment running average coefficient 151 | # In-place operations to update the averages at the same time 152 | next_m.mul_(beta1).add_(1 - beta1, grad) 153 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 154 | update = next_m / (next_v.sqrt() + group['e']) 155 | 156 | # Just adding the square of the weights to the loss function is *not* 157 | # the correct way of using L2 regularization/weight decay with Adam, 158 | # since that will interact with the m and v parameters in strange ways. 159 | # 160 | # Instead we want ot decay the weights in a manner that doesn't interact 161 | # with the m/v parameters. This is equivalent to adding the square 162 | # of the weights to the loss with plain (non-momentum) SGD. 163 | if group['weight_decay_rate'] > 0.0: 164 | update += group['weight_decay_rate'] * p.data 165 | 166 | if group['t_total'] != -1: 167 | schedule_fct = SCHEDULES[group['schedule']] 168 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 169 | else: 170 | lr_scheduled = group['lr'] 171 | 172 | update_with_lr = lr_scheduled * update 173 | p.data.add_(-update_with_lr) 174 | 175 | state['step'] += 1 176 | 177 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 178 | # bias_correction1 = 1 - beta1 ** state['step'] 179 | # bias_correction2 = 1 - beta2 ** state['step'] 180 | 181 | return loss 182 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | index = 0 73 | with open(vocab_file, "r") as reader: 74 | while True: 75 | token = convert_to_unicode(reader.readline()) 76 | if not token: 77 | break 78 | token = token.strip() 79 | vocab[token] = index 80 | index += 1 81 | return vocab 82 | 83 | 84 | def convert_tokens_to_ids(vocab, tokens): 85 | """Converts a sequence of tokens into ids using the vocab.""" 86 | ids = [] 87 | for token in tokens: 88 | ids.append(vocab[token]) 89 | return ids 90 | 91 | 92 | def whitespace_tokenize(text): 93 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 94 | text = text.strip() 95 | if not text: 96 | return [] 97 | tokens = text.split() 98 | return tokens 99 | 100 | 101 | class FullTokenizer(object): 102 | """Runs end-to-end tokenziation.""" 103 | 104 | def __init__(self, vocab_file, do_lower_case=True): 105 | self.vocab = load_vocab(vocab_file) 106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 108 | 109 | def tokenize(self, text): 110 | split_tokens = [] 111 | for token in self.basic_tokenizer.tokenize(text): 112 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 113 | split_tokens.append(sub_token) 114 | 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | return convert_tokens_to_ids(self.vocab, tokens) 119 | 120 | 121 | class BasicTokenizer(object): 122 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 123 | 124 | def __init__(self, do_lower_case=True): 125 | """Constructs a BasicTokenizer. 126 | 127 | Args: 128 | do_lower_case: Whether to lower case the input. 
129 | """ 130 | self.do_lower_case = do_lower_case 131 | 132 | def tokenize(self, text): 133 | """Tokenizes a piece of text.""" 134 | text = convert_to_unicode(text) 135 | text = self._clean_text(text) 136 | orig_tokens = whitespace_tokenize(text) 137 | split_tokens = [] 138 | for token in orig_tokens: 139 | if self.do_lower_case: 140 | token = token.lower() 141 | token = self._run_strip_accents(token) 142 | split_tokens.extend(self._run_split_on_punc(token)) 143 | 144 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 145 | return output_tokens 146 | 147 | def _run_strip_accents(self, text): 148 | """Strips accents from a piece of text.""" 149 | text = unicodedata.normalize("NFD", text) 150 | output = [] 151 | for char in text: 152 | cat = unicodedata.category(char) 153 | if cat == "Mn": 154 | continue 155 | output.append(char) 156 | return "".join(output) 157 | 158 | def _run_split_on_punc(self, text): 159 | """Splits punctuation on a piece of text.""" 160 | chars = list(text) 161 | i = 0 162 | start_new_word = True 163 | output = [] 164 | while i < len(chars): 165 | char = chars[i] 166 | if _is_punctuation(char): 167 | output.append([char]) 168 | start_new_word = True 169 | else: 170 | if start_new_word: 171 | output.append([]) 172 | start_new_word = False 173 | output[-1].append(char) 174 | i += 1 175 | 176 | return ["".join(x) for x in output] 177 | 178 | def _clean_text(self, text): 179 | """Performs invalid character removal and whitespace cleanup on text.""" 180 | output = [] 181 | for char in text: 182 | cp = ord(char) 183 | if cp == 0 or cp == 0xfffd or _is_control(char): 184 | continue 185 | if _is_whitespace(char): 186 | output.append(" ") 187 | else: 188 | output.append(char) 189 | return "".join(output) 190 | 191 | 192 | class WordpieceTokenizer(object): 193 | """Runs WordPiece tokenization.""" 194 | 195 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 196 | self.vocab = vocab 197 | self.unk_token = unk_token 198 | self.max_input_chars_per_word = max_input_chars_per_word 199 | 200 | def tokenize(self, text): 201 | """Tokenizes a piece of text into its word pieces. 202 | 203 | This uses a greedy longest-match-first algorithm to perform tokenization 204 | using the given vocabulary. 205 | 206 | For example: 207 | input = "unaffable" 208 | output = ["un", "##aff", "##able"] 209 | 210 | Args: 211 | text: A single token or whitespace separated tokens. This should have 212 | already been passed through `BasicTokenizer. 213 | 214 | Returns: 215 | A list of wordpiece tokens. 
216 | """ 217 | 218 | text = convert_to_unicode(text) 219 | 220 | output_tokens = [] 221 | for token in whitespace_tokenize(text): 222 | chars = list(token) 223 | if len(chars) > self.max_input_chars_per_word: 224 | output_tokens.append(self.unk_token) 225 | continue 226 | 227 | is_bad = False 228 | start = 0 229 | sub_tokens = [] 230 | while start < len(chars): 231 | end = len(chars) 232 | cur_substr = None 233 | while start < end: 234 | substr = "".join(chars[start:end]) 235 | if start > 0: 236 | substr = "##" + substr 237 | if substr in self.vocab: 238 | cur_substr = substr 239 | break 240 | end -= 1 241 | if cur_substr is None: 242 | is_bad = True 243 | break 244 | sub_tokens.append(cur_substr) 245 | start = end 246 | 247 | if is_bad: 248 | output_tokens.append(self.unk_token) 249 | else: 250 | output_tokens.extend(sub_tokens) 251 | return output_tokens 252 | 253 | 254 | def _is_whitespace(char): 255 | """Checks whether `chars` is a whitespace character.""" 256 | # \t, \n, and \r are technically contorl characters but we treat them 257 | # as whitespace since they are generally considered as such. 258 | if char == " " or char == "\t" or char == "\n" or char == "\r": 259 | return True 260 | cat = unicodedata.category(char) 261 | if cat == "Zs": 262 | return True 263 | return False 264 | 265 | 266 | def _is_control(char): 267 | """Checks whether `chars` is a control character.""" 268 | # These are technically control characters but we count them as whitespace 269 | # characters. 270 | if char == "\t" or char == "\n" or char == "\r": 271 | return False 272 | cat = unicodedata.category(char) 273 | if cat.startswith("C"): 274 | return True 275 | return False 276 | 277 | 278 | def _is_punctuation(char): 279 | """Checks whether `chars` is a punctuation character.""" 280 | cp = ord(char) 281 | # We treat all non-letter/number ASCII as punctuation. 282 | # Characters such as "^", "$", and "`" are not in the Unicode 283 | # Punctuation class but we treat them as punctuation anyways, for 284 | # consistency. 285 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 286 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 287 | return True 288 | cat = unicodedata.category(char) 289 | if cat.startswith("P"): 290 | return True 291 | return False 292 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import codecs 23 | import collections 24 | import logging 25 | import json 26 | import re 27 | 28 | import torch 29 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 30 | from torch.utils.data.distributed import DistributedSampler 31 | 32 | import tokenization 33 | from modeling import BertConfig, BertModel 34 | 35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 36 | datefmt = '%m/%d/%Y %H:%M:%S', 37 | level = logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | class InputExample(object): 42 | 43 | def __init__(self, unique_id, text_a, text_b): 44 | self.unique_id = unique_id 45 | self.text_a = text_a 46 | self.text_b = text_b 47 | 48 | 49 | class InputFeatures(object): 50 | """A single set of features of data.""" 51 | 52 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 53 | self.unique_id = unique_id 54 | self.tokens = tokens 55 | self.input_ids = input_ids 56 | self.input_mask = input_mask 57 | self.input_type_ids = input_type_ids 58 | 59 | 60 | def convert_examples_to_features(examples, seq_length, tokenizer): 61 | """Loads a data file into a list of `InputBatch`s.""" 62 | 63 | features = [] 64 | for (ex_index, example) in enumerate(examples): 65 | tokens_a = tokenizer.tokenize(example.text_a) 66 | 67 | tokens_b = None 68 | if example.text_b: 69 | tokens_b = tokenizer.tokenize(example.text_b) 70 | 71 | if tokens_b: 72 | # Modifies `tokens_a` and `tokens_b` in place so that the total 73 | # length is less than the specified length. 74 | # Account for [CLS], [SEP], [SEP] with "- 3" 75 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 76 | else: 77 | # Account for [CLS] and [SEP] with "- 2" 78 | if len(tokens_a) > seq_length - 2: 79 | tokens_a = tokens_a[0:(seq_length - 2)] 80 | 81 | # The convention in BERT is: 82 | # (a) For sequence pairs: 83 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 84 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 85 | # (b) For single sequences: 86 | # tokens: [CLS] the dog is hairy . [SEP] 87 | # type_ids: 0 0 0 0 0 0 0 88 | # 89 | # Where "type_ids" are used to indicate whether this is the first 90 | # sequence or the second sequence. The embedding vectors for `type=0` and 91 | # `type=1` were learned during pre-training and are added to the wordpiece 92 | # embedding vector (and position vector). 
This is not *strictly* necessary 93 | # since the [SEP] token unambigiously separates the sequences, but it makes 94 | # it easier for the model to learn the concept of sequences. 95 | # 96 | # For classification tasks, the first vector (corresponding to [CLS]) is 97 | # used as as the "sentence vector". Note that this only makes sense because 98 | # the entire model is fine-tuned. 99 | tokens = [] 100 | input_type_ids = [] 101 | tokens.append("[CLS]") 102 | input_type_ids.append(0) 103 | for token in tokens_a: 104 | tokens.append(token) 105 | input_type_ids.append(0) 106 | tokens.append("[SEP]") 107 | input_type_ids.append(0) 108 | 109 | if tokens_b: 110 | for token in tokens_b: 111 | tokens.append(token) 112 | input_type_ids.append(1) 113 | tokens.append("[SEP]") 114 | input_type_ids.append(1) 115 | 116 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 117 | 118 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 119 | # tokens are attended to. 120 | input_mask = [1] * len(input_ids) 121 | 122 | # Zero-pad up to the sequence length. 123 | while len(input_ids) < seq_length: 124 | input_ids.append(0) 125 | input_mask.append(0) 126 | input_type_ids.append(0) 127 | 128 | assert len(input_ids) == seq_length 129 | assert len(input_mask) == seq_length 130 | assert len(input_type_ids) == seq_length 131 | 132 | if ex_index < 5: 133 | logger.info("*** Example ***") 134 | logger.info("unique_id: %s" % (example.unique_id)) 135 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 136 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 137 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 138 | logger.info( 139 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 140 | 141 | features.append( 142 | InputFeatures( 143 | unique_id=example.unique_id, 144 | tokens=tokens, 145 | input_ids=input_ids, 146 | input_mask=input_mask, 147 | input_type_ids=input_type_ids)) 148 | return features 149 | 150 | 151 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 152 | """Truncates a sequence pair in place to the maximum length.""" 153 | 154 | # This is a simple heuristic which will always truncate the longer sequence 155 | # one token at a time. This makes more sense than truncating an equal percent 156 | # of tokens from each, since if one sequence is very short then each token 157 | # that's truncated likely contains more information than a longer sequence. 
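    # For example, with max_length=6 and token lists of length 5 and 2, one
    # token is popped from the longer list, leaving 4 + 2 = 6 tokens.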
158 | while True: 159 | total_length = len(tokens_a) + len(tokens_b) 160 | if total_length <= max_length: 161 | break 162 | if len(tokens_a) > len(tokens_b): 163 | tokens_a.pop() 164 | else: 165 | tokens_b.pop() 166 | 167 | 168 | def read_examples(input_file): 169 | """Read a list of `InputExample`s from an input file.""" 170 | examples = [] 171 | unique_id = 0 172 | with open(input_file, "r") as reader: 173 | while True: 174 | line = tokenization.convert_to_unicode(reader.readline()) 175 | if not line: 176 | break 177 | line = line.strip() 178 | text_a = None 179 | text_b = None 180 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 181 | if m is None: 182 | text_a = line 183 | else: 184 | text_a = m.group(1) 185 | text_b = m.group(2) 186 | examples.append( 187 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 188 | unique_id += 1 189 | return examples 190 | 191 | 192 | def main(): 193 | parser = argparse.ArgumentParser() 194 | 195 | ## Required parameters 196 | parser.add_argument("--input_file", default=None, type=str, required=True) 197 | parser.add_argument("--vocab_file", default=None, type=str, required=True, 198 | help="The vocabulary file that the BERT model was trained on.") 199 | parser.add_argument("--output_file", default=None, type=str, required=True) 200 | parser.add_argument("--bert_config_file", default=None, type=str, required=True, 201 | help="The config json file corresponding to the pre-trained BERT model. " 202 | "This specifies the model architecture.") 203 | parser.add_argument("--init_checkpoint", default=None, type=str, required=True, 204 | help="Initial checkpoint (usually from a pre-trained BERT model).") 205 | 206 | ## Other parameters 207 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) 208 | parser.add_argument("--max_seq_length", default=128, type=int, 209 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " 210 | "than this will be truncated, and sequences shorter than this will be padded.") 211 | parser.add_argument("--do_lower_case", default=True, action='store_true', 212 | help="Whether to lower case the input text. 
Should be True for uncased " 213 | "models and False for cased models.") 214 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") 215 | parser.add_argument("--local_rank", 216 | type=int, 217 | default=-1, 218 | help = "local_rank for distributed training on gpus") 219 | 220 | args = parser.parse_args() 221 | 222 | if args.local_rank == -1 or args.no_cuda: 223 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 224 | n_gpu = torch.cuda.device_count() 225 | else: 226 | device = torch.device("cuda", args.local_rank) 227 | n_gpu = 1 228 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 229 | torch.distributed.init_process_group(backend='nccl') 230 | logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) 231 | 232 | layer_indexes = [int(x) for x in args.layers.split(",")] 233 | 234 | bert_config = BertConfig.from_json_file(args.bert_config_file) 235 | 236 | tokenizer = tokenization.FullTokenizer( 237 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 238 | 239 | examples = read_examples(args.input_file) 240 | 241 | features = convert_examples_to_features( 242 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 243 | 244 | unique_id_to_feature = {} 245 | for feature in features: 246 | unique_id_to_feature[feature.unique_id] = feature 247 | 248 | model = BertModel(bert_config) 249 | if args.init_checkpoint is not None: 250 | model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 251 | model.to(device) 252 | 253 | if args.local_rank != -1: 254 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 255 | output_device=args.local_rank) 256 | elif n_gpu > 1: 257 | model = torch.nn.DataParallel(model) 258 | 259 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 260 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 261 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 262 | 263 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 264 | if args.local_rank == -1: 265 | eval_sampler = SequentialSampler(eval_data) 266 | else: 267 | eval_sampler = DistributedSampler(eval_data) 268 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 269 | 270 | model.eval() 271 | with open(args.output_file, "w", encoding='utf-8') as writer: 272 | for input_ids, input_mask, example_indices in eval_dataloader: 273 | input_ids = input_ids.to(device) 274 | input_mask = input_mask.to(device) 275 | 276 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 277 | all_encoder_layers = all_encoder_layers 278 | 279 | for b, example_index in enumerate(example_indices): 280 | feature = features[example_index.item()] 281 | unique_id = int(feature.unique_id) 282 | # feature = unique_id_to_feature[unique_id] 283 | output_json = collections.OrderedDict() 284 | output_json["linex_index"] = unique_id 285 | all_out_features = [] 286 | for (i, token) in enumerate(feature.tokens): 287 | all_layers = [] 288 | for (j, layer_index) in enumerate(layer_indexes): 289 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() 290 | layer_output = layer_output[b] 291 | layers = collections.OrderedDict() 292 | layers["index"] = layer_index 293 | layers["values"] = [ 294 | round(x.item(), 6) for x 
in layer_output[i] 295 | ] 296 | all_layers.append(layers) 297 | out_features = collections.OrderedDict() 298 | out_features["token"] = token 299 | out_features["layers"] = all_layers 300 | all_out_features.append(out_features) 301 | output_json["features"] = all_out_features 302 | writer.write(json.dumps(output_json) + "\n") 303 | 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of Google AI's BERT model with a script to load Google's pre-trained models 2 | 3 | ## Introduction 4 | 5 | This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 6 | 7 | This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below). 8 | 9 | The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated). 10 | 11 | ## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models)) 12 | 13 | You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script. 14 | 15 | This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`). 16 | 17 | You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too. 18 | 19 | To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. 20 | 21 | Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: 22 | 23 | ```shell 24 | export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 25 | 26 | python convert_tf_checkpoint_to_pytorch.py \ 27 | --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ 28 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 29 | --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin 30 | ``` 31 | 32 | You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models). 
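Once converted, the resulting `pytorch_model.bin` can be loaded the same way the example scripts load it. Here is a minimal sketch (the file paths are placeholders) mirroring the loading logic of `extract_features.py`:

```python
import torch

from modeling import BertConfig, BertModel

# Placeholders: point these at the files kept/produced by the conversion step.
config = BertConfig.from_json_file("bert_config.json")
model = BertModel(config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location='cpu'))
model.eval()  # the loaded weights can now be used for feature extraction or fine-tuning
```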
33 | 34 | ## PyTorch models for BERT 35 | 36 | We include three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py): 37 | 38 | - `BertModel` - the basic BERT Transformer model 39 | - `BertForSequenceClassification` - the BERT model with a sequence classification head on top 40 | - `BertForQuestionAnswering` - the BERT model with a token classification head on top 41 | 42 | Here are some details on each class. 43 | 44 | ### 1. `BertModel` 45 | 46 | `BertModel` is the basic BERT Transformer model with a layer of summed token, position and token-type embeddings followed by a series of identical self-attention blocks (12 for BERT-base, 24 for BERT-large). 47 | 48 | The inputs and outputs are **identical to the TensorFlow model inputs and outputs**. 49 | 50 | We detail them here. This model takes as inputs: 51 | 52 | - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the token preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and 53 | - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token and type 1 corresponds to a `sentence B` token (see the BERT paper for more details). 54 | - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying-length sentences. 55 | 56 | This model outputs a tuple composed of: 57 | 58 | - `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and 59 | - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pre-trained on top of the hidden state associated with the first token of the input (`[CLS]`) to train on the Next-Sentence task (see the BERT paper). 60 | 61 | An example of how to use this class is given in the `extract_features.py` script, which can be used to extract the hidden states of the model for a given input. 62 | 63 | ### 2. `BertForSequenceClassification` 64 | 65 | `BertForSequenceClassification` is a fine-tuning model that includes `BertModel` and a sequence-level (sequence or pair of sequences) classifier on top of the `BertModel`. 66 | 67 | The sequence-level classifier is a linear layer that takes as input the last hidden state of the first token (`[CLS]`) in the input sequence (see Figures 3a and 3b in the BERT paper). 68 | 69 | An example of how to use this class is given in the `run_classifier.py` script, which can be used to fine-tune a single-sequence (or sequence-pair) classifier using BERT, for example for the MRPC task. 70 | 71 | ### 3. `BertForQuestionAnswering` 72 | 73 | `BertForQuestionAnswering` is a fine-tuning model that includes `BertModel` with a token-level classifier on top of the full sequence of last hidden states. 74 | 75 | The token-level classifier takes as input the full sequence of the last hidden states and computes several (e.g.
two) scores for each token that can for example respectively be the score that a given token is a `start_span` and an `end_span` token (see Figures 3c and 3d in the BERT paper). 76 | 77 | An example of how to use this class is given in the `run_squad.py` script, which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task. 78 | 79 | ## Installation, requirements, test 80 | 81 | This code was tested on Python 3.5+. The requirements are: 82 | 83 | - PyTorch (>= 0.4.1) 84 | - tqdm 85 | 86 | To install the dependencies: 87 | 88 | ```bash 89 | pip install -r ./requirements.txt 90 | ``` 91 | 92 | A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). 93 | 94 | You can run the tests with the command: 95 | ```bash 96 | pytest -sv ./tests/ 97 | ``` 98 | 99 | ## Training on large batches: gradient accumulation, multi-GPU and distributed training 100 | 101 | BERT-base and BERT-large have 110M and 340M parameters respectively, and it can be difficult to fine-tune them on a single GPU with the batch size recommended for good performance (in most cases a batch size of 32). 102 | 103 | To help with fine-tuning these models, we have included three techniques that you can activate in the fine-tuning scripts `run_classifier.py` and `run_squad.py`: gradient accumulation, multi-GPU and distributed training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month. 104 | 105 | Here is how to use these techniques in our scripts: 106 | 107 | - **Gradient accumulation**: Gradient accumulation can be used by supplying an integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradients will be accumulated over `gradient_accumulation_steps` steps. 108 | - **Multi-GPU**: Multi-GPU training is automatically activated when several GPUs are detected, and the batches are split over the GPUs. 109 | - **Distributed training**: Distributed training can be activated by supplying an integer greater than or equal to 0 to the `--local_rank` argument. To use distributed training, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see the above blog post for more details): 110 | 111 | ```bash 112 | python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=$THIS_MACHINE_INDEX --master_addr="192.168.1.1" --master_port=1234 run_classifier.py (--arg1 --arg2 --arg3 and all other arguments of the run_classifier script) 113 | ``` 114 | 115 | Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines (0, 1, 2...) and the machine with rank 0 has an IP address `192.168.1.1` and an open port `1234`. 116 | 117 | ## TPU support and pretraining scripts 118 | 119 | TPUs are not supported by the current stable release of PyTorch (0.4.1). However, the next version of PyTorch (v1.0) should support training on TPU and is expected to be released soon (see the recent [official announcement](https://cloud.google.com/blog/products/ai-machine-learning/introducing-pytorch-across-google-cloud)). 120 | 121 | We will add TPU support when this next release is published.
122 | 123 | The original TensorFlow code further comprises two scripts for pre-training BERT: [create_pretraining_data.py](https://github.com/google-research/bert/blob/master/create_pretraining_data.py) and [run_pretraining.py](https://github.com/google-research/bert/blob/master/run_pretraining.py). 124 | 125 | Since pre-training BERT is a particularly expensive operation that basically requires one or several TPUs to be completed in a reasonable amount of time (see details [here](https://github.com/google-research/bert#pre-training-with-bert)), we have decided to wait for the inclusion of TPU support in PyTorch before converting these pre-training scripts. 126 | 127 | ## Comparing the PyTorch model and the TensorFlow model predictions 128 | 129 | We also include [two Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model. 130 | 131 | - The first notebook ([Comparing TF and PT models.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models.ipynb)) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models. 132 | 133 | - The second notebook ([Comparing TF and PT models SQuAD predictions.ipynb](https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/notebooks/Comparing%20TF%20and%20PT%20models%20SQuAD%20predictions.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models. 134 | 135 | Please follow the instructions given in the notebooks to run and modify them. They are also a nice example of how to use the models in a simpler way than the full fine-tuning scripts we provide. 136 | 137 | ## Fine-tuning with BERT: running the examples 138 | 139 | We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD. 140 | 141 | Before running these examples you should download the 142 | [GLUE data](https://gluebenchmark.com/tasks) by running 143 | [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) 144 | and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base` 145 | checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section. 146 | 147 | This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase 148 | Corpus (MRPC) and runs in less than 10 minutes on a single K-80.
149 | 150 | ```shell 151 | export GLUE_DIR=/path/to/glue 152 | 153 | python run_classifier.py \ 154 | --task_name MRPC \ 155 | --do_train \ 156 | --do_eval \ 157 | --do_lower_case \ 158 | --data_dir $GLUE_DIR/MRPC/ \ 159 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 160 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 161 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 162 | --max_seq_length 128 \ 163 | --train_batch_size 32 \ 164 | --learning_rate 2e-5 \ 165 | --num_train_epochs 3.0 \ 166 | --output_dir /tmp/mrpc_output/ 167 | ``` 168 | 169 | Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation results between 82 and 87. 170 | 171 | The second example fine-tunes `BERT-Base` on the SQuAD question answering task. 172 | 173 | The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory. 174 | 175 | * [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json) 176 | * [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json) 177 | * [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py) 178 | 179 | ```shell 180 | export SQUAD_DIR=/path/to/SQUAD 181 | 182 | python run_squad.py \ 183 | --vocab_file $BERT_BASE_DIR/vocab.txt \ 184 | --bert_config_file $BERT_BASE_DIR/bert_config.json \ 185 | --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin \ 186 | --do_train \ 187 | --train_file $SQUAD_DIR/train-v1.1.json \ 188 | --do_predict \ 189 | --predict_file $SQUAD_DIR/dev-v1.1.json \ 190 | --train_batch_size 12 \ 191 | --learning_rate 5e-5 \ 192 | --num_train_epochs 2.0 \ 193 | --max_seq_length 384 \ 194 | --doc_stride 128 \ 195 | --output_dir ../debug_squad/ 196 | ``` 197 | 198 | There is currently a bug in the `run_squad.py` script that we are investigating. The reported numbers are very low (F1 of 41.8 and exact match of 21.7) even though the correct answer is usually in the n-best predictions. We are investigating that right now on the develop branch, follow [this issue](https://github.com/huggingface/pytorch-pretrained-BERT/issues/3) for more updates. 199 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 
31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 132 | """ 133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 136 | 137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 138 | # any TensorFlow checkpoint file 139 | self.LayerNorm = BERTLayerNorm(config) 140 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 141 | 142 | def forward(self, input_ids, token_type_ids=None): 143 | seq_length = input_ids.size(1) 144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | if token_type_ids is None: 147 | token_type_ids = torch.zeros_like(input_ids) 148 | 149 | words_embeddings = self.word_embeddings(input_ids) 150 | position_embeddings = self.position_embeddings(position_ids) 151 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 152 | 153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 154 | embeddings = self.LayerNorm(embeddings) 155 | embeddings = self.dropout(embeddings) 156 | return embeddings 157 | 158 | 159 | class BERTSelfAttention(nn.Module): 160 | def __init__(self, config): 161 | super(BERTSelfAttention, self).__init__() 162 | if config.hidden_size % config.num_attention_heads != 0: 163 | raise ValueError( 164 | "The hidden size (%d) is not a multiple of the number of attention " 165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 166 | self.num_attention_heads = config.num_attention_heads 167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 168 | self.all_head_size = self.num_attention_heads * self.attention_head_size 169 | 170 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 171 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 173 | 174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 175 | 176 | def transpose_for_scores(self, x): 177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 178 | x = x.view(*new_x_shape) 179 | return x.permute(0, 2, 1, 3) 180 | 181 | def forward(self, hidden_states, attention_mask): 182 | mixed_query_layer = self.query(hidden_states) 183 | mixed_key_layer = self.key(hidden_states) 184 | mixed_value_layer = self.value(hidden_states) 185 | 186 | query_layer = self.transpose_for_scores(mixed_query_layer) 187 | key_layer = self.transpose_for_scores(mixed_key_layer) 188 | value_layer = self.transpose_for_scores(mixed_value_layer) 189 | 190 | # Take the dot product between "query" and "key" to get the raw attention scores. 
191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 194 | attention_scores = attention_scores + attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | 199 | # This is actually dropping out entire tokens to attend to, which might 200 | # seem a bit unusual, but is taken from the original Transformer paper. 201 | attention_probs = self.dropout(attention_probs) 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 206 | context_layer = context_layer.view(*new_context_layer_shape) 207 | return context_layer 208 | 209 | 210 | class BERTSelfOutput(nn.Module): 211 | def __init__(self, config): 212 | super(BERTSelfOutput, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.LayerNorm = BERTLayerNorm(config) 215 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 216 | 217 | def forward(self, hidden_states, input_tensor): 218 | hidden_states = self.dense(hidden_states) 219 | hidden_states = self.dropout(hidden_states) 220 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 221 | return hidden_states 222 | 223 | 224 | class BERTAttention(nn.Module): 225 | def __init__(self, config): 226 | super(BERTAttention, self).__init__() 227 | self.self = BERTSelfAttention(config) 228 | self.output = BERTSelfOutput(config) 229 | 230 | def forward(self, input_tensor, attention_mask): 231 | self_output = self.self(input_tensor, attention_mask) 232 | attention_output = self.output(self_output, input_tensor) 233 | return attention_output 234 | 235 | 236 | class BERTIntermediate(nn.Module): 237 | def __init__(self, config): 238 | super(BERTIntermediate, self).__init__() 239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 240 | self.intermediate_act_fn = gelu 241 | 242 | def forward(self, hidden_states): 243 | hidden_states = self.dense(hidden_states) 244 | hidden_states = self.intermediate_act_fn(hidden_states) 245 | return hidden_states 246 | 247 | 248 | class BERTOutput(nn.Module): 249 | def __init__(self, config): 250 | super(BERTOutput, self).__init__() 251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 252 | self.LayerNorm = BERTLayerNorm(config) 253 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 254 | 255 | def forward(self, hidden_states, input_tensor): 256 | hidden_states = self.dense(hidden_states) 257 | hidden_states = self.dropout(hidden_states) 258 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 259 | return hidden_states 260 | 261 | 262 | class BERTLayer(nn.Module): 263 | def __init__(self, config): 264 | super(BERTLayer, self).__init__() 265 | self.attention = BERTAttention(config) 266 | self.intermediate = BERTIntermediate(config) 267 | self.output = BERTOutput(config) 268 | 269 | def forward(self, hidden_states, attention_mask): 270 | attention_output = self.attention(hidden_states, attention_mask) 271 | intermediate_output = self.intermediate(attention_output) 272 | layer_output = self.output(intermediate_output, attention_output) 273 | return layer_output 274 | 275 | 276 | class BERTEncoder(nn.Module): 277 | def 
__init__(self, config): 278 | super(BERTEncoder, self).__init__() 279 | layer = BERTLayer(config) 280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 281 | 282 | def forward(self, hidden_states, attention_mask): 283 | all_encoder_layers = [] 284 | for layer_module in self.layer: 285 | hidden_states = layer_module(hidden_states, attention_mask) 286 | all_encoder_layers.append(hidden_states) 287 | return all_encoder_layers 288 | 289 | 290 | class BERTPooler(nn.Module): 291 | def __init__(self, config): 292 | super(BERTPooler, self).__init__() 293 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 294 | self.activation = nn.Tanh() 295 | 296 | def forward(self, hidden_states): 297 | # We "pool" the model by simply taking the hidden state corresponding 298 | # to the first token. 299 | first_token_tensor = hidden_states[:, 0] 300 | pooled_output = self.dense(first_token_tensor) 301 | pooled_output = self.activation(pooled_output) 302 | return pooled_output 303 | 304 | 305 | class BertModel(nn.Module): 306 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 307 | 308 | Example usage: 309 | ```python 310 | # Already been converted into WordPiece token ids 311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 314 | 315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 317 | 318 | model = modeling.BertModel(config=config) 319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 320 | ``` 321 | """ 322 | def __init__(self, config: BertConfig): 323 | """Constructor for BertModel. 324 | 325 | Args: 326 | config: `BertConfig` instance. 327 | """ 328 | super(BertModel, self).__init__() 329 | self.embeddings = BERTEmbeddings(config) 330 | self.encoder = BERTEncoder(config) 331 | self.pooler = BERTPooler(config) 332 | 333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 334 | if attention_mask is None: 335 | attention_mask = torch.ones_like(input_ids) 336 | if token_type_ids is None: 337 | token_type_ids = torch.zeros_like(input_ids) 338 | 339 | # We create a 3D attention mask from a 2D tensor mask. 340 | # Sizes are [batch_size, 1, 1, from_seq_length] 341 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 342 | # this attention mask is more simple than the triangular masking of causal attention 343 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 345 | 346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 347 | # masked positions, this operation will create a tensor which is 0.0 for 348 | # positions we want to attend and -10000.0 for masked positions. 349 | # Since we are adding it to the raw scores before the softmax, this is 350 | # effectively the same as removing these entirely. 
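# For example, an attention_mask row [1, 1, 0] becomes [0.0, 0.0, -10000.0] after the two lines below.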
351 | extended_attention_mask = extended_attention_mask.float() 352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 353 | 354 | embedding_output = self.embeddings(input_ids, token_type_ids) 355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 356 | sequence_output = all_encoder_layers[-1] 357 | pooled_output = self.pooler(sequence_output) 358 | return all_encoder_layers, pooled_output 359 | 360 | class BertForSequenceClassification(nn.Module): 361 | """BERT model for classification. 362 | This module is composed of the BERT model with a linear layer on top of 363 | the pooled output. 364 | 365 | Example usage: 366 | ```python 367 | # Already been converted into WordPiece token ids 368 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 369 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 370 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 371 | 372 | config = BertConfig(vocab_size=32000, hidden_size=512, 373 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 374 | 375 | num_labels = 2 376 | 377 | model = BertForSequenceClassification(config, num_labels) 378 | logits = model(input_ids, token_type_ids, input_mask) 379 | ``` 380 | """ 381 | def __init__(self, config, num_labels): 382 | super(BertForSequenceClassification, self).__init__() 383 | self.bert = BertModel(config) 384 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 385 | self.classifier = nn.Linear(config.hidden_size, num_labels) 386 | 387 | def init_weights(module): 388 | if isinstance(module, (nn.Linear, nn.Embedding)): 389 | # Slightly different from the TF version which uses truncated_normal for initialization 390 | # cf https://github.com/pytorch/pytorch/pull/5617 391 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 392 | elif isinstance(module, BERTLayerNorm): 393 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 394 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 395 | if isinstance(module, nn.Linear): 396 | module.bias.data.zero_() 397 | self.apply(init_weights) 398 | 399 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 400 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 401 | pooled_output = self.dropout(pooled_output) 402 | logits = self.classifier(pooled_output) 403 | 404 | if labels is not None: 405 | loss_fct = CrossEntropyLoss() 406 | loss = loss_fct(logits, labels) 407 | return loss, logits 408 | else: 409 | return logits 410 | 411 | class BertForQuestionAnswering(nn.Module): 412 | """BERT model for Question Answering (span extraction). 
413 | This module is composed of the BERT model with a linear layer on top of 414 | the sequence output that computes start_logits and end_logits 415 | 416 | Example usage: 417 | ```python 418 | # Already been converted into WordPiece token ids 419 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 420 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 421 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 422 | 423 | config = BertConfig(vocab_size=32000, hidden_size=512, 424 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 425 | 426 | model = BertForQuestionAnswering(config) 427 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 428 | ``` 429 | """ 430 | def __init__(self, config): 431 | super(BertForQuestionAnswering, self).__init__() 432 | self.bert = BertModel(config) 433 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 434 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 435 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 436 | 437 | def init_weights(module): 438 | if isinstance(module, (nn.Linear, nn.Embedding)): 439 | # Slightly different from the TF version which uses truncated_normal for initialization 440 | # cf https://github.com/pytorch/pytorch/pull/5617 441 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 442 | elif isinstance(module, BERTLayerNorm): 443 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 444 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 445 | if isinstance(module, nn.Linear): 446 | module.bias.data.zero_() 447 | self.apply(init_weights) 448 | 449 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 450 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 451 | sequence_output = all_encoder_layers[-1] 452 | logits = self.qa_outputs(sequence_output) 453 | start_logits, end_logits = logits.split(1, dim=-1) 454 | start_logits = start_logits.squeeze(-1) 455 | end_logits = end_logits.squeeze(-1) 456 | 457 | if start_positions is not None and end_positions is not None: 458 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 459 | start_positions = start_positions.squeeze(-1) 460 | end_positions = end_positions.squeeze(-1) 461 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 462 | ignored_index = start_logits.size(1) 463 | start_positions.clamp_(0, ignored_index) 464 | end_positions.clamp_(0, ignored_index) 465 | 466 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 467 | start_loss = loss_fct(start_logits, start_positions) 468 | end_loss = loss_fct(end_logits, end_positions) 469 | total_loss = (start_loss + end_loss) / 2 470 | return total_loss, (start_logits, end_logits) 471 | else: 472 | return start_logits, end_logits 473 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import csv 22 | import os 23 | import logging 24 | import argparse 25 | import random 26 | from tqdm import tqdm, trange 27 | 28 | import numpy as np 29 | import torch 30 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 31 | from torch.utils.data.distributed import DistributedSampler 32 | 33 | import tokenization 34 | from modeling import BertConfig, BertForSequenceClassification 35 | from optimization import BERTAdam 36 | 37 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 38 | datefmt = '%m/%d/%Y %H:%M:%S', 39 | level = logging.INFO) 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | class InputExample(object): 44 | """A single training/test example for simple sequence classification.""" 45 | 46 | def __init__(self, guid, text_a, text_b=None, label=None): 47 | """Constructs a InputExample. 48 | 49 | Args: 50 | guid: Unique id for the example. 51 | text_a: string. The untokenized text of the first sequence. For single 52 | sequence tasks, only this sequence must be specified. 53 | text_b: (Optional) string. The untokenized text of the second sequence. 54 | Only must be specified for sequence pair tasks. 55 | label: (Optional) string. The label of the example. This should be 56 | specified for train and dev examples, but not for test examples. 
57 | """ 58 | self.guid = guid 59 | self.text_a = text_a 60 | self.text_b = text_b 61 | self.label = label 62 | 63 | 64 | class InputFeatures(object): 65 | """A single set of features of data.""" 66 | 67 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 68 | self.input_ids = input_ids 69 | self.input_mask = input_mask 70 | self.segment_ids = segment_ids 71 | self.label_id = label_id 72 | 73 | 74 | class DataProcessor(object): 75 | """Base class for data converters for sequence classification data sets.""" 76 | 77 | def get_train_examples(self, data_dir): 78 | """Gets a collection of `InputExample`s for the train set.""" 79 | raise NotImplementedError() 80 | 81 | def get_dev_examples(self, data_dir): 82 | """Gets a collection of `InputExample`s for the dev set.""" 83 | raise NotImplementedError() 84 | 85 | def get_labels(self): 86 | """Gets the list of labels for this data set.""" 87 | raise NotImplementedError() 88 | 89 | @classmethod 90 | def _read_tsv(cls, input_file, quotechar=None): 91 | """Reads a tab separated value file.""" 92 | with open(input_file, "r") as f: 93 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 94 | lines = [] 95 | for line in reader: 96 | lines.append(line) 97 | return lines 98 | 99 | 100 | class MrpcProcessor(DataProcessor): 101 | """Processor for the MRPC data set (GLUE version).""" 102 | 103 | def get_train_examples(self, data_dir): 104 | """See base class.""" 105 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 106 | return self._create_examples( 107 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 108 | 109 | def get_dev_examples(self, data_dir): 110 | """See base class.""" 111 | return self._create_examples( 112 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 113 | 114 | def get_labels(self): 115 | """See base class.""" 116 | return ["0", "1"] 117 | 118 | def _create_examples(self, lines, set_type): 119 | """Creates examples for the training and dev sets.""" 120 | examples = [] 121 | for (i, line) in enumerate(lines): 122 | if i == 0: 123 | continue 124 | guid = "%s-%s" % (set_type, i) 125 | text_a = tokenization.convert_to_unicode(line[3]) 126 | text_b = tokenization.convert_to_unicode(line[4]) 127 | label = tokenization.convert_to_unicode(line[0]) 128 | examples.append( 129 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 130 | return examples 131 | 132 | 133 | class MnliProcessor(DataProcessor): 134 | """Processor for the MultiNLI data set (GLUE version).""" 135 | 136 | def get_train_examples(self, data_dir): 137 | """See base class.""" 138 | return self._create_examples( 139 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 140 | 141 | def get_dev_examples(self, data_dir): 142 | """See base class.""" 143 | return self._create_examples( 144 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 145 | "dev_matched") 146 | 147 | def get_labels(self): 148 | """See base class.""" 149 | return ["contradiction", "entailment", "neutral"] 150 | 151 | def _create_examples(self, lines, set_type): 152 | """Creates examples for the training and dev sets.""" 153 | examples = [] 154 | for (i, line) in enumerate(lines): 155 | if i == 0: 156 | continue 157 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 158 | text_a = tokenization.convert_to_unicode(line[8]) 159 | text_b = tokenization.convert_to_unicode(line[9]) 160 | label = tokenization.convert_to_unicode(line[-1]) 161 | examples.append( 162 | InputExample(guid=guid, 
text_a=text_a, text_b=text_b, label=label)) 163 | return examples 164 | 165 | 166 | class ColaProcessor(DataProcessor): 167 | """Processor for the CoLA data set (GLUE version).""" 168 | 169 | def get_train_examples(self, data_dir): 170 | """See base class.""" 171 | return self._create_examples( 172 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 173 | 174 | def get_dev_examples(self, data_dir): 175 | """See base class.""" 176 | return self._create_examples( 177 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 178 | 179 | def get_labels(self): 180 | """See base class.""" 181 | return ["0", "1"] 182 | 183 | def _create_examples(self, lines, set_type): 184 | """Creates examples for the training and dev sets.""" 185 | examples = [] 186 | for (i, line) in enumerate(lines): 187 | guid = "%s-%s" % (set_type, i) 188 | text_a = tokenization.convert_to_unicode(line[3]) 189 | label = tokenization.convert_to_unicode(line[1]) 190 | examples.append( 191 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 192 | return examples 193 | 194 | 195 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 196 | """Loads a data file into a list of `InputBatch`s.""" 197 | 198 | label_map = {} 199 | for (i, label) in enumerate(label_list): 200 | label_map[label] = i 201 | 202 | features = [] 203 | for (ex_index, example) in enumerate(examples): 204 | tokens_a = tokenizer.tokenize(example.text_a) 205 | 206 | tokens_b = None 207 | if example.text_b: 208 | tokens_b = tokenizer.tokenize(example.text_b) 209 | 210 | if tokens_b: 211 | # Modifies `tokens_a` and `tokens_b` in place so that the total 212 | # length is less than the specified length. 213 | # Account for [CLS], [SEP], [SEP] with "- 3" 214 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 215 | else: 216 | # Account for [CLS] and [SEP] with "- 2" 217 | if len(tokens_a) > max_seq_length - 2: 218 | tokens_a = tokens_a[0:(max_seq_length - 2)] 219 | 220 | # The convention in BERT is: 221 | # (a) For sequence pairs: 222 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 223 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 224 | # (b) For single sequences: 225 | # tokens: [CLS] the dog is hairy . [SEP] 226 | # type_ids: 0 0 0 0 0 0 0 227 | # 228 | # Where "type_ids" are used to indicate whether this is the first 229 | # sequence or the second sequence. The embedding vectors for `type=0` and 230 | # `type=1` were learned during pre-training and are added to the wordpiece 231 | # embedding vector (and position vector). This is not *strictly* necessary 232 | # since the [SEP] token unambigiously separates the sequences, but it makes 233 | # it easier for the model to learn the concept of sequences. 234 | # 235 | # For classification tasks, the first vector (corresponding to [CLS]) is 236 | # used as as the "sentence vector". Note that this only makes sense because 237 | # the entire model is fine-tuned. 238 | tokens = [] 239 | segment_ids = [] 240 | tokens.append("[CLS]") 241 | segment_ids.append(0) 242 | for token in tokens_a: 243 | tokens.append(token) 244 | segment_ids.append(0) 245 | tokens.append("[SEP]") 246 | segment_ids.append(0) 247 | 248 | if tokens_b: 249 | for token in tokens_b: 250 | tokens.append(token) 251 | segment_ids.append(1) 252 | tokens.append("[SEP]") 253 | segment_ids.append(1) 254 | 255 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 256 | 257 | # The mask has 1 for real tokens and 0 for padding tokens. 
Only real 258 | # tokens are attended to. 259 | input_mask = [1] * len(input_ids) 260 | 261 | # Zero-pad up to the sequence length. 262 | while len(input_ids) < max_seq_length: 263 | input_ids.append(0) 264 | input_mask.append(0) 265 | segment_ids.append(0) 266 | 267 | assert len(input_ids) == max_seq_length 268 | assert len(input_mask) == max_seq_length 269 | assert len(segment_ids) == max_seq_length 270 | 271 | label_id = label_map[example.label] 272 | if ex_index < 5: 273 | logger.info("*** Example ***") 274 | logger.info("guid: %s" % (example.guid)) 275 | logger.info("tokens: %s" % " ".join( 276 | [tokenization.printable_text(x) for x in tokens])) 277 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 278 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 279 | logger.info( 280 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 281 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 282 | 283 | features.append( 284 | InputFeatures( 285 | input_ids=input_ids, 286 | input_mask=input_mask, 287 | segment_ids=segment_ids, 288 | label_id=label_id)) 289 | return features 290 | 291 | 292 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 293 | """Truncates a sequence pair in place to the maximum length.""" 294 | 295 | # This is a simple heuristic which will always truncate the longer sequence 296 | # one token at a time. This makes more sense than truncating an equal percent 297 | # of tokens from each, since if one sequence is very short then each token 298 | # that's truncated likely contains more information than a longer sequence. 299 | while True: 300 | total_length = len(tokens_a) + len(tokens_b) 301 | if total_length <= max_length: 302 | break 303 | if len(tokens_a) > len(tokens_b): 304 | tokens_a.pop() 305 | else: 306 | tokens_b.pop() 307 | 308 | def accuracy(out, labels): 309 | outputs = np.argmax(out, axis=1) 310 | return np.sum(outputs==labels) 311 | 312 | def main(): 313 | parser = argparse.ArgumentParser() 314 | 315 | ## Required parameters 316 | parser.add_argument("--data_dir", 317 | default=None, 318 | type=str, 319 | required=True, 320 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 321 | parser.add_argument("--bert_config_file", 322 | default=None, 323 | type=str, 324 | required=True, 325 | help="The config json file corresponding to the pre-trained BERT model. \n" 326 | "This specifies the model architecture.") 327 | parser.add_argument("--task_name", 328 | default=None, 329 | type=str, 330 | required=True, 331 | help="The name of the task to train.") 332 | parser.add_argument("--vocab_file", 333 | default=None, 334 | type=str, 335 | required=True, 336 | help="The vocabulary file that the BERT model was trained on.") 337 | parser.add_argument("--output_dir", 338 | default=None, 339 | type=str, 340 | required=True, 341 | help="The output directory where the model checkpoints will be written.") 342 | 343 | ## Other parameters 344 | parser.add_argument("--init_checkpoint", 345 | default=None, 346 | type=str, 347 | help="Initial checkpoint (usually from a pre-trained BERT model).") 348 | parser.add_argument("--do_lower_case", 349 | default=False, 350 | action='store_true', 351 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 352 | parser.add_argument("--max_seq_length", 353 | default=128, 354 | type=int, 355 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 356 | "Sequences longer than this will be truncated, and sequences shorter \n" 357 | "than this will be padded.") 358 | parser.add_argument("--do_train", 359 | default=False, 360 | action='store_true', 361 | help="Whether to run training.") 362 | parser.add_argument("--do_eval", 363 | default=False, 364 | action='store_true', 365 | help="Whether to run eval on the dev set.") 366 | parser.add_argument("--train_batch_size", 367 | default=32, 368 | type=int, 369 | help="Total batch size for training.") 370 | parser.add_argument("--eval_batch_size", 371 | default=8, 372 | type=int, 373 | help="Total batch size for eval.") 374 | parser.add_argument("--learning_rate", 375 | default=5e-5, 376 | type=float, 377 | help="The initial learning rate for Adam.") 378 | parser.add_argument("--num_train_epochs", 379 | default=3.0, 380 | type=float, 381 | help="Total number of training epochs to perform.") 382 | parser.add_argument("--warmup_proportion", 383 | default=0.1, 384 | type=float, 385 | help="Proportion of training to perform linear learning rate warmup for. " 386 | "E.g., 0.1 = 10%% of training.") 387 | parser.add_argument("--save_checkpoints_steps", 388 | default=1000, 389 | type=int, 390 | help="How often to save the model checkpoint.") 391 | parser.add_argument("--no_cuda", 392 | default=False, 393 | action='store_true', 394 | help="Whether not to use CUDA when available") 395 | parser.add_argument("--accumulate_gradients", 396 | type=int, 397 | default=1, 398 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 399 | parser.add_argument("--local_rank", 400 | type=int, 401 | default=-1, 402 | help="local_rank for distributed training on gpus") 403 | parser.add_argument('--seed', 404 | type=int, 405 | default=42, 406 | help="random seed for initialization") 407 | parser.add_argument('--gradient_accumulation_steps', 408 | type=int, 409 | default=1, 410 | help="Number of updates steps to accumualte before performing a backward/update pass.") 411 | args = parser.parse_args() 412 | 413 | processors = { 414 | "cola": ColaProcessor, 415 | "mnli": MnliProcessor, 416 | "mrpc": MrpcProcessor, 417 | } 418 | 419 | if args.local_rank == -1 or args.no_cuda: 420 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 421 | n_gpu = torch.cuda.device_count() 422 | else: 423 | device = torch.device("cuda", args.local_rank) 424 | n_gpu = 1 425 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 426 | torch.distributed.init_process_group(backend='nccl') 427 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 428 | 429 | if args.accumulate_gradients < 1: 430 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 431 | args.accumulate_gradients)) 432 | 433 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) 434 | 435 | random.seed(args.seed) 436 | np.random.seed(args.seed) 437 | torch.manual_seed(args.seed) 438 | if n_gpu > 0: 439 | torch.cuda.manual_seed_all(args.seed) 440 | 441 | if not args.do_train and not args.do_eval: 442 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 443 | 444 | bert_config = BertConfig.from_json_file(args.bert_config_file) 445 | 446 | if args.max_seq_length > bert_config.max_position_embeddings: 447 | raise ValueError( 448 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( 449 | 
args.max_seq_length, bert_config.max_position_embeddings)) 450 | 451 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 452 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 453 | os.makedirs(args.output_dir, exist_ok=True) 454 | 455 | task_name = args.task_name.lower() 456 | 457 | if task_name not in processors: 458 | raise ValueError("Task not found: %s" % (task_name)) 459 | 460 | processor = processors[task_name]() 461 | 462 | label_list = processor.get_labels() 463 | 464 | tokenizer = tokenization.FullTokenizer( 465 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 466 | 467 | train_examples = None 468 | num_train_steps = None 469 | if args.do_train: 470 | train_examples = processor.get_train_examples(args.data_dir) 471 | num_train_steps = int( 472 | len(train_examples) / args.train_batch_size * args.num_train_epochs) 473 | 474 | model = BertForSequenceClassification(bert_config, len(label_list)) 475 | if args.init_checkpoint is not None: 476 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 477 | model.to(device) 478 | 479 | if args.local_rank != -1: 480 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 481 | output_device=args.local_rank) 482 | elif n_gpu > 1: 483 | model = torch.nn.DataParallel(model) 484 | 485 | no_decay = ['bias', 'gamma', 'beta'] 486 | optimizer_parameters = [ 487 | {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01}, 488 | {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0} 489 | ] 490 | 491 | optimizer = BERTAdam(optimizer_parameters, 492 | lr=args.learning_rate, 493 | warmup=args.warmup_proportion, 494 | t_total=num_train_steps) 495 | 496 | global_step = 0 497 | if args.do_train: 498 | train_features = convert_examples_to_features( 499 | train_examples, label_list, args.max_seq_length, tokenizer) 500 | logger.info("***** Running training *****") 501 | logger.info(" Num examples = %d", len(train_examples)) 502 | logger.info(" Batch size = %d", args.train_batch_size) 503 | logger.info(" Num steps = %d", num_train_steps) 504 | 505 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 506 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 507 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 508 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 509 | 510 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 511 | if args.local_rank == -1: 512 | train_sampler = RandomSampler(train_data) 513 | else: 514 | train_sampler = DistributedSampler(train_data) 515 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 516 | 517 | model.train() 518 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 519 | tr_loss = 0 520 | nb_tr_examples, nb_tr_steps = 0, 0 521 | for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")): 522 | input_ids = input_ids.to(device) 523 | input_mask = input_mask.to(device) 524 | segment_ids = segment_ids.to(device) 525 | label_ids = label_ids.to(device) 526 | 527 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids) 528 | if n_gpu > 1: 529 | loss = loss.mean() # mean() to average on multi-gpu. 
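                # Running totals below are for reporting only; loss.backward() runs on
                # every batch, while optimizer.step() and model.zero_grad() fire once
                # every `gradient_accumulation_steps` batches.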
530 | tr_loss += loss.item() 531 | nb_tr_examples += input_ids.size(0) 532 | nb_tr_steps += 1 533 | loss.backward() 534 | 535 | if (step + 1) % args.gradient_accumulation_steps == 0: 536 | optimizer.step() # We have accumulated enought gradients 537 | model.zero_grad() 538 | global_step += 1 539 | 540 | if args.do_eval: 541 | eval_examples = processor.get_dev_examples(args.data_dir) 542 | eval_features = convert_examples_to_features( 543 | eval_examples, label_list, args.max_seq_length, tokenizer) 544 | 545 | logger.info("***** Running evaluation *****") 546 | logger.info(" Num examples = %d", len(eval_examples)) 547 | logger.info(" Batch size = %d", args.eval_batch_size) 548 | 549 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 550 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 551 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 552 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 553 | 554 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 555 | if args.local_rank == -1: 556 | eval_sampler = SequentialSampler(eval_data) 557 | else: 558 | eval_sampler = DistributedSampler(eval_data) 559 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 560 | 561 | model.eval() 562 | eval_loss, eval_accuracy = 0, 0 563 | nb_eval_steps, nb_eval_examples = 0, 0 564 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: 565 | input_ids = input_ids.to(device) 566 | input_mask = input_mask.to(device) 567 | segment_ids = segment_ids.to(device) 568 | label_ids = label_ids.to(device) 569 | 570 | tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) 571 | 572 | logits = logits.detach().cpu().numpy() 573 | label_ids = label_ids.to('cpu').numpy() 574 | tmp_eval_accuracy = accuracy(logits, label_ids) 575 | 576 | eval_loss += tmp_eval_loss.mean().item() 577 | eval_accuracy += tmp_eval_accuracy 578 | 579 | nb_eval_examples += input_ids.size(0) 580 | nb_eval_steps += 1 581 | 582 | eval_loss = eval_loss / nb_eval_steps #len(eval_dataloader) 583 | eval_accuracy = eval_accuracy / nb_eval_examples #len(eval_dataloader) 584 | 585 | result = {'eval_loss': eval_loss, 586 | 'eval_accuracy': eval_accuracy, 587 | 'global_step': global_step, 588 | 'loss': tr_loss/nb_tr_steps}#'loss': loss.item()} 589 | 590 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 591 | with open(output_eval_file, "w") as writer: 592 | logger.info("***** Eval results *****") 593 | for key in sorted(result.keys()): 594 | logger.info(" %s = %s", key, str(result[key])) 595 | writer.write("%s = %s\n" % (key, str(result[key]))) 596 | 597 | if __name__ == "__main__": 598 | main() 599 | -------------------------------------------------------------------------------- /run_squad.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run BERT on SQuAD.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import collections 23 | import logging 24 | import json 25 | import math 26 | import os 27 | import random 28 | import six 29 | from tqdm import tqdm, trange 30 | 31 | import numpy as np 32 | import torch 33 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 34 | from torch.utils.data.distributed import DistributedSampler 35 | 36 | import tokenization 37 | from modeling import BertConfig, BertForQuestionAnswering 38 | from optimization import BERTAdam 39 | 40 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 41 | datefmt = '%m/%d/%Y %H:%M:%S', 42 | level = logging.INFO) 43 | logger = logging.getLogger(__name__) 44 | 45 | 46 | class SquadExample(object): 47 | """A single training/test example for simple sequence classification.""" 48 | 49 | def __init__(self, 50 | qas_id, 51 | question_text, 52 | doc_tokens, 53 | orig_answer_text=None, 54 | start_position=None, 55 | end_position=None): 56 | self.qas_id = qas_id 57 | self.question_text = question_text 58 | self.doc_tokens = doc_tokens 59 | self.orig_answer_text = orig_answer_text 60 | self.start_position = start_position 61 | self.end_position = end_position 62 | 63 | def __str__(self): 64 | return self.__repr__() 65 | 66 | def __repr__(self): 67 | s = "" 68 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 69 | s += ", question_text: %s" % ( 70 | tokenization.printable_text(self.question_text)) 71 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 72 | if self.start_position: 73 | s += ", start_position: %d" % (self.start_position) 74 | if self.start_position: 75 | s += ", end_position: %d" % (self.end_position) 76 | return s 77 | 78 | 79 | class InputFeatures(object): 80 | """A single set of features of data.""" 81 | 82 | def __init__(self, 83 | unique_id, 84 | example_index, 85 | doc_span_index, 86 | tokens, 87 | token_to_orig_map, 88 | token_is_max_context, 89 | input_ids, 90 | input_mask, 91 | segment_ids, 92 | start_position=None, 93 | end_position=None): 94 | self.unique_id = unique_id 95 | self.example_index = example_index 96 | self.doc_span_index = doc_span_index 97 | self.tokens = tokens 98 | self.token_to_orig_map = token_to_orig_map 99 | self.token_is_max_context = token_is_max_context 100 | self.input_ids = input_ids 101 | self.input_mask = input_mask 102 | self.segment_ids = segment_ids 103 | self.start_position = start_position 104 | self.end_position = end_position 105 | 106 | 107 | def read_squad_examples(input_file, is_training): 108 | """Read a SQuAD json file into a list of SquadExample.""" 109 | with open(input_file, "r") as reader: 110 | input_data = json.load(reader)["data"] 111 | 112 | def is_whitespace(c): 113 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 114 | return True 115 | return False 116 | 117 | examples = [] 118 | for entry in input_data: 119 | for 
paragraph in entry["paragraphs"]: 120 | paragraph_text = paragraph["context"] 121 | doc_tokens = [] 122 | char_to_word_offset = [] 123 | prev_is_whitespace = True 124 | for c in paragraph_text: 125 | if is_whitespace(c): 126 | prev_is_whitespace = True 127 | else: 128 | if prev_is_whitespace: 129 | doc_tokens.append(c) 130 | else: 131 | doc_tokens[-1] += c 132 | prev_is_whitespace = False 133 | char_to_word_offset.append(len(doc_tokens) - 1) 134 | 135 | for qa in paragraph["qas"]: 136 | qas_id = qa["id"] 137 | question_text = qa["question"] 138 | start_position = None 139 | end_position = None 140 | orig_answer_text = None 141 | if is_training: 142 | if len(qa["answers"]) != 1: 143 | raise ValueError( 144 | "For training, each question should have exactly 1 answer.") 145 | answer = qa["answers"][0] 146 | orig_answer_text = answer["text"] 147 | answer_offset = answer["answer_start"] 148 | answer_length = len(orig_answer_text) 149 | start_position = char_to_word_offset[answer_offset] 150 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 151 | # Only add answers where the text can be exactly recovered from the 152 | # document. If this CAN'T happen it's likely due to weird Unicode 153 | # stuff so we will just skip the example. 154 | # 155 | # Note that this means for training mode, every example is NOT 156 | # guaranteed to be preserved. 157 | actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) 158 | cleaned_answer_text = " ".join( 159 | tokenization.whitespace_tokenize(orig_answer_text)) 160 | if actual_text.find(cleaned_answer_text) == -1: 161 | logger.warning("Could not find answer: '%s' vs. '%s'", 162 | actual_text, cleaned_answer_text) 163 | continue 164 | 165 | example = SquadExample( 166 | qas_id=qas_id, 167 | question_text=question_text, 168 | doc_tokens=doc_tokens, 169 | orig_answer_text=orig_answer_text, 170 | start_position=start_position, 171 | end_position=end_position) 172 | examples.append(example) 173 | return examples 174 | 175 | 176 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 177 | doc_stride, max_query_length, is_training): 178 | """Loads a data file into a list of `InputBatch`s.""" 179 | 180 | unique_id = 1000000000 181 | 182 | features = [] 183 | for (example_index, example) in enumerate(examples): 184 | query_tokens = tokenizer.tokenize(example.question_text) 185 | 186 | if len(query_tokens) > max_query_length: 187 | query_tokens = query_tokens[0:max_query_length] 188 | 189 | tok_to_orig_index = [] 190 | orig_to_tok_index = [] 191 | all_doc_tokens = [] 192 | for (i, token) in enumerate(example.doc_tokens): 193 | orig_to_tok_index.append(len(all_doc_tokens)) 194 | sub_tokens = tokenizer.tokenize(token) 195 | for sub_token in sub_tokens: 196 | tok_to_orig_index.append(i) 197 | all_doc_tokens.append(sub_token) 198 | 199 | tok_start_position = None 200 | tok_end_position = None 201 | if is_training: 202 | tok_start_position = orig_to_tok_index[example.start_position] 203 | if example.end_position < len(example.doc_tokens) - 1: 204 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 205 | else: 206 | tok_end_position = len(all_doc_tokens) - 1 207 | (tok_start_position, tok_end_position) = _improve_answer_span( 208 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 209 | example.orig_answer_text) 210 | 211 | # The -3 accounts for [CLS], [SEP] and [SEP] 212 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 213 | 214 | # We can have documents that are longer than the 
maximum sequence length. 215 | # To deal with this we do a sliding window approach, where we take chunks 216 | # of the up to our max length with a stride of `doc_stride`. 217 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 218 | "DocSpan", ["start", "length"]) 219 | doc_spans = [] 220 | start_offset = 0 221 | while start_offset < len(all_doc_tokens): 222 | length = len(all_doc_tokens) - start_offset 223 | if length > max_tokens_for_doc: 224 | length = max_tokens_for_doc 225 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 226 | if start_offset + length == len(all_doc_tokens): 227 | break 228 | start_offset += min(length, doc_stride) 229 | 230 | for (doc_span_index, doc_span) in enumerate(doc_spans): 231 | tokens = [] 232 | token_to_orig_map = {} 233 | token_is_max_context = {} 234 | segment_ids = [] 235 | tokens.append("[CLS]") 236 | segment_ids.append(0) 237 | for token in query_tokens: 238 | tokens.append(token) 239 | segment_ids.append(0) 240 | tokens.append("[SEP]") 241 | segment_ids.append(0) 242 | 243 | for i in range(doc_span.length): 244 | split_token_index = doc_span.start + i 245 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 246 | 247 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 248 | split_token_index) 249 | token_is_max_context[len(tokens)] = is_max_context 250 | tokens.append(all_doc_tokens[split_token_index]) 251 | segment_ids.append(1) 252 | tokens.append("[SEP]") 253 | segment_ids.append(1) 254 | 255 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 256 | 257 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 258 | # tokens are attended to. 259 | input_mask = [1] * len(input_ids) 260 | 261 | # Zero-pad up to the sequence length. 262 | while len(input_ids) < max_seq_length: 263 | input_ids.append(0) 264 | input_mask.append(0) 265 | segment_ids.append(0) 266 | 267 | assert len(input_ids) == max_seq_length 268 | assert len(input_mask) == max_seq_length 269 | assert len(segment_ids) == max_seq_length 270 | 271 | start_position = None 272 | end_position = None 273 | if is_training: 274 | # For training, if our document chunk does not contain an annotation 275 | # we throw it out, since there is nothing to predict. 
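        # When the annotated answer does fall inside this chunk, its start/end indices
        # are re-expressed relative to the chunk, offset by len(query_tokens) + 2 to
        # account for the leading [CLS] and the first [SEP].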
276 | doc_start = doc_span.start 277 | doc_end = doc_span.start + doc_span.length - 1 278 | if (example.start_position < doc_start or 279 | example.end_position < doc_start or 280 | example.start_position > doc_end or example.end_position > doc_end): 281 | continue 282 | 283 | doc_offset = len(query_tokens) + 2 284 | start_position = tok_start_position - doc_start + doc_offset 285 | end_position = tok_end_position - doc_start + doc_offset 286 | 287 | if example_index < 20: 288 | logger.info("*** Example ***") 289 | logger.info("unique_id: %s" % (unique_id)) 290 | logger.info("example_index: %s" % (example_index)) 291 | logger.info("doc_span_index: %s" % (doc_span_index)) 292 | logger.info("tokens: %s" % " ".join( 293 | [tokenization.printable_text(x) for x in tokens])) 294 | logger.info("token_to_orig_map: %s" % " ".join( 295 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) 296 | logger.info("token_is_max_context: %s" % " ".join([ 297 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) 298 | ])) 299 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 300 | logger.info( 301 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 302 | logger.info( 303 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 304 | if is_training: 305 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 306 | logger.info("start_position: %d" % (start_position)) 307 | logger.info("end_position: %d" % (end_position)) 308 | logger.info( 309 | "answer: %s" % (tokenization.printable_text(answer_text))) 310 | 311 | features.append( 312 | InputFeatures( 313 | unique_id=unique_id, 314 | example_index=example_index, 315 | doc_span_index=doc_span_index, 316 | tokens=tokens, 317 | token_to_orig_map=token_to_orig_map, 318 | token_is_max_context=token_is_max_context, 319 | input_ids=input_ids, 320 | input_mask=input_mask, 321 | segment_ids=segment_ids, 322 | start_position=start_position, 323 | end_position=end_position)) 324 | unique_id += 1 325 | 326 | return features 327 | 328 | 329 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 330 | orig_answer_text): 331 | """Returns tokenized answer spans that better match the annotated answer.""" 332 | 333 | # The SQuAD annotations are character based. We first project them to 334 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 335 | # often find a "better match". For example: 336 | # 337 | # Question: What year was John Smith born? 338 | # Context: The leader was John Smith (1895-1943). 339 | # Answer: 1895 340 | # 341 | # The original whitespace-tokenized answer will be "(1895-1943).". However 342 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 343 | # the exact answer, 1895. 344 | # 345 | # However, this is not always possible. Consider the following: 346 | # 347 | # Question: What country is the top exporter of electornics? 348 | # Context: The Japanese electronics industry is the lagest in the world. 349 | # Answer: Japan 350 | # 351 | # In this case, the annotator chose "Japan" as a character sub-span of 352 | # the word "Japanese". Since our WordPiece tokenizer does not split 353 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 354 | # in SQuAD, but does happen. 
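  # Try every sub-span of [input_start, input_end], longest first from each start,
  # and return the first one whose joined tokens exactly match the re-tokenized
  # answer; otherwise fall back to the original whitespace-based span.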
355 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 356 | 357 | for new_start in range(input_start, input_end + 1): 358 | for new_end in range(input_end, new_start - 1, -1): 359 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 360 | if text_span == tok_answer_text: 361 | return (new_start, new_end) 362 | 363 | return (input_start, input_end) 364 | 365 | 366 | def _check_is_max_context(doc_spans, cur_span_index, position): 367 | """Check if this is the 'max context' doc span for the token.""" 368 | 369 | # Because of the sliding window approach taken to scoring documents, a single 370 | # token can appear in multiple documents. E.g. 371 | # Doc: the man went to the store and bought a gallon of milk 372 | # Span A: the man went to the 373 | # Span B: to the store and bought 374 | # Span C: and bought a gallon of 375 | # ... 376 | # 377 | # Now the word 'bought' will have two scores from spans B and C. We only 378 | # want to consider the score with "maximum context", which we define as 379 | # the *minimum* of its left and right context (the *sum* of left and 380 | # right context will always be the same, of course). 381 | # 382 | # In the example the maximum context for 'bought' would be span C since 383 | # it has 1 left context and 3 right context, while span B has 4 left context 384 | # and 0 right context. 385 | best_score = None 386 | best_span_index = None 387 | for (span_index, doc_span) in enumerate(doc_spans): 388 | end = doc_span.start + doc_span.length - 1 389 | if position < doc_span.start: 390 | continue 391 | if position > end: 392 | continue 393 | num_left_context = position - doc_span.start 394 | num_right_context = end - position 395 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 396 | if best_score is None or score > best_score: 397 | best_score = score 398 | best_span_index = span_index 399 | 400 | return cur_span_index == best_span_index 401 | 402 | 403 | 404 | RawResult = collections.namedtuple("RawResult", 405 | ["unique_id", "start_logits", "end_logits"]) 406 | 407 | 408 | def write_predictions(all_examples, all_features, all_results, n_best_size, 409 | max_answer_length, do_lower_case, output_prediction_file, 410 | output_nbest_file, verbose_logging): 411 | """Write final predictions to the json file.""" 412 | logger.info("Writing predictions to: %s" % (output_prediction_file)) 413 | logger.info("Writing nbest to: %s" % (output_nbest_file)) 414 | 415 | example_index_to_features = collections.defaultdict(list) 416 | for feature in all_features: 417 | example_index_to_features[feature.example_index].append(feature) 418 | 419 | unique_id_to_result = {} 420 | for result in all_results: 421 | unique_id_to_result[result.unique_id] = result 422 | 423 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 424 | "PrelimPrediction", 425 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 426 | 427 | all_predictions = collections.OrderedDict() 428 | all_nbest_json = collections.OrderedDict() 429 | for (example_index, example) in enumerate(all_examples): 430 | features = example_index_to_features[example_index] 431 | 432 | prelim_predictions = [] 433 | for (feature_index, feature) in enumerate(features): 434 | result = unique_id_to_result[feature.unique_id] 435 | 436 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 437 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 438 | for start_index in start_indexes: 439 | for end_index in 
end_indexes: 440 | # We could hypothetically create invalid predictions, e.g., predict 441 | # that the start of the span is in the question. We throw out all 442 | # invalid predictions. 443 | if start_index >= len(feature.tokens): 444 | continue 445 | if end_index >= len(feature.tokens): 446 | continue 447 | if start_index not in feature.token_to_orig_map: 448 | continue 449 | if end_index not in feature.token_to_orig_map: 450 | continue 451 | if not feature.token_is_max_context.get(start_index, False): 452 | continue 453 | if end_index < start_index: 454 | continue 455 | length = end_index - start_index + 1 456 | if length > max_answer_length: 457 | continue 458 | prelim_predictions.append( 459 | _PrelimPrediction( 460 | feature_index=feature_index, 461 | start_index=start_index, 462 | end_index=end_index, 463 | start_logit=result.start_logits[start_index], 464 | end_logit=result.end_logits[end_index])) 465 | 466 | prelim_predictions = sorted( 467 | prelim_predictions, 468 | key=lambda x: (x.start_logit + x.end_logit), 469 | reverse=True) 470 | 471 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 472 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 473 | 474 | seen_predictions = {} 475 | nbest = [] 476 | for pred in prelim_predictions: 477 | if len(nbest) >= n_best_size: 478 | break 479 | feature = features[pred.feature_index] 480 | 481 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 482 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 483 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 484 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 485 | tok_text = " ".join(tok_tokens) 486 | 487 | # De-tokenize WordPieces that have been split off. 488 | tok_text = tok_text.replace(" ##", "") 489 | tok_text = tok_text.replace("##", "") 490 | 491 | # Clean whitespace 492 | tok_text = tok_text.strip() 493 | tok_text = " ".join(tok_text.split()) 494 | orig_text = " ".join(orig_tokens) 495 | 496 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) 497 | if final_text in seen_predictions: 498 | continue 499 | 500 | seen_predictions[final_text] = True 501 | nbest.append( 502 | _NbestPrediction( 503 | text=final_text, 504 | start_logit=pred.start_logit, 505 | end_logit=pred.end_logit)) 506 | 507 | # In very rare edge cases we could have no valid predictions. So we 508 | # just create a nonce prediction in this case to avoid failure. 
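      # The surviving candidates (padded with a nonce "empty" entry if necessary) are
      # scored with a softmax over start_logit + end_logit, and the highest-probability
      # text becomes the prediction recorded for this qas_id.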
509 | if not nbest: 510 | nbest.append( 511 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 512 | 513 | assert len(nbest) >= 1 514 | 515 | total_scores = [] 516 | for entry in nbest: 517 | total_scores.append(entry.start_logit + entry.end_logit) 518 | 519 | probs = _compute_softmax(total_scores) 520 | 521 | nbest_json = [] 522 | for (i, entry) in enumerate(nbest): 523 | output = collections.OrderedDict() 524 | output["text"] = entry.text 525 | output["probability"] = probs[i] 526 | output["start_logit"] = entry.start_logit 527 | output["end_logit"] = entry.end_logit 528 | nbest_json.append(output) 529 | 530 | assert len(nbest_json) >= 1 531 | 532 | all_predictions[example.qas_id] = nbest_json[0]["text"] 533 | all_nbest_json[example.qas_id] = nbest_json 534 | 535 | with open(output_prediction_file, "w") as writer: 536 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 537 | 538 | with open(output_nbest_file, "w") as writer: 539 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 540 | 541 | 542 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 543 | """Project the tokenized prediction back to the original text.""" 544 | 545 | # When we created the data, we kept track of the alignment between original 546 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 547 | # now `orig_text` contains the span of our original text corresponding to the 548 | # span that we predicted. 549 | # 550 | # However, `orig_text` may contain extra characters that we don't want in 551 | # our prediction. 552 | # 553 | # For example, let's say: 554 | # pred_text = steve smith 555 | # orig_text = Steve Smith's 556 | # 557 | # We don't want to return `orig_text` because it contains the extra "'s". 558 | # 559 | # We don't want to return `pred_text` because it's already been normalized 560 | # (the SQuAD eval script also does punctuation stripping/lower casing but 561 | # our tokenizer does additional normalization like stripping accent 562 | # characters). 563 | # 564 | # What we really want to return is "Steve Smith". 565 | # 566 | # Therefore, we have to apply a semi-complicated alignment heruistic between 567 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 568 | # can fail in certain cases in which case we just return `orig_text`. 569 | 570 | def _strip_spaces(text): 571 | ns_chars = [] 572 | ns_to_s_map = collections.OrderedDict() 573 | for (i, c) in enumerate(text): 574 | if c == " ": 575 | continue 576 | ns_to_s_map[len(ns_chars)] = i 577 | ns_chars.append(c) 578 | ns_text = "".join(ns_chars) 579 | return (ns_text, ns_to_s_map) 580 | 581 | # We first tokenize `orig_text`, strip whitespace from the result 582 | # and `pred_text`, and check if they are the same length. If they are 583 | # NOT the same length, the heuristic has failed. If they are the same 584 | # length, we assume the characters are one-to-one aligned. 
585 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 586 | 587 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 588 | 589 | start_position = tok_text.find(pred_text) 590 | if start_position == -1: 591 | if verbose_logging: 592 | logger.info( 593 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 594 | return orig_text 595 | end_position = start_position + len(pred_text) - 1 596 | 597 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 598 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 599 | 600 | if len(orig_ns_text) != len(tok_ns_text): 601 | if verbose_logging: 602 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'", 603 | orig_ns_text, tok_ns_text) 604 | return orig_text 605 | 606 | # We then project the characters in `pred_text` back to `orig_text` using 607 | # the character-to-character alignment. 608 | tok_s_to_ns_map = {} 609 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 610 | tok_s_to_ns_map[tok_index] = i 611 | 612 | orig_start_position = None 613 | if start_position in tok_s_to_ns_map: 614 | ns_start_position = tok_s_to_ns_map[start_position] 615 | if ns_start_position in orig_ns_to_s_map: 616 | orig_start_position = orig_ns_to_s_map[ns_start_position] 617 | 618 | if orig_start_position is None: 619 | if verbose_logging: 620 | logger.info("Couldn't map start position") 621 | return orig_text 622 | 623 | orig_end_position = None 624 | if end_position in tok_s_to_ns_map: 625 | ns_end_position = tok_s_to_ns_map[end_position] 626 | if ns_end_position in orig_ns_to_s_map: 627 | orig_end_position = orig_ns_to_s_map[ns_end_position] 628 | 629 | if orig_end_position is None: 630 | if verbose_logging: 631 | logger.info("Couldn't map end position") 632 | return orig_text 633 | 634 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 635 | return output_text 636 | 637 | 638 | def _get_best_indexes(logits, n_best_size): 639 | """Get the n-best logits from a list.""" 640 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 641 | 642 | best_indexes = [] 643 | for i in range(len(index_and_score)): 644 | if i >= n_best_size: 645 | break 646 | best_indexes.append(index_and_score[i][0]) 647 | return best_indexes 648 | 649 | 650 | def _compute_softmax(scores): 651 | """Compute softmax probability over raw logits.""" 652 | if not scores: 653 | return [] 654 | 655 | max_score = None 656 | for score in scores: 657 | if max_score is None or score > max_score: 658 | max_score = score 659 | 660 | exp_scores = [] 661 | total_sum = 0.0 662 | for score in scores: 663 | x = math.exp(score - max_score) 664 | exp_scores.append(x) 665 | total_sum += x 666 | 667 | probs = [] 668 | for score in exp_scores: 669 | probs.append(score / total_sum) 670 | return probs 671 | 672 | 673 | def main(): 674 | parser = argparse.ArgumentParser() 675 | 676 | ## Required parameters 677 | parser.add_argument("--bert_config_file", default=None, type=str, required=True, 678 | help="The config json file corresponding to the pre-trained BERT model. 
" 679 | "This specifies the model architecture.") 680 | parser.add_argument("--vocab_file", default=None, type=str, required=True, 681 | help="The vocabulary file that the BERT model was trained on.") 682 | parser.add_argument("--output_dir", default=None, type=str, required=True, 683 | help="The output directory where the model checkpoints will be written.") 684 | 685 | ## Other parameters 686 | parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json") 687 | parser.add_argument("--predict_file", default=None, type=str, 688 | help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 689 | parser.add_argument("--init_checkpoint", default=None, type=str, 690 | help="Initial checkpoint (usually from a pre-trained BERT model).") 691 | parser.add_argument("--do_lower_case", default=True, action='store_true', 692 | help="Whether to lower case the input text. Should be True for uncased " 693 | "models and False for cased models.") 694 | parser.add_argument("--max_seq_length", default=384, type=int, 695 | help="The maximum total input sequence length after WordPiece tokenization. Sequences " 696 | "longer than this will be truncated, and sequences shorter than this will be padded.") 697 | parser.add_argument("--doc_stride", default=128, type=int, 698 | help="When splitting up a long document into chunks, how much stride to take between chunks.") 699 | parser.add_argument("--max_query_length", default=64, type=int, 700 | help="The maximum number of tokens for the question. Questions longer than this will " 701 | "be truncated to this length.") 702 | parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.") 703 | parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.") 704 | parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") 705 | parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.") 706 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 707 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 708 | help="Total number of training epochs to perform.") 709 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 710 | help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " 711 | "of training.") 712 | parser.add_argument("--save_checkpoints_steps", default=1000, type=int, 713 | help="How often to save the model checkpoint.") 714 | parser.add_argument("--iterations_per_loop", default=1000, type=int, 715 | help="How many steps to make in each estimator call.") 716 | parser.add_argument("--n_best_size", default=20, type=int, 717 | help="The total number of n-best predictions to generate in the nbest_predictions.json " 718 | "output file.") 719 | parser.add_argument("--max_answer_length", default=30, type=int, 720 | help="The maximum length of an answer that can be generated. This is needed because the start " 721 | "and end predictions are not conditioned on one another.") 722 | 723 | parser.add_argument("--verbose_logging", default=False, action='store_true', 724 | help="If true, all of the warnings related to data processing will be printed. 
" 725 | "A number of warnings are expected for a normal SQuAD evaluation.") 726 | parser.add_argument("--no_cuda", 727 | default=False, 728 | action='store_true', 729 | help="Whether not to use CUDA when available") 730 | parser.add_argument("--local_rank", 731 | type=int, 732 | default=-1, 733 | help="local_rank for distributed training on gpus") 734 | parser.add_argument("--accumulate_gradients", 735 | type=int, 736 | default=1, 737 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 738 | parser.add_argument('--seed', 739 | type=int, 740 | default=42, 741 | help="random seed for initialization") 742 | parser.add_argument('--gradient_accumulation_steps', 743 | type=int, 744 | default=1, 745 | help="Number of updates steps to accumualte before performing a backward/update pass.") 746 | 747 | args = parser.parse_args() 748 | 749 | if args.local_rank == -1 or args.no_cuda: 750 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 751 | n_gpu = torch.cuda.device_count() 752 | else: 753 | device = torch.device("cuda", args.local_rank) 754 | n_gpu = 1 755 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 756 | torch.distributed.init_process_group(backend='nccl') 757 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 758 | 759 | if args.accumulate_gradients < 1: 760 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 761 | args.accumulate_gradients)) 762 | 763 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) 764 | 765 | random.seed(args.seed) 766 | np.random.seed(args.seed) 767 | torch.manual_seed(args.seed) 768 | if n_gpu > 0: 769 | torch.cuda.manual_seed_all(args.seed) 770 | 771 | if not args.do_train and not args.do_predict: 772 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 773 | 774 | if args.do_train: 775 | if not args.train_file: 776 | raise ValueError( 777 | "If `do_train` is True, then `train_file` must be specified.") 778 | if args.do_predict: 779 | if not args.predict_file: 780 | raise ValueError( 781 | "If `do_predict` is True, then `predict_file` must be specified.") 782 | 783 | bert_config = BertConfig.from_json_file(args.bert_config_file) 784 | 785 | if args.max_seq_length > bert_config.max_position_embeddings: 786 | raise ValueError( 787 | "Cannot use sequence length %d because the BERT model " 788 | "was only trained up to sequence length %d" % 789 | (args.max_seq_length, bert_config.max_position_embeddings)) 790 | 791 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 792 | raise ValueError("Output directory () already exists and is not empty.") 793 | os.makedirs(args.output_dir, exist_ok=True) 794 | 795 | tokenizer = tokenization.FullTokenizer( 796 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 797 | 798 | train_examples = None 799 | num_train_steps = None 800 | if args.do_train: 801 | train_examples = read_squad_examples( 802 | input_file=args.train_file, is_training=True) 803 | num_train_steps = int( 804 | len(train_examples) / args.train_batch_size * args.num_train_epochs) 805 | 806 | model = BertForQuestionAnswering(bert_config) 807 | if args.init_checkpoint is not None: 808 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 809 | model.to(device) 810 | 811 | if args.local_rank != -1: 812 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 813 | output_device=args.local_rank) 814 | elif n_gpu > 1: 815 | model = torch.nn.DataParallel(model) 816 | 817 | no_decay = ['bias', 'gamma', 'beta'] 818 | optimizer_parameters = [ 819 | {'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01}, 820 | {'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0} 821 | ] 822 | 823 | optimizer = BERTAdam(optimizer_parameters, 824 | lr=args.learning_rate, 825 | warmup=args.warmup_proportion, 826 | t_total=num_train_steps) 827 | 828 | global_step = 0 829 | if args.do_train: 830 | train_features = convert_examples_to_features( 831 | examples=train_examples, 832 | tokenizer=tokenizer, 833 | max_seq_length=args.max_seq_length, 834 | doc_stride=args.doc_stride, 835 | max_query_length=args.max_query_length, 836 | is_training=True) 837 | logger.info("***** Running training *****") 838 | logger.info(" Num orig examples = %d", len(train_examples)) 839 | logger.info(" Num split examples = %d", len(train_features)) 840 | logger.info(" Batch size = %d", args.train_batch_size) 841 | logger.info(" Num steps = %d", num_train_steps) 842 | 843 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 844 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 845 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 846 | all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) 847 | all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) 848 | 849 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 850 | all_start_positions, all_end_positions) 851 | if args.local_rank == -1: 852 | train_sampler = RandomSampler(train_data) 853 | else: 854 | train_sampler = DistributedSampler(train_data) 855 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 856 | 857 | model.train() 858 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 859 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 860 | input_ids, input_mask, segment_ids, start_positions, end_positions = batch 861 | input_ids = input_ids.to(device) 862 | input_mask = input_mask.to(device) 863 | segment_ids = segment_ids.to(device) 864 | start_positions = start_positions.to(device) 865 | end_positions = start_positions.to(device) 866 | 867 | start_positions = start_positions.view(-1, 1) 868 | end_positions = end_positions.view(-1, 1) 869 | 870 | loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) 871 | if n_gpu > 1: 872 | loss = loss.mean() # mean() to average on multi-gpu. 
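          # As in run_classifier.py, gradients accumulate across batches and the
          # optimizer steps (followed by model.zero_grad()) only every
          # `gradient_accumulation_steps` iterations.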
873 | 874 | loss.backward() 875 | if (step + 1) % args.gradient_accumulation_steps == 0: 876 | optimizer.step() # We have accumulated enought gradients 877 | model.zero_grad() 878 | global_step += 1 879 | 880 | if args.do_predict: 881 | eval_examples = read_squad_examples( 882 | input_file=args.predict_file, is_training=False) 883 | eval_features = convert_examples_to_features( 884 | examples=eval_examples, 885 | tokenizer=tokenizer, 886 | max_seq_length=args.max_seq_length, 887 | doc_stride=args.doc_stride, 888 | max_query_length=args.max_query_length, 889 | is_training=False) 890 | 891 | logger.info("***** Running predictions *****") 892 | logger.info(" Num orig examples = %d", len(eval_examples)) 893 | logger.info(" Num split examples = %d", len(eval_features)) 894 | logger.info(" Batch size = %d", args.predict_batch_size) 895 | 896 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 897 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 898 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 899 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 900 | 901 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) 902 | if args.local_rank == -1: 903 | eval_sampler = SequentialSampler(eval_data) 904 | else: 905 | eval_sampler = DistributedSampler(eval_data) 906 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size) 907 | 908 | model.eval() 909 | all_results = [] 910 | logger.info("Start evaluating") 911 | for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): 912 | if len(all_results) % 1000 == 0: 913 | logger.info("Processing example: %d" % (len(all_results))) 914 | 915 | input_ids = input_ids.to(device) 916 | input_mask = input_mask.to(device) 917 | segment_ids = segment_ids.to(device) 918 | 919 | with torch.no_grad(): 920 | batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) 921 | 922 | for i, example_index in enumerate(example_indices): 923 | start_logits = batch_start_logits[i].detach().cpu().tolist() 924 | end_logits = batch_end_logits[i].detach().cpu().tolist() 925 | 926 | eval_feature = eval_features[example_index.item()] 927 | unique_id = int(eval_feature.unique_id) 928 | all_results.append(RawResult(unique_id=unique_id, 929 | start_logits=start_logits, 930 | end_logits=end_logits)) 931 | 932 | output_prediction_file = os.path.join(args.output_dir, "predictions.json") 933 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json") 934 | write_predictions(eval_examples, eval_features, all_results, 935 | args.n_best_size, args.max_answer_length, 936 | args.do_lower_case, output_prediction_file, 937 | output_nbest_file, args.verbose_logging) 938 | 939 | 940 | if __name__ == "__main__": 941 | main() 942 | --------------------------------------------------------------------------------
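The sliding-window logic in run_squad.py's convert_examples_to_features and _check_is_max_context is the least obvious part of the pipeline, so here is a condensed, self-contained sketch of it. The helper names make_doc_spans and is_max_context are illustrative only and do not exist in the repository; the scoring rule (minimum of left/right context plus a small length bonus) mirrors the code above.

import collections

_DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    # Split a long document into overlapping chunks of at most
    # `max_tokens_for_doc` tokens, advancing by `doc_stride` tokens each time.
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

def is_max_context(doc_spans, cur_span_index, position):
    # A token that appears in several spans is attributed to the span where it
    # has the most surrounding context: min(left context, right context), with a
    # small bonus for longer spans to break ties.
    best_score, best_span_index = None, None
    for span_index, doc_span in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start or position > end:
            continue
        score = min(position - doc_span.start, end - position) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score, best_span_index = score, span_index
    return cur_span_index == best_span_index

if __name__ == "__main__":
    tokens = "the man went to the store and bought a gallon of milk".split()
    spans = make_doc_spans(len(tokens), max_tokens_for_doc=5, doc_stride=3)
    print(spans)
    # Position 7 ('bought') appears in two spans; only the one with more
    # context on both sides is flagged as its max-context span.
    print([is_max_context(spans, i, 7) for i in range(len(spans))])

Running the sketch prints the overlapping spans for the example sentence used in the comments and shows that only the third span is flagged as the max-context span for position 7 ('bought').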