├── MANIFEST.in
├── samples
│   ├── input.txt
│   └── sample_text.txt
├── requirements.txt
├── docker
│   └── Dockerfile
├── .circleci
│   └── config.yml
├── pytorch_pretrained_bert
│   ├── __init__.py
│   ├── __main__.py
│   ├── convert_tf_checkpoint_to_pytorch.py
│   ├── optimization.py
│   ├── file_utils.py
│   └── tokenization.py
├── examples
│   ├── run.sh
│   ├── evaluate-v1.1.py
│   ├── eval_squad_v2.0.py
│   ├── extract_features.py
│   ├── run_swag.py
│   ├── run_classifier.py
│   └── run_lm_finetuning.py
├── eval.py
├── tests
│   ├── optimization_test.py
│   ├── tokenization_test.py
│   └── modeling_test.py
├── .gitignore
├── setup.py
├── gridSearch.py
└── LICENSE
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | 
--------------------------------------------------------------------------------
/samples/input.txt:
--------------------------------------------------------------------------------
1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | torch>=0.4.1
3 | # progress bars in model download and training scripts
4 | tqdm
5 | # Accessing files from S3 directly.
6 | boto3
7 | # Used for downloading models over HTTP
8 | requests
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:latest
2 | 
3 | RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext
4 | 
5 | RUN pip install pytorch-pretrained-bert
6 | 
7 | WORKDIR /workspace
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | jobs:
3 |   build:
4 |     working_directory: ~/pytorch-pretrained-BERT
5 |     docker:
6 |       - image: circleci/python:3.7
7 |     steps:
8 |       - checkout
9 |       - run: sudo pip install --progress-bar off . 
10 | - run: sudo pip install pytest 11 | - run: python -m pytest -sv tests/ 12 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.0" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 4 | BertForMaskedLM, BertForNextSentencePrediction, 5 | BertForSequenceClassification, BertForMultipleChoice, 6 | BertForTokenClassification, BertForQuestionAnswering) 7 | from .optimization import BertAdam 8 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE 9 | -------------------------------------------------------------------------------- /examples/run.sh: -------------------------------------------------------------------------------- 1 | export SQUAD_DIR=/home/meefly/working/tdt/01_data/squad 2 | nohup python run_squad_v2.0.py \ 3 | --bert_model /home/meefly/working/tdt/03_bert/pytorch_pretrained_BERT/uncased_L-12_H-768_A-12/ \ 4 | --do_train \ 5 | --do_predict \ 6 | --do_lower_case \ 7 | --train_file $SQUAD_DIR/train-v2.0.json \ 8 | --predict_file $SQUAD_DIR/dev-v2.0.json \ 9 | --train_batch_size 12 \ 10 | --learning_rate 3e-5 \ 11 | --num_train_epochs 2.0 \ 12 | --max_seq_length 384 \ 13 | --doc_stride 128 \ 14 | --output_dir /tmp/debug_squad/ & 15 | 16 | 17 | python eval_squad_v2.0.py $SQUAD_DIR/dev-v2.0.json /tmp/debug_squad/predictions.json 18 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | try: 5 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 6 | except ModuleNotFoundError: 7 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 8 | "In that case, it requires TensorFlow to be installed. 
Please see " 9 | "https://www.tensorflow.org/install/ for installation instructions.") 10 | raise 11 | 12 | if len(sys.argv) != 5: 13 | # pylint: disable=line-too-long 14 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 15 | else: 16 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 17 | TF_CONFIG = sys.argv.pop() 18 | TF_CHECKPOINT = sys.argv.pop() 19 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | filepath = '/tmp' 4 | filenames = [filename for filename in os.listdir(filepath) if 'SQuAD' in filename] 5 | with open('results.out', 'r') as f: 6 | evaled = [line.strip() for line in f.readlines()] 7 | filenames = [filename for filename in filenames if filename not in evaled] 8 | for filename in filenames: 9 | cmd = [] 10 | cmd.append("export SQUAD_DIR=/home/meijie/data/squad") 11 | cmd.append('echo {}>>results.out'.format(filename)) 12 | for epoch in range(5): 13 | try: 14 | if 'v1' in filename: 15 | cmd.append("python examples/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json" 16 | " /tmp/{0}/predictions_{1}.json >>results.out" 17 | .format(filename, epoch)) 18 | elif 'v2' in filename: 19 | cmd.append("python examples/eval_squad_v2.0.py $SQUAD_DIR/dev-v2.0.json" 20 | " /tmp/{0}/predictions_{1}.json >>results.out" 21 | .format(filename, epoch)) 22 | except Exception as e: 23 | print(e) 24 | cmd = ";".join(cmd) 25 | os.system(cmd) 26 | -------------------------------------------------------------------------------- /tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | 21 | import torch 22 | 23 | from pytorch_pretrained_bert import BertAdam 24 | 25 | class OptimizationTest(unittest.TestCase): 26 | 27 | def assertListAlmostEqual(self, list1, list2, tol): 28 | self.assertEqual(len(list1), len(list2)) 29 | for a, b in zip(list1, list2): 30 | self.assertAlmostEqual(a, b, delta=tol) 31 | 32 | def test_adam(self): 33 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 34 | target = torch.tensor([0.4, 0.2, -0.5]) 35 | criterion = torch.nn.MSELoss() 36 | # No warmup, constant schedule, no gradient clipping 37 | optimizer = BertAdam(params=[w], lr=2e-1, 38 | weight_decay=0.0, 39 | max_grad_norm=-1) 40 | for _ in range(100): 41 | loss = criterion(w, target) 42 | loss.backward() 43 | optimizer.step() 44 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 
45 | w.grad.zero_() 46 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 47 | 48 | 49 | if __name__ == "__main__": 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # vscode 119 | .vscode 120 | 121 | # TF code 122 | tensorflow_code 123 | 124 | # nohup 125 | *.out 126 | # Models 127 | models -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py and setup.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. 21 | 22 | 5. 
Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (pypi suggest using twine as other methods upload files via plaintext.) 26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi allennlp 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 34 | 35 | """ 36 | from setuptools import find_packages, setup 37 | 38 | setup( 39 | name="pytorch_pretrained_bert", 40 | version="0.4.0", 41 | author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors", 42 | author_email="thomas@huggingface.co", 43 | description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", 44 | long_description=open("README.md", "r", encoding='utf-8').read(), 45 | long_description_content_type="text/markdown", 46 | keywords='BERT NLP deep learning google', 47 | license='Apache', 48 | url="https://github.com/huggingface/pytorch-pretrained-BERT", 49 | packages=find_packages(exclude=["*.tests", "*.tests.*", 50 | "tests.*", "tests"]), 51 | install_requires=['torch>=0.4.1', 52 | 'numpy', 53 | 'boto3', 54 | 'requests', 55 | 'tqdm'], 56 | entry_points={ 57 | 'console_scripts': [ 58 | "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main" 59 | ] 60 | }, 61 | python_requires='>=3.5.0', 62 | tests_require=['pytest'], 63 | classifiers=[ 64 | 'Intended Audience :: Science/Research', 65 | 'License :: OSI Approved :: Apache Software License', 66 | 'Programming Language :: Python :: 3', 67 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 68 | ], 69 | ) 70 | -------------------------------------------------------------------------------- /gridSearch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import itertools 3 | import time 4 | theta = [1, 3, 10] 5 | # theta = [1] 6 | # alpha = [0.001, 0.03, 0.3] 7 | alpha = [0, 0.6, 0.7, 1] 8 | beta = [0, 1, 2, 4] 9 | small_or_large = 'large' 10 | for theta, alpha, beta in itertools.product(theta, alpha, beta): 11 | cmd = [] 12 | cmd.append("export CUDA_VISIBLE_DEVICES=0,2,3,4,6,7") 13 | cmd.append("export SQUAD_DIR=/home/meijie/data/squad") 14 | cmd.append("export PYTHONPATH=/home/meijie/working/pytorch_pretrained_BERT/:$PYTHONPATH") 15 | if small_or_large == 'small': 16 | cmd.append("export SAVE_DIR=/tmp/SQuAD_v1-{0}_{1}_{2}_newloss_saveLoss/".format(theta, alpha, beta)) 17 | cmd.append("python examples/run_squad.py \ 18 | --bert_model /data/nfsdata/meijie/data/uncased_L-12_H-768_A-12 \ 19 | --do_train \ 20 | --do_predict \ 21 | --do_lower_case \ 22 | --train_file $SQUAD_DIR/train-v1.1.json \ 23 | --predict_file $SQUAD_DIR/dev-v1.1.json \ 24 | --train_batch_size 1 \ 25 | --learning_rate 3e-5 \ 26 | --num_train_epochs 3.0 \ 27 | --max_seq_length 384 \ 28 | --doc_stride 128 \ 29 | --seed 1\ 30 | --theta {0}\ 31 | --alpha {1}\ 32 | --beta {2}\ 33 | --output_dir $SAVE_DIR > ./out/{0}_{1}_{2}_newloss_saveLoss.out 2>&1" 34 | .format(theta, alpha, beta)) 35 | elif small_or_large == 'large': 36 | cmd.append("export SAVE_DIR=/tmp/SQuAD_v2-{0}_{1}_{2}_newloss_large_2/".format(theta, alpha, beta)) 37 | cmd.append("python examples/run_squad.py \ 38 | --bert_model /home/meijie/data/bert/uncased_L-24_H-1024_A-16 \ 39 | --do_train \ 40 | --version_2_with_negative\ 41 | --do_predict \ 42 | 
--do_lower_case \ 43 | --train_file $SQUAD_DIR/train-v2.0.json \ 44 | --predict_file $SQUAD_DIR/dev-v2.0.json \ 45 | --learning_rate 3e-5 \ 46 | --num_train_epochs 3 \ 47 | --max_seq_length 384 \ 48 | --doc_stride 128 \ 49 | --output_dir $SAVE_DIR \ 50 | --train_batch_size 24 \ 51 | --theta {0}\ 52 | --alpha {1}\ 53 | --beta {2}\ 54 | --gradient_accumulation_steps 2\ 55 | --loss_scale 128 > ./out/{0}_{1}_{2}_newloss_large_2.out 2>&1".format(theta, alpha, beta)) 56 | cmd = ";".join(cmd) 57 | for i in range(4): 58 | return_code = os.system(cmd) 59 | if return_code == 0: 60 | break 61 | else: 62 | print('sleep for {} secs'.format(10 ** i)) 63 | time.sleep(10 ** i) 64 | -------------------------------------------------------------------------------- /examples/evaluate-v1.1.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """ 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive score 0.' 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 66 | prediction = predictions[qa['id']] 67 | exact_match += metric_max_over_ground_truths( 68 | exact_match_score, prediction, ground_truths) 69 | f1 += metric_max_over_ground_truths( 70 | f1_score, prediction, ground_truths) 71 | 72 | exact_match = 100.0 * exact_match / total 73 | f1 = 100.0 * f1 / total 74 | 75 | return {'exact_match': exact_match, 'f1': f1} 76 | 77 | 78 | if __name__ == '__main__': 79 | expected_version = '1.1' 80 | parser = argparse.ArgumentParser( 81 | description='Evaluation for SQuAD ' + expected_version) 82 | parser.add_argument('dataset_file', help='Dataset file') 83 | parser.add_argument('prediction_file', help='Prediction File') 84 | args = parser.parse_args() 85 | with open(args.dataset_file) as dataset_file: 86 | dataset_json = json.load(dataset_file) 87 | if (dataset_json['version'] != expected_version): 88 | print('Evaluation expects v-' + expected_version + 89 | ', but got dataset with v-' + dataset_json['version'], 90 | file=sys.stderr) 91 | dataset = dataset_json['data'] 92 | with open(args.prediction_file) as prediction_file: 93 | predictions = json.load(prediction_file) 94 | print(json.dumps(evaluate(dataset, predictions))) 95 | -------------------------------------------------------------------------------- /samples/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 
20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
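# A minimal usage sketch for this converter (the checkpoint paths below are
# placeholders, not files shipped with the repository). It can be invoked through
# the console entry point declared in setup.py:
#
#   pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
#       /path/to/uncased_L-12_H-768_A-12/bert_model.ckpt \
#       /path/to/uncased_L-12_H-768_A-12/bert_config.json \
#       /path/to/uncased_L-12_H-768_A-12/pytorch_model.bin
#
# or by running this module directly with the --tf_checkpoint_path,
# --bert_config_file and --pytorch_dump_path flags defined at the bottom of the file.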
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from .modeling import BertConfig, BertForPreTraining 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | config_path = os.path.abspath(bert_config_file) 32 | tf_path = os.path.abspath(tf_checkpoint_path) 33 | print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) 34 | # Load weights from TF model 35 | init_vars = tf.train.list_variables(tf_path) 36 | names = [] 37 | arrays = [] 38 | for name, shape in init_vars: 39 | print("Loading TF weight {} with shape {}".format(name, shape)) 40 | array = tf.train.load_variable(tf_path, name) 41 | names.append(name) 42 | arrays.append(array) 43 | 44 | # Initialise PyTorch model 45 | config = BertConfig.from_json_file(bert_config_file) 46 | print("Building PyTorch model from configuration: {}".format(str(config))) 47 | model = BertForPreTraining(config) 48 | 49 | for name, array in zip(names, arrays): 50 | name = name.split('/') 51 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 52 | # which are not required for using pretrained model 53 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name): 54 | print("Skipping {}".format("/".join(name))) 55 | continue 56 | pointer = model 57 | for m_name in name: 58 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 59 | l = re.split(r'_(\d+)', m_name) 60 | else: 61 | l = [m_name] 62 | if l[0] == 'kernel' or l[0] == 'gamma': 63 | pointer = getattr(pointer, 'weight') 64 | elif l[0] == 'output_bias' or l[0] == 'beta': 65 | pointer = getattr(pointer, 'bias') 66 | elif l[0] == 'output_weights': 67 | pointer = getattr(pointer, 'weight') 68 | else: 69 | pointer = getattr(pointer, l[0]) 70 | if len(l) >= 2: 71 | num = int(l[1]) 72 | pointer = pointer[num] 73 | if m_name[-11:] == '_embeddings': 74 | pointer = getattr(pointer, 'weight') 75 | elif m_name == 'kernel': 76 | array = np.transpose(array) 77 | try: 78 | assert pointer.shape == array.shape 79 | except AssertionError as e: 80 | e.args += (pointer.shape, array.shape) 81 | raise 82 | print("Initialize PyTorch weight {}".format(name)) 83 | pointer.data = torch.from_numpy(array) 84 | 85 | # Save pytorch-model 86 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 87 | torch.save(model.state_dict(), pytorch_dump_path) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | ## Required parameters 93 | parser.add_argument("--tf_checkpoint_path", 94 | default = None, 95 | type = str, 96 | required = True, 97 | help = "Path the TensorFlow checkpoint path.") 98 | parser.add_argument("--bert_config_file", 99 | default = None, 100 | type = str, 101 | required = True, 102 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 103 | "This specifies the model architecture.") 104 | parser.add_argument("--pytorch_dump_path", 105 | default = None, 106 | type = str, 107 | required = True, 108 | help = "Path to the output PyTorch model.") 109 | args = parser.parse_args() 110 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 111 | args.bert_config_file, 112 | args.pytorch_dump_path) 113 | -------------------------------------------------------------------------------- /tests/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import unittest 21 | 22 | from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer, 23 | _is_whitespace, _is_control, _is_punctuation) 24 | 25 | 26 | class TokenizationTest(unittest.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer: 34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 35 | 36 | vocab_file = vocab_writer.name 37 | 38 | tokenizer = BertTokenizer(vocab_file) 39 | os.remove(vocab_file) 40 | 41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 42 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 43 | 44 | self.assertListEqual( 45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 46 | 47 | def test_full_tokenizer_raises_error_for_long_sequences(self): 48 | vocab_tokens = [ 49 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 50 | "##ing", "," 51 | ] 52 | with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer: 53 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 54 | vocab_file = vocab_writer.name 55 | 56 | tokenizer = BertTokenizer(vocab_file, max_len=10) 57 | os.remove(vocab_file) 58 | tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time") 59 | indices = tokenizer.convert_tokens_to_ids(tokens) 60 | self.assertListEqual(indices, [0 for _ in range(10)]) 61 | 62 | tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .") 63 | self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens) 64 | 65 | def test_chinese(self): 66 | tokenizer = BasicTokenizer() 67 | 68 | self.assertListEqual( 69 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 70 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 71 | 72 | def test_basic_tokenizer_lower(self): 73 | tokenizer = BasicTokenizer(do_lower_case=True) 74 | 75 | self.assertListEqual( 76 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 77 | ["hello", "!", "how", "are", "you", "?"]) 78 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 79 | 80 | def test_basic_tokenizer_no_lower(self): 81 | tokenizer = BasicTokenizer(do_lower_case=False) 82 | 83 | self.assertListEqual( 84 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 85 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 86 | 87 | def test_wordpiece_tokenizer(self): 88 | vocab_tokens = [ 89 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 90 | "##ing" 91 | ] 92 | 93 | vocab = {} 94 | for (i, token) in enumerate(vocab_tokens): 95 | vocab[token] = i 96 | tokenizer = WordpieceTokenizer(vocab=vocab) 97 | 98 | self.assertListEqual(tokenizer.tokenize(""), []) 99 | 100 | self.assertListEqual( 101 | tokenizer.tokenize("unwanted running"), 102 | ["un", "##want", "##ed", "runn", "##ing"]) 103 | 104 | self.assertListEqual( 105 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(_is_whitespace(u" ")) 109 | self.assertTrue(_is_whitespace(u"\t")) 110 | self.assertTrue(_is_whitespace(u"\r")) 111 | self.assertTrue(_is_whitespace(u"\n")) 112 | self.assertTrue(_is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(_is_whitespace(u"A")) 115 | self.assertFalse(_is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(_is_control(u"\u0005")) 119 | 120 | self.assertFalse(_is_control(u"A")) 121 | self.assertFalse(_is_control(u" ")) 122 | self.assertFalse(_is_control(u"\t")) 123 | self.assertFalse(_is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(_is_punctuation(u"-")) 127 | self.assertTrue(_is_punctuation(u"$")) 128 | self.assertTrue(_is_punctuation(u"`")) 129 | self.assertTrue(_is_punctuation(u".")) 130 | 131 | self.assertFalse(_is_punctuation(u"A")) 132 | self.assertFalse(_is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == '__main__': 136 | unittest.main() 137 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | 23 | def warmup_cosine(x, warmup=0.002): 24 | if x < warmup: 25 | return x/warmup 26 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 27 | 28 | def warmup_constant(x, warmup=0.002): 29 | if x < warmup: 30 | return x/warmup 31 | return 1.0 32 | 33 | def warmup_linear(x, warmup=0.002): 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 - x 37 | 38 | SCHEDULES = { 39 | 'warmup_cosine':warmup_cosine, 40 | 'warmup_constant':warmup_constant, 41 | 'warmup_linear':warmup_linear, 42 | } 43 | 44 | 45 | class BertAdam(Optimizer): 46 | """Implements BERT version of Adam algorithm with weight decay fix. 47 | Params: 48 | lr: learning rate 49 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 50 | t_total: total number of training steps for the learning 51 | rate schedule, -1 means constant learning rate. Default: -1 52 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 53 | b1: Adams b1. Default: 0.9 54 | b2: Adams b2. Default: 0.999 55 | e: Adams epsilon. Default: 1e-6 56 | weight_decay: Weight decay. Default: 0.01 57 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 58 | """ 59 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 60 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 61 | max_grad_norm=1.0): 62 | if lr is not required and lr < 0.0: 63 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 64 | if schedule not in SCHEDULES: 65 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 66 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 67 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 68 | if not 0.0 <= b1 < 1.0: 69 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 70 | if not 0.0 <= b2 < 1.0: 71 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 72 | if not e >= 0.0: 73 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 74 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 75 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 76 | max_grad_norm=max_grad_norm) 77 | super(BertAdam, self).__init__(params, defaults) 78 | 79 | def get_lr(self): 80 | lr = [] 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | state = self.state[p] 84 | if len(state) == 0: 85 | return [0] 86 | if group['t_total'] != -1: 87 | schedule_fct = SCHEDULES[group['schedule']] 88 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 89 | else: 90 | lr_scheduled = group['lr'] 91 | lr.append(lr_scheduled) 92 | return lr 93 | 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 
100 | """ 101 | loss = None 102 | if closure is not None: 103 | loss = closure() 104 | 105 | for group in self.param_groups: 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data 110 | if grad.is_sparse: 111 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 112 | 113 | state = self.state[p] 114 | 115 | # State initialization 116 | if len(state) == 0: 117 | state['step'] = 0 118 | # Exponential moving average of gradient values 119 | state['next_m'] = torch.zeros_like(p.data) 120 | # Exponential moving average of squared gradient values 121 | state['next_v'] = torch.zeros_like(p.data) 122 | 123 | next_m, next_v = state['next_m'], state['next_v'] 124 | beta1, beta2 = group['b1'], group['b2'] 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | # In-place operations to update the averages at the same time 132 | next_m.mul_(beta1).add_(1 - beta1, grad) 133 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 134 | update = next_m / (next_v.sqrt() + group['e']) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if group['weight_decay'] > 0.0: 144 | update += group['weight_decay'] * p.data 145 | 146 | if group['t_total'] != -1: 147 | schedule_fct = SCHEDULES[group['schedule']] 148 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | update_with_lr = lr_scheduled * update 153 | p.data.add_(-update_with_lr) 154 | 155 | state['step'] += 1 156 | 157 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 158 | # No bias correction 159 | # bias_correction1 = 1 - beta1 ** state['step'] 160 | # bias_correction2 = 1 - beta2 ** state['step'] 161 | 162 | return loss 163 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | 7 | import os 8 | import logging 9 | import shutil 10 | import tempfile 11 | import json 12 | from urllib.parse import urlparse 13 | from pathlib import Path 14 | from typing import Optional, Tuple, Union, IO, Callable, Set 15 | from hashlib import sha256 16 | from functools import wraps 17 | 18 | from tqdm import tqdm 19 | 20 | import boto3 21 | from botocore.exceptions import ClientError 22 | import requests 23 | 24 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 25 | 26 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 27 | Path.home() / '.pytorch_pretrained_bert')) 28 | 29 | 30 | def url_to_filename(url: str, etag: str = None) -> str: 31 | """ 32 | Convert `url` into a hashed filename in a repeatable way. 
33 | If `etag` is specified, append its hash to the url's, delimited 34 | by a period. 35 | """ 36 | url_bytes = url.encode('utf-8') 37 | url_hash = sha256(url_bytes) 38 | filename = url_hash.hexdigest() 39 | 40 | if etag: 41 | etag_bytes = etag.encode('utf-8') 42 | etag_hash = sha256(etag_bytes) 43 | filename += '.' + etag_hash.hexdigest() 44 | 45 | return filename 46 | 47 | 48 | def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: 49 | """ 50 | Return the url and etag (which may be ``None``) stored for `filename`. 51 | Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. 52 | """ 53 | if cache_dir is None: 54 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 55 | if isinstance(cache_dir, Path): 56 | cache_dir = str(cache_dir) 57 | 58 | cache_path = os.path.join(cache_dir, filename) 59 | if not os.path.exists(cache_path): 60 | raise FileNotFoundError("file {} not found".format(cache_path)) 61 | 62 | meta_path = cache_path + '.json' 63 | if not os.path.exists(meta_path): 64 | raise FileNotFoundError("file {} not found".format(meta_path)) 65 | 66 | with open(meta_path) as meta_file: 67 | metadata = json.load(meta_file) 68 | url = metadata['url'] 69 | etag = metadata['etag'] 70 | 71 | return url, etag 72 | 73 | 74 | def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: 75 | """ 76 | Given something that might be a URL (or might be a local path), 77 | determine which. If it's a URL, download the file and cache it, and 78 | return the path to the cached file. If it's already a local path, 79 | make sure the file exists and then return the path. 80 | """ 81 | if cache_dir is None: 82 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 83 | if isinstance(url_or_filename, Path): 84 | url_or_filename = str(url_or_filename) 85 | if isinstance(cache_dir, Path): 86 | cache_dir = str(cache_dir) 87 | 88 | parsed = urlparse(url_or_filename) 89 | 90 | if parsed.scheme in ('http', 'https', 's3'): 91 | # URL, so get it from the cache (downloading if necessary) 92 | return get_from_cache(url_or_filename, cache_dir) 93 | elif os.path.exists(url_or_filename): 94 | # File, and it exists. 95 | return url_or_filename 96 | elif parsed.scheme == '': 97 | # File, but it doesn't exist. 98 | raise FileNotFoundError("file {} not found".format(url_or_filename)) 99 | else: 100 | # Something unknown 101 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 102 | 103 | 104 | def split_s3_path(url: str) -> Tuple[str, str]: 105 | """Split a full s3 path into the bucket name and path.""" 106 | parsed = urlparse(url) 107 | if not parsed.netloc or not parsed.path: 108 | raise ValueError("bad s3 path {}".format(url)) 109 | bucket_name = parsed.netloc 110 | s3_path = parsed.path 111 | # Remove '/' at beginning of path. 112 | if s3_path.startswith("/"): 113 | s3_path = s3_path[1:] 114 | return bucket_name, s3_path 115 | 116 | 117 | def s3_request(func: Callable): 118 | """ 119 | Wrapper function for s3 requests in order to create more helpful error 120 | messages. 
121 | """ 122 | 123 | @wraps(func) 124 | def wrapper(url: str, *args, **kwargs): 125 | try: 126 | return func(url, *args, **kwargs) 127 | except ClientError as exc: 128 | if int(exc.response["Error"]["Code"]) == 404: 129 | raise FileNotFoundError("file {} not found".format(url)) 130 | else: 131 | raise 132 | 133 | return wrapper 134 | 135 | 136 | @s3_request 137 | def s3_etag(url: str) -> Optional[str]: 138 | """Check ETag on S3 object.""" 139 | s3_resource = boto3.resource("s3") 140 | bucket_name, s3_path = split_s3_path(url) 141 | s3_object = s3_resource.Object(bucket_name, s3_path) 142 | return s3_object.e_tag 143 | 144 | 145 | @s3_request 146 | def s3_get(url: str, temp_file: IO) -> None: 147 | """Pull a file directly from S3.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 151 | 152 | 153 | def http_get(url: str, temp_file: IO) -> None: 154 | req = requests.get(url, stream=True) 155 | content_length = req.headers.get('Content-Length') 156 | total = int(content_length) if content_length is not None else None 157 | progress = tqdm(unit="B", total=total) 158 | for chunk in req.iter_content(chunk_size=1024): 159 | if chunk: # filter out keep-alive new chunks 160 | progress.update(len(chunk)) 161 | temp_file.write(chunk) 162 | progress.close() 163 | 164 | 165 | def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: 166 | """ 167 | Given a URL, look for the corresponding dataset in the local cache. 168 | If it's not there, download it. Then return the path to the cached file. 169 | """ 170 | if cache_dir is None: 171 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 172 | if isinstance(cache_dir, Path): 173 | cache_dir = str(cache_dir) 174 | 175 | os.makedirs(cache_dir, exist_ok=True) 176 | 177 | # Get eTag to add to filename, if it exists. 178 | if url.startswith("s3://"): 179 | etag = s3_etag(url) 180 | else: 181 | response = requests.head(url, allow_redirects=True) 182 | if response.status_code != 200: 183 | raise IOError("HEAD request failed for url {} with status code {}" 184 | .format(url, response.status_code)) 185 | etag = response.headers.get("ETag") 186 | 187 | filename = url_to_filename(url, etag) 188 | 189 | # get cache path to put the file 190 | cache_path = os.path.join(cache_dir, filename) 191 | 192 | if not os.path.exists(cache_path): 193 | # Download to temporary file, then copy to cache dir once finished. 194 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
195 | with tempfile.NamedTemporaryFile() as temp_file: 196 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 197 | 198 | # GET file object 199 | if url.startswith("s3://"): 200 | s3_get(url, temp_file) 201 | else: 202 | http_get(url, temp_file) 203 | 204 | # we are copying the file before closing it, so flush to avoid truncation 205 | temp_file.flush() 206 | # shutil.copyfileobj() starts at the current position, so go to the start 207 | temp_file.seek(0) 208 | 209 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 210 | with open(cache_path, 'wb') as cache_file: 211 | shutil.copyfileobj(temp_file, cache_file) 212 | 213 | logger.info("creating metadata file for %s", cache_path) 214 | meta = {'url': url, 'etag': etag} 215 | meta_path = cache_path + '.json' 216 | with open(meta_path, 'w') as meta_file: 217 | json.dump(meta, meta_file) 218 | 219 | logger.info("removing temp file %s", temp_file.name) 220 | 221 | return cache_path 222 | 223 | 224 | def read_set_from_file(filename: str) -> Set[str]: 225 | ''' 226 | Extract a de-duped collection (set) of text from a file. 227 | Expected file format is one item per line. 228 | ''' 229 | collection = set() 230 | with open(filename, 'r', encoding='utf-8') as file_: 231 | for line in file_: 232 | collection.add(line.rstrip()) 233 | return collection 234 | 235 | 236 | def get_file_extension(path: str, dot=True, lower: bool = True): 237 | ext = os.path.splitext(path)[1] 238 | ext = ext if dot else ext[1:] 239 | return ext.lower() if lower else ext 240 | -------------------------------------------------------------------------------- /examples/eval_squad_v2.0.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 
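For illustration only (the question IDs below are made up), such a file is a single
JSON object mapping each question ID to a probability, e.g.
{"q-0001": 0.03, "q-0002": 0.97}.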
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def get_raw_scores(dataset, preds): 82 | exact_scores = {} 83 | f1_scores = {} 84 | for article in dataset: 85 | for p in article['paragraphs']: 86 | for qa in p['qas']: 87 | qid = qa['id'] 88 | gold_answers = [a['text'] for a in qa['answers'] 89 | if normalize_answer(a['text'])] 90 | if not gold_answers: 91 | # For unanswerable questions, only correct answer is empty string 92 | gold_answers = [''] 93 | if qid not in preds: 94 | print('Missing prediction for %s' % qid) 95 | continue 96 | a_pred = preds[qid] 97 | # Take max over all gold answers 98 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 99 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 100 | return exact_scores, f1_scores 101 | 102 | def 
apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 103 | new_scores = {} 104 | for qid, s in scores.items(): 105 | pred_na = na_probs[qid] > na_prob_thresh 106 | if pred_na: 107 | new_scores[qid] = float(not qid_to_has_ans[qid]) 108 | else: 109 | new_scores[qid] = s 110 | return new_scores 111 | 112 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 113 | if not qid_list: 114 | total = len(exact_scores) 115 | return collections.OrderedDict([ 116 | ('exact', 100.0 * sum(exact_scores.values()) / total), 117 | ('f1', 100.0 * sum(f1_scores.values()) / total), 118 | ('total', total), 119 | ]) 120 | else: 121 | total = len(qid_list) 122 | return collections.OrderedDict([ 123 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 124 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 125 | ('total', total), 126 | ]) 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 131 | 132 | def plot_pr_curve(precisions, recalls, out_image, title): 133 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 134 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 135 | plt.xlabel('Recall') 136 | plt.ylabel('Precision') 137 | plt.xlim([0.0, 1.05]) 138 | plt.ylim([0.0, 1.05]) 139 | plt.title(title) 140 | plt.savefig(out_image) 141 | plt.clf() 142 | 143 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 144 | out_image=None, title=None): 145 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 146 | true_pos = 0.0 147 | cur_p = 1.0 148 | cur_r = 0.0 149 | precisions = [1.0] 150 | recalls = [0.0] 151 | avg_prec = 0.0 152 | for i, qid in enumerate(qid_list): 153 | if qid_to_has_ans[qid]: 154 | true_pos += scores[qid] 155 | cur_p = true_pos / float(i+1) 156 | cur_r = true_pos / float(num_true_pos) 157 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 158 | # i.e., if we can put a threshold after this point 159 | avg_prec += cur_p * (cur_r - recalls[-1]) 160 | precisions.append(cur_p) 161 | recalls.append(cur_r) 162 | if out_image: 163 | plot_pr_curve(precisions, recalls, out_image, title) 164 | return {'ap': 100.0 * avg_prec} 165 | 166 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 167 | qid_to_has_ans, out_image_dir): 168 | if out_image_dir and not os.path.exists(out_image_dir): 169 | os.makedirs(out_image_dir) 170 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 171 | if num_true_pos == 0: 172 | return 173 | pr_exact = make_precision_recall_eval( 174 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 175 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 176 | title='Precision-Recall curve for Exact Match score') 177 | pr_f1 = make_precision_recall_eval( 178 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 179 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 180 | title='Precision-Recall curve for F1 score') 181 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 182 | pr_oracle = make_precision_recall_eval( 183 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 184 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 185 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 186 | merge_eval(main_eval, pr_exact, 'pr_exact') 187 | merge_eval(main_eval, pr_f1, 'pr_f1') 188 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 189 | 190 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 191 | if not qid_list: 192 | return 193 | x = [na_probs[k] for k in qid_list] 194 | weights = np.ones_like(x) / float(len(x)) 195 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 196 | plt.xlabel('Model probability of no-answer') 197 | plt.ylabel('Proportion of dataset') 198 | plt.title('Histogram of no-answer probability: %s' % name) 199 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 200 | plt.clf() 201 | 202 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 203 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 204 | cur_score = num_no_ans 205 | best_score = cur_score 206 | best_thresh = 0.0 207 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 208 | for i, qid in enumerate(qid_list): 209 | if qid not in scores: continue 210 | if qid_to_has_ans[qid]: 211 | diff = scores[qid] 212 | else: 213 | if preds[qid]: 214 | diff = -1 215 | else: 216 | diff = 0 217 | cur_score += diff 218 | if cur_score > best_score: 219 | best_score = cur_score 220 | best_thresh = na_probs[qid] 221 | return 100.0 * best_score / len(scores), best_thresh 222 | 223 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 224 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 225 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 226 | main_eval['best_exact'] = best_exact 227 | main_eval['best_exact_thresh'] = exact_thresh 228 | main_eval['best_f1'] = best_f1 229 | main_eval['best_f1_thresh'] = f1_thresh 230 | 231 | def main(): 232 | with open(OPTS.data_file) as f: 233 | dataset_json = json.load(f) 234 | dataset = dataset_json['data'] 235 | with open(OPTS.pred_file) as f: 236 | preds = json.load(f) 237 | if OPTS.na_prob_file: 238 | with open(OPTS.na_prob_file) as f: 239 | na_probs = json.load(f) 240 | else: 241 | na_probs = {k: 0.0 for k in preds} 242 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 243 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 244 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 245 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 246 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 247 | OPTS.na_prob_thresh) 248 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 249 | OPTS.na_prob_thresh) 250 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 251 | if has_ans_qids: 252 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 253 | merge_eval(out_eval, has_ans_eval, 'HasAns') 254 | if no_ans_qids: 255 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 256 | merge_eval(out_eval, no_ans_eval, 'NoAns') 257 | if OPTS.na_prob_file: 258 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 259 | if OPTS.na_prob_file and OPTS.out_image_dir: 260 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 261 | qid_to_has_ans, OPTS.out_image_dir) 262 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 263 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 264 | if OPTS.out_file: 265 | with open(OPTS.out_file, 'w') as f: 266 | json.dump(out_eval, f) 267 | else: 268 | 
print(json.dumps(out_eval, indent=2)) 269 | 270 | if __name__ == '__main__': 271 | OPTS = parse_args() 272 | if OPTS.out_image_dir: 273 | import matplotlib 274 | matplotlib.use('Agg') 275 | import matplotlib.pyplot as plt 276 | main() 277 | 278 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /examples/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import collections 23 | import logging 24 | import json 25 | import re 26 | 27 | import torch 28 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 29 | from torch.utils.data.distributed import DistributedSampler 30 | 31 | from pytorch_pretrained_bert.tokenization import BertTokenizer 32 | from pytorch_pretrained_bert.modeling import BertModel 33 | 34 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 35 | datefmt = '%m/%d/%Y %H:%M:%S', 36 | level = logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | class InputExample(object): 41 | 42 | def __init__(self, unique_id, text_a, text_b): 43 | self.unique_id = unique_id 44 | self.text_a = text_a 45 | self.text_b = text_b 46 | 47 | 48 | class InputFeatures(object): 49 | """A single set of features of data.""" 50 | 51 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 52 | self.unique_id = unique_id 53 | self.tokens = tokens 54 | self.input_ids = input_ids 55 | self.input_mask = input_mask 56 | self.input_type_ids = input_type_ids 57 | 58 | 59 | def convert_examples_to_features(examples, seq_length, tokenizer): 60 | """Loads a data file into a list of `InputBatch`s.""" 61 | 62 | features = [] 63 | for (ex_index, example) in enumerate(examples): 64 | tokens_a = tokenizer.tokenize(example.text_a) 65 | 66 | tokens_b = None 67 | if example.text_b: 68 | tokens_b = tokenizer.tokenize(example.text_b) 69 | 70 | if tokens_b: 71 | # Modifies `tokens_a` and `tokens_b` in place so that the total 72 | # length is less than the specified length. 73 | # Account for [CLS], [SEP], [SEP] with "- 3" 74 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 75 | else: 76 | # Account for [CLS] and [SEP] with "- 2" 77 | if len(tokens_a) > seq_length - 2: 78 | tokens_a = tokens_a[0:(seq_length - 2)] 79 | 80 | # The convention in BERT is: 81 | # (a) For sequence pairs: 82 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 83 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 84 | # (b) For single sequences: 85 | # tokens: [CLS] the dog is hairy . [SEP] 86 | # type_ids: 0 0 0 0 0 0 0 87 | # 88 | # Where "type_ids" are used to indicate whether this is the first 89 | # sequence or the second sequence. The embedding vectors for `type=0` and 90 | # `type=1` were learned during pre-training and are added to the wordpiece 91 | # embedding vector (and position vector). This is not *strictly* necessary 92 | # since the [SEP] token unambigiously separates the sequences, but it makes 93 | # it easier for the model to learn the concept of sequences. 94 | # 95 | # For classification tasks, the first vector (corresponding to [CLS]) is 96 | # used as as the "sentence vector". Note that this only makes sense because 97 | # the entire model is fine-tuned. 
98 | tokens = [] 99 | input_type_ids = [] 100 | tokens.append("[CLS]") 101 | input_type_ids.append(0) 102 | for token in tokens_a: 103 | tokens.append(token) 104 | input_type_ids.append(0) 105 | tokens.append("[SEP]") 106 | input_type_ids.append(0) 107 | 108 | if tokens_b: 109 | for token in tokens_b: 110 | tokens.append(token) 111 | input_type_ids.append(1) 112 | tokens.append("[SEP]") 113 | input_type_ids.append(1) 114 | 115 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 116 | 117 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 118 | # tokens are attended to. 119 | input_mask = [1] * len(input_ids) 120 | 121 | # Zero-pad up to the sequence length. 122 | while len(input_ids) < seq_length: 123 | input_ids.append(0) 124 | input_mask.append(0) 125 | input_type_ids.append(0) 126 | 127 | assert len(input_ids) == seq_length 128 | assert len(input_mask) == seq_length 129 | assert len(input_type_ids) == seq_length 130 | 131 | if ex_index < 5: 132 | logger.info("*** Example ***") 133 | logger.info("unique_id: %s" % (example.unique_id)) 134 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 135 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 136 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 137 | logger.info( 138 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 139 | 140 | features.append( 141 | InputFeatures( 142 | unique_id=example.unique_id, 143 | tokens=tokens, 144 | input_ids=input_ids, 145 | input_mask=input_mask, 146 | input_type_ids=input_type_ids)) 147 | return features 148 | 149 | 150 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 151 | """Truncates a sequence pair in place to the maximum length.""" 152 | 153 | # This is a simple heuristic which will always truncate the longer sequence 154 | # one token at a time. This makes more sense than truncating an equal percent 155 | # of tokens from each, since if one sequence is very short then each token 156 | # that's truncated likely contains more information than a longer sequence. 
157 | while True: 158 | total_length = len(tokens_a) + len(tokens_b) 159 | if total_length <= max_length: 160 | break 161 | if len(tokens_a) > len(tokens_b): 162 | tokens_a.pop() 163 | else: 164 | tokens_b.pop() 165 | 166 | 167 | def read_examples(input_file): 168 | """Read a list of `InputExample`s from an input file.""" 169 | examples = [] 170 | unique_id = 0 171 | with open(input_file, "r", encoding='utf-8') as reader: 172 | while True: 173 | line = reader.readline() 174 | if not line: 175 | break 176 | line = line.strip() 177 | text_a = None 178 | text_b = None 179 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 180 | if m is None: 181 | text_a = line 182 | else: 183 | text_a = m.group(1) 184 | text_b = m.group(2) 185 | examples.append( 186 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 187 | unique_id += 1 188 | return examples 189 | 190 | 191 | def main(): 192 | parser = argparse.ArgumentParser() 193 | 194 | ## Required parameters 195 | parser.add_argument("--input_file", default=None, type=str, required=True) 196 | parser.add_argument("--output_file", default=None, type=str, required=True) 197 | parser.add_argument("--bert_model", default=None, type=str, required=True, 198 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 199 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 200 | 201 | ## Other parameters 202 | parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") 203 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) 204 | parser.add_argument("--max_seq_length", default=128, type=int, 205 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " 206 | "than this will be truncated, and sequences shorter than this will be padded.") 207 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") 208 | parser.add_argument("--local_rank", 209 | type=int, 210 | default=-1, 211 | help = "local_rank for distributed training on gpus") 212 | parser.add_argument("--no_cuda", 213 | action='store_true', 214 | help="Whether not to use CUDA when available") 215 | 216 | args = parser.parse_args() 217 | 218 | if args.local_rank == -1 or args.no_cuda: 219 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 220 | n_gpu = torch.cuda.device_count() 221 | else: 222 | device = torch.device("cuda", args.local_rank) 223 | n_gpu = 1 224 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 225 | torch.distributed.init_process_group(backend='nccl') 226 | logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1))) 227 | 228 | layer_indexes = [int(x) for x in args.layers.split(",")] 229 | 230 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 231 | 232 | examples = read_examples(args.input_file) 233 | 234 | features = convert_examples_to_features( 235 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 236 | 237 | unique_id_to_feature = {} 238 | for feature in features: 239 | unique_id_to_feature[feature.unique_id] = feature 240 | 241 | model = BertModel.from_pretrained(args.bert_model) 242 | model.to(device) 243 | 244 | if args.local_rank != -1: 245 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 246 | output_device=args.local_rank) 247 | elif n_gpu > 1: 
248 | model = torch.nn.DataParallel(model) 249 | 250 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 251 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 252 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 253 | 254 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 255 | if args.local_rank == -1: 256 | eval_sampler = SequentialSampler(eval_data) 257 | else: 258 | eval_sampler = DistributedSampler(eval_data) 259 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 260 | 261 | model.eval() 262 | with open(args.output_file, "w", encoding='utf-8') as writer: 263 | for input_ids, input_mask, example_indices in eval_dataloader: 264 | input_ids = input_ids.to(device) 265 | input_mask = input_mask.to(device) 266 | 267 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 268 | all_encoder_layers = all_encoder_layers 269 | 270 | for b, example_index in enumerate(example_indices): 271 | feature = features[example_index.item()] 272 | unique_id = int(feature.unique_id) 273 | # feature = unique_id_to_feature[unique_id] 274 | output_json = collections.OrderedDict() 275 | output_json["linex_index"] = unique_id 276 | all_out_features = [] 277 | for (i, token) in enumerate(feature.tokens): 278 | all_layers = [] 279 | for (j, layer_index) in enumerate(layer_indexes): 280 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() 281 | layer_output = layer_output[b] 282 | layers = collections.OrderedDict() 283 | layers["index"] = layer_index 284 | layers["values"] = [ 285 | round(x.item(), 6) for x in layer_output[i] 286 | ] 287 | all_layers.append(layers) 288 | out_features = collections.OrderedDict() 289 | out_features["token"] = token 290 | out_features["layers"] = all_layers 291 | all_out_features.append(out_features) 292 | output_json["features"] = all_out_features 293 | writer.write(json.dumps(output_json) + "\n") 294 | 295 | 296 | if __name__ == "__main__": 297 | main() 298 | -------------------------------------------------------------------------------- /tests/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import json 21 | import random 22 | 23 | import torch 24 | 25 | from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM, 26 | BertForNextSentencePrediction, BertForPreTraining, 27 | BertForQuestionAnswering, BertForSequenceClassification, 28 | BertForTokenClassification) 29 | 30 | 31 | class BertModelTest(unittest.TestCase): 32 | class BertModelTester(object): 33 | 34 | def __init__(self, 35 | parent, 36 | batch_size=13, 37 | seq_length=7, 38 | is_training=True, 39 | use_input_mask=True, 40 | use_token_type_ids=True, 41 | use_labels=True, 42 | vocab_size=99, 43 | hidden_size=32, 44 | num_hidden_layers=5, 45 | num_attention_heads=4, 46 | intermediate_size=37, 47 | hidden_act="gelu", 48 | hidden_dropout_prob=0.1, 49 | attention_probs_dropout_prob=0.1, 50 | max_position_embeddings=512, 51 | type_vocab_size=16, 52 | type_sequence_label_size=2, 53 | initializer_range=0.02, 54 | num_labels=3, 55 | scope=None): 56 | self.parent = parent 57 | self.batch_size = batch_size 58 | self.seq_length = seq_length 59 | self.is_training = is_training 60 | self.use_input_mask = use_input_mask 61 | self.use_token_type_ids = use_token_type_ids 62 | self.use_labels = use_labels 63 | self.vocab_size = vocab_size 64 | self.hidden_size = hidden_size 65 | self.num_hidden_layers = num_hidden_layers 66 | self.num_attention_heads = num_attention_heads 67 | self.intermediate_size = intermediate_size 68 | self.hidden_act = hidden_act 69 | self.hidden_dropout_prob = hidden_dropout_prob 70 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 71 | self.max_position_embeddings = max_position_embeddings 72 | self.type_vocab_size = type_vocab_size 73 | self.type_sequence_label_size = type_sequence_label_size 74 | self.initializer_range = initializer_range 75 | self.num_labels = num_labels 76 | self.scope = scope 77 | 78 | def prepare_config_and_inputs(self): 79 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 80 | 81 | input_mask = None 82 | if self.use_input_mask: 83 | input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2) 84 | 85 | token_type_ids = None 86 | if self.use_token_type_ids: 87 | token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) 88 | 89 | sequence_labels = None 90 | token_labels = None 91 | if self.use_labels: 92 | sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) 93 | token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels) 94 | 95 | config = BertConfig( 96 | vocab_size_or_config_json_file=self.vocab_size, 97 | hidden_size=self.hidden_size, 98 | num_hidden_layers=self.num_hidden_layers, 99 | num_attention_heads=self.num_attention_heads, 100 | intermediate_size=self.intermediate_size, 101 | hidden_act=self.hidden_act, 102 | hidden_dropout_prob=self.hidden_dropout_prob, 103 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 104 | max_position_embeddings=self.max_position_embeddings, 105 | type_vocab_size=self.type_vocab_size, 106 | initializer_range=self.initializer_range) 107 | 108 | return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels 109 | 110 | def check_loss_output(self, result): 111 | self.parent.assertListEqual( 112 | list(result["loss"].size()), 113 | []) 114 | 115 | def 
create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 116 | model = BertModel(config=config) 117 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 118 | outputs = { 119 | "sequence_output": all_encoder_layers[-1], 120 | "pooled_output": pooled_output, 121 | "all_encoder_layers": all_encoder_layers, 122 | } 123 | return outputs 124 | 125 | def check_bert_model_output(self, result): 126 | self.parent.assertListEqual( 127 | [size for layer in result["all_encoder_layers"] for size in layer.size()], 128 | [self.batch_size, self.seq_length, self.hidden_size] * self.num_hidden_layers) 129 | self.parent.assertListEqual( 130 | list(result["sequence_output"].size()), 131 | [self.batch_size, self.seq_length, self.hidden_size]) 132 | self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) 133 | 134 | 135 | def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 136 | model = BertForMaskedLM(config=config) 137 | loss = model(input_ids, token_type_ids, input_mask, token_labels) 138 | prediction_scores = model(input_ids, token_type_ids, input_mask) 139 | outputs = { 140 | "loss": loss, 141 | "prediction_scores": prediction_scores, 142 | } 143 | return outputs 144 | 145 | def check_bert_for_masked_lm_output(self, result): 146 | self.parent.assertListEqual( 147 | list(result["prediction_scores"].size()), 148 | [self.batch_size, self.seq_length, self.vocab_size]) 149 | 150 | def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 151 | model = BertForNextSentencePrediction(config=config) 152 | loss = model(input_ids, token_type_ids, input_mask, sequence_labels) 153 | seq_relationship_score = model(input_ids, token_type_ids, input_mask) 154 | outputs = { 155 | "loss": loss, 156 | "seq_relationship_score": seq_relationship_score, 157 | } 158 | return outputs 159 | 160 | def check_bert_for_next_sequence_prediction_output(self, result): 161 | self.parent.assertListEqual( 162 | list(result["seq_relationship_score"].size()), 163 | [self.batch_size, 2]) 164 | 165 | 166 | def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 167 | model = BertForPreTraining(config=config) 168 | loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels) 169 | prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask) 170 | outputs = { 171 | "loss": loss, 172 | "prediction_scores": prediction_scores, 173 | "seq_relationship_score": seq_relationship_score, 174 | } 175 | return outputs 176 | 177 | def check_bert_for_pretraining_output(self, result): 178 | self.parent.assertListEqual( 179 | list(result["prediction_scores"].size()), 180 | [self.batch_size, self.seq_length, self.vocab_size]) 181 | self.parent.assertListEqual( 182 | list(result["seq_relationship_score"].size()), 183 | [self.batch_size, 2]) 184 | 185 | 186 | def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 187 | model = BertForQuestionAnswering(config=config) 188 | loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels) 189 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 190 | outputs = { 191 | "loss": loss, 192 | "start_logits": start_logits, 193 | "end_logits": 
end_logits, 194 | } 195 | return outputs 196 | 197 | def check_bert_for_question_answering_output(self, result): 198 | self.parent.assertListEqual( 199 | list(result["start_logits"].size()), 200 | [self.batch_size, self.seq_length]) 201 | self.parent.assertListEqual( 202 | list(result["end_logits"].size()), 203 | [self.batch_size, self.seq_length]) 204 | 205 | 206 | def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 207 | model = BertForSequenceClassification(config=config, num_labels=self.num_labels) 208 | loss = model(input_ids, token_type_ids, input_mask, sequence_labels) 209 | logits = model(input_ids, token_type_ids, input_mask) 210 | outputs = { 211 | "loss": loss, 212 | "logits": logits, 213 | } 214 | return outputs 215 | 216 | def check_bert_for_sequence_classification_output(self, result): 217 | self.parent.assertListEqual( 218 | list(result["logits"].size()), 219 | [self.batch_size, self.num_labels]) 220 | 221 | 222 | def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): 223 | model = BertForTokenClassification(config=config, num_labels=self.num_labels) 224 | loss = model(input_ids, token_type_ids, input_mask, token_labels) 225 | logits = model(input_ids, token_type_ids, input_mask) 226 | outputs = { 227 | "loss": loss, 228 | "logits": logits, 229 | } 230 | return outputs 231 | 232 | def check_bert_for_token_classification_output(self, result): 233 | self.parent.assertListEqual( 234 | list(result["logits"].size()), 235 | [self.batch_size, self.seq_length, self.num_labels]) 236 | 237 | 238 | def test_default(self): 239 | self.run_tester(BertModelTest.BertModelTester(self)) 240 | 241 | def test_config_to_json_string(self): 242 | config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37) 243 | obj = json.loads(config.to_json_string()) 244 | self.assertEqual(obj["vocab_size"], 99) 245 | self.assertEqual(obj["hidden_size"], 37) 246 | 247 | def run_tester(self, tester): 248 | config_and_inputs = tester.prepare_config_and_inputs() 249 | output_result = tester.create_bert_model(*config_and_inputs) 250 | tester.check_bert_model_output(output_result) 251 | 252 | output_result = tester.create_bert_for_masked_lm(*config_and_inputs) 253 | tester.check_bert_for_masked_lm_output(output_result) 254 | tester.check_loss_output(output_result) 255 | 256 | output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs) 257 | tester.check_bert_for_next_sequence_prediction_output(output_result) 258 | tester.check_loss_output(output_result) 259 | 260 | output_result = tester.create_bert_for_pretraining(*config_and_inputs) 261 | tester.check_bert_for_pretraining_output(output_result) 262 | tester.check_loss_output(output_result) 263 | 264 | output_result = tester.create_bert_for_question_answering(*config_and_inputs) 265 | tester.check_bert_for_question_answering_output(output_result) 266 | tester.check_loss_output(output_result) 267 | 268 | output_result = tester.create_bert_for_sequence_classification(*config_and_inputs) 269 | tester.check_bert_for_sequence_classification_output(output_result) 270 | tester.check_loss_output(output_result) 271 | 272 | output_result = tester.create_bert_for_token_classification(*config_and_inputs) 273 | tester.check_bert_for_token_classification_output(output_result) 274 | tester.check_loss_output(output_result) 275 | 276 | @classmethod 277 | def ids_tensor(cls, shape, vocab_size, rng=None, 
name=None): 278 | """Creates a random int32 tensor of the shape within the vocab size.""" 279 | if rng is None: 280 | rng = random.Random() 281 | 282 | total_dims = 1 283 | for dim in shape: 284 | total_dims *= dim 285 | 286 | values = [] 287 | for _ in range(total_dims): 288 | values.append(rng.randint(0, vocab_size - 1)) 289 | 290 | return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() 291 | 292 | 293 | if __name__ == "__main__": 294 | unittest.main() 295 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import os 24 | import logging 25 | 26 | from .file_utils import cached_path 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 31 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 32 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 33 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 34 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 35 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 36 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 37 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'bert-base-uncased': 512, 41 | 'bert-large-uncased': 512, 42 | 'bert-base-cased': 512, 43 | 'bert-large-cased': 512, 44 | 'bert-base-multilingual-uncased': 512, 45 | 'bert-base-multilingual-cased': 512, 46 | 'bert-base-chinese': 512, 47 | } 48 | VOCAB_NAME = 'vocab.txt' 49 | 50 | 51 | def load_vocab(vocab_file): 52 | """Loads a vocabulary file into a dictionary.""" 53 | vocab = collections.OrderedDict() 54 | index = 0 55 | with open(vocab_file, "r", encoding="utf-8") as reader: 56 | while True: 57 | token = reader.readline() 58 | if not token: 59 | break 60 | token = token.strip() 61 | vocab[token] = index 62 | index += 1 63 | return vocab 64 | 65 | 66 | def whitespace_tokenize(text): 67 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 68 | text = text.strip() 69 | if not text: 70 | return [] 71 | tokens = text.split() 72 | return tokens 73 | 74 | 75 | class 
BertTokenizer(object): 76 | """Runs end-to-end tokenization: punctuation splitting + wordpiece""" 77 | 78 | def __init__(self, vocab_file, do_lower_case=True, max_len=None, 79 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 80 | if not os.path.isfile(vocab_file): 81 | raise ValueError( 82 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 83 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 84 | self.vocab = load_vocab(vocab_file) 85 | self.ids_to_tokens = collections.OrderedDict( 86 | [(ids, tok) for tok, ids in self.vocab.items()]) 87 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 88 | never_split=never_split) 89 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 90 | self.max_len = max_len if max_len is not None else int(1e12) 91 | 92 | def tokenize(self, text): 93 | split_tokens = [] 94 | for token in self.basic_tokenizer.tokenize(text): 95 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 96 | split_tokens.append(sub_token) 97 | return split_tokens 98 | 99 | def convert_tokens_to_ids(self, tokens): 100 | """Converts a sequence of tokens into ids using the vocab.""" 101 | ids = [] 102 | for token in tokens: 103 | ids.append(self.vocab[token]) 104 | if len(ids) > self.max_len: 105 | raise ValueError( 106 | "Token indices sequence length is longer than the specified maximum " 107 | " sequence length for this BERT model ({} > {}). Running this" 108 | " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) 109 | ) 110 | return ids 111 | 112 | def convert_ids_to_tokens(self, ids): 113 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 114 | tokens = [] 115 | for i in ids: 116 | tokens.append(self.ids_to_tokens[i]) 117 | return tokens 118 | 119 | @classmethod 120 | def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): 121 | """ 122 | Instantiate a PreTrainedBertModel from a pre-trained model file. 123 | Download and cache the pre-trained model file if needed. 124 | """ 125 | if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: 126 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] 127 | else: 128 | vocab_file = pretrained_model_name 129 | if os.path.isdir(vocab_file): 130 | vocab_file = os.path.join(vocab_file, VOCAB_NAME) 131 | # redirect to the cache, if necessary 132 | try: 133 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 134 | except FileNotFoundError: 135 | logger.error( 136 | "Model name '{}' was not found in model name list ({}). " 137 | "We assumed '{}' was a path or url but couldn't find any file " 138 | "associated to this path or url.".format( 139 | pretrained_model_name, 140 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 141 | vocab_file)) 142 | return None 143 | if resolved_vocab_file == vocab_file: 144 | logger.info("loading vocabulary file {}".format(vocab_file)) 145 | else: 146 | logger.info("loading vocabulary file {} from cache at {}".format( 147 | vocab_file, resolved_vocab_file)) 148 | if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 149 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 150 | # than the number of positional embeddings 151 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name] 152 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 153 | # Instantiate tokenizer. 
154 | tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) 155 | return tokenizer 156 | 157 | 158 | class BasicTokenizer(object): 159 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 160 | 161 | def __init__(self, 162 | do_lower_case=True, 163 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 164 | """Constructs a BasicTokenizer. 165 | 166 | Args: 167 | do_lower_case: Whether to lower case the input. 168 | """ 169 | self.do_lower_case = do_lower_case 170 | self.never_split = never_split 171 | 172 | def tokenize(self, text): 173 | """Tokenizes a piece of text.""" 174 | text = self._clean_text(text) 175 | # This was added on November 1st, 2018 for the multilingual and Chinese 176 | # models. This is also applied to the English models now, but it doesn't 177 | # matter since the English models were not trained on any Chinese data 178 | # and generally don't have any Chinese data in them (there are Chinese 179 | # characters in the vocabulary because Wikipedia does have some Chinese 180 | # words in the English Wikipedia.). 181 | text = self._tokenize_chinese_chars(text) 182 | orig_tokens = whitespace_tokenize(text) 183 | split_tokens = [] 184 | for token in orig_tokens: 185 | if self.do_lower_case and token not in self.never_split: 186 | token = token.lower() 187 | token = self._run_strip_accents(token) 188 | split_tokens.extend(self._run_split_on_punc(token)) 189 | 190 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 191 | return output_tokens 192 | 193 | def _run_strip_accents(self, text): 194 | """Strips accents from a piece of text.""" 195 | text = unicodedata.normalize("NFD", text) 196 | output = [] 197 | for char in text: 198 | cat = unicodedata.category(char) 199 | if cat == "Mn": 200 | continue 201 | output.append(char) 202 | return "".join(output) 203 | 204 | def _run_split_on_punc(self, text): 205 | """Splits punctuation on a piece of text.""" 206 | if text in self.never_split: 207 | return [text] 208 | chars = list(text) 209 | i = 0 210 | start_new_word = True 211 | output = [] 212 | while i < len(chars): 213 | char = chars[i] 214 | if _is_punctuation(char): 215 | output.append([char]) 216 | start_new_word = True 217 | else: 218 | if start_new_word: 219 | output.append([]) 220 | start_new_word = False 221 | output[-1].append(char) 222 | i += 1 223 | 224 | return ["".join(x) for x in output] 225 | 226 | def _tokenize_chinese_chars(self, text): 227 | """Adds whitespace around any CJK character.""" 228 | output = [] 229 | for char in text: 230 | cp = ord(char) 231 | if self._is_chinese_char(cp): 232 | output.append(" ") 233 | output.append(char) 234 | output.append(" ") 235 | else: 236 | output.append(char) 237 | return "".join(output) 238 | 239 | def _is_chinese_char(self, cp): 240 | """Checks whether CP is the codepoint of a CJK character.""" 241 | # This defines a "chinese character" as anything in the CJK Unicode block: 242 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 243 | # 244 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 245 | # despite its name. The modern Korean Hangul alphabet is a different block, 246 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 247 | # space-separated words, so they are not treated specially and handled 248 | # like the all of the other languages. 
249 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 250 | (cp >= 0x3400 and cp <= 0x4DBF) or # 251 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 252 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 253 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 254 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 255 | (cp >= 0xF900 and cp <= 0xFAFF) or # 256 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 257 | return True 258 | 259 | return False 260 | 261 | def _clean_text(self, text): 262 | """Performs invalid character removal and whitespace cleanup on text.""" 263 | output = [] 264 | for char in text: 265 | cp = ord(char) 266 | if cp == 0 or cp == 0xfffd or _is_control(char): 267 | continue 268 | if _is_whitespace(char): 269 | output.append(" ") 270 | else: 271 | output.append(char) 272 | return "".join(output) 273 | 274 | 275 | class WordpieceTokenizer(object): 276 | """Runs WordPiece tokenization.""" 277 | 278 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 279 | self.vocab = vocab 280 | self.unk_token = unk_token 281 | self.max_input_chars_per_word = max_input_chars_per_word 282 | 283 | def tokenize(self, text): 284 | """Tokenizes a piece of text into its word pieces. 285 | 286 | This uses a greedy longest-match-first algorithm to perform tokenization 287 | using the given vocabulary. 288 | 289 | For example: 290 | input = "unaffable" 291 | output = ["un", "##aff", "##able"] 292 | 293 | Args: 294 | text: A single token or whitespace separated tokens. This should have 295 | already been passed through `BasicTokenizer`. 296 | 297 | Returns: 298 | A list of wordpiece tokens. 299 | """ 300 | 301 | output_tokens = [] 302 | for token in whitespace_tokenize(text): 303 | chars = list(token) 304 | if len(chars) > self.max_input_chars_per_word: 305 | output_tokens.append(self.unk_token) 306 | continue 307 | 308 | is_bad = False 309 | start = 0 310 | sub_tokens = [] 311 | while start < len(chars): 312 | end = len(chars) 313 | cur_substr = None 314 | while start < end: 315 | substr = "".join(chars[start:end]) 316 | if start > 0: 317 | substr = "##" + substr 318 | if substr in self.vocab: 319 | cur_substr = substr 320 | break 321 | end -= 1 322 | if cur_substr is None: 323 | is_bad = True 324 | break 325 | sub_tokens.append(cur_substr) 326 | start = end 327 | 328 | if is_bad: 329 | output_tokens.append(self.unk_token) 330 | else: 331 | output_tokens.extend(sub_tokens) 332 | return output_tokens 333 | 334 | 335 | def _is_whitespace(char): 336 | """Checks whether `chars` is a whitespace character.""" 337 | # \t, \n, and \r are technically contorl characters but we treat them 338 | # as whitespace since they are generally considered as such. 339 | if char == " " or char == "\t" or char == "\n" or char == "\r": 340 | return True 341 | cat = unicodedata.category(char) 342 | if cat == "Zs": 343 | return True 344 | return False 345 | 346 | 347 | def _is_control(char): 348 | """Checks whether `chars` is a control character.""" 349 | # These are technically control characters but we count them as whitespace 350 | # characters. 351 | if char == "\t" or char == "\n" or char == "\r": 352 | return False 353 | cat = unicodedata.category(char) 354 | if cat.startswith("C"): 355 | return True 356 | return False 357 | 358 | 359 | def _is_punctuation(char): 360 | """Checks whether `chars` is a punctuation character.""" 361 | cp = ord(char) 362 | # We treat all non-letter/number ASCII as punctuation. 
363 | # Characters such as "^", "$", and "`" are not in the Unicode 364 | # Punctuation class but we treat them as punctuation anyways, for 365 | # consistency. 366 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 367 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 368 | return True 369 | cat = unicodedata.category(char) 370 | if cat.startswith("P"): 371 | return True 372 | return False 373 | -------------------------------------------------------------------------------- /examples/run_swag.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | import logging 19 | import os 20 | import argparse 21 | import random 22 | from tqdm import tqdm, trange 23 | import csv 24 | 25 | import numpy as np 26 | import torch 27 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 28 | from torch.utils.data.distributed import DistributedSampler 29 | 30 | from pytorch_pretrained_bert.tokenization import BertTokenizer 31 | from pytorch_pretrained_bert.modeling import BertForMultipleChoice 32 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 33 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 34 | 35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 36 | datefmt = '%m/%d/%Y %H:%M:%S', 37 | level = logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | class SwagExample(object): 42 | """A single training/test example for the SWAG dataset.""" 43 | def __init__(self, 44 | swag_id, 45 | context_sentence, 46 | start_ending, 47 | ending_0, 48 | ending_1, 49 | ending_2, 50 | ending_3, 51 | label = None): 52 | self.swag_id = swag_id 53 | self.context_sentence = context_sentence 54 | self.start_ending = start_ending 55 | self.endings = [ 56 | ending_0, 57 | ending_1, 58 | ending_2, 59 | ending_3, 60 | ] 61 | self.label = label 62 | 63 | def __str__(self): 64 | return self.__repr__() 65 | 66 | def __repr__(self): 67 | l = [ 68 | f"swag_id: {self.swag_id}", 69 | f"context_sentence: {self.context_sentence}", 70 | f"start_ending: {self.start_ending}", 71 | f"ending_0: {self.endings[0]}", 72 | f"ending_1: {self.endings[1]}", 73 | f"ending_2: {self.endings[2]}", 74 | f"ending_3: {self.endings[3]}", 75 | ] 76 | 77 | if self.label is not None: 78 | l.append(f"label: {self.label}") 79 | 80 | return ", ".join(l) 81 | 82 | 83 | class InputFeatures(object): 84 | def __init__(self, 85 | example_id, 86 | choices_features, 87 | label 88 | 89 | ): 90 | self.example_id = example_id 91 | self.choices_features = [ 92 | { 93 | 'input_ids': input_ids, 94 | 'input_mask': input_mask, 95 | 'segment_ids': segment_ids 96 | } 97 | for _, input_ids, 
input_mask, segment_ids in choices_features 98 | ] 99 | self.label = label 100 | 101 | 102 | def read_swag_examples(input_file, is_training): 103 | with open(input_file, 'r', encoding='utf-8') as f: 104 | reader = csv.reader(f) 105 | lines = list(reader) 106 | 107 | if is_training and lines[0][-1] != 'label': 108 | raise ValueError( 109 | "For training, the input file must contain a label column." 110 | ) 111 | 112 | examples = [ 113 | SwagExample( 114 | swag_id = line[2], 115 | context_sentence = line[4], 116 | start_ending = line[5], # in the swag dataset, the 117 | # common beginning of each 118 | # choice is stored in "sent2". 119 | ending_0 = line[7], 120 | ending_1 = line[8], 121 | ending_2 = line[9], 122 | ending_3 = line[10], 123 | label = int(line[11]) if is_training else None 124 | ) for line in lines[1:] # we skip the line with the column names 125 | ] 126 | 127 | return examples 128 | 129 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 130 | is_training): 131 | """Loads a data file into a list of `InputBatch`s.""" 132 | 133 | # Swag is a multiple choice task. To perform this task using Bert, 134 | # we will use the formatting proposed in "Improving Language 135 | # Understanding by Generative Pre-Training" and suggested by 136 | # @jacobdevlin-google in this issue 137 | # https://github.com/google-research/bert/issues/38. 138 | # 139 | # Each choice will correspond to a sample on which we run the 140 | # inference. For a given Swag example, we will create the 4 141 | # following inputs: 142 | # - [CLS] context [SEP] choice_1 [SEP] 143 | # - [CLS] context [SEP] choice_2 [SEP] 144 | # - [CLS] context [SEP] choice_3 [SEP] 145 | # - [CLS] context [SEP] choice_4 [SEP] 146 | # The model will output a single value for each input. To get the 147 | # final decision of the model, we will run a softmax over these 4 148 | # outputs. 149 | features = [] 150 | for example_index, example in enumerate(examples): 151 | context_tokens = tokenizer.tokenize(example.context_sentence) 152 | start_ending_tokens = tokenizer.tokenize(example.start_ending) 153 | 154 | choices_features = [] 155 | for ending_index, ending in enumerate(example.endings): 156 | # We create a copy of the context tokens in order to be 157 | # able to shrink it according to ending_tokens 158 | context_tokens_choice = context_tokens[:] 159 | ending_tokens = start_ending_tokens + tokenizer.tokenize(ending) 160 | # Modifies `context_tokens_choice` and `ending_tokens` in 161 | # place so that the total length is less than the 162 | # specified length. Account for [CLS], [SEP], [SEP] with 163 | # "- 3" 164 | _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3) 165 | 166 | tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"] 167 | segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1) 168 | 169 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 170 | input_mask = [1] * len(input_ids) 171 | 172 | # Zero-pad up to the sequence length. 
173 | padding = [0] * (max_seq_length - len(input_ids)) 174 | input_ids += padding 175 | input_mask += padding 176 | segment_ids += padding 177 | 178 | assert len(input_ids) == max_seq_length 179 | assert len(input_mask) == max_seq_length 180 | assert len(segment_ids) == max_seq_length 181 | 182 | choices_features.append((tokens, input_ids, input_mask, segment_ids)) 183 | 184 | label = example.label 185 | if example_index < 5: 186 | logger.info("*** Example ***") 187 | logger.info(f"swag_id: {example.swag_id}") 188 | for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features): 189 | logger.info(f"choice: {choice_idx}") 190 | logger.info(f"tokens: {' '.join(tokens)}") 191 | logger.info(f"input_ids: {' '.join(map(str, input_ids))}") 192 | logger.info(f"input_mask: {' '.join(map(str, input_mask))}") 193 | logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}") 194 | if is_training: 195 | logger.info(f"label: {label}") 196 | 197 | features.append( 198 | InputFeatures( 199 | example_id = example.swag_id, 200 | choices_features = choices_features, 201 | label = label 202 | ) 203 | ) 204 | 205 | return features 206 | 207 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 208 | """Truncates a sequence pair in place to the maximum length.""" 209 | 210 | # This is a simple heuristic which will always truncate the longer sequence 211 | # one token at a time. This makes more sense than truncating an equal percent 212 | # of tokens from each, since if one sequence is very short then each token 213 | # that's truncated likely contains more information than a longer sequence. 214 | while True: 215 | total_length = len(tokens_a) + len(tokens_b) 216 | if total_length <= max_length: 217 | break 218 | if len(tokens_a) > len(tokens_b): 219 | tokens_a.pop() 220 | else: 221 | tokens_b.pop() 222 | 223 | def accuracy(out, labels): 224 | outputs = np.argmax(out, axis=1) 225 | return np.sum(outputs == labels) 226 | 227 | def select_field(features, field): 228 | return [ 229 | [ 230 | choice[field] 231 | for choice in feature.choices_features 232 | ] 233 | for feature in features 234 | ] 235 | 236 | def main(): 237 | parser = argparse.ArgumentParser() 238 | 239 | ## Required parameters 240 | parser.add_argument("--data_dir", 241 | default=None, 242 | type=str, 243 | required=True, 244 | help="The input data dir. Should contain the .csv files (or other data files) for the task.") 245 | parser.add_argument("--bert_model", default=None, type=str, required=True, 246 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 247 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 248 | "bert-base-multilingual-cased, bert-base-chinese.") 249 | parser.add_argument("--output_dir", 250 | default=None, 251 | type=str, 252 | required=True, 253 | help="The output directory where the model checkpoints will be written.") 254 | 255 | ## Other parameters 256 | parser.add_argument("--max_seq_length", 257 | default=128, 258 | type=int, 259 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 260 | "Sequences longer than this will be truncated, and sequences shorter \n" 261 | "than this will be padded.") 262 | parser.add_argument("--do_train", 263 | action='store_true', 264 | help="Whether to run training.") 265 | parser.add_argument("--do_eval", 266 | action='store_true', 267 | help="Whether to run eval on the dev set.") 268 | parser.add_argument("--do_lower_case", 269 | action='store_true', 270 | help="Set this flag if you are using an uncased model.") 271 | parser.add_argument("--train_batch_size", 272 | default=32, 273 | type=int, 274 | help="Total batch size for training.") 275 | parser.add_argument("--eval_batch_size", 276 | default=8, 277 | type=int, 278 | help="Total batch size for eval.") 279 | parser.add_argument("--learning_rate", 280 | default=5e-5, 281 | type=float, 282 | help="The initial learning rate for Adam.") 283 | parser.add_argument("--num_train_epochs", 284 | default=3.0, 285 | type=float, 286 | help="Total number of training epochs to perform.") 287 | parser.add_argument("--warmup_proportion", 288 | default=0.1, 289 | type=float, 290 | help="Proportion of training to perform linear learning rate warmup for. " 291 | "E.g., 0.1 = 10%% of training.") 292 | parser.add_argument("--no_cuda", 293 | action='store_true', 294 | help="Whether not to use CUDA when available") 295 | parser.add_argument("--local_rank", 296 | type=int, 297 | default=-1, 298 | help="local_rank for distributed training on gpus") 299 | parser.add_argument('--seed', 300 | type=int, 301 | default=42, 302 | help="random seed for initialization") 303 | parser.add_argument('--gradient_accumulation_steps', 304 | type=int, 305 | default=1, 306 | help="Number of updates steps to accumulate before performing a backward/update pass.") 307 | parser.add_argument('--fp16', 308 | action='store_true', 309 | help="Whether to use 16-bit float precision instead of 32-bit") 310 | parser.add_argument('--loss_scale', 311 | type=float, default=0, 312 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 313 | "0 (default value): dynamic loss scaling.\n" 314 | "Positive power of 2: static loss scaling value.\n") 315 | 316 | args = parser.parse_args() 317 | 318 | if args.local_rank == -1 or args.no_cuda: 319 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 320 | n_gpu = torch.cuda.device_count() 321 | else: 322 | torch.cuda.set_device(args.local_rank) 323 | device = torch.device("cuda", args.local_rank) 324 | n_gpu = 1 325 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 326 | torch.distributed.init_process_group(backend='nccl') 327 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 328 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 329 | 330 | if args.gradient_accumulation_steps < 1: 331 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 332 | args.gradient_accumulation_steps)) 333 | 334 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 335 | 336 | random.seed(args.seed) 337 | np.random.seed(args.seed) 338 | torch.manual_seed(args.seed) 339 | if n_gpu > 0: 340 | torch.cuda.manual_seed_all(args.seed) 341 | 342 | if not args.do_train and not args.do_eval: 343 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 344 | 345 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 346 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 347 | os.makedirs(args.output_dir, exist_ok=True) 348 | 349 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 350 | 351 | train_examples = None 352 | num_train_optimization_steps = None 353 | if args.do_train: 354 | train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True) 355 | num_train_optimization_steps = int( 356 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 357 | if args.local_rank != -1: 358 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 359 | 360 | # Prepare model 361 | model = BertForMultipleChoice.from_pretrained(args.bert_model, 362 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), 363 | num_choices=4) 364 | if args.fp16: 365 | model.half() 366 | model.to(device) 367 | if args.local_rank != -1: 368 | try: 369 | from apex.parallel import DistributedDataParallel as DDP 370 | except ImportError: 371 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 372 | 373 | model = DDP(model) 374 | elif n_gpu > 1: 375 | model = torch.nn.DataParallel(model) 376 | 377 | # Prepare optimizer 378 | param_optimizer = list(model.named_parameters()) 379 | 380 | # hack to remove pooler, which is not used 381 | # thus it produce None grad that break apex 382 | param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] 383 | 384 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 385 | optimizer_grouped_parameters = [ 386 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 387 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 388 | ] 389 | if args.fp16: 390 | try: 391 | from apex.optimizers import FP16_Optimizer 392 | from 
apex.optimizers import FusedAdam 393 | except ImportError: 394 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 395 | 396 | optimizer = FusedAdam(optimizer_grouped_parameters, 397 | lr=args.learning_rate, 398 | bias_correction=False, 399 | max_grad_norm=1.0) 400 | if args.loss_scale == 0: 401 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 402 | else: 403 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 404 | else: 405 | optimizer = BertAdam(optimizer_grouped_parameters, 406 | lr=args.learning_rate, 407 | warmup=args.warmup_proportion, 408 | t_total=num_train_optimization_steps) 409 | 410 | global_step = 0 411 | if args.do_train: 412 | train_features = convert_examples_to_features( 413 | train_examples, tokenizer, args.max_seq_length, True) 414 | logger.info("***** Running training *****") 415 | logger.info(" Num examples = %d", len(train_examples)) 416 | logger.info(" Batch size = %d", args.train_batch_size) 417 | logger.info(" Num steps = %d", num_train_optimization_steps) 418 | all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long) 419 | all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long) 420 | all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long) 421 | all_label = torch.tensor([f.label for f in train_features], dtype=torch.long) 422 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label) 423 | if args.local_rank == -1: 424 | train_sampler = RandomSampler(train_data) 425 | else: 426 | train_sampler = DistributedSampler(train_data) 427 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 428 | 429 | model.train() 430 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 431 | tr_loss = 0 432 | nb_tr_examples, nb_tr_steps = 0, 0 433 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 434 | batch = tuple(t.to(device) for t in batch) 435 | input_ids, input_mask, segment_ids, label_ids = batch 436 | loss = model(input_ids, segment_ids, input_mask, label_ids) 437 | if n_gpu > 1: 438 | loss = loss.mean() # mean() to average on multi-gpu. 
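# A quick worked example of the scaling/accumulation logic that follows
# (numbers assumed for illustration): with a requested --train_batch_size of 32
# and --gradient_accumulation_steps of 4, the script reduces the per-forward
# batch to 8 (see the integer division near the top of main()), divides each
# loss by 4 below, and only calls optimizer.step() every 4th iteration, so a
# single parameter update still reflects 32 examples and global_step advances
# once per update rather than once per mini-batch.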
439 | if args.fp16 and args.loss_scale != 1.0: 440 | # rescale loss for fp16 training 441 | # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html 442 | loss = loss * args.loss_scale 443 | if args.gradient_accumulation_steps > 1: 444 | loss = loss / args.gradient_accumulation_steps 445 | tr_loss += loss.item() 446 | nb_tr_examples += input_ids.size(0) 447 | nb_tr_steps += 1 448 | 449 | if args.fp16: 450 | optimizer.backward(loss) 451 | else: 452 | loss.backward() 453 | if (step + 1) % args.gradient_accumulation_steps == 0: 454 | if args.fp16: 455 | # modify learning rate with special warm up BERT uses 456 | # if args.fp16 is False, BertAdam is used that handles this automatically 457 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion) 458 | for param_group in optimizer.param_groups: 459 | param_group['lr'] = lr_this_step 460 | optimizer.step() 461 | optimizer.zero_grad() 462 | global_step += 1 463 | 464 | # Save a trained model 465 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 466 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 467 | torch.save(model_to_save.state_dict(), output_model_file) 468 | 469 | # Load a trained model that you have fine-tuned 470 | model_state_dict = torch.load(output_model_file) 471 | model = BertForMultipleChoice.from_pretrained(args.bert_model, 472 | state_dict=model_state_dict, 473 | num_choices=4) 474 | model.to(device) 475 | 476 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 477 | eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True) 478 | eval_features = convert_examples_to_features( 479 | eval_examples, tokenizer, args.max_seq_length, True) 480 | logger.info("***** Running evaluation *****") 481 | logger.info(" Num examples = %d", len(eval_examples)) 482 | logger.info(" Batch size = %d", args.eval_batch_size) 483 | all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long) 484 | all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long) 485 | all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long) 486 | all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long) 487 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label) 488 | # Run prediction for full data 489 | eval_sampler = SequentialSampler(eval_data) 490 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 491 | 492 | model.eval() 493 | eval_loss, eval_accuracy = 0, 0 494 | nb_eval_steps, nb_eval_examples = 0, 0 495 | for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: 496 | input_ids = input_ids.to(device) 497 | input_mask = input_mask.to(device) 498 | segment_ids = segment_ids.to(device) 499 | label_ids = label_ids.to(device) 500 | 501 | with torch.no_grad(): 502 | tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) 503 | logits = model(input_ids, segment_ids, input_mask) 504 | 505 | logits = logits.detach().cpu().numpy() 506 | label_ids = label_ids.to('cpu').numpy() 507 | tmp_eval_accuracy = accuracy(logits, label_ids) 508 | 509 | eval_loss += tmp_eval_loss.mean().item() 510 | eval_accuracy += tmp_eval_accuracy 511 | 512 | nb_eval_examples += input_ids.size(0) 513 | nb_eval_steps += 1 514 | 515 | eval_loss = eval_loss / 
nb_eval_steps 516 | eval_accuracy = eval_accuracy / nb_eval_examples 517 | 518 | result = {'eval_loss': eval_loss, 519 | 'eval_accuracy': eval_accuracy, 520 | 'global_step': global_step, 521 | 'loss': tr_loss/nb_tr_steps} 522 | 523 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 524 | with open(output_eval_file, "w") as writer: 525 | logger.info("***** Eval results *****") 526 | for key in sorted(result.keys()): 527 | logger.info(" %s = %s", key, str(result[key])) 528 | writer.write("%s = %s\n" % (key, str(result[key]))) 529 | 530 | 531 | if __name__ == "__main__": 532 | main() 533 | -------------------------------------------------------------------------------- /examples/run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import csv 23 | import os 24 | import logging 25 | import argparse 26 | import random 27 | from tqdm import tqdm, trange 28 | 29 | import numpy as np 30 | import torch 31 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 32 | from torch.utils.data.distributed import DistributedSampler 33 | 34 | from pytorch_pretrained_bert.tokenization import BertTokenizer 35 | from pytorch_pretrained_bert.modeling import BertForSequenceClassification 36 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 37 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 38 | 39 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 40 | datefmt = '%m/%d/%Y %H:%M:%S', 41 | level = logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | class InputExample(object): 46 | """A single training/test example for simple sequence classification.""" 47 | 48 | def __init__(self, guid, text_a, text_b=None, label=None): 49 | """Constructs a InputExample. 50 | 51 | Args: 52 | guid: Unique id for the example. 53 | text_a: string. The untokenized text of the first sequence. For single 54 | sequence tasks, only this sequence must be specified. 55 | text_b: (Optional) string. The untokenized text of the second sequence. 56 | Only must be specified for sequence pair tasks. 57 | label: (Optional) string. The label of the example. This should be 58 | specified for train and dev examples, but not for test examples. 
59 | """ 60 | self.guid = guid 61 | self.text_a = text_a 62 | self.text_b = text_b 63 | self.label = label 64 | 65 | 66 | class InputFeatures(object): 67 | """A single set of features of data.""" 68 | 69 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 70 | self.input_ids = input_ids 71 | self.input_mask = input_mask 72 | self.segment_ids = segment_ids 73 | self.label_id = label_id 74 | 75 | 76 | class DataProcessor(object): 77 | """Base class for data converters for sequence classification data sets.""" 78 | 79 | def get_train_examples(self, data_dir): 80 | """Gets a collection of `InputExample`s for the train set.""" 81 | raise NotImplementedError() 82 | 83 | def get_dev_examples(self, data_dir): 84 | """Gets a collection of `InputExample`s for the dev set.""" 85 | raise NotImplementedError() 86 | 87 | def get_labels(self): 88 | """Gets the list of labels for this data set.""" 89 | raise NotImplementedError() 90 | 91 | @classmethod 92 | def _read_tsv(cls, input_file, quotechar=None): 93 | """Reads a tab separated value file.""" 94 | with open(input_file, "r", encoding='utf-8') as f: 95 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 96 | lines = [] 97 | for line in reader: 98 | lines.append(line) 99 | return lines 100 | 101 | 102 | class MrpcProcessor(DataProcessor): 103 | """Processor for the MRPC data set (GLUE version).""" 104 | 105 | def get_train_examples(self, data_dir): 106 | """See base class.""" 107 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) 108 | return self._create_examples( 109 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 110 | 111 | def get_dev_examples(self, data_dir): 112 | """See base class.""" 113 | return self._create_examples( 114 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 115 | 116 | def get_labels(self): 117 | """See base class.""" 118 | return ["0", "1"] 119 | 120 | def _create_examples(self, lines, set_type): 121 | """Creates examples for the training and dev sets.""" 122 | examples = [] 123 | for (i, line) in enumerate(lines): 124 | if i == 0: 125 | continue 126 | guid = "%s-%s" % (set_type, i) 127 | text_a = line[3] 128 | text_b = line[4] 129 | label = line[0] 130 | examples.append( 131 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 132 | return examples 133 | 134 | 135 | class MnliProcessor(DataProcessor): 136 | """Processor for the MultiNLI data set (GLUE version).""" 137 | 138 | def get_train_examples(self, data_dir): 139 | """See base class.""" 140 | return self._create_examples( 141 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 142 | 143 | def get_dev_examples(self, data_dir): 144 | """See base class.""" 145 | return self._create_examples( 146 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 147 | "dev_matched") 148 | 149 | def get_labels(self): 150 | """See base class.""" 151 | return ["contradiction", "entailment", "neutral"] 152 | 153 | def _create_examples(self, lines, set_type): 154 | """Creates examples for the training and dev sets.""" 155 | examples = [] 156 | for (i, line) in enumerate(lines): 157 | if i == 0: 158 | continue 159 | guid = "%s-%s" % (set_type, line[0]) 160 | text_a = line[8] 161 | text_b = line[9] 162 | label = line[-1] 163 | examples.append( 164 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 165 | return examples 166 | 167 | 168 | class ColaProcessor(DataProcessor): 169 | """Processor for the CoLA data set (GLUE version).""" 170 | 171 | def 
get_train_examples(self, data_dir): 172 | """See base class.""" 173 | return self._create_examples( 174 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 175 | 176 | def get_dev_examples(self, data_dir): 177 | """See base class.""" 178 | return self._create_examples( 179 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 180 | 181 | def get_labels(self): 182 | """See base class.""" 183 | return ["0", "1"] 184 | 185 | def _create_examples(self, lines, set_type): 186 | """Creates examples for the training and dev sets.""" 187 | examples = [] 188 | for (i, line) in enumerate(lines): 189 | guid = "%s-%s" % (set_type, i) 190 | text_a = line[3] 191 | label = line[1] 192 | examples.append( 193 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 194 | return examples 195 | 196 | 197 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 198 | """Loads a data file into a list of `InputBatch`s.""" 199 | 200 | label_map = {label : i for i, label in enumerate(label_list)} 201 | 202 | features = [] 203 | for (ex_index, example) in enumerate(examples): 204 | tokens_a = tokenizer.tokenize(example.text_a) 205 | 206 | tokens_b = None 207 | if example.text_b: 208 | tokens_b = tokenizer.tokenize(example.text_b) 209 | # Modifies `tokens_a` and `tokens_b` in place so that the total 210 | # length is less than the specified length. 211 | # Account for [CLS], [SEP], [SEP] with "- 3" 212 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 213 | else: 214 | # Account for [CLS] and [SEP] with "- 2" 215 | if len(tokens_a) > max_seq_length - 2: 216 | tokens_a = tokens_a[:(max_seq_length - 2)] 217 | 218 | # The convention in BERT is: 219 | # (a) For sequence pairs: 220 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 221 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 222 | # (b) For single sequences: 223 | # tokens: [CLS] the dog is hairy . [SEP] 224 | # type_ids: 0 0 0 0 0 0 0 225 | # 226 | # Where "type_ids" are used to indicate whether this is the first 227 | # sequence or the second sequence. The embedding vectors for `type=0` and 228 | # `type=1` were learned during pre-training and are added to the wordpiece 229 | # embedding vector (and position vector). This is not *strictly* necessary 230 | # since the [SEP] token unambigiously separates the sequences, but it makes 231 | # it easier for the model to learn the concept of sequences. 232 | # 233 | # For classification tasks, the first vector (corresponding to [CLS]) is 234 | # used as as the "sentence vector". Note that this only makes sense because 235 | # the entire model is fine-tuned. 236 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 237 | segment_ids = [0] * len(tokens) 238 | 239 | if tokens_b: 240 | tokens += tokens_b + ["[SEP]"] 241 | segment_ids += [1] * (len(tokens_b) + 1) 242 | 243 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 244 | 245 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 246 | # tokens are attended to. 247 | input_mask = [1] * len(input_ids) 248 | 249 | # Zero-pad up to the sequence length. 
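# Worked illustration using the pair example from the comment above (sequence
# length assumed): the tokens
#   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# give 14 input_ids, segment_ids of eight 0s followed by six 1s, and an
# input_mask of 14 ones; with max_seq_length = 16 the padding step below
# appends two 0s to each of the three lists, and only the first 14 positions
# are attended to because their mask value is 1.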
250 | padding = [0] * (max_seq_length - len(input_ids)) 251 | input_ids += padding 252 | input_mask += padding 253 | segment_ids += padding 254 | 255 | assert len(input_ids) == max_seq_length 256 | assert len(input_mask) == max_seq_length 257 | assert len(segment_ids) == max_seq_length 258 | 259 | label_id = label_map[example.label] 260 | if ex_index < 5: 261 | logger.info("*** Example ***") 262 | logger.info("guid: %s" % (example.guid)) 263 | logger.info("tokens: %s" % " ".join( 264 | [str(x) for x in tokens])) 265 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 266 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 267 | logger.info( 268 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 269 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 270 | 271 | features.append( 272 | InputFeatures(input_ids=input_ids, 273 | input_mask=input_mask, 274 | segment_ids=segment_ids, 275 | label_id=label_id)) 276 | return features 277 | 278 | 279 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 280 | """Truncates a sequence pair in place to the maximum length.""" 281 | 282 | # This is a simple heuristic which will always truncate the longer sequence 283 | # one token at a time. This makes more sense than truncating an equal percent 284 | # of tokens from each, since if one sequence is very short then each token 285 | # that's truncated likely contains more information than a longer sequence. 286 | while True: 287 | total_length = len(tokens_a) + len(tokens_b) 288 | if total_length <= max_length: 289 | break 290 | if len(tokens_a) > len(tokens_b): 291 | tokens_a.pop() 292 | else: 293 | tokens_b.pop() 294 | 295 | def accuracy(out, labels): 296 | outputs = np.argmax(out, axis=1) 297 | return np.sum(outputs == labels) 298 | 299 | def main(): 300 | parser = argparse.ArgumentParser() 301 | 302 | ## Required parameters 303 | parser.add_argument("--data_dir", 304 | default=None, 305 | type=str, 306 | required=True, 307 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 308 | parser.add_argument("--bert_model", default=None, type=str, required=True, 309 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 310 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " 311 | "bert-base-multilingual-cased, bert-base-chinese.") 312 | parser.add_argument("--task_name", 313 | default=None, 314 | type=str, 315 | required=True, 316 | help="The name of the task to train.") 317 | parser.add_argument("--output_dir", 318 | default=None, 319 | type=str, 320 | required=True, 321 | help="The output directory where the model predictions and checkpoints will be written.") 322 | 323 | ## Other parameters 324 | parser.add_argument("--max_seq_length", 325 | default=128, 326 | type=int, 327 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 328 | "Sequences longer than this will be truncated, and sequences shorter \n" 329 | "than this will be padded.") 330 | parser.add_argument("--do_train", 331 | action='store_true', 332 | help="Whether to run training.") 333 | parser.add_argument("--do_eval", 334 | action='store_true', 335 | help="Whether to run eval on the dev set.") 336 | parser.add_argument("--do_lower_case", 337 | action='store_true', 338 | help="Set this flag if you are using an uncased model.") 339 | parser.add_argument("--train_batch_size", 340 | default=32, 341 | type=int, 342 | help="Total batch size for training.") 343 | parser.add_argument("--eval_batch_size", 344 | default=8, 345 | type=int, 346 | help="Total batch size for eval.") 347 | parser.add_argument("--learning_rate", 348 | default=5e-5, 349 | type=float, 350 | help="The initial learning rate for Adam.") 351 | parser.add_argument("--num_train_epochs", 352 | default=3.0, 353 | type=float, 354 | help="Total number of training epochs to perform.") 355 | parser.add_argument("--warmup_proportion", 356 | default=0.1, 357 | type=float, 358 | help="Proportion of training to perform linear learning rate warmup for. " 359 | "E.g., 0.1 = 10%% of training.") 360 | parser.add_argument("--no_cuda", 361 | action='store_true', 362 | help="Whether not to use CUDA when available") 363 | parser.add_argument("--local_rank", 364 | type=int, 365 | default=-1, 366 | help="local_rank for distributed training on gpus") 367 | parser.add_argument('--seed', 368 | type=int, 369 | default=42, 370 | help="random seed for initialization") 371 | parser.add_argument('--gradient_accumulation_steps', 372 | type=int, 373 | default=1, 374 | help="Number of updates steps to accumulate before performing a backward/update pass.") 375 | parser.add_argument('--fp16', 376 | action='store_true', 377 | help="Whether to use 16-bit float precision instead of 32-bit") 378 | parser.add_argument('--loss_scale', 379 | type=float, default=0, 380 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 381 | "0 (default value): dynamic loss scaling.\n" 382 | "Positive power of 2: static loss scaling value.\n") 383 | 384 | args = parser.parse_args() 385 | 386 | processors = { 387 | "cola": ColaProcessor, 388 | "mnli": MnliProcessor, 389 | "mrpc": MrpcProcessor, 390 | } 391 | 392 | num_labels_task = { 393 | "cola": 2, 394 | "mnli": 3, 395 | "mrpc": 2, 396 | } 397 | 398 | if args.local_rank == -1 or args.no_cuda: 399 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 400 | n_gpu = torch.cuda.device_count() 401 | else: 402 | torch.cuda.set_device(args.local_rank) 403 | device = torch.device("cuda", args.local_rank) 404 | n_gpu = 1 405 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 406 | torch.distributed.init_process_group(backend='nccl') 407 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 408 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 409 | 410 | if args.gradient_accumulation_steps < 1: 411 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 412 | args.gradient_accumulation_steps)) 413 | 414 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 415 | 416 | random.seed(args.seed) 417 | np.random.seed(args.seed) 418 | torch.manual_seed(args.seed) 419 | if n_gpu > 0: 420 | torch.cuda.manual_seed_all(args.seed) 421 | 422 | if not args.do_train and not args.do_eval: 423 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 424 | 425 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 426 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 427 | os.makedirs(args.output_dir, exist_ok=True) 428 | 429 | task_name = args.task_name.lower() 430 | 431 | if task_name not in processors: 432 | raise ValueError("Task not found: %s" % (task_name)) 433 | 434 | processor = processors[task_name]() 435 | num_labels = num_labels_task[task_name] 436 | label_list = processor.get_labels() 437 | 438 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 439 | 440 | train_examples = None 441 | num_train_optimization_steps = None 442 | if args.do_train: 443 | train_examples = processor.get_train_examples(args.data_dir) 444 | num_train_optimization_steps = int( 445 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 446 | if args.local_rank != -1: 447 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 448 | 449 | # Prepare model 450 | model = BertForSequenceClassification.from_pretrained(args.bert_model, 451 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), 452 | num_labels = num_labels) 453 | if args.fp16: 454 | model.half() 455 | model.to(device) 456 | if args.local_rank != -1: 457 | try: 458 | from apex.parallel import DistributedDataParallel as DDP 459 | except ImportError: 460 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 461 | 462 | model = DDP(model) 463 | elif n_gpu > 1: 464 | model = torch.nn.DataParallel(model) 465 | 466 | # Prepare optimizer 467 | param_optimizer = list(model.named_parameters()) 468 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 469 | optimizer_grouped_parameters = [ 470 | 
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 471 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 472 | ] 473 | if args.fp16: 474 | try: 475 | from apex.optimizers import FP16_Optimizer 476 | from apex.optimizers import FusedAdam 477 | except ImportError: 478 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 479 | 480 | optimizer = FusedAdam(optimizer_grouped_parameters, 481 | lr=args.learning_rate, 482 | bias_correction=False, 483 | max_grad_norm=1.0) 484 | if args.loss_scale == 0: 485 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 486 | else: 487 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 488 | 489 | else: 490 | optimizer = BertAdam(optimizer_grouped_parameters, 491 | lr=args.learning_rate, 492 | warmup=args.warmup_proportion, 493 | t_total=num_train_optimization_steps) 494 | 495 | global_step = 0 496 | nb_tr_steps = 0 497 | tr_loss = 0 498 | if args.do_train: 499 | train_features = convert_examples_to_features( 500 | train_examples, label_list, args.max_seq_length, tokenizer) 501 | logger.info("***** Running training *****") 502 | logger.info(" Num examples = %d", len(train_examples)) 503 | logger.info(" Batch size = %d", args.train_batch_size) 504 | logger.info(" Num steps = %d", num_train_optimization_steps) 505 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 506 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 507 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 508 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 509 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 510 | if args.local_rank == -1: 511 | train_sampler = RandomSampler(train_data) 512 | else: 513 | train_sampler = DistributedSampler(train_data) 514 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 515 | 516 | model.train() 517 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 518 | tr_loss = 0 519 | nb_tr_examples, nb_tr_steps = 0, 0 520 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 521 | batch = tuple(t.to(device) for t in batch) 522 | input_ids, input_mask, segment_ids, label_ids = batch 523 | loss = model(input_ids, segment_ids, input_mask, label_ids) 524 | if n_gpu > 1: 525 | loss = loss.mean() # mean() to average on multi-gpu. 
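# Note on the manual LR adjustment a few lines below, applied only when --fp16
# is set (numbers assumed for illustration): with --learning_rate 5e-5,
# --warmup_proportion 0.1 and 1,000 optimization steps, warmup_linear ramps the
# rate linearly from 0 to 5e-5 over the first 100 updates (about 2.5e-5 at
# update 50). Without --fp16 this block is skipped because BertAdam is
# constructed with warmup and t_total and applies the schedule itself.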
526 | if args.gradient_accumulation_steps > 1: 527 | loss = loss / args.gradient_accumulation_steps 528 | 529 | if args.fp16: 530 | optimizer.backward(loss) 531 | else: 532 | loss.backward() 533 | 534 | tr_loss += loss.item() 535 | nb_tr_examples += input_ids.size(0) 536 | nb_tr_steps += 1 537 | if (step + 1) % args.gradient_accumulation_steps == 0: 538 | if args.fp16: 539 | # modify learning rate with special warm up BERT uses 540 | # if args.fp16 is False, BertAdam is used that handles this automatically 541 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion) 542 | for param_group in optimizer.param_groups: 543 | param_group['lr'] = lr_this_step 544 | optimizer.step() 545 | optimizer.zero_grad() 546 | global_step += 1 547 | 548 | # Save a trained model 549 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 550 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 551 | if args.do_train: 552 | torch.save(model_to_save.state_dict(), output_model_file) 553 | 554 | # Load a trained model that you have fine-tuned 555 | model_state_dict = torch.load(output_model_file) 556 | model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels) 557 | model.to(device) 558 | 559 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 560 | eval_examples = processor.get_dev_examples(args.data_dir) 561 | eval_features = convert_examples_to_features( 562 | eval_examples, label_list, args.max_seq_length, tokenizer) 563 | logger.info("***** Running evaluation *****") 564 | logger.info(" Num examples = %d", len(eval_examples)) 565 | logger.info(" Batch size = %d", args.eval_batch_size) 566 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 567 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 568 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 569 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long) 570 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 571 | # Run prediction for full data 572 | eval_sampler = SequentialSampler(eval_data) 573 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 574 | 575 | model.eval() 576 | eval_loss, eval_accuracy = 0, 0 577 | nb_eval_steps, nb_eval_examples = 0, 0 578 | 579 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 580 | input_ids = input_ids.to(device) 581 | input_mask = input_mask.to(device) 582 | segment_ids = segment_ids.to(device) 583 | label_ids = label_ids.to(device) 584 | 585 | with torch.no_grad(): 586 | tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) 587 | logits = model(input_ids, segment_ids, input_mask) 588 | 589 | logits = logits.detach().cpu().numpy() 590 | label_ids = label_ids.to('cpu').numpy() 591 | tmp_eval_accuracy = accuracy(logits, label_ids) 592 | 593 | eval_loss += tmp_eval_loss.mean().item() 594 | eval_accuracy += tmp_eval_accuracy 595 | 596 | nb_eval_examples += input_ids.size(0) 597 | nb_eval_steps += 1 598 | 599 | eval_loss = eval_loss / nb_eval_steps 600 | eval_accuracy = eval_accuracy / nb_eval_examples 601 | loss = tr_loss/nb_tr_steps if args.do_train else None 602 | result = {'eval_loss': eval_loss, 603 | 
'eval_accuracy': eval_accuracy, 604 | 'global_step': global_step, 605 | 'loss': loss} 606 | 607 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 608 | with open(output_eval_file, "w") as writer: 609 | logger.info("***** Eval results *****") 610 | for key in sorted(result.keys()): 611 | logger.info(" %s = %s", key, str(result[key])) 612 | writer.write("%s = %s\n" % (key, str(result[key]))) 613 | 614 | if __name__ == "__main__": 615 | main() 616 | -------------------------------------------------------------------------------- /examples/run_lm_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """BERT finetuning runner.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import logging 24 | import argparse 25 | from tqdm import tqdm, trange 26 | 27 | import numpy as np 28 | import torch 29 | from torch.utils.data import DataLoader, RandomSampler 30 | from torch.utils.data.distributed import DistributedSampler 31 | 32 | from pytorch_pretrained_bert.tokenization import BertTokenizer 33 | from pytorch_pretrained_bert.modeling import BertForPreTraining 34 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 35 | 36 | from torch.utils.data import Dataset 37 | import random 38 | 39 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 40 | datefmt='%m/%d/%Y %H:%M:%S', 41 | level=logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | class BERTDataset(Dataset): 46 | def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): 47 | self.vocab = tokenizer.vocab 48 | self.tokenizer = tokenizer 49 | self.seq_len = seq_len 50 | self.on_memory = on_memory 51 | self.corpus_lines = corpus_lines # number of non-empty lines in input corpus 52 | self.corpus_path = corpus_path 53 | self.encoding = encoding 54 | self.current_doc = 0 # to avoid random sentence from same doc 55 | 56 | # for loading samples directly from file 57 | self.sample_counter = 0 # used to keep track of full epochs on file 58 | self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair 59 | 60 | # for loading samples in memory 61 | self.current_random_doc = 0 62 | self.num_docs = 0 63 | self.sample_to_doc = [] # map sample index to doc and line 64 | 65 | # load samples into memory 66 | if on_memory: 67 | self.all_docs = [] 68 | doc = [] 69 | self.corpus_lines = 0 70 | with open(corpus_path, "r", encoding=encoding) as f: 71 | for line in tqdm(f, desc="Loading Dataset", total=corpus_lines): 72 | line = line.strip() 73 | if line == "": 74 | 
self.all_docs.append(doc) 75 | doc = [] 76 | #remove last added sample because there won't be a subsequent line anymore in the doc 77 | self.sample_to_doc.pop() 78 | else: 79 | #store as one sample 80 | sample = {"doc_id": len(self.all_docs), 81 | "line": len(doc)} 82 | self.sample_to_doc.append(sample) 83 | doc.append(line) 84 | self.corpus_lines = self.corpus_lines + 1 85 | 86 | # if last row in file is not empty 87 | if self.all_docs[-1] != doc: 88 | self.all_docs.append(doc) 89 | self.sample_to_doc.pop() 90 | 91 | self.num_docs = len(self.all_docs) 92 | 93 | # load samples later lazily from disk 94 | else: 95 | if self.corpus_lines is None: 96 | with open(corpus_path, "r", encoding=encoding) as f: 97 | self.corpus_lines = 0 98 | for line in tqdm(f, desc="Loading Dataset", total=corpus_lines): 99 | if line.strip() == "": 100 | self.num_docs += 1 101 | else: 102 | self.corpus_lines += 1 103 | 104 | # if doc does not end with empty line 105 | if line.strip() != "": 106 | self.num_docs += 1 107 | 108 | self.file = open(corpus_path, "r", encoding=encoding) 109 | self.random_file = open(corpus_path, "r", encoding=encoding) 110 | 111 | def __len__(self): 112 | # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0. 113 | return self.corpus_lines - self.num_docs - 1 114 | 115 | def __getitem__(self, item): 116 | cur_id = self.sample_counter 117 | self.sample_counter += 1 118 | if not self.on_memory: 119 | # after one epoch we start again from beginning of file 120 | if cur_id != 0 and (cur_id % len(self) == 0): 121 | self.file.close() 122 | self.file = open(self.corpus_path, "r", encoding=self.encoding) 123 | 124 | t1, t2, is_next_label = self.random_sent(item) 125 | 126 | # tokenize 127 | tokens_a = self.tokenizer.tokenize(t1) 128 | tokens_b = self.tokenizer.tokenize(t2) 129 | 130 | # combine to one sample 131 | cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label) 132 | 133 | # transform sample to features 134 | cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer) 135 | 136 | cur_tensors = (torch.tensor(cur_features.input_ids), 137 | torch.tensor(cur_features.input_mask), 138 | torch.tensor(cur_features.segment_ids), 139 | torch.tensor(cur_features.lm_label_ids), 140 | torch.tensor(cur_features.is_next)) 141 | 142 | return cur_tensors 143 | 144 | def random_sent(self, index): 145 | """ 146 | Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences 147 | from one doc. With 50% the second sentence will be a random one from another doc. 148 | :param index: int, index of sample. 149 | :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label 150 | """ 151 | t1, t2 = self.get_corpus_line(index) 152 | if random.random() > 0.5: 153 | label = 0 154 | else: 155 | t2 = self.get_random_line() 156 | label = 1 157 | 158 | assert len(t1) > 0 159 | assert len(t2) > 0 160 | return t1, t2, label 161 | 162 | def get_corpus_line(self, item): 163 | """ 164 | Get one sample from corpus consisting of a pair of two subsequent lines from the same doc. 165 | :param item: int, index of sample. 
166 | :return: (str, str), two subsequent sentences from corpus 167 | """ 168 | t1 = "" 169 | t2 = "" 170 | assert item < self.corpus_lines 171 | if self.on_memory: 172 | sample = self.sample_to_doc[item] 173 | t1 = self.all_docs[sample["doc_id"]][sample["line"]] 174 | t2 = self.all_docs[sample["doc_id"]][sample["line"]+1] 175 | # used later to avoid random nextSentence from same doc 176 | self.current_doc = sample["doc_id"] 177 | return t1, t2 178 | else: 179 | if self.line_buffer is None: 180 | # read first non-empty line of file 181 | while t1 == "" : 182 | t1 = self.file.__next__().strip() 183 | t2 = self.file.__next__().strip() 184 | else: 185 | # use t2 from previous iteration as new t1 186 | t1 = self.line_buffer 187 | t2 = self.file.__next__().strip() 188 | # skip empty rows that are used for separating documents and keep track of current doc id 189 | while t2 == "" or t1 == "": 190 | t1 = self.file.__next__().strip() 191 | t2 = self.file.__next__().strip() 192 | self.current_doc = self.current_doc+1 193 | self.line_buffer = t2 194 | 195 | assert t1 != "" 196 | assert t2 != "" 197 | return t1, t2 198 | 199 | def get_random_line(self): 200 | """ 201 | Get random line from another document for nextSentence task. 202 | :return: str, content of one line 203 | """ 204 | # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large 205 | # corpora. However, just to be careful, we try to make sure that 206 | # the random document is not the same as the document we're processing. 207 | for _ in range(10): 208 | if self.on_memory: 209 | rand_doc_idx = random.randint(0, len(self.all_docs)-1) 210 | rand_doc = self.all_docs[rand_doc_idx] 211 | line = rand_doc[random.randrange(len(rand_doc))] 212 | else: 213 | rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000) 214 | #pick random line 215 | for _ in range(rand_index): 216 | line = self.get_next_line() 217 | #check if our picked random line is really from another doc like we want it to be 218 | if self.current_random_doc != self.current_doc: 219 | break 220 | return line 221 | 222 | def get_next_line(self): 223 | """ Gets next line of random_file and starts over when reaching end of file""" 224 | try: 225 | line = self.random_file.__next__().strip() 226 | #keep track of which document we are currently looking at to later avoid having the same doc as t1 227 | if line == "": 228 | self.current_random_doc = self.current_random_doc + 1 229 | line = self.random_file.__next__().strip() 230 | except StopIteration: 231 | self.random_file.close() 232 | self.random_file = open(self.corpus_path, "r", encoding=self.encoding) 233 | line = self.random_file.__next__().strip() 234 | return line 235 | 236 | 237 | class InputExample(object): 238 | """A single training/test example for the language model.""" 239 | 240 | def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None): 241 | """Constructs a InputExample. 242 | 243 | Args: 244 | guid: Unique id for the example. 245 | tokens_a: string. The untokenized text of the first sequence. For single 246 | sequence tasks, only this sequence must be specified. 247 | tokens_b: (Optional) string. The untokenized text of the second sequence. 248 | Only must be specified for sequence pair tasks. 249 | label: (Optional) string. The label of the example. This should be 250 | specified for train and dev examples, but not for test examples. 
251 | """ 252 | self.guid = guid 253 | self.tokens_a = tokens_a 254 | self.tokens_b = tokens_b 255 | self.is_next = is_next # nextSentence 256 | self.lm_labels = lm_labels # masked words for language model 257 | 258 | 259 | class InputFeatures(object): 260 | """A single set of features of data.""" 261 | 262 | def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids): 263 | self.input_ids = input_ids 264 | self.input_mask = input_mask 265 | self.segment_ids = segment_ids 266 | self.is_next = is_next 267 | self.lm_label_ids = lm_label_ids 268 | 269 | 270 | def random_word(tokens, tokenizer): 271 | """ 272 | Masking some random tokens for Language Model task with probabilities as in the original BERT paper. 273 | :param tokens: list of str, tokenized sentence. 274 | :param tokenizer: Tokenizer, object used for tokenization (we need it's vocab here) 275 | :return: (list of str, list of int), masked tokens and related labels for LM prediction 276 | """ 277 | output_label = [] 278 | 279 | for i, token in enumerate(tokens): 280 | prob = random.random() 281 | # mask token with 15% probability 282 | if prob < 0.15: 283 | prob /= 0.15 284 | 285 | # 80% randomly change token to mask token 286 | if prob < 0.8: 287 | tokens[i] = "[MASK]" 288 | 289 | # 10% randomly change token to random token 290 | elif prob < 0.9: 291 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0] 292 | 293 | # -> rest 10% randomly keep current token 294 | 295 | # append current token to output (we will predict these later) 296 | try: 297 | output_label.append(tokenizer.vocab[token]) 298 | except KeyError: 299 | # For unknown words (should not occur with BPE vocab) 300 | output_label.append(tokenizer.vocab["[UNK]"]) 301 | logger.warning("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token)) 302 | else: 303 | # no masking token (will be ignored by loss function later) 304 | output_label.append(-1) 305 | 306 | return tokens, output_label 307 | 308 | 309 | def convert_example_to_features(example, max_seq_length, tokenizer): 310 | """ 311 | Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with 312 | IDs, LM labels, input_mask, CLS and SEP tokens etc. 313 | :param example: InputExample, containing sentence input as strings and is_next label 314 | :param max_seq_length: int, maximum length of sequence. 315 | :param tokenizer: Tokenizer 316 | :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training) 317 | """ 318 | tokens_a = example.tokens_a 319 | tokens_b = example.tokens_b 320 | # Modifies `tokens_a` and `tokens_b` in place so that the total 321 | # length is less than the specified length. 322 | # Account for [CLS], [SEP], [SEP] with "- 3" 323 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 324 | 325 | tokens_a, t1_label = random_word(tokens_a, tokenizer) 326 | tokens_b, t2_label = random_word(tokens_b, tokenizer) 327 | # concatenate lm labels and account for CLS, SEP, SEP 328 | lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1]) 329 | 330 | # The convention in BERT is: 331 | # (a) For sequence pairs: 332 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 333 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 334 | # (b) For single sequences: 335 | # tokens: [CLS] the dog is hairy . [SEP] 336 | # type_ids: 0 0 0 0 0 0 0 337 | # 338 | # Where "type_ids" are used to indicate whether this is the first 339 | # sequence or the second sequence. 
The embedding vectors for `type=0` and 340 | # `type=1` were learned during pre-training and are added to the wordpiece 341 | # embedding vector (and position vector). This is not *strictly* necessary 342 | # since the [SEP] token unambigiously separates the sequences, but it makes 343 | # it easier for the model to learn the concept of sequences. 344 | # 345 | # For classification tasks, the first vector (corresponding to [CLS]) is 346 | # used as as the "sentence vector". Note that this only makes sense because 347 | # the entire model is fine-tuned. 348 | tokens = [] 349 | segment_ids = [] 350 | tokens.append("[CLS]") 351 | segment_ids.append(0) 352 | for token in tokens_a: 353 | tokens.append(token) 354 | segment_ids.append(0) 355 | tokens.append("[SEP]") 356 | segment_ids.append(0) 357 | 358 | assert len(tokens_b) > 0 359 | for token in tokens_b: 360 | tokens.append(token) 361 | segment_ids.append(1) 362 | tokens.append("[SEP]") 363 | segment_ids.append(1) 364 | 365 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 366 | 367 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 368 | # tokens are attended to. 369 | input_mask = [1] * len(input_ids) 370 | 371 | # Zero-pad up to the sequence length. 372 | while len(input_ids) < max_seq_length: 373 | input_ids.append(0) 374 | input_mask.append(0) 375 | segment_ids.append(0) 376 | lm_label_ids.append(-1) 377 | 378 | assert len(input_ids) == max_seq_length 379 | assert len(input_mask) == max_seq_length 380 | assert len(segment_ids) == max_seq_length 381 | assert len(lm_label_ids) == max_seq_length 382 | 383 | if example.guid < 5: 384 | logger.info("*** Example ***") 385 | logger.info("guid: %s" % (example.guid)) 386 | logger.info("tokens: %s" % " ".join( 387 | [str(x) for x in tokens])) 388 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 389 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 390 | logger.info( 391 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 392 | logger.info("LM label: %s " % (lm_label_ids)) 393 | logger.info("Is next sentence label: %s " % (example.is_next)) 394 | 395 | features = InputFeatures(input_ids=input_ids, 396 | input_mask=input_mask, 397 | segment_ids=segment_ids, 398 | lm_label_ids=lm_label_ids, 399 | is_next=example.is_next) 400 | return features 401 | 402 | 403 | def main(): 404 | parser = argparse.ArgumentParser() 405 | 406 | ## Required parameters 407 | parser.add_argument("--train_file", 408 | default=None, 409 | type=str, 410 | required=True, 411 | help="The input train corpus.") 412 | parser.add_argument("--bert_model", default=None, type=str, required=True, 413 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 414 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 415 | parser.add_argument("--output_dir", 416 | default=None, 417 | type=str, 418 | required=True, 419 | help="The output directory where the model checkpoints will be written.") 420 | 421 | ## Other parameters 422 | parser.add_argument("--max_seq_length", 423 | default=128, 424 | type=int, 425 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 426 | "Sequences longer than this will be truncated, and sequences shorter \n" 427 | "than this will be padded.") 428 | parser.add_argument("--do_train", 429 | action='store_true', 430 | help="Whether to run training.") 431 | parser.add_argument("--train_batch_size", 432 | default=32, 433 | type=int, 434 | help="Total batch size for training.") 435 | parser.add_argument("--learning_rate", 436 | default=3e-5, 437 | type=float, 438 | help="The initial learning rate for Adam.") 439 | parser.add_argument("--num_train_epochs", 440 | default=3.0, 441 | type=float, 442 | help="Total number of training epochs to perform.") 443 | parser.add_argument("--warmup_proportion", 444 | default=0.1, 445 | type=float, 446 | help="Proportion of training to perform linear learning rate warmup for. " 447 | "E.g., 0.1 = 10%% of training.") 448 | parser.add_argument("--no_cuda", 449 | action='store_true', 450 | help="Whether not to use CUDA when available") 451 | parser.add_argument("--on_memory", 452 | action='store_true', 453 | help="Whether to load train samples into memory or use disk") 454 | parser.add_argument("--do_lower_case", 455 | action='store_true', 456 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 457 | parser.add_argument("--local_rank", 458 | type=int, 459 | default=-1, 460 | help="local_rank for distributed training on gpus") 461 | parser.add_argument('--seed', 462 | type=int, 463 | default=42, 464 | help="random seed for initialization") 465 | parser.add_argument('--gradient_accumulation_steps', 466 | type=int, 467 | default=1, 468 | help="Number of updates steps to accumualte before performing a backward/update pass.") 469 | parser.add_argument('--fp16', 470 | action='store_true', 471 | help="Whether to use 16-bit float precision instead of 32-bit") 472 | parser.add_argument('--loss_scale', 473 | type = float, default = 0, 474 | help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 475 | "0 (default value): dynamic loss scaling.\n" 476 | "Positive power of 2: static loss scaling value.\n") 477 | 478 | args = parser.parse_args() 479 | 480 | if args.local_rank == -1 or args.no_cuda: 481 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 482 | n_gpu = torch.cuda.device_count() 483 | else: 484 | torch.cuda.set_device(args.local_rank) 485 | device = torch.device("cuda", args.local_rank) 486 | n_gpu = 1 487 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 488 | torch.distributed.init_process_group(backend='nccl') 489 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( 490 | device, n_gpu, bool(args.local_rank != -1), args.fp16)) 491 | 492 | if args.gradient_accumulation_steps < 1: 493 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 494 | args.gradient_accumulation_steps)) 495 | 496 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 497 | 498 | random.seed(args.seed) 499 | np.random.seed(args.seed) 500 | torch.manual_seed(args.seed) 501 | if n_gpu > 0: 502 | torch.cuda.manual_seed_all(args.seed) 503 | 504 | if not args.do_train: 505 | raise ValueError("Training is currently the only implemented execution option. 
Please set `do_train`.") 506 | 507 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 508 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 509 | os.makedirs(args.output_dir, exist_ok=True) 510 | 511 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 512 | 513 | #train_examples = None 514 | num_train_optimization_steps = None 515 | if args.do_train: 516 | print("Loading Train Dataset", args.train_file) 517 | train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length, 518 | corpus_lines=None, on_memory=args.on_memory) 519 | num_train_optimization_steps = int( 520 | len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs 521 | if args.local_rank != -1: 522 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 523 | 524 | # Prepare model 525 | model = BertForPreTraining.from_pretrained(args.bert_model) 526 | if args.fp16: 527 | model.half() 528 | model.to(device) 529 | if args.local_rank != -1: 530 | try: 531 | from apex.parallel import DistributedDataParallel as DDP 532 | except ImportError: 533 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 534 | model = DDP(model) 535 | elif n_gpu > 1: 536 | model = torch.nn.DataParallel(model) 537 | 538 | # Prepare optimizer 539 | param_optimizer = list(model.named_parameters()) 540 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 541 | optimizer_grouped_parameters = [ 542 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 543 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 544 | ] 545 | 546 | if args.fp16: 547 | try: 548 | from apex.optimizers import FP16_Optimizer 549 | from apex.optimizers import FusedAdam 550 | except ImportError: 551 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") 552 | 553 | optimizer = FusedAdam(optimizer_grouped_parameters, 554 | lr=args.learning_rate, 555 | bias_correction=False, 556 | max_grad_norm=1.0) 557 | if args.loss_scale == 0: 558 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 559 | else: 560 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 561 | 562 | else: 563 | optimizer = BertAdam(optimizer_grouped_parameters, 564 | lr=args.learning_rate, 565 | warmup=args.warmup_proportion, 566 | t_total=num_train_optimization_steps) 567 | 568 | global_step = 0 569 | if args.do_train: 570 | logger.info("***** Running training *****") 571 | logger.info(" Num examples = %d", len(train_dataset)) 572 | logger.info(" Batch size = %d", args.train_batch_size) 573 | logger.info(" Num steps = %d", num_train_optimization_steps) 574 | 575 | if args.local_rank == -1: 576 | train_sampler = RandomSampler(train_dataset) 577 | else: 578 | #TODO: check if this works with current data generator from disk that relies on file.__next__ 579 | # (it doesn't return item back by index) 580 | train_sampler = DistributedSampler(train_dataset) 581 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 582 | 583 | model.train() 584 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 585 | tr_loss = 0 586 | nb_tr_examples, nb_tr_steps = 0, 0 587 | for step, batch in 
enumerate(tqdm(train_dataloader, desc="Iteration")): 588 | batch = tuple(t.to(device) for t in batch) 589 | input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch 590 | loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) 591 | if n_gpu > 1: 592 | loss = loss.mean() # mean() to average on multi-gpu. 593 | if args.gradient_accumulation_steps > 1: 594 | loss = loss / args.gradient_accumulation_steps 595 | if args.fp16: 596 | optimizer.backward(loss) 597 | else: 598 | loss.backward() 599 | tr_loss += loss.item() 600 | nb_tr_examples += input_ids.size(0) 601 | nb_tr_steps += 1 602 | if (step + 1) % args.gradient_accumulation_steps == 0: 603 | if args.fp16: 604 | # modify learning rate with special warm up BERT uses 605 | # if args.fp16 is False, BertAdam is used that handles this automatically 606 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion) 607 | for param_group in optimizer.param_groups: 608 | param_group['lr'] = lr_this_step 609 | optimizer.step() 610 | optimizer.zero_grad() 611 | global_step += 1 612 | 613 | # Save a trained model 614 | logger.info("** ** * Saving fine - tuned model ** ** * ") 615 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 616 | output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") 617 | if args.do_train: 618 | torch.save(model_to_save.state_dict(), output_model_file) 619 | 620 | 621 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 622 | """Truncates a sequence pair in place to the maximum length.""" 623 | 624 | # This is a simple heuristic which will always truncate the longer sequence 625 | # one token at a time. This makes more sense than truncating an equal percent 626 | # of tokens from each, since if one sequence is very short then each token 627 | # that's truncated likely contains more information than a longer sequence. 628 | while True: 629 | total_length = len(tokens_a) + len(tokens_b) 630 | if total_length <= max_length: 631 | break 632 | if len(tokens_a) > len(tokens_b): 633 | tokens_a.pop() 634 | else: 635 | tokens_b.pop() 636 | 637 | 638 | def accuracy(out, labels): 639 | outputs = np.argmax(out, axis=1) 640 | return np.sum(outputs == labels) 641 | 642 | 643 | if __name__ == "__main__": 644 | main() --------------------------------------------------------------------------------
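The masking scheme implemented by random_word in run_lm_finetuning.py above (15% of tokens selected; of those, 80% become [MASK], 10% become a random vocabulary token, 10% stay unchanged, and every unselected position gets label -1 so the LM loss ignores it) can be sanity-checked with the small self-contained sketch below. The toy vocabulary, token list, and seed are made up for illustration; the real script draws replacements from the BertTokenizer WordPiece vocabulary.

import random

# Toy stand-in for the WordPiece vocab (assumed values, illustration only).
vocab = {"[UNK]": 0, "[MASK]": 1, "the": 2, "dog": 3, "is": 4, "hairy": 5, ".": 6}

def mask_tokens(tokens, vocab, mask_prob=0.15):
    """Apply the same rule as random_word(): select ~15% of positions; of the
    selected ones, 80% -> [MASK], 10% -> random vocab token, 10% -> unchanged.
    Unselected positions get label -1 so the LM loss ignores them."""
    output_label = []
    for i, token in enumerate(tokens):
        prob = random.random()
        if prob < mask_prob:
            prob /= mask_prob
            if prob < 0.8:
                tokens[i] = "[MASK]"
            elif prob < 0.9:
                tokens[i] = random.choice(list(vocab))
            # else: keep the original token (the remaining 10%)
            output_label.append(vocab.get(token, vocab["[UNK]"]))
        else:
            output_label.append(-1)
    return tokens, output_label

if __name__ == "__main__":
    random.seed(0)
    print(mask_tokens(["the", "dog", "is", "hairy", "."], vocab))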