├── .circleci └── config.yml ├── .github └── stale.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── docker └── Dockerfile ├── docs └── imgs │ ├── warmup_constant_schedule.png │ ├── warmup_cosine_hard_restarts_schedule.png │ ├── warmup_cosine_schedule.png │ ├── warmup_cosine_warm_restarts_schedule.png │ └── warmup_linear_schedule.png ├── examples ├── data.py ├── decode.py ├── eval_multi.py ├── eval_squad.py ├── extract_features.py ├── lm_finetuning │ ├── README.md │ ├── finetune_on_pregenerated.py │ ├── pregenerate_training_data.py │ └── simple_lm_finetuning.py ├── run_classifier.py ├── run_gpt2.py ├── run_openai_gpt.py ├── run_squad.py ├── run_swag.py ├── run_transfo_xl.py ├── utils.py └── vector.py ├── hubconf.py ├── notebooks ├── Comparing-TF-and-PT-models-MLM-NSP.ipynb ├── Comparing-TF-and-PT-models-SQuAD.ipynb └── Comparing-TF-and-PT-models.ipynb ├── pytorch_pretrained_bert ├── __init__.py ├── __main__.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── optimization.py ├── optimization_openai.py ├── tokenization.py ├── tokenization_gpt2.py ├── tokenization_openai.py └── tokenization_transfo_xl.py ├── requirements.txt ├── samples ├── input.txt └── sample_text.txt ├── setup.py └── tests ├── conftest.py ├── modeling_gpt2_test.py ├── modeling_openai_test.py ├── modeling_test.py ├── modeling_transfo_xl_test.py ├── optimization_test.py ├── tokenization_gpt2_test.py ├── tokenization_openai_test.py ├── tokenization_test.py └── tokenization_transfo_xl_test.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build_py3: 4 | working_directory: ~/pytorch-pretrained-BERT 5 | docker: 6 | - image: 
circleci/python:3.5 7 | steps: 8 | - checkout 9 | - run: sudo pip install --progress-bar off . 10 | - run: sudo pip install pytest ftfy spacy 11 | - run: sudo python -m spacy download en 12 | - run: python -m pytest -sv tests/ --runslow 13 | build_py2: 14 | working_directory: ~/pytorch-pretrained-BERT 15 | docker: 16 | - image: circleci/python:2.7 17 | steps: 18 | - checkout 19 | - run: sudo pip install --progress-bar off . 20 | - run: sudo pip install pytest spacy 21 | - run: sudo pip install ftfy==4.4.3 22 | - run: sudo python -m spacy download en 23 | - run: python -m pytest -sv tests/ --runslow 24 | workflows: 25 | version: 2 26 | build_and_test: 27 | jobs: 28 | - build_py3 29 | - build_py2 -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. 
Set to `false` to disable 17 | closeComment: false -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # 
mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # vscode 119 | .vscode 120 | 121 | # TF code 122 | tensorflow_code 123 | 124 | # Models 125 | models -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Code 2 | 3 | Code for multi-span answer Reading Comprehension 4 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext 4 | 5 | RUN pip install pytorch-pretrained-bert 6 | 7 | WORKDIR /workspace -------------------------------------------------------------------------------- /docs/imgs/warmup_constant_schedule.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lixinsu/multi_span/dcdc57da11a350e0d1521bc3039318b3d04975a2/docs/imgs/warmup_constant_schedule.png -------------------------------------------------------------------------------- /docs/imgs/warmup_cosine_hard_restarts_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinsu/multi_span/dcdc57da11a350e0d1521bc3039318b3d04975a2/docs/imgs/warmup_cosine_hard_restarts_schedule.png -------------------------------------------------------------------------------- /docs/imgs/warmup_cosine_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinsu/multi_span/dcdc57da11a350e0d1521bc3039318b3d04975a2/docs/imgs/warmup_cosine_schedule.png -------------------------------------------------------------------------------- /docs/imgs/warmup_cosine_warm_restarts_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinsu/multi_span/dcdc57da11a350e0d1521bc3039318b3d04975a2/docs/imgs/warmup_cosine_warm_restarts_schedule.png -------------------------------------------------------------------------------- /docs/imgs/warmup_linear_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lixinsu/multi_span/dcdc57da11a350e0d1521bc3039318b3d04975a2/docs/imgs/warmup_linear_schedule.png -------------------------------------------------------------------------------- /examples/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import json 5 | from pytorch_pretrained_bert.tokenization import (BasicTokenizer, 6 | BertTokenizer, 7 | whitespace_tokenize) 8 | 9 | class SquadExample(object): 10 | """ 11 | A single training/test example for the Squad dataset. 
12 | For examples without an answer, the start and end position are -1. 13 | """ 14 | 15 | def __init__(self, 16 | qas_id, 17 | question_text, 18 | doc_tokens, 19 | orig_answer_text=None, 20 | start_position=None, 21 | end_position=None, 22 | is_impossible=None): 23 | self.qas_id = qas_id 24 | self.question_text = question_text 25 | self.doc_tokens = doc_tokens 26 | self.orig_answer_text = orig_answer_text 27 | self.start_position = start_position 28 | self.end_position = end_position 29 | self.is_impossible = is_impossible 30 | 31 | def __str__(self): 32 | return self.__repr__() 33 | 34 | def __repr__(self): 35 | s = "" 36 | s += "qas_id: %s\n" % (self.qas_id) 37 | s += ", question_text: %s\n" % ( 38 | self.question_text) 39 | s += ", doc_tokens: [{}]\n".format(" ".join(self.doc_tokens)) 40 | if self.start_position: 41 | s += ", start_position: {}\n" .format (self.start_position) 42 | if self.end_position: 43 | s += ", end_position: {}\n".format (self.end_position) 44 | if self.is_impossible: 45 | s += ", is_impossible: {}\n".format (self.is_impossible) 46 | return s 47 | 48 | 49 | def is_whitespace(c): 50 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 51 | return True 52 | return False 53 | 54 | def split_by_space(paragraph_text): 55 | doc_tokens = [] 56 | char_to_word_offset = [] 57 | prev_is_whitespace = True 58 | for c in paragraph_text: 59 | if is_whitespace(c): 60 | prev_is_whitespace = True 61 | else: 62 | if prev_is_whitespace: 63 | doc_tokens.append(c) 64 | else: 65 | doc_tokens[-1] += c 66 | prev_is_whitespace = False 67 | char_to_word_offset.append(len(doc_tokens) - 1) 68 | return doc_tokens, char_to_word_offset 69 | 70 | def read_squad_examples(input_file, is_training, version_2_with_negative): 71 | """Read a SQuAD json file into a list of SquadExample.""" 72 | with open(input_file, "r", encoding='utf-8') as reader: 73 | input_data = json.load(reader)["data"] 74 | 75 | examples = [] 76 | for entry in input_data: 77 | for 
paragraph in entry["paragraphs"]: 78 | paragraph_text = paragraph["context"] 79 | doc_tokens, char_to_word_offset = split_by_space(paragraph_text) 80 | for qa in paragraph["qas"]: 81 | qas_id = qa["id"] 82 | question_text = qa["question"] 83 | start_position = None 84 | end_position = None 85 | orig_answer_text = None 86 | is_impossible = False 87 | if is_training: 88 | if version_2_with_negative: 89 | is_impossible = qa["is_impossible"] 90 | if (len(qa["answers"]) != 1) and (not is_impossible): 91 | raise ValueError( 92 | "For training, each question should have exactly 1 answer.") 93 | if not is_impossible: 94 | answer = qa["answers"][0] 95 | orig_answer_text = answer["text"] 96 | answer_offset = answer["answer_start"] 97 | answer_length = len(orig_answer_text) 98 | start_position = char_to_word_offset[answer_offset] 99 | end_position = char_to_word_offset[answer_offset + answer_length - 1] 100 | # Only add answers where the text can be exactly recovered from the 101 | # document. If this CAN'T happen it's likely due to weird Unicode 102 | # stuff so we will just skip the example. 103 | # 104 | # Note that this means for training mode, every example is NOT 105 | # guaranteed to be preserved. 106 | actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) 107 | cleaned_answer_text = " ".join( 108 | whitespace_tokenize(orig_answer_text)) 109 | if actual_text.find(cleaned_answer_text) == -1: 110 | logger.warning("Could not find answer: '%s' vs. 
'%s'", 111 | actual_text, cleaned_answer_text) 112 | continue 113 | else: 114 | start_position = -1 115 | end_position = -1 116 | orig_answer_text = "" 117 | 118 | example = SquadExample( 119 | qas_id=qas_id, 120 | question_text=question_text, 121 | doc_tokens=doc_tokens, 122 | orig_answer_text=orig_answer_text, 123 | start_position=start_position, 124 | end_position=end_position, 125 | is_impossible=is_impossible) 126 | examples.append(example) 127 | return examples 128 | 129 | 130 | def read_multi_examples(input_file, is_training, version_2_with_negative=False): 131 | input_data = [json.loads(line) for line in open(input_file)] 132 | examples = [] 133 | for entry in input_data: 134 | paragraph_text = entry['passage'] 135 | doc_tokens, char_to_word_offset = split_by_space(paragraph_text) 136 | qas_id = entry['query_id'] 137 | question_text = entry['query'] 138 | is_impossible = False 139 | if is_training: 140 | start_position = [char_to_word_offset[x[0]]for x in entry['positions']] 141 | end_position = [char_to_word_offset[x[1]-1]for x in entry['positions']] 142 | orig_answer_text = entry['answers'] 143 | else: 144 | start_position = None 145 | end_position = None 146 | orig_answer_text = None 147 | example = SquadExample( 148 | qas_id=qas_id, 149 | question_text=question_text, 150 | doc_tokens=doc_tokens, 151 | orig_answer_text=orig_answer_text, 152 | start_position=start_position, 153 | end_position=end_position, 154 | is_impossible=is_impossible) 155 | examples.append(example) 156 | return examples 157 | -------------------------------------------------------------------------------- /examples/decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import sys 5 | import pickle 6 | import collections 7 | from utils import write_predictions_couple_labeling, write_predictions_single_labeling 8 | 9 | RawResult = collections.namedtuple("RawResult", 10 | ["unique_id", 
"start_logits", "end_logits"]) 11 | eval_examples, eval_features, all_results = pickle.load(open(sys.argv[1],'rb')) 12 | 13 | #write_predictions_couple_labeling(eval_examples, eval_features, all_results, 14 | # 20, 30, True, 'output.json', 'output_nbest.json', None, False, False, 0) 15 | 16 | write_predictions_single_labeling(eval_examples, eval_features, all_results, 17 | 20, 30, True, 'output.json', 'output_nbest.json', None, False, False, 0) 18 | -------------------------------------------------------------------------------- /examples/eval_multi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import json 5 | import sys 6 | import os 7 | import argparse 8 | import statistics 9 | from collections import defaultdict 10 | import ipdb 11 | 12 | 13 | def load_prediction_file(filename): 14 | return json.load(open(filename)) 15 | 16 | 17 | def load_reference_file(filename): 18 | records = [json.loads(line) for line in open(filename)] 19 | print('{} samples in dataset'.format(len(records))) 20 | return {datum['query_id']:datum['answers'] for datum in records} 21 | 22 | 23 | def lengths(gts): 24 | return sum([len(x) for x in gts]) * 1.0 / len(gts) 25 | 26 | 27 | 28 | def exact_match(qid2answers, qid2preditions): 29 | qid2em = defaultdict(dict) 30 | for qid, answers in qid2answers.items(): 31 | if qid not in qid2predictions: 32 | print('{} is not answered'.format(qid)) 33 | continue 34 | predictions = qid2predictions[qid] 35 | if isinstance( predictions, str ): 36 | predictions = [predictions ] 37 | if len(answers) != len(predictions): 38 | qid2em[qid]['em'] = 0.0 39 | continue 40 | if isinstance(predictions[0], dict): 41 | predictions = [pred['text'] for pred in predictions] 42 | if set(answers) == set(predictions): 43 | qid2em[qid]['em'] = 1.0 44 | else: 45 | qid2em[qid]['em'] = 0.0 46 | return 100*statistics.mean([x['em'] for x in qid2em.values()]), qid2em 47 | 48 | 49 | def 
_f1(preds, answers): 50 | if len(preds) == 0: 51 | return 0.0 52 | tp = len(set(preds) & set(answers)) 53 | if tp == 0: 54 | return 0.0 55 | p = float(tp) / len(preds) 56 | r = float(tp) / len(answers) 57 | return 2*p*r/(p+r) 58 | 59 | 60 | def f_measure(qid2answer, qid2prediction): 61 | qid2f1 = defaultdict(dict) 62 | for qid, answers in qid2answers.items(): 63 | if qid not in qid2predictions: 64 | print('{} is not answered'.format(qid)) 65 | continue 66 | predictions = qid2predictions[qid] 67 | if isinstance(predictions[0], dict): 68 | predictions = [pred['text'] for pred in predictions] 69 | if isinstance( predictions, str ): 70 | predictions = [predictions ] 71 | qid2f1[qid]['f1'] = _f1(predictions, answers) 72 | return 100*statistics.mean([x['f1'] for x in qid2f1.values()]), qid2f1 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('dataset', default=None) 78 | parser.add_argument('pred', default=None) 79 | args = parser.parse_args() 80 | qid2answers = load_reference_file(args.dataset) 81 | qid2predictions = load_prediction_file(args.pred) 82 | print('{} samples in prediction'.format(len(qid2predictions))) 83 | em, qid2em = exact_match(qid2answers, qid2predictions) 84 | f1, qid2f1 = f_measure(qid2answers, qid2predictions) 85 | print('em: {:.2f}\nf1: {:.2f}'.format(em, f1)) 86 | -------------------------------------------------------------------------------- /examples/eval_squad.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. 
""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive score 0.' 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 66 | prediction = predictions[qa['id']] 67 | exact_match += metric_max_over_ground_truths( 68 | exact_match_score, prediction, ground_truths) 69 | f1 += metric_max_over_ground_truths( 70 | f1_score, prediction, ground_truths) 71 | 72 | exact_match = 100.0 * exact_match / total 73 | f1 = 100.0 * f1 / total 74 | 75 | return {'exact_match': exact_match, 'f1': f1} 76 | 77 | 78 | if __name__ == '__main__': 79 | expected_version = '1.1' 80 | parser = argparse.ArgumentParser( 81 | description='Evaluation for SQuAD ' + expected_version) 82 | parser.add_argument('dataset_file', help='Dataset file') 83 | parser.add_argument('prediction_file', help='Prediction File') 84 | args = parser.parse_args() 85 | with open(args.dataset_file) as dataset_file: 86 | dataset_json = json.load(dataset_file) 87 | if (dataset_json['version'] != expected_version): 88 | print('Evaluation expects v-' + expected_version + 89 | ', but got dataset with v-' + dataset_json['version'], 90 | file=sys.stderr) 91 | dataset = dataset_json['data'] 92 | with open(args.prediction_file) as prediction_file: 93 | predictions = json.load(prediction_file) 94 | print(json.dumps(evaluate(dataset, predictions))) 95 | -------------------------------------------------------------------------------- /examples/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import collections 23 | import logging 24 | import json 25 | import re 26 | 27 | import torch 28 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 29 | from torch.utils.data.distributed import DistributedSampler 30 | 31 | from pytorch_pretrained_bert.tokenization import BertTokenizer 32 | from pytorch_pretrained_bert.modeling import BertModel 33 | 34 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 35 | datefmt = '%m/%d/%Y %H:%M:%S', 36 | level = logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | class InputExample(object): 41 | 42 | def __init__(self, unique_id, text_a, text_b): 43 | self.unique_id = unique_id 44 | self.text_a = text_a 45 | self.text_b = text_b 46 | 47 | 48 | class InputFeatures(object): 49 | """A single set of features of data.""" 50 | 51 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 52 | self.unique_id = unique_id 53 | self.tokens = tokens 54 | self.input_ids = input_ids 55 | self.input_mask = input_mask 56 | self.input_type_ids = input_type_ids 57 | 58 | 59 | def convert_examples_to_features(examples, seq_length, tokenizer): 60 | """Loads a data file into a list of `InputFeature`s.""" 61 | 62 | features = [] 63 | for (ex_index, example) in 
enumerate(examples): 64 | tokens_a = tokenizer.tokenize(example.text_a) 65 | 66 | tokens_b = None 67 | if example.text_b: 68 | tokens_b = tokenizer.tokenize(example.text_b) 69 | 70 | if tokens_b: 71 | # Modifies `tokens_a` and `tokens_b` in place so that the total 72 | # length is less than the specified length. 73 | # Account for [CLS], [SEP], [SEP] with "- 3" 74 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 75 | else: 76 | # Account for [CLS] and [SEP] with "- 2" 77 | if len(tokens_a) > seq_length - 2: 78 | tokens_a = tokens_a[0:(seq_length - 2)] 79 | 80 | # The convention in BERT is: 81 | # (a) For sequence pairs: 82 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 83 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 84 | # (b) For single sequences: 85 | # tokens: [CLS] the dog is hairy . [SEP] 86 | # type_ids: 0 0 0 0 0 0 0 87 | # 88 | # Where "type_ids" are used to indicate whether this is the first 89 | # sequence or the second sequence. The embedding vectors for `type=0` and 90 | # `type=1` were learned during pre-training and are added to the wordpiece 91 | # embedding vector (and position vector). This is not *strictly* necessary 92 | # since the [SEP] token unambigiously separates the sequences, but it makes 93 | # it easier for the model to learn the concept of sequences. 94 | # 95 | # For classification tasks, the first vector (corresponding to [CLS]) is 96 | # used as as the "sentence vector". Note that this only makes sense because 97 | # the entire model is fine-tuned. 
98 | tokens = [] 99 | input_type_ids = [] 100 | tokens.append("[CLS]") 101 | input_type_ids.append(0) 102 | for token in tokens_a: 103 | tokens.append(token) 104 | input_type_ids.append(0) 105 | tokens.append("[SEP]") 106 | input_type_ids.append(0) 107 | 108 | if tokens_b: 109 | for token in tokens_b: 110 | tokens.append(token) 111 | input_type_ids.append(1) 112 | tokens.append("[SEP]") 113 | input_type_ids.append(1) 114 | 115 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 116 | 117 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 118 | # tokens are attended to. 119 | input_mask = [1] * len(input_ids) 120 | 121 | # Zero-pad up to the sequence length. 122 | while len(input_ids) < seq_length: 123 | input_ids.append(0) 124 | input_mask.append(0) 125 | input_type_ids.append(0) 126 | 127 | assert len(input_ids) == seq_length 128 | assert len(input_mask) == seq_length 129 | assert len(input_type_ids) == seq_length 130 | 131 | if ex_index < 5: 132 | logger.info("*** Example ***") 133 | logger.info("unique_id: %s" % (example.unique_id)) 134 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 135 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 136 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 137 | logger.info( 138 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 139 | 140 | features.append( 141 | InputFeatures( 142 | unique_id=example.unique_id, 143 | tokens=tokens, 144 | input_ids=input_ids, 145 | input_mask=input_mask, 146 | input_type_ids=input_type_ids)) 147 | return features 148 | 149 | 150 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 151 | """Truncates a sequence pair in place to the maximum length.""" 152 | 153 | # This is a simple heuristic which will always truncate the longer sequence 154 | # one token at a time. 
This makes more sense than truncating an equal percent 155 | # of tokens from each, since if one sequence is very short then each token 156 | # that's truncated likely contains more information than a longer sequence. 157 | while True: 158 | total_length = len(tokens_a) + len(tokens_b) 159 | if total_length <= max_length: 160 | break 161 | if len(tokens_a) > len(tokens_b): 162 | tokens_a.pop() 163 | else: 164 | tokens_b.pop() 165 | 166 | 167 | def read_examples(input_file): 168 | """Read a list of `InputExample`s from an input file.""" 169 | examples = [] 170 | unique_id = 0 171 | with open(input_file, "r", encoding='utf-8') as reader: 172 | while True: 173 | line = reader.readline() 174 | if not line: 175 | break 176 | line = line.strip() 177 | text_a = None 178 | text_b = None 179 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 180 | if m is None: 181 | text_a = line 182 | else: 183 | text_a = m.group(1) 184 | text_b = m.group(2) 185 | examples.append( 186 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 187 | unique_id += 1 188 | return examples 189 | 190 | 191 | def main(): 192 | parser = argparse.ArgumentParser() 193 | 194 | ## Required parameters 195 | parser.add_argument("--input_file", default=None, type=str, required=True) 196 | parser.add_argument("--output_file", default=None, type=str, required=True) 197 | parser.add_argument("--bert_model", default=None, type=str, required=True, 198 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 199 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 200 | 201 | ## Other parameters 202 | parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") 203 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) 204 | parser.add_argument("--max_seq_length", default=128, type=int, 205 | help="The maximum total input sequence length after WordPiece tokenization. 
Sequences longer " 206 | "than this will be truncated, and sequences shorter than this will be padded.") 207 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") 208 | parser.add_argument("--local_rank", 209 | type=int, 210 | default=-1, 211 | help = "local_rank for distributed training on gpus") 212 | parser.add_argument("--no_cuda", 213 | action='store_true', 214 | help="Whether not to use CUDA when available") 215 | 216 | args = parser.parse_args() 217 | 218 | if args.local_rank == -1 or args.no_cuda: 219 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 220 | n_gpu = torch.cuda.device_count() 221 | else: 222 | device = torch.device("cuda", args.local_rank) 223 | n_gpu = 1 224 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 225 | torch.distributed.init_process_group(backend='nccl') 226 | logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1))) 227 | 228 | layer_indexes = [int(x) for x in args.layers.split(",")] 229 | 230 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 231 | 232 | examples = read_examples(args.input_file) 233 | 234 | features = convert_examples_to_features( 235 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 236 | 237 | unique_id_to_feature = {} 238 | for feature in features: 239 | unique_id_to_feature[feature.unique_id] = feature 240 | 241 | model = BertModel.from_pretrained(args.bert_model) 242 | model.to(device) 243 | 244 | if args.local_rank != -1: 245 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 246 | output_device=args.local_rank) 247 | elif n_gpu > 1: 248 | model = torch.nn.DataParallel(model) 249 | 250 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 251 | all_input_mask = torch.tensor([f.input_mask for f in 
features], dtype=torch.long) 252 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 253 | 254 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 255 | if args.local_rank == -1: 256 | eval_sampler = SequentialSampler(eval_data) 257 | else: 258 | eval_sampler = DistributedSampler(eval_data) 259 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 260 | 261 | model.eval() 262 | with open(args.output_file, "w", encoding='utf-8') as writer: 263 | for input_ids, input_mask, example_indices in eval_dataloader: 264 | input_ids = input_ids.to(device) 265 | input_mask = input_mask.to(device) 266 | 267 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 268 | all_encoder_layers = all_encoder_layers 269 | 270 | for b, example_index in enumerate(example_indices): 271 | feature = features[example_index.item()] 272 | unique_id = int(feature.unique_id) 273 | # feature = unique_id_to_feature[unique_id] 274 | output_json = collections.OrderedDict() 275 | output_json["linex_index"] = unique_id 276 | all_out_features = [] 277 | for (i, token) in enumerate(feature.tokens): 278 | all_layers = [] 279 | for (j, layer_index) in enumerate(layer_indexes): 280 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() 281 | layer_output = layer_output[b] 282 | layers = collections.OrderedDict() 283 | layers["index"] = layer_index 284 | layers["values"] = [ 285 | round(x.item(), 6) for x in layer_output[i] 286 | ] 287 | all_layers.append(layers) 288 | out_features = collections.OrderedDict() 289 | out_features["token"] = token 290 | out_features["layers"] = all_layers 291 | all_out_features.append(out_features) 292 | output_json["features"] = all_out_features 293 | writer.write(json.dumps(output_json) + "\n") 294 | 295 | 296 | if __name__ == "__main__": 297 | main() 298 | 
-------------------------------------------------------------------------------- /examples/lm_finetuning/README.md: -------------------------------------------------------------------------------- 1 | # BERT Model Finetuning using Masked Language Modeling objective 2 | 3 | ## Introduction 4 | 5 | The three example scripts in this folder can be used to **fine-tune** a pre-trained BERT model using the pretraining objective (combination of masked language modeling and next sentence prediction loss). In general, pretrained models like BERT are first trained with a pretraining objective (masked language modeling and next sentence prediction for BERT) on a large and general natural language corpus. A classifier head is then added on top of the pre-trained architecture and the model is quickly fine-tuned on a target task, while still (hopefully) retaining its general language understanding. This greatly reduces overfitting and yields state-of-the-art results, especially when training data for the target task are limited. 6 | 7 | The [ULMFiT paper](https://arxiv.org/abs/1801.06146) took a slightly different approach, however, and added an intermediate step in which the model is fine-tuned on text **from the same domain as the target task and using the pretraining objective** before the final stage in which the classifier head is added and the model is trained on the target task itself. This paper reported significantly improved results from this step, and found that they could get high-quality classifications even with only tiny numbers (<1000) of labelled training examples, as long as they had a lot of unlabelled data from the target domain. 8 | 9 | The BERT model has more capacity than the LSTM models used in the ULMFiT work, but the [BERT paper](https://arxiv.org/abs/1810.04805) did not test finetuning using the pretraining objective and at the present stage there aren't many examples of this approach being used for Transformer-based language models. 
As such, it's hard to predict what effect this step will have on final model performance, but it's reasonable to conjecture that this approach can improve the final classification performance, especially when a large unlabelled corpus from the target domain is available, labelled data is limited, or the target domain is very unusual and different from 'normal' English text. If you are aware of any literature on this subject, please feel free to add it in here, or open an issue and tag me (@Rocketknight1) and I'll include it. 10 | 11 | ## Input format 12 | 13 | The scripts in this folder expect a single file as input, consisting of untokenized text, with one **sentence** per line, and one blank line between documents. The reason for the sentence splitting is that part of BERT's training involves a _next sentence_ objective in which the model must predict whether two sequences of text are contiguous text from the same document or not, and to avoid making the task _too easy_, the split point between the sequences is always at the end of a sentence. The linebreaks in the file are therefore necessary to mark the points where the text can be split. 14 | 15 | ## Usage 16 | 17 | There are two ways to fine-tune a language model using these scripts. The first _quick_ approach is to use [`simple_lm_finetuning.py`](./simple_lm_finetuning.py). This script does everything in a single script, but generates training instances that consist of just two sentences. This is quite different from the BERT paper, where (confusingly) the NextSentence task concatenated sentences together from each document to form two long multi-sentences, which the paper just referred to as _sentences_. The difference between this simple approach and the original paper approach can have a significant effect for long sequences since two sentences will be much shorter than the max sequence length. 
In this case, most of each training example will just consist of blank padding characters, which wastes a lot of computation and results in a model that isn't really training on long sequences. 18 | 19 | As such, the preferred approach (assuming you have documents containing multiple contiguous sentences from your target domain) is to use [`pregenerate_training_data.py`](./pregenerate_training_data.py) to pre-process your data into training examples following the methodology used for LM training in the original BERT paper and repository. Since there is a significant random component to training data generation for BERT, this script includes an option to generate multiple _epochs_ of pre-processed data, to avoid training on the same random splits each epoch. Generating an epoch of data for each training epoch should result in a better final model, and so we recommend doing so. 20 | 21 | You can then train on the pregenerated data using [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py), pointing it to the folder created by [`pregenerate_training_data.py`](./pregenerate_training_data.py). Note that you should use the same `bert_model` and case options for both! Also note that `max_seq_len` does not need to be specified for the [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py) script, as it is inferred from the training examples. 22 | 23 | There are various options that can be tweaked, but they are mostly set to the values from the BERT paper/repository and default values should make sense. The most relevant ones are: 24 | 25 | - `--max_seq_len`: Controls the length of training examples (in wordpiece tokens) seen by the model. Defaults to 128 but can be set as high as 512. Higher values may yield stronger language models at the cost of slower and more memory-intensive training. 26 | - `--fp16`: Enables fast half-precision training on recent GPUs.
27 | 28 | In addition, if memory usage is an issue, especially when training on a single GPU, reducing `--train_batch_size` from the default 32 to a lower number (4-16) can be helpful, or leaving `--train_batch_size` at the default and increasing `--gradient_accumulation_steps` to 2-8. Changing `--gradient_accumulation_steps` may be preferable as alterations to the batch size may require corresponding changes in the learning rate to compensate. There is also a `--reduce_memory` option for both the `pregenerate_training_data.py` and `finetune_on_pregenerated.py` scripts that spills data to disc in shelf objects or numpy memmaps rather than retaining it in memory, which significantly reduces memory usage with little performance impact. 29 | 30 | ## Examples 31 | 32 | ### Simple fine-tuning 33 | 34 | ``` 35 | python3 simple_lm_finetuning.py 36 | --train_corpus my_corpus.txt 37 | --bert_model bert-base-uncased 38 | --do_lower_case 39 | --output_dir finetuned_lm/ 40 | --do_train 41 | ``` 42 | 43 | ### Pregenerating training data 44 | 45 | ``` 46 | python3 pregenerate_training_data.py 47 | --train_corpus my_corpus.txt 48 | --bert_model bert-base-uncased 49 | --do_lower_case 50 | --output_dir training/ 51 | --epochs_to_generate 3 52 | --max_seq_len 256 53 | ``` 54 | 55 | ### Training on pregenerated data 56 | 57 | ``` 58 | python3 finetune_on_pregenerated.py 59 | --pregenerated_data training/ 60 | --bert_model bert-base-uncased 61 | --do_lower_case 62 | --output_dir finetuned_lm/ 63 | --epochs 3 64 | ``` 65 | -------------------------------------------------------------------------------- /examples/lm_finetuning/pregenerate_training_data.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | from tqdm import tqdm, trange 4 | from tempfile import TemporaryDirectory 5 | import shelve 6 | 7 | from random import random, randrange, randint, shuffle, choice, sample 8 | from 
pytorch_pretrained_bert.tokenization import BertTokenizer 9 | import numpy as np 10 | import json 11 | 12 | 13 | class DocumentDatabase: 14 | def __init__(self, reduce_memory=False): 15 | if reduce_memory: 16 | self.temp_dir = TemporaryDirectory() 17 | self.working_dir = Path(self.temp_dir.name) 18 | self.document_shelf_filepath = self.working_dir / 'shelf.db' 19 | self.document_shelf = shelve.open(str(self.document_shelf_filepath), 20 | flag='n', protocol=-1) 21 | self.documents = None 22 | else: 23 | self.documents = [] 24 | self.document_shelf = None 25 | self.document_shelf_filepath = None 26 | self.temp_dir = None 27 | self.doc_lengths = [] 28 | self.doc_cumsum = None 29 | self.cumsum_max = None 30 | self.reduce_memory = reduce_memory 31 | 32 | def add_document(self, document): 33 | if not document: 34 | return 35 | if self.reduce_memory: 36 | current_idx = len(self.doc_lengths) 37 | self.document_shelf[str(current_idx)] = document 38 | else: 39 | self.documents.append(document) 40 | self.doc_lengths.append(len(document)) 41 | 42 | def _precalculate_doc_weights(self): 43 | self.doc_cumsum = np.cumsum(self.doc_lengths) 44 | self.cumsum_max = self.doc_cumsum[-1] 45 | 46 | def sample_doc(self, current_idx, sentence_weighted=True): 47 | # Uses the current iteration counter to ensure we don't sample the same doc twice 48 | if sentence_weighted: 49 | # With sentence weighting, we sample docs proportionally to their sentence length 50 | if self.doc_cumsum is None or len(self.doc_cumsum) != len(self.doc_lengths): 51 | self._precalculate_doc_weights() 52 | rand_start = self.doc_cumsum[current_idx] 53 | rand_end = rand_start + self.cumsum_max - self.doc_lengths[current_idx] 54 | sentence_index = randrange(rand_start, rand_end) % self.cumsum_max 55 | sampled_doc_index = np.searchsorted(self.doc_cumsum, sentence_index, side='right') 56 | else: 57 | # If we don't use sentence weighting, then every doc has an equal chance to be chosen 58 | sampled_doc_index = (current_idx + 
randrange(1, len(self.doc_lengths))) % len(self.doc_lengths) 59 | assert sampled_doc_index != current_idx 60 | if self.reduce_memory: 61 | return self.document_shelf[str(sampled_doc_index)] 62 | else: 63 | return self.documents[sampled_doc_index] 64 | 65 | def __len__(self): 66 | return len(self.doc_lengths) 67 | 68 | def __getitem__(self, item): 69 | if self.reduce_memory: 70 | return self.document_shelf[str(item)] 71 | else: 72 | return self.documents[item] 73 | 74 | def __enter__(self): 75 | return self 76 | 77 | def __exit__(self, exc_type, exc_val, traceback): 78 | if self.document_shelf is not None: 79 | self.document_shelf.close() 80 | if self.temp_dir is not None: 81 | self.temp_dir.cleanup() 82 | 83 | 84 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): 85 | """Truncates a pair of sequences to a maximum sequence length. Lifted from Google's BERT repo.""" 86 | while True: 87 | total_length = len(tokens_a) + len(tokens_b) 88 | if total_length <= max_num_tokens: 89 | break 90 | 91 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 92 | assert len(trunc_tokens) >= 1 93 | 94 | # We want to sometimes truncate from the front and sometimes from the 95 | # back to add more randomness and avoid biases. 96 | if random() < 0.5: 97 | del trunc_tokens[0] 98 | else: 99 | trunc_tokens.pop() 100 | 101 | 102 | def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list): 103 | """Creates the predictions for the masked LM objective. 
This is mostly copied from the Google BERT repo, but 104 | with several refactors to clean it up and remove a lot of unnecessary variables.""" 105 | cand_indices = [] 106 | for (i, token) in enumerate(tokens): 107 | if token == "[CLS]" or token == "[SEP]": 108 | continue 109 | cand_indices.append(i) 110 | 111 | num_to_mask = min(max_predictions_per_seq, 112 | max(1, int(round(len(tokens) * masked_lm_prob)))) 113 | shuffle(cand_indices) 114 | mask_indices = sorted(sample(cand_indices, num_to_mask)) 115 | masked_token_labels = [] 116 | for index in mask_indices: 117 | # 80% of the time, replace with [MASK] 118 | if random() < 0.8: 119 | masked_token = "[MASK]" 120 | else: 121 | # 10% of the time, keep original 122 | if random() < 0.5: 123 | masked_token = tokens[index] 124 | # 10% of the time, replace with random word 125 | else: 126 | masked_token = choice(vocab_list) 127 | masked_token_labels.append(tokens[index]) 128 | # Once we've saved the true label for that token, we can overwrite it with the masked version 129 | tokens[index] = masked_token 130 | 131 | return tokens, mask_indices, masked_token_labels 132 | 133 | 134 | def create_instances_from_document( 135 | doc_database, doc_idx, max_seq_length, short_seq_prob, 136 | masked_lm_prob, max_predictions_per_seq, vocab_list): 137 | """This code is mostly a duplicate of the equivalent function from Google BERT's repo. 138 | However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function. 
139 | Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence 140 | (rather than each document) has an equal chance of being sampled as a false example for the NextSentence task.""" 141 | document = doc_database[doc_idx] 142 | # Account for [CLS], [SEP], [SEP] 143 | max_num_tokens = max_seq_length - 3 144 | 145 | # We *usually* want to fill up the entire sequence since we are padding 146 | # to `max_seq_length` anyways, so short sequences are generally wasted 147 | # computation. However, we *sometimes* 148 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 149 | # sequences to minimize the mismatch between pre-training and fine-tuning. 150 | # The `target_seq_length` is just a rough target however, whereas 151 | # `max_seq_length` is a hard limit. 152 | target_seq_length = max_num_tokens 153 | if random() < short_seq_prob: 154 | target_seq_length = randint(2, max_num_tokens) 155 | 156 | # We DON'T just concatenate all of the tokens from a document into a long 157 | # sequence and choose an arbitrary split point because this would make the 158 | # next sentence prediction task too easy. Instead, we split the input into 159 | # segments "A" and "B" based on the actual "sentences" provided by the user 160 | # input. 161 | instances = [] 162 | current_chunk = [] 163 | current_length = 0 164 | i = 0 165 | while i < len(document): 166 | segment = document[i] 167 | current_chunk.append(segment) 168 | current_length += len(segment) 169 | if i == len(document) - 1 or current_length >= target_seq_length: 170 | if current_chunk: 171 | # `a_end` is how many segments from `current_chunk` go into the `A` 172 | # (first) sentence. 
173 | a_end = 1 174 | if len(current_chunk) >= 2: 175 | a_end = randrange(1, len(current_chunk)) 176 | 177 | tokens_a = [] 178 | for j in range(a_end): 179 | tokens_a.extend(current_chunk[j]) 180 | 181 | tokens_b = [] 182 | 183 | # Random next 184 | if len(current_chunk) == 1 or random() < 0.5: 185 | is_random_next = True 186 | target_b_length = target_seq_length - len(tokens_a) 187 | 188 | # Sample a random document, with longer docs being sampled more frequently 189 | random_document = doc_database.sample_doc(current_idx=doc_idx, sentence_weighted=True) 190 | 191 | random_start = randrange(0, len(random_document)) 192 | for j in range(random_start, len(random_document)): 193 | tokens_b.extend(random_document[j]) 194 | if len(tokens_b) >= target_b_length: 195 | break 196 | # We didn't actually use these segments so we "put them back" so 197 | # they don't go to waste. 198 | num_unused_segments = len(current_chunk) - a_end 199 | i -= num_unused_segments 200 | # Actual next 201 | else: 202 | is_random_next = False 203 | for j in range(a_end, len(current_chunk)): 204 | tokens_b.extend(current_chunk[j]) 205 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) 206 | 207 | assert len(tokens_a) >= 1 208 | assert len(tokens_b) >= 1 209 | 210 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] 211 | # The segment IDs are 0 for the [CLS] token, the A tokens and the first [SEP] 212 | # They are 1 for the B tokens and the final [SEP] 213 | segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)] 214 | 215 | tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions( 216 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_list) 217 | 218 | instance = { 219 | "tokens": tokens, 220 | "segment_ids": segment_ids, 221 | "is_random_next": is_random_next, 222 | "masked_lm_positions": masked_lm_positions, 223 | "masked_lm_labels": masked_lm_labels} 224 | instances.append(instance) 225 | current_chunk = [] 226 
| current_length = 0 227 | i += 1 228 | 229 | return instances 230 | 231 | 232 | def main(): 233 | parser = ArgumentParser() 234 | parser.add_argument('--train_corpus', type=Path, required=True) 235 | parser.add_argument("--output_dir", type=Path, required=True) 236 | parser.add_argument("--bert_model", type=str, required=True, 237 | choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased", 238 | "bert-base-multilingual", "bert-base-chinese"]) 239 | parser.add_argument("--do_lower_case", action="store_true") 240 | 241 | parser.add_argument("--reduce_memory", action="store_true", 242 | help="Reduce memory usage for large datasets by keeping data on disc rather than in memory") 243 | 244 | parser.add_argument("--epochs_to_generate", type=int, default=3, 245 | help="Number of epochs of data to pregenerate") 246 | parser.add_argument("--max_seq_len", type=int, default=128) 247 | parser.add_argument("--short_seq_prob", type=float, default=0.1, 248 | help="Probability of making a short sentence as a training example") 249 | parser.add_argument("--masked_lm_prob", type=float, default=0.15, 250 | help="Probability of masking each token for the LM task") 251 | parser.add_argument("--max_predictions_per_seq", type=int, default=20, 252 | help="Maximum number of tokens to mask in each sequence") 253 | 254 | args = parser.parse_args() 255 | 256 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 257 | vocab_list = list(tokenizer.vocab.keys()) 258 | with DocumentDatabase(reduce_memory=args.reduce_memory) as docs: 259 | with args.train_corpus.open() as f: 260 | doc = [] 261 | for line in tqdm(f, desc="Loading Dataset", unit=" lines"): 262 | line = line.strip() 263 | if line == "": 264 | docs.add_document(doc) 265 | doc = [] 266 | else: 267 | tokens = tokenizer.tokenize(line) 268 | doc.append(tokens) 269 | if doc: 270 | docs.add_document(doc) # If the last doc didn't end on a newline, make sure it still gets added 271 | if 
len(docs) <= 1: 272 | exit("ERROR: No document breaks were found in the input file! These are necessary to allow the script to " 273 | "ensure that random NextSentences are not sampled from the same document. Please add blank lines to " 274 | "indicate breaks between documents in your input file. If your dataset does not contain multiple " 275 | "documents, blank lines can be inserted at any natural boundary, such as the ends of chapters, " 276 | "sections or paragraphs.") 277 | 278 | args.output_dir.mkdir(exist_ok=True) 279 | for epoch in trange(args.epochs_to_generate, desc="Epoch"): 280 | epoch_filename = args.output_dir / f"epoch_{epoch}.json" 281 | num_instances = 0 282 | with epoch_filename.open('w') as epoch_file: 283 | for doc_idx in trange(len(docs), desc="Document"): 284 | doc_instances = create_instances_from_document( 285 | docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob, 286 | masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq, 287 | vocab_list=vocab_list) 288 | doc_instances = [json.dumps(instance) for instance in doc_instances] 289 | for instance in doc_instances: 290 | epoch_file.write(instance + '\n') 291 | num_instances += 1 292 | metrics_file = args.output_dir / f"epoch_{epoch}_metrics.json" 293 | with metrics_file.open('w') as metrics_file: 294 | metrics = { 295 | "num_training_examples": num_instances, 296 | "max_seq_len": args.max_seq_len 297 | } 298 | metrics_file.write(json.dumps(metrics)) 299 | 300 | 301 | if __name__ == '__main__': 302 | main() 303 | -------------------------------------------------------------------------------- /examples/run_gpt2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import logging 5 | from tqdm import trange 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | from pytorch_pretrained_bert import GPT2LMHeadModel, 
GPT2Tokenizer 12 | 13 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | datefmt = '%m/%d/%Y %H:%M:%S', 15 | level = logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | def top_k_logits(logits, k): 19 | """ 20 | Masks everything but the k top entries as -infinity (1e10). 21 | Used to mask logits such that e^-infinity -> 0 won't contribute to the 22 | sum of the denominator. 23 | """ 24 | if k == 0: 25 | return logits 26 | else: 27 | values = torch.topk(logits, k)[0] 28 | batch_mins = values[:, -1].view(-1, 1).expand_as(logits) 29 | return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits) 30 | 31 | def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): 32 | if start_token is None: 33 | assert context is not None, 'Specify exactly one of start_token and context!' 34 | context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) 35 | else: 36 | assert context is None, 'Specify exactly one of start_token and context!' 
37 | context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long) 38 | prev = context 39 | output = context 40 | past = None 41 | with torch.no_grad(): 42 | for i in trange(length): 43 | logits, past = model(prev, past=past) 44 | logits = logits[:, -1, :] / temperature 45 | logits = top_k_logits(logits, k=top_k) 46 | log_probs = F.softmax(logits, dim=-1) 47 | if sample: 48 | prev = torch.multinomial(log_probs, num_samples=1) 49 | else: 50 | _, prev = torch.topk(log_probs, k=1, dim=-1) 51 | output = torch.cat((output, prev), dim=1) 52 | return output 53 | 54 | def run_model(): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint') 57 | parser.add_argument("--seed", type=int, default=0) 58 | parser.add_argument("--nsamples", type=int, default=1) 59 | parser.add_argument("--batch_size", type=int, default=-1) 60 | parser.add_argument("--length", type=int, default=-1) 61 | parser.add_argument("--temperature", type=float, default=1.0) 62 | parser.add_argument("--top_k", type=int, default=0) 63 | parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.') 64 | args = parser.parse_args() 65 | print(args) 66 | 67 | if args.batch_size == -1: 68 | args.batch_size = 1 69 | assert args.nsamples % args.batch_size == 0 70 | 71 | np.random.seed(args.seed) 72 | torch.random.manual_seed(args.seed) 73 | torch.cuda.manual_seed(args.seed) 74 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 75 | 76 | enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) 77 | model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) 78 | model.to(device) 79 | model.eval() 80 | 81 | if args.length == -1: 82 | args.length = model.config.n_ctx // 2 83 | elif args.length > model.config.n_ctx: 84 | raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) 85 | 86 
| while True: 87 | context_tokens = [] 88 | if not args.unconditional: 89 | raw_text = input("Model prompt >>> ") 90 | while not raw_text: 91 | print('Prompt should not be empty!') 92 | raw_text = input("Model prompt >>> ") 93 | context_tokens = enc.encode(raw_text) 94 | generated = 0 95 | for _ in range(args.nsamples // args.batch_size): 96 | out = sample_sequence( 97 | model=model, length=args.length, 98 | context=context_tokens, 99 | start_token=None, 100 | batch_size=args.batch_size, 101 | temperature=args.temperature, top_k=args.top_k, device=device 102 | ) 103 | out = out[:, len(context_tokens):].tolist() 104 | for i in range(args.batch_size): 105 | generated += 1 106 | text = enc.decode(out[i]) 107 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 108 | print(text) 109 | print("=" * 80) 110 | else: 111 | generated = 0 112 | for _ in range(args.nsamples // args.batch_size): 113 | out = sample_sequence( 114 | model=model, length=args.length, 115 | context=None, 116 | start_token=enc.encoder['<|endoftext|>'], 117 | batch_size=args.batch_size, 118 | temperature=args.temperature, top_k=args.top_k, device=device 119 | ) 120 | out = out[:,1:].tolist() 121 | for i in range(args.batch_size): 122 | generated += 1 123 | text = enc.decode(out[i]) 124 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 125 | print(text) 126 | print("=" * 80) 127 | 128 | if __name__ == '__main__': 129 | run_model() 130 | 131 | 132 | -------------------------------------------------------------------------------- /examples/run_openai_gpt.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT model fine-tuning script. 17 | Adapted from https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/train.py 18 | It self adapted from https://github.com/openai/finetune-transformer-lm/blob/master/train.py 19 | 20 | This script with default values fine-tunes and evaluate a pretrained OpenAI GPT on the RocStories dataset: 21 | python run_openai_gpt.py \ 22 | --model_name openai-gpt \ 23 | --do_train \ 24 | --do_eval \ 25 | --train_dataset $ROC_STORIES_DIR/cloze_test_val__spring2016\ -\ cloze_test_ALL_val.csv \ 26 | --eval_dataset $ROC_STORIES_DIR/cloze_test_test__spring2016\ -\ cloze_test_ALL_test.csv \ 27 | --output_dir ../log \ 28 | --train_batch_size 16 \ 29 | """ 30 | import argparse 31 | import os 32 | import csv 33 | import random 34 | import logging 35 | from tqdm import tqdm, trange 36 | 37 | import numpy as np 38 | import torch 39 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 40 | TensorDataset) 41 | 42 | from pytorch_pretrained_bert import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, 43 | OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME) 44 | 45 | ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz" 46 | 47 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 48 | datefmt = '%m/%d/%Y %H:%M:%S', 49 | level = logging.INFO) 50 | logger = logging.getLogger(__name__) 51 | 52 | def accuracy(out, labels): 53 | outputs = np.argmax(out, axis=1) 54 | return np.sum(outputs == labels) 55 | 
56 | def load_rocstories_dataset(dataset_path): 57 | """ Output a list of tuples(story, 1st continuation, 2nd continuation, label) """ 58 | with open(dataset_path, encoding='utf_8') as f: 59 | f = csv.reader(f) 60 | output = [] 61 | next(f) # skip the first line 62 | for line in tqdm(f): 63 | output.append((' '.join(line[1:5]), line[5], line[6], int(line[-1])-1)) 64 | return output 65 | 66 | def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token): 67 | """ Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label) 68 | 69 | To Transformer inputs of shape (n_batch, n_alternative, length) comprising for each batch, continuation: 70 | input_ids[batch, alternative, :] = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token] 71 | """ 72 | tensor_datasets = [] 73 | for dataset in encoded_datasets: 74 | n_batch = len(dataset) 75 | input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64) 76 | mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64) 77 | lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64) 78 | mc_labels = np.zeros((n_batch,), dtype=np.int64) 79 | for i, (story, cont1, cont2, mc_label), in enumerate(dataset): 80 | with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token] 81 | with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token] 82 | input_ids[i, 0, :len(with_cont1)] = with_cont1 83 | input_ids[i, 1, :len(with_cont2)] = with_cont2 84 | mc_token_ids[i, 0] = len(with_cont1) - 1 85 | mc_token_ids[i, 1] = len(with_cont2) - 1 86 | lm_labels[i, 0, :len(with_cont1)] = with_cont1 87 | lm_labels[i, 1, :len(with_cont2)] = with_cont2 88 | mc_labels[i] = mc_label 89 | all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels) 90 | tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs)) 91 | return tensor_datasets 92 
| 93 | def main(): 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--model_name', type=str, default='openai-gpt', 96 | help='pretrained model name') 97 | parser.add_argument("--do_train", action='store_true', help="Whether to run training.") 98 | parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") 99 | parser.add_argument("--output_dir", default=None, type=str, required=True, 100 | help="The output directory where the model predictions and checkpoints will be written.") 101 | parser.add_argument('--train_dataset', type=str, default='') 102 | parser.add_argument('--eval_dataset', type=str, default='') 103 | parser.add_argument('--seed', type=int, default=42) 104 | parser.add_argument('--num_train_epochs', type=int, default=3) 105 | parser.add_argument('--train_batch_size', type=int, default=8) 106 | parser.add_argument('--eval_batch_size', type=int, default=16) 107 | parser.add_argument('--max_grad_norm', type=int, default=1) 108 | parser.add_argument('--learning_rate', type=float, default=6.25e-5) 109 | parser.add_argument('--warmup_proportion', type=float, default=0.002) 110 | parser.add_argument('--lr_schedule', type=str, default='warmup_linear') 111 | parser.add_argument('--weight_decay', type=float, default=0.01) 112 | parser.add_argument('--lm_coef', type=float, default=0.9) 113 | parser.add_argument('--n_valid', type=int, default=374) 114 | 115 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 116 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 117 | args = parser.parse_args() 118 | print(args) 119 | 120 | if args.server_ip and args.server_port: 121 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 122 | import ptvsd 123 | print("Waiting for debugger attach") 124 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), 
redirect_output=True) 125 | ptvsd.wait_for_attach() 126 | 127 | random.seed(args.seed) 128 | np.random.seed(args.seed) 129 | torch.manual_seed(args.seed) 130 | torch.cuda.manual_seed_all(args.seed) 131 | 132 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 133 | n_gpu = torch.cuda.device_count() 134 | logger.info("device: {}, n_gpu {}".format(device, n_gpu)) 135 | 136 | if not args.do_train and not args.do_eval: 137 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 138 | 139 | if not os.path.exists(args.output_dir): 140 | os.makedirs(args.output_dir) 141 | 142 | # Load tokenizer and model 143 | # This loading functions also add new tokens and embeddings called `special tokens` 144 | # These new embeddings will be fine-tuned on the RocStories dataset 145 | special_tokens = ['_start_', '_delimiter_', '_classify_'] 146 | tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens) 147 | special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens) 148 | model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens)) 149 | model.to(device) 150 | 151 | # Load and encode the datasets 152 | if not args.train_dataset and not args.eval_dataset: 153 | roc_stories = cached_path(ROCSTORIES_URL) 154 | def tokenize_and_encode(obj): 155 | """ Tokenize and encode a nested object """ 156 | if isinstance(obj, str): 157 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj)) 158 | elif isinstance(obj, int): 159 | return obj 160 | return list(tokenize_and_encode(o) for o in obj) 161 | logger.info("Encoding dataset...") 162 | train_dataset = load_rocstories_dataset(args.train_dataset) 163 | eval_dataset = load_rocstories_dataset(args.eval_dataset) 164 | datasets = (train_dataset, eval_dataset) 165 | encoded_datasets = tokenize_and_encode(datasets) 166 | 167 | # Compute the max input length for the Transformer 168 | 
max_length = model.config.n_positions // 2 - 2 169 | input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3 \ 170 | for dataset in encoded_datasets for story, cont1, cont2, _ in dataset) 171 | input_length = min(input_length, model.config.n_positions) # Max size of input for the pre-trained model 172 | 173 | # Prepare inputs tensors and dataloaders 174 | tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_length, *special_tokens_ids) 175 | train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1] 176 | 177 | train_data = TensorDataset(*train_tensor_dataset) 178 | train_sampler = RandomSampler(train_data) 179 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 180 | 181 | eval_data = TensorDataset(*eval_tensor_dataset) 182 | eval_sampler = SequentialSampler(eval_data) 183 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size) 184 | 185 | # Prepare optimizer 186 | if args.do_train: 187 | param_optimizer = list(model.named_parameters()) 188 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 189 | optimizer_grouped_parameters = [ 190 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 191 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 192 | ] 193 | num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size 194 | optimizer = OpenAIAdam(optimizer_grouped_parameters, 195 | lr=args.learning_rate, 196 | warmup=args.warmup_proportion, 197 | max_grad_norm=args.max_grad_norm, 198 | weight_decay=args.weight_decay, 199 | t_total=num_train_optimization_steps) 200 | 201 | if args.do_train: 202 | nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None 203 | model.train() 204 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 205 | tr_loss 
= 0 206 | nb_tr_steps = 0 207 | tqdm_bar = tqdm(train_dataloader, desc="Training") 208 | for step, batch in enumerate(tqdm_bar): 209 | batch = tuple(t.to(device) for t in batch) 210 | input_ids, mc_token_ids, lm_labels, mc_labels = batch 211 | losses = model(input_ids, mc_token_ids, lm_labels, mc_labels) 212 | loss = args.lm_coef * losses[0] + losses[1] 213 | loss.backward() 214 | optimizer.step() 215 | optimizer.zero_grad() 216 | tr_loss += loss.item() 217 | exp_average_loss = loss.item() if exp_average_loss is None else 0.7*exp_average_loss+0.3*loss.item() 218 | nb_tr_steps += 1 219 | tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, optimizer.get_lr()[0]) 220 | 221 | # Save a trained model 222 | if args.do_train: 223 | # Save a trained model, configuration and tokenizer 224 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 225 | 226 | # If we save using the predefined names, we can load using `from_pretrained` 227 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) 228 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME) 229 | 230 | torch.save(model_to_save.state_dict(), output_model_file) 231 | model_to_save.config.to_json_file(output_config_file) 232 | tokenizer.save_vocabulary(args.output_dir) 233 | 234 | # Load a trained model and vocabulary that you have fine-tuned 235 | model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir) 236 | tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir) 237 | model.to(device) 238 | 239 | if args.do_eval: 240 | model.eval() 241 | eval_loss, eval_accuracy = 0, 0 242 | nb_eval_steps, nb_eval_examples = 0, 0 243 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 244 | batch = tuple(t.to(device) for t in batch) 245 | input_ids, mc_token_ids, lm_labels, mc_labels = batch 246 | with torch.no_grad(): 247 | _, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels) 248 | _, mc_logits = 
model(input_ids, mc_token_ids) 249 | 250 | mc_logits = mc_logits.detach().cpu().numpy() 251 | mc_labels = mc_labels.to('cpu').numpy() 252 | tmp_eval_accuracy = accuracy(mc_logits, mc_labels) 253 | 254 | eval_loss += mc_loss.mean().item() 255 | eval_accuracy += tmp_eval_accuracy 256 | 257 | nb_eval_examples += input_ids.size(0) 258 | nb_eval_steps += 1 259 | 260 | eval_loss = eval_loss / nb_eval_steps 261 | eval_accuracy = eval_accuracy / nb_eval_examples 262 | train_loss = tr_loss/nb_tr_steps if args.do_train else None 263 | result = {'eval_loss': eval_loss, 264 | 'eval_accuracy': eval_accuracy, 265 | 'train_loss': train_loss} 266 | 267 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 268 | with open(output_eval_file, "w") as writer: 269 | logger.info("***** Eval results *****") 270 | for key in sorted(result.keys()): 271 | logger.info(" %s = %s", key, str(result[key])) 272 | writer.write("%s = %s\n" % (key, str(result[key]))) 273 | 274 | if __name__ == '__main__': 275 | main() 276 | -------------------------------------------------------------------------------- /examples/run_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ PyTorch Transformer XL model evaluation script. 17 | Adapted from https://github.com/kimiyoung/transformer-xl. 18 | In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py 19 | 20 | This script with default values evaluates a pretrained Transformer-XL on WikiText 103 21 | """ 22 | from __future__ import absolute_import, division, print_function, unicode_literals 23 | 24 | import argparse 25 | import logging 26 | import time 27 | import math 28 | 29 | import torch 30 | 31 | from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer 32 | 33 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 34 | datefmt = '%m/%d/%Y %H:%M:%S', 35 | level = logging.INFO) 36 | logger = logging.getLogger(__name__) 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 40 | parser.add_argument('--model_name', type=str, default='transfo-xl-wt103', 41 | help='pretrained model name') 42 | parser.add_argument('--split', type=str, default='test', 43 | choices=['all', 'valid', 'test'], 44 | help='which split to evaluate') 45 | parser.add_argument('--batch_size', type=int, default=10, 46 | help='batch size') 47 | parser.add_argument('--tgt_len', type=int, default=128, 48 | help='number of tokens to predict') 49 | parser.add_argument('--ext_len', type=int, default=0, 50 | help='length of the extended context') 51 | parser.add_argument('--mem_len', type=int, default=1600, 52 | help='length of the retained previous heads') 53 | parser.add_argument('--clamp_len', type=int, default=1000, 54 | help='max positional embedding index') 55 | parser.add_argument('--no_cuda', action='store_true', 56 | help='Do not use CUDA even though CUA is available') 57 | parser.add_argument('--work_dir', type=str, required=True, 58 | help='path to 
the work_dir') 59 | parser.add_argument('--no_log', action='store_true', 60 | help='do not log the eval result') 61 | parser.add_argument('--same_length', action='store_true', 62 | help='set same length attention with masking') 63 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 64 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 65 | args = parser.parse_args() 66 | assert args.ext_len >= 0, 'extended context length must be non-negative' 67 | 68 | if args.server_ip and args.server_port: 69 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 70 | import ptvsd 71 | print("Waiting for debugger attach") 72 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 73 | ptvsd.wait_for_attach() 74 | 75 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 76 | logger.info("device: {}".format(device)) 77 | 78 | # Load a pre-processed dataset 79 | # You can also build the corpus yourself using TransfoXLCorpus methods 80 | # The pre-processing involve computing word frequencies to prepare the Adaptive input and SoftMax 81 | # and tokenizing the dataset 82 | # The pre-processed corpus is a convertion (using the conversion script ) 83 | tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name) 84 | corpus = TransfoXLCorpus.from_pretrained(args.model_name) 85 | ntokens = len(corpus.vocab) 86 | 87 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, 88 | device=device, ext_len=args.ext_len) 89 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, 90 | device=device, ext_len=args.ext_len) 91 | 92 | # Load a pre-trained model 93 | model = TransfoXLLMHeadModel.from_pretrained(args.model_name) 94 | model = model.to(device) 95 | 96 | logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len 
{}'.format( 97 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) 98 | 99 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 100 | if args.clamp_len > 0: 101 | model.clamp_len = args.clamp_len 102 | if args.same_length: 103 | model.same_length = True 104 | 105 | ############################################################################### 106 | # Evaluation code 107 | ############################################################################### 108 | def evaluate(eval_iter): 109 | # Turn on evaluation mode which disables dropout. 110 | model.eval() 111 | total_len, total_loss = 0, 0. 112 | start_time = time.time() 113 | with torch.no_grad(): 114 | mems = None 115 | for idx, (data, target, seq_len) in enumerate(eval_iter): 116 | ret = model(data, target, mems) 117 | loss, mems = ret 118 | loss = loss.mean() 119 | total_loss += seq_len * loss.item() 120 | total_len += seq_len 121 | total_time = time.time() - start_time 122 | logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format( 123 | total_time, 1000 * total_time / (idx+1))) 124 | return total_loss / total_len 125 | 126 | # Run on test data. 
127 | if args.split == 'all': 128 | test_loss = evaluate(te_iter) 129 | valid_loss = evaluate(va_iter) 130 | elif args.split == 'valid': 131 | valid_loss = evaluate(va_iter) 132 | test_loss = None 133 | elif args.split == 'test': 134 | test_loss = evaluate(te_iter) 135 | valid_loss = None 136 | 137 | def format_log(loss, split): 138 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( 139 | split, loss, math.exp(loss)) 140 | return log_str 141 | 142 | log_str = '' 143 | if valid_loss is not None: 144 | log_str += format_log(valid_loss, 'valid') 145 | if test_loss is not None: 146 | log_str += format_log(test_loss, 'test') 147 | 148 | logger.info('=' * 100) 149 | logger.info(log_str) 150 | logger.info('=' * 100) 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_pretrained_bert.tokenization import BertTokenizer 2 | from pytorch_pretrained_bert.modeling import ( 3 | BertModel, 4 | BertForNextSentencePrediction, 5 | BertForMaskedLM, 6 | BertForMultipleChoice, 7 | BertForPreTraining, 8 | BertForQuestionAnswering, 9 | BertForSequenceClassification, 10 | BertForTokenClassification, 11 | ) 12 | 13 | dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex'] 14 | 15 | # A lot of models share the same param doc. Use a decorator 16 | # to save typing 17 | bert_docstring = """ 18 | Params: 19 | pretrained_model_name_or_path: either: 20 | - a str with the name of a pre-trained model to load 21 | . `bert-base-uncased` 22 | . `bert-large-uncased` 23 | . `bert-base-cased` 24 | . `bert-large-cased` 25 | . `bert-base-multilingual-uncased` 26 | . `bert-base-multilingual-cased` 27 | . `bert-base-chinese` 28 | - a path or url to a pretrained model archive containing: 29 | . `bert_config.json` a configuration file for the model 30 | . 
`pytorch_model.bin` a PyTorch dump of a BertForPreTraining 31 | instance 32 | - a path or url to a pretrained model archive containing: 33 | . `bert_config.json` a configuration file for the model 34 | . `model.chkpt` a TensorFlow checkpoint 35 | from_tf: should we load the weights from a locally saved TensorFlow 36 | checkpoint 37 | cache_dir: an optional path to a folder in which the pre-trained models 38 | will be cached. 39 | state_dict: an optional state dictionnary 40 | (collections.OrderedDict object) to use instead of Google 41 | pre-trained models 42 | *inputs, **kwargs: additional input for the specific Bert class 43 | (ex: num_labels for BertForSequenceClassification) 44 | """ 45 | 46 | 47 | def _append_from_pretrained_docstring(docstr): 48 | def docstring_decorator(fn): 49 | fn.__doc__ = fn.__doc__ + docstr 50 | return fn 51 | return docstring_decorator 52 | 53 | 54 | def bertTokenizer(*args, **kwargs): 55 | """ 56 | Instantiate a BertTokenizer from a pre-trained/customized vocab file 57 | Args: 58 | pretrained_model_name_or_path: Path to pretrained model archive 59 | or one of pre-trained vocab configs below. 60 | * bert-base-uncased 61 | * bert-large-uncased 62 | * bert-base-cased 63 | * bert-large-cased 64 | * bert-base-multilingual-uncased 65 | * bert-base-multilingual-cased 66 | * bert-base-chinese 67 | Keyword args: 68 | cache_dir: an optional path to a specific directory to download and cache 69 | the pre-trained model weights. 70 | Default: None 71 | do_lower_case: Whether to lower case the input. 72 | Only has an effect when do_wordpiece_only=False 73 | Default: True 74 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 75 | Default: True 76 | max_len: An artificial maximum length to truncate tokenized sequences to; 77 | Effective maximum length is always the minimum of this 78 | value (if specified) and the underlying BERT model's 79 | sequence length. 
80 | Default: None 81 | never_split: List of tokens which will never be split during tokenization. 82 | Only has an effect when do_wordpiece_only=False 83 | Default: ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"] 84 | 85 | Example: 86 | >>> sentence = 'Hello, World!' 87 | >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False) 88 | >>> toks = tokenizer.tokenize(sentence) 89 | ['Hello', '##,', 'World', '##!'] 90 | >>> ids = tokenizer.convert_tokens_to_ids(toks) 91 | [8667, 28136, 1291, 28125] 92 | """ 93 | tokenizer = BertTokenizer.from_pretrained(*args, **kwargs) 94 | return tokenizer 95 | 96 | 97 | @_append_from_pretrained_docstring(bert_docstring) 98 | def bertModel(*args, **kwargs): 99 | """ 100 | BertModel is the basic BERT Transformer model with a layer of summed token, 101 | position and sequence embeddings followed by a series of identical 102 | self-attention blocks (12 for BERT-base, 24 for BERT-large). 103 | """ 104 | model = BertModel.from_pretrained(*args, **kwargs) 105 | return model 106 | 107 | 108 | @_append_from_pretrained_docstring(bert_docstring) 109 | def bertForNextSentencePrediction(*args, **kwargs): 110 | """ 111 | BERT model with next sentence prediction head. 112 | This module comprises the BERT model followed by the next sentence 113 | classification head. 114 | """ 115 | model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs) 116 | return model 117 | 118 | 119 | @_append_from_pretrained_docstring(bert_docstring) 120 | def bertForPreTraining(*args, **kwargs): 121 | """ 122 | BERT model with pre-training heads. 123 | This module comprises the BERT model followed by the two pre-training heads 124 | - the masked language modeling head, and 125 | - the next sentence classification head. 
126 | """ 127 | model = BertForPreTraining.from_pretrained(*args, **kwargs) 128 | return model 129 | 130 | 131 | @_append_from_pretrained_docstring(bert_docstring) 132 | def bertForMaskedLM(*args, **kwargs): 133 | """ 134 | BertForMaskedLM includes the BertModel Transformer followed by the 135 | (possibly) pre-trained masked language modeling head. 136 | """ 137 | model = BertForMaskedLM.from_pretrained(*args, **kwargs) 138 | return model 139 | 140 | 141 | @_append_from_pretrained_docstring(bert_docstring) 142 | def bertForSequenceClassification(*args, **kwargs): 143 | """ 144 | BertForSequenceClassification is a fine-tuning model that includes 145 | BertModel and a sequence-level (sequence or pair of sequences) classifier 146 | on top of the BertModel. 147 | 148 | The sequence-level classifier is a linear layer that takes as input the 149 | last hidden state of the first character in the input sequence 150 | (see Figures 3a and 3b in the BERT paper). 151 | """ 152 | model = BertForSequenceClassification.from_pretrained(*args, **kwargs) 153 | return model 154 | 155 | 156 | @_append_from_pretrained_docstring(bert_docstring) 157 | def bertForMultipleChoice(*args, **kwargs): 158 | """ 159 | BertForMultipleChoice is a fine-tuning model that includes BertModel and a 160 | linear layer on top of the BertModel. 161 | """ 162 | model = BertForMultipleChoice.from_pretrained(*args, **kwargs) 163 | return model 164 | 165 | 166 | @_append_from_pretrained_docstring(bert_docstring) 167 | def bertForQuestionAnswering(*args, **kwargs): 168 | """ 169 | BertForQuestionAnswering is a fine-tuning model that includes BertModel 170 | with a token-level classifiers on top of the full sequence of last hidden 171 | states. 
172 | """ 173 | model = BertForQuestionAnswering.from_pretrained(*args, **kwargs) 174 | return model 175 | 176 | 177 | @_append_from_pretrained_docstring(bert_docstring) 178 | def bertForTokenClassification(*args, **kwargs): 179 | """ 180 | BertForTokenClassification is a fine-tuning model that includes BertModel 181 | and a token-level classifier on top of the BertModel. 182 | 183 | The token-level classifier is a linear layer that takes as input the last 184 | hidden state of the sequence. 185 | """ 186 | model = BertForTokenClassification.from_pretrained(*args, **kwargs) 187 | return model 188 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.2" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME 25 | 
-------------------------------------------------------------------------------- /pytorch_pretrained_bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. 
Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save 
dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from torch.hub import _get_torch_home 27 | torch_cache_home = _get_torch_home() 28 | except ImportError: 29 | torch_cache_home = os.path.expanduser( 30 | os.getenv('TORCH_HOME', os.path.join( 31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') 33 | 34 | try: 35 | from urllib.parse import urlparse 36 | except ImportError: 37 | from urlparse import urlparse 38 | 39 | try: 40 | from pathlib import Path 41 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 42 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) 43 | except (AttributeError, ImportError): 44 | PYTORCH_PRETRAINED_BERT_CACHE = 
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 45 | default_cache_path) 46 | 47 | CONFIG_NAME = "config.json" 48 | WEIGHTS_NAME = "pytorch_model.bin" 49 | 50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 51 | 52 | 53 | def url_to_filename(url, etag=None): 54 | """ 55 | Convert `url` into a hashed filename in a repeatable way. 56 | If `etag` is specified, append its hash to the url's, delimited 57 | by a period. 58 | """ 59 | url_bytes = url.encode('utf-8') 60 | url_hash = sha256(url_bytes) 61 | filename = url_hash.hexdigest() 62 | 63 | if etag: 64 | etag_bytes = etag.encode('utf-8') 65 | etag_hash = sha256(etag_bytes) 66 | filename += '.' + etag_hash.hexdigest() 67 | 68 | return filename 69 | 70 | 71 | def filename_to_url(filename, cache_dir=None): 72 | """ 73 | Return the url and etag (which may be ``None``) stored for `filename`. 74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 75 | """ 76 | if cache_dir is None: 77 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 79 | cache_dir = str(cache_dir) 80 | 81 | cache_path = os.path.join(cache_dir, filename) 82 | if not os.path.exists(cache_path): 83 | raise EnvironmentError("file {} not found".format(cache_path)) 84 | 85 | meta_path = cache_path + '.json' 86 | if not os.path.exists(meta_path): 87 | raise EnvironmentError("file {} not found".format(meta_path)) 88 | 89 | with open(meta_path, encoding="utf-8") as meta_file: 90 | metadata = json.load(meta_file) 91 | url = metadata['url'] 92 | etag = metadata['etag'] 93 | 94 | return url, etag 95 | 96 | 97 | def cached_path(url_or_filename, cache_dir=None): 98 | """ 99 | Given something that might be a URL (or might be a local path), 100 | determine which. If it's a URL, download the file and cache it, and 101 | return the path to the cached file. If it's already a local path, 102 | make sure the file exists and then return the path. 
103 | """ 104 | if cache_dir is None: 105 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 106 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 107 | url_or_filename = str(url_or_filename) 108 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 109 | cache_dir = str(cache_dir) 110 | 111 | parsed = urlparse(url_or_filename) 112 | 113 | if parsed.scheme in ('http', 'https', 's3'): 114 | # URL, so get it from the cache (downloading if necessary) 115 | return get_from_cache(url_or_filename, cache_dir) 116 | elif os.path.exists(url_or_filename): 117 | # File, and it exists. 118 | return url_or_filename 119 | elif parsed.scheme == '': 120 | # File, but it doesn't exist. 121 | raise EnvironmentError("file {} not found".format(url_or_filename)) 122 | else: 123 | # Something unknown 124 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 125 | 126 | 127 | def split_s3_path(url): 128 | """Split a full s3 path into the bucket name and path.""" 129 | parsed = urlparse(url) 130 | if not parsed.netloc or not parsed.path: 131 | raise ValueError("bad s3 path {}".format(url)) 132 | bucket_name = parsed.netloc 133 | s3_path = parsed.path 134 | # Remove '/' at beginning of path. 135 | if s3_path.startswith("/"): 136 | s3_path = s3_path[1:] 137 | return bucket_name, s3_path 138 | 139 | 140 | def s3_request(func): 141 | """ 142 | Wrapper function for s3 requests in order to create more helpful error 143 | messages. 
144 | """ 145 | 146 | @wraps(func) 147 | def wrapper(url, *args, **kwargs): 148 | try: 149 | return func(url, *args, **kwargs) 150 | except ClientError as exc: 151 | if int(exc.response["Error"]["Code"]) == 404: 152 | raise EnvironmentError("file {} not found".format(url)) 153 | else: 154 | raise 155 | 156 | return wrapper 157 | 158 | 159 | @s3_request 160 | def s3_etag(url): 161 | """Check ETag on S3 object.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_object = s3_resource.Object(bucket_name, s3_path) 165 | return s3_object.e_tag 166 | 167 | 168 | @s3_request 169 | def s3_get(url, temp_file): 170 | """Pull a file directly from S3.""" 171 | s3_resource = boto3.resource("s3") 172 | bucket_name, s3_path = split_s3_path(url) 173 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 174 | 175 | 176 | def http_get(url, temp_file): 177 | req = requests.get(url, stream=True) 178 | content_length = req.headers.get('Content-Length') 179 | total = int(content_length) if content_length is not None else None 180 | progress = tqdm(unit="B", total=total) 181 | for chunk in req.iter_content(chunk_size=1024): 182 | if chunk: # filter out keep-alive new chunks 183 | progress.update(len(chunk)) 184 | temp_file.write(chunk) 185 | progress.close() 186 | 187 | 188 | def get_from_cache(url, cache_dir=None): 189 | """ 190 | Given a URL, look for the corresponding dataset in the local cache. 191 | If it's not there, download it. Then return the path to the cached file. 192 | """ 193 | if cache_dir is None: 194 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 195 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 196 | cache_dir = str(cache_dir) 197 | 198 | if not os.path.exists(cache_dir): 199 | os.makedirs(cache_dir) 200 | 201 | # Get eTag to add to filename, if it exists. 
202 | if url.startswith("s3://"): 203 | etag = s3_etag(url) 204 | else: 205 | try: 206 | response = requests.head(url, allow_redirects=True) 207 | if response.status_code != 200: 208 | etag = None 209 | else: 210 | etag = response.headers.get("ETag") 211 | except EnvironmentError: 212 | etag = None 213 | 214 | if sys.version_info[0] == 2 and etag is not None: 215 | etag = etag.decode('utf-8') 216 | filename = url_to_filename(url, etag) 217 | 218 | # get cache path to put the file 219 | cache_path = os.path.join(cache_dir, filename) 220 | 221 | # If we don't have a connection (etag is None) and can't identify the file 222 | # try to get the last downloaded one 223 | if not os.path.exists(cache_path) and etag is None: 224 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 225 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 226 | if matching_files: 227 | cache_path = os.path.join(cache_dir, matching_files[-1]) 228 | 229 | if not os.path.exists(cache_path): 230 | # Download to temporary file, then copy to cache dir once finished. 231 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
232 | with tempfile.NamedTemporaryFile() as temp_file: 233 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 234 | 235 | # GET file object 236 | if url.startswith("s3://"): 237 | s3_get(url, temp_file) 238 | else: 239 | http_get(url, temp_file) 240 | 241 | # we are copying the file before closing it, so flush to avoid truncation 242 | temp_file.flush() 243 | # shutil.copyfileobj() starts at the current position, so go to the start 244 | temp_file.seek(0) 245 | 246 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 247 | with open(cache_path, 'wb') as cache_file: 248 | shutil.copyfileobj(temp_file, cache_file) 249 | 250 | logger.info("creating metadata file for %s", cache_path) 251 | meta = {'url': url, 'etag': etag} 252 | meta_path = cache_path + '.json' 253 | with open(meta_path, 'w') as meta_file: 254 | output_string = json.dumps(meta) 255 | if sys.version_info[0] == 2 and isinstance(output_string, str): 256 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 257 | meta_file.write(output_string) 258 | 259 | logger.info("removing temp file %s", temp_file.name) 260 | 261 | return cache_path 262 | 263 | 264 | def read_set_from_file(filename): 265 | ''' 266 | Extract a de-duped collection (set) of text from a file. 267 | Expected file format is one item per line. 
268 | ''' 269 | collection = set() 270 | with open(filename, 'r', encoding='utf-8') as file_: 271 | for line in file_: 272 | collection.add(line.rstrip()) 273 | return collection 274 | 275 | 276 | def get_file_extension(path, dot=True, lower=True): 277 | ext = os.path.splitext(path)[1] 278 | ext = ext if dot else ext[1:] 279 | return ext.lower() if lower else ext 280 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | import abc 24 | import sys 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | if sys.version_info >= (3, 4): 30 | ABC = abc.ABC 31 | else: 32 | ABC = abc.ABCMeta('ABC', (), {}) 33 | 34 | 35 | class _LRSchedule(ABC): 36 | """ Parent of all LRSchedules here. 
""" 37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense 38 | def __init__(self, warmup=0.002, t_total=-1, **kw): 39 | """ 40 | :param warmup: what fraction of t_total steps will be used for linear warmup 41 | :param t_total: how many training steps (updates) are planned 42 | :param kw: 43 | """ 44 | super(_LRSchedule, self).__init__(**kw) 45 | if t_total < 0: 46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) 47 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 49 | warmup = max(warmup, 0.) 50 | self.warmup, self.t_total = float(warmup), float(t_total) 51 | self.warned_for_t_total_at_progress = -1 52 | 53 | def get_lr(self, step, nowarn=False): 54 | """ 55 | :param step: which of t_total steps we're on 56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps 57 | :return: learning rate multiplier for current update 58 | """ 59 | if self.t_total < 0: 60 | return 1. 61 | progress = float(step) / self.t_total 62 | ret = self.get_lr_(progress) 63 | # warning for exceeding t_total (only active with warmup_linear 64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: 65 | logger.warning( 66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." 67 | .format(ret, self.__class__.__name__)) 68 | self.warned_for_t_total_at_progress = progress 69 | # end warning 70 | return ret 71 | 72 | @abc.abstractmethod 73 | def get_lr_(self, progress): 74 | """ 75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress 76 | :return: learning rate multiplier for current update 77 | """ 78 | return 1. 
79 | 80 | 81 | class ConstantLR(_LRSchedule): 82 | def get_lr_(self, progress): 83 | return 1. 84 | 85 | 86 | class WarmupCosineSchedule(_LRSchedule): 87 | """ 88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. 90 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 91 | """ 92 | warn_t_total = True 93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): 94 | """ 95 | :param warmup: see LRSchedule 96 | :param t_total: see LRSchedule 97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. 98 | :param kw: 99 | """ 100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) 101 | self.cycles = cycles 102 | 103 | def get_lr_(self, progress): 104 | if progress < self.warmup: 105 | return progress / self.warmup 106 | else: 107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) 109 | 110 | 111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): 112 | """ 113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 115 | learning rate (with hard restarts). 116 | """ 117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 119 | assert(cycles >= 1.) 120 | 121 | def get_lr_(self, progress): 122 | if progress < self.warmup: 123 | return progress / self.warmup 124 | else: 125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 126 | ret = 0.5 * (1. 
+ math.cos(math.pi * ((self.cycles * progress) % 1))) 127 | return ret 128 | 129 | 130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): 131 | """ 132 | All training progress is divided in `cycles` (default=1.) parts of equal length. 133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., 134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve. 135 | """ 136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): 137 | assert(warmup * cycles < 1.) 138 | warmup = warmup * cycles if warmup >= 0 else warmup 139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) 140 | 141 | def get_lr_(self, progress): 142 | progress = progress * self.cycles % 1. 143 | if progress < self.warmup: 144 | return progress / self.warmup 145 | else: 146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup 147 | ret = 0.5 * (1. + math.cos(math.pi * progress)) 148 | return ret 149 | 150 | 151 | class WarmupConstantSchedule(_LRSchedule): 152 | """ 153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 154 | Keeps learning rate equal to 1. after warmup. 155 | """ 156 | def get_lr_(self, progress): 157 | if progress < self.warmup: 158 | return progress / self.warmup 159 | return 1. 160 | 161 | 162 | class WarmupLinearSchedule(_LRSchedule): 163 | """ 164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. 165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. 166 | """ 167 | warn_t_total = True 168 | def get_lr_(self, progress): 169 | if progress < self.warmup: 170 | return progress / self.warmup 171 | return max((progress - 1.) / (self.warmup - 1.), 0.) 
172 | 173 | 174 | SCHEDULES = { 175 | None: ConstantLR, 176 | "none": ConstantLR, 177 | "warmup_cosine": WarmupCosineSchedule, 178 | "warmup_constant": WarmupConstantSchedule, 179 | "warmup_linear": WarmupLinearSchedule 180 | } 181 | 182 | 183 | class BertAdam(Optimizer): 184 | """Implements BERT version of Adam algorithm with weight decay fix. 185 | Params: 186 | lr: learning rate 187 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 188 | t_total: total number of training steps for the learning 189 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 190 | schedule: schedule to use for the warmup (see above). 191 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). 192 | If `None` or `'none'`, learning rate is always kept constant. 193 | Default : `'warmup_linear'` 194 | b1: Adams b1. Default: 0.9 195 | b2: Adams b2. Default: 0.999 196 | e: Adams epsilon. Default: 1e-6 197 | weight_decay: Weight decay. Default: 0.01 198 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 199 | """ 200 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 201 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): 202 | if lr is not required and lr < 0.0: 203 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 204 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 205 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 206 | if not 0.0 <= b1 < 1.0: 207 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 208 | if not 0.0 <= b2 < 1.0: 209 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 210 | if not e >= 0.0: 211 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 212 | # initialize schedule object 213 | if not isinstance(schedule, _LRSchedule): 214 | schedule_type = SCHEDULES[schedule] 215 | schedule = schedule_type(warmup=warmup, t_total=t_total) 216 | else: 217 | if warmup != -1 or t_total != -1: 218 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " 219 | "Please specify custom warmup and t_total in _LRSchedule object.") 220 | defaults = dict(lr=lr, schedule=schedule, 221 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 222 | max_grad_norm=max_grad_norm) 223 | super(BertAdam, self).__init__(params, defaults) 224 | 225 | def get_lr(self): 226 | lr = [] 227 | for group in self.param_groups: 228 | for p in group['params']: 229 | state = self.state[p] 230 | if len(state) == 0: 231 | return [0] 232 | lr_scheduled = group['lr'] 233 | lr_scheduled *= group['schedule'].get_lr(state['step']) 234 | lr.append(lr_scheduled) 235 | return lr 236 | 237 | def step(self, closure=None): 238 | """Performs a single optimization step. 239 | 240 | Arguments: 241 | closure (callable, optional): A closure that reevaluates the model 242 | and returns the loss. 
243 | """ 244 | loss = None 245 | if closure is not None: 246 | loss = closure() 247 | 248 | for group in self.param_groups: 249 | for p in group['params']: 250 | if p.grad is None: 251 | continue 252 | grad = p.grad.data 253 | if grad.is_sparse: 254 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 255 | 256 | state = self.state[p] 257 | 258 | # State initialization 259 | if len(state) == 0: 260 | state['step'] = 0 261 | # Exponential moving average of gradient values 262 | state['next_m'] = torch.zeros_like(p.data) 263 | # Exponential moving average of squared gradient values 264 | state['next_v'] = torch.zeros_like(p.data) 265 | 266 | next_m, next_v = state['next_m'], state['next_v'] 267 | beta1, beta2 = group['b1'], group['b2'] 268 | 269 | # Add grad clipping 270 | if group['max_grad_norm'] > 0: 271 | clip_grad_norm_(p, group['max_grad_norm']) 272 | 273 | # Decay the first and second moment running average coefficient 274 | # In-place operations to update the averages at the same time 275 | next_m.mul_(beta1).add_(1 - beta1, grad) 276 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 277 | update = next_m / (next_v.sqrt() + group['e']) 278 | 279 | # Just adding the square of the weights to the loss function is *not* 280 | # the correct way of using L2 regularization/weight decay with Adam, 281 | # since that will interact with the m and v parameters in strange ways. 282 | # 283 | # Instead we want to decay the weights in a manner that doesn't interact 284 | # with the m/v parameters. This is equivalent to adding the square 285 | # of the weights to the loss with plain (non-momentum) SGD. 
286 | if group['weight_decay'] > 0.0: 287 | update += group['weight_decay'] * p.data 288 | 289 | lr_scheduled = group['lr'] 290 | lr_scheduled *= group['schedule'].get_lr(state['step']) 291 | 292 | update_with_lr = lr_scheduled * update 293 | p.data.add_(-update_with_lr) 294 | 295 | state['step'] += 1 296 | 297 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 298 | # No bias correction 299 | # bias_correction1 = 1 - beta1 ** state['step'] 300 | # bias_correction2 = 1 - beta2 ** state['step'] 301 | 302 | return loss 303 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ 24 | WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class OpenAIAdam(Optimizer): 30 | """Implements Open AI version of Adam algorithm with weight decay fix. 31 | """ 32 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 33 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 34 | vector_l2=False, max_grad_norm=-1, **kwargs): 35 | if lr is not required and lr < 0.0: 36 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 37 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: 38 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 39 | if not 0.0 <= b1 < 1.0: 40 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 41 | if not 0.0 <= b2 < 1.0: 42 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 43 | if not e >= 0.0: 44 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 45 | # initialize schedule object 46 | if not isinstance(schedule, _LRSchedule): 47 | schedule_type = SCHEDULES[schedule] 48 | schedule = schedule_type(warmup=warmup, t_total=t_total) 49 | else: 50 | if warmup != -1 or t_total != -1: 51 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. 
" 52 | "Please specify custom warmup and t_total in _LRSchedule object.") 53 | defaults = dict(lr=lr, schedule=schedule, 54 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 55 | max_grad_norm=max_grad_norm) 56 | super(OpenAIAdam, self).__init__(params, defaults) 57 | 58 | def get_lr(self): 59 | lr = [] 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | state = self.state[p] 63 | if len(state) == 0: 64 | return [0] 65 | lr_scheduled = group['lr'] 66 | lr_scheduled *= group['schedule'].get_lr(state['step']) 67 | lr.append(lr_scheduled) 68 | return lr 69 | 70 | def step(self, closure=None): 71 | """Performs a single optimization step. 72 | 73 | Arguments: 74 | closure (callable, optional): A closure that reevaluates the model 75 | and returns the loss. 76 | """ 77 | loss = None 78 | if closure is not None: 79 | loss = closure() 80 | 81 | for group in self.param_groups: 82 | for p in group['params']: 83 | if p.grad is None: 84 | continue 85 | grad = p.grad.data 86 | if grad.is_sparse: 87 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like(p.data) 96 | # Exponential moving average of squared gradient values 97 | state['exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | beta1, beta2 = group['b1'], group['b2'] 101 | 102 | state['step'] += 1 103 | 104 | # Add grad clipping 105 | if group['max_grad_norm'] > 0: 106 | clip_grad_norm_(p, group['max_grad_norm']) 107 | 108 | # Decay the first and second moment running average coefficient 109 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 110 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 111 | denom = exp_avg_sq.sqrt().add_(group['e']) 112 | 113 | bias_correction1 = 1 
- beta1 ** state['step'] 114 | bias_correction2 = 1 - beta2 ** state['step'] 115 | 116 | lr_scheduled = group['lr'] 117 | lr_scheduled *= group['schedule'].get_lr(state['step']) 118 | 119 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 120 | 121 | p.data.addcdiv_(-step_size, exp_avg, denom) 122 | 123 | # Add weight decay at the end (fixed version) 124 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 125 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /pytorch_pretrained_bert/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 
31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .file_utils import cached_path 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 39 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 40 | } 41 | PRETRAINED_MERGES_ARCHIVE_MAP = { 42 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 43 | } 44 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 45 | 'gpt2': 1024, 46 | } 47 | VOCAB_NAME = 'vocab.json' 48 | MERGES_NAME = 'merges.txt' 49 | SPECIAL_TOKENS_NAME = 'special_tokens.txt' 50 | 51 | @lru_cache() 52 | def bytes_to_unicode(): 53 | """ 54 | Returns list of utf-8 byte and a corresponding list of unicode strings. 55 | The reversible bpe codes work on unicode strings. 56 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 57 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 58 | This is a signficant percentage of your normal, say, 32K bpe vocab. 59 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 60 | And avoids mapping to whitespace/control characters the bpe code barfs on. 61 | """ 62 | _chr = unichr if sys.version_info[0] == 2 else chr 63 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 64 | cs = bs[:] 65 | n = 0 66 | for b in range(2**8): 67 | if b not in bs: 68 | bs.append(b) 69 | cs.append(2**8+n) 70 | n += 1 71 | cs = [_chr(n) for n in cs] 72 | return dict(zip(bs, cs)) 73 | 74 | def get_pairs(word): 75 | """Return set of symbol pairs in a word. 76 | 77 | Word is represented as tuple of symbols (symbols being variable-length strings). 78 | """ 79 | pairs = set() 80 | prev_char = word[0] 81 | for char in word[1:]: 82 | pairs.add((prev_char, char)) 83 | prev_char = char 84 | return pairs 85 | 86 | class GPT2Tokenizer(object): 87 | """ 88 | GPT-2 BPE tokenizer. 
Peculiarities: 89 | - Byte-level BPE 90 | """ 91 | @classmethod 92 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 93 | """ 94 | Instantiate a PreTrainedBertModel from a pre-trained model file. 95 | Download and cache the pre-trained model file if needed. 96 | """ 97 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 98 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 99 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 100 | special_tokens_file = None 101 | else: 102 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 103 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 104 | special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) 105 | if not os.path.exists(special_tokens_file): 106 | special_tokens_file = None 107 | else: 108 | logger.info("loading special tokens file {}".format(special_tokens_file)) 109 | # redirect to the cache, if necessary 110 | try: 111 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 112 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 113 | except EnvironmentError: 114 | logger.error( 115 | "Model name '{}' was not found in model name list ({}). 
" 116 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 117 | "at this path or url.".format( 118 | pretrained_model_name_or_path, 119 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 120 | pretrained_model_name_or_path, 121 | vocab_file, merges_file)) 122 | return None 123 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 124 | logger.info("loading vocabulary file {}".format(vocab_file)) 125 | logger.info("loading merges file {}".format(merges_file)) 126 | else: 127 | logger.info("loading vocabulary file {} from cache at {}".format( 128 | vocab_file, resolved_vocab_file)) 129 | logger.info("loading merges file {} from cache at {}".format( 130 | merges_file, resolved_merges_file)) 131 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 132 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 133 | # than the number of positional embeddings 134 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 135 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 136 | # Instantiate tokenizer. 
137 | if special_tokens_file and 'special_tokens' not in kwargs: 138 | special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] 139 | else: 140 | special_tokens = kwargs.pop('special_tokens', []) 141 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) 142 | return tokenizer 143 | 144 | def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): 145 | self.max_len = max_len if max_len is not None else int(1e12) 146 | self.encoder = json.load(open(vocab_file)) 147 | self.decoder = {v:k for k,v in self.encoder.items()} 148 | self.errors = errors # how to handle errors in decoding 149 | self.byte_encoder = bytes_to_unicode() 150 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 151 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 152 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 153 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 154 | self.cache = {} 155 | 156 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 157 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 158 | 159 | self.special_tokens = {} 160 | self.special_tokens_decoder = {} 161 | self.set_special_tokens(special_tokens) 162 | 163 | def __len__(self): 164 | return len(self.encoder) + len(self.special_tokens) 165 | 166 | def set_special_tokens(self, special_tokens): 167 | """ Add a list of additional tokens to the encoder. 168 | The additional tokens are indexed starting from the last index of the 169 | current vocabulary in the order of the `special_tokens` list. 
170 | """ 171 | if not special_tokens: 172 | self.special_tokens = {} 173 | self.special_tokens_decoder = {} 174 | return 175 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 176 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 177 | logger.info("Special tokens {}".format(self.special_tokens)) 178 | 179 | def bpe(self, token): 180 | if token in self.cache: 181 | return self.cache[token] 182 | word = tuple(token) 183 | pairs = get_pairs(word) 184 | 185 | if not pairs: 186 | return token 187 | 188 | while True: 189 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 190 | if bigram not in self.bpe_ranks: 191 | break 192 | first, second = bigram 193 | new_word = [] 194 | i = 0 195 | while i < len(word): 196 | try: 197 | j = word.index(first, i) 198 | new_word.extend(word[i:j]) 199 | i = j 200 | except: 201 | new_word.extend(word[i:]) 202 | break 203 | 204 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 205 | new_word.append(first+second) 206 | i += 2 207 | else: 208 | new_word.append(word[i]) 209 | i += 1 210 | new_word = tuple(new_word) 211 | word = new_word 212 | if len(word) == 1: 213 | break 214 | else: 215 | pairs = get_pairs(word) 216 | word = ' '.join(word) 217 | self.cache[token] = word 218 | return word 219 | 220 | def tokenize(self, text): 221 | """ Tokenize a string. """ 222 | bpe_tokens = [] 223 | for token in re.findall(self.pat, text): 224 | if sys.version_info[0] == 2: 225 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 226 | else: 227 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 228 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 229 | return bpe_tokens 230 | 231 | def convert_tokens_to_ids(self, tokens): 232 | """ Converts a sequence of tokens into ids using the vocab. 
""" 233 | ids = [] 234 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 235 | if tokens in self.special_tokens: 236 | return self.special_tokens[tokens] 237 | else: 238 | return self.encoder.get(tokens, 0) 239 | for token in tokens: 240 | if token in self.special_tokens: 241 | ids.append(self.special_tokens[token]) 242 | else: 243 | ids.append(self.encoder.get(token, 0)) 244 | if len(ids) > self.max_len: 245 | logger.warning( 246 | "Token indices sequence length is longer than the specified maximum " 247 | " sequence length for this OpenAI GPT model ({} > {}). Running this" 248 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 249 | ) 250 | return ids 251 | 252 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 253 | """Converts a sequence of ids in BPE tokens using the vocab.""" 254 | tokens = [] 255 | for i in ids: 256 | if i in self.special_tokens_decoder: 257 | if not skip_special_tokens: 258 | tokens.append(self.special_tokens_decoder[i]) 259 | else: 260 | tokens.append(self.decoder[i]) 261 | return tokens 262 | 263 | def encode(self, text): 264 | return self.convert_tokens_to_ids(self.tokenize(text)) 265 | 266 | def decode(self, tokens): 267 | text = ''.join([self.decoder[token] for token in tokens]) 268 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 269 | return text 270 | 271 | def save_vocabulary(self, vocab_path): 272 | """Save the tokenizer vocabulary and merge files to a directory.""" 273 | if not os.path.isdir(vocab_path): 274 | logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) 275 | return 276 | vocab_file = os.path.join(vocab_path, VOCAB_NAME) 277 | merge_file = os.path.join(vocab_path, MERGES_NAME) 278 | special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) 279 | 280 | with open(vocab_file, 'w', encoding='utf-8') as f: 281 | f.write(json.dumps(self.encoder, 
ensure_ascii=False)) 282 | 283 | index = 0 284 | with open(merge_file, "w", encoding="utf-8") as writer: 285 | writer.write(u'#version: 0.2\n') 286 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 287 | if index != token_index: 288 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 289 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 290 | index = token_index 291 | writer.write(' '.join(bpe_tokens) + u'\n') 292 | index += 1 293 | 294 | index = len(self.encoder) 295 | with open(special_tokens_file, 'w', encoding='utf-8') as writer: 296 | for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): 297 | if index != token_index: 298 | logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." 299 | " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) 300 | index = token_index 301 | writer.write(token + u'\n') 302 | index += 1 303 | 304 | return vocab_file, merge_file, special_tokens_file 305 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | torch>=0.4.1 3 | # progress bars in model download and training scripts 4 | tqdm 5 | # Accessing files from S3 directly. 6 | boto3 7 | # Used for downloading models over HTTP 8 | requests 9 | # For OpenAI GPT 10 | regex -------------------------------------------------------------------------------- /samples/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? 
||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /samples/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 
29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py 3 | 4 | To create the package for pypi. 5 | 6 | 1. Change the version in __init__.py and setup.py. 7 | 8 | 2. Commit these changes with the message: "Release: VERSION" 9 | 10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " 11 | Push the tag to git: git push --tags origin master 12 | 13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between 14 | creating the wheel and the source distribution (obviously). 15 | 16 | For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. 17 | (this will build a wheel for the python version you use to build it - make sure you use python 3.x). 18 | 19 | For the sources, run: "python setup.py sdist" 20 | You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. 21 | 22 | 5. 
Check that everything looks correct by uploading the package to the pypi test server: 23 | 24 | twine upload dist/* -r pypitest 25 | (pypi suggest using twine as other methods upload files via plaintext.) 26 | 27 | Check that you can install it in a virtualenv by running: 28 | pip install -i https://testpypi.python.org/pypi allennlp 29 | 30 | 6. Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 34 | 35 | """ 36 | from io import open 37 | from setuptools import find_packages, setup 38 | 39 | setup( 40 | name="pytorch_pretrained_bert", 41 | version="0.6.2", 42 | author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors", 43 | author_email="thomas@huggingface.co", 44 | description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", 45 | long_description=open("README.md", "r", encoding='utf-8').read(), 46 | long_description_content_type="text/markdown", 47 | keywords='BERT NLP deep learning google', 48 | license='Apache', 49 | url="https://github.com/huggingface/pytorch-pretrained-BERT", 50 | packages=find_packages(exclude=["*.tests", "*.tests.*", 51 | "tests.*", "tests"]), 52 | install_requires=['torch>=0.4.1', 53 | 'numpy', 54 | 'boto3', 55 | 'requests', 56 | 'tqdm', 57 | 'regex'], 58 | entry_points={ 59 | 'console_scripts': [ 60 | "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main", 61 | ] 62 | }, 63 | # python_requires='>=3.5.0', 64 | tests_require=['pytest'], 65 | classifiers=[ 66 | 'Intended Audience :: Science/Research', 67 | 'License :: OSI Approved :: Apache Software License', 68 | 'Programming Language :: Python :: 3', 69 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 70 | ], 71 | ) 72 | -------------------------------------------------------------------------------- /tests/conftest.py: 
-------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /tests/modeling_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import unittest 21 | import json 22 | import random 23 | import shutil 24 | import pytest 25 | 26 | import torch 27 | 28 | from pytorch_pretrained_bert import (GPT2Config, GPT2Model, 29 | GPT2LMHeadModel, GPT2DoubleHeadsModel) 30 | from pytorch_pretrained_bert.modeling_gpt2 import PRETRAINED_MODEL_ARCHIVE_MAP 31 | 32 | class GPT2ModelTest(unittest.TestCase): 33 | class GPT2ModelTester(object): 34 | 35 | def __init__(self, 36 | parent, 37 | batch_size=13, 38 | seq_length=7, 39 | is_training=True, 40 | use_position_ids=True, 41 | use_token_type_ids=True, 42 | use_labels=True, 43 | vocab_size=99, 44 | n_positions=33, 45 | n_embd=32, 46 | n_layer=5, 47 | n_head=4, 48 | n_choices=3, 49 | type_sequence_label_size=2, 50 | initializer_range=0.02, 51 | num_labels=3, 52 | scope=None): 53 | self.parent = parent 54 | self.batch_size = batch_size 55 | self.seq_length = seq_length 56 | self.is_training = is_training 57 | self.use_position_ids = use_position_ids 58 | self.use_token_type_ids = use_token_type_ids 59 | self.use_labels = use_labels 60 | self.vocab_size = vocab_size 61 | self.n_positions = n_positions 62 | self.n_embd = n_embd 63 | self.n_layer = n_layer 64 | self.n_head = n_head 65 | self.n_choices = n_choices 66 | self.type_sequence_label_size = type_sequence_label_size 67 | self.initializer_range = initializer_range 68 | self.num_labels = num_labels 69 | self.scope = scope 70 | 71 | def prepare_config_and_inputs(self): 72 | input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size) 73 | 74 | position_ids = None 75 | if self.use_position_ids: 76 | position_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions) 77 | 78 | token_type_ids = None 79 | if self.use_token_type_ids: 80 | total_voc = self.vocab_size 81 | token_type_ids = 
GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc) 82 | 83 | mc_labels = None 84 | lm_labels = None 85 | mc_token_ids = None 86 | if self.use_labels: 87 | mc_labels = GPT2ModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) 88 | lm_labels = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels) 89 | mc_token_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length) 90 | 91 | config = GPT2Config( 92 | vocab_size_or_config_json_file=self.vocab_size, 93 | n_positions=self.n_positions, 94 | n_embd=self.n_embd, 95 | n_layer=self.n_layer, 96 | n_head=self.n_head, 97 | initializer_range=self.initializer_range) 98 | 99 | return (config, input_ids, token_type_ids, position_ids, 100 | mc_labels, lm_labels, mc_token_ids) 101 | 102 | def create_gpt2_model(self, config, input_ids, token_type_ids, position_ids, 103 | mc_labels, lm_labels, mc_token_ids): 104 | model = GPT2Model(config) 105 | model.eval() 106 | hidden_states, presents = model(input_ids, position_ids, token_type_ids) 107 | outputs = { 108 | "hidden_states": hidden_states, 109 | "presents": presents, 110 | } 111 | return outputs 112 | 113 | def check_gpt2_model_output(self, result): 114 | self.parent.assertListEqual( 115 | list(result["hidden_states"].size()), 116 | [self.batch_size, self.n_choices, self.seq_length, self.n_embd]) 117 | 118 | 119 | def create_gpt2_lm_head(self, config, input_ids, token_type_ids, position_ids, 120 | mc_labels, lm_labels, mc_token_ids): 121 | model = GPT2LMHeadModel(config) 122 | model.eval() 123 | loss = model(input_ids, position_ids, token_type_ids, lm_labels) 124 | lm_logits, presents = model(input_ids, position_ids, token_type_ids) 125 | outputs = { 126 | "loss": loss, 127 | "lm_logits": lm_logits, 128 | "presents": presents, 129 | } 130 | return outputs 131 | 132 | def check_gpt2_lm_head_output(self, result): 133 | total_voc = self.vocab_size 134 | 
self.parent.assertListEqual( 135 | list(result["lm_logits"].size()), 136 | [self.batch_size, self.n_choices, self.seq_length, total_voc]) 137 | 138 | def check_gpt2_lm_head_loss_output(self, result): 139 | self.parent.assertListEqual( 140 | list(result["loss"].size()), 141 | []) 142 | 143 | def create_gpt2_double_heads(self, config, input_ids, token_type_ids, position_ids, 144 | mc_labels, lm_labels, mc_token_ids): 145 | model = GPT2DoubleHeadsModel(config) 146 | model.eval() 147 | loss = model(input_ids, mc_token_ids, 148 | lm_labels=lm_labels, mc_labels=mc_labels, 149 | token_type_ids=token_type_ids, position_ids=position_ids) 150 | lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids) 151 | outputs = { 152 | "loss": loss, 153 | "lm_logits": lm_logits, 154 | "mc_logits": mc_logits, 155 | "presents": presents, 156 | } 157 | return outputs 158 | 159 | def check_gpt2_double_heads_output(self, result): 160 | total_voc = self.vocab_size 161 | self.parent.assertListEqual( 162 | list(result["lm_logits"].size()), 163 | [self.batch_size, self.n_choices, self.seq_length, total_voc]) 164 | self.parent.assertListEqual( 165 | list(result["mc_logits"].size()), 166 | [self.batch_size, self.n_choices]) 167 | 168 | def check_gpt2_double_heads_loss_output(self, result): 169 | self.parent.assertListEqual( 170 | [list(l.size()) for l in result["loss"]], 171 | [[], []]) 172 | 173 | def test_default(self): 174 | self.run_tester(GPT2ModelTest.GPT2ModelTester(self)) 175 | 176 | def test_config_to_json_string(self): 177 | config = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37) 178 | obj = json.loads(config.to_json_string()) 179 | self.assertEqual(obj["vocab_size"], 99) 180 | self.assertEqual(obj["n_embd"], 37) 181 | 182 | def test_config_to_json_file(self): 183 | config_first = GPT2Config(vocab_size_or_config_json_file=99, n_embd=37) 184 | json_file_path = "/tmp/config.json" 185 | 
config_first.to_json_file(json_file_path) 186 | config_second = GPT2Config.from_json_file(json_file_path) 187 | os.remove(json_file_path) 188 | self.assertEqual(config_second.to_dict(), config_first.to_dict()) 189 | 190 | @pytest.mark.slow 191 | def test_model_from_pretrained(self): 192 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 193 | for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 194 | model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir) 195 | shutil.rmtree(cache_dir) 196 | self.assertIsNotNone(model) 197 | 198 | def run_tester(self, tester): 199 | config_and_inputs = tester.prepare_config_and_inputs() 200 | output_result = tester.create_gpt2_model(*config_and_inputs) 201 | tester.check_gpt2_model_output(output_result) 202 | 203 | output_result = tester.create_gpt2_lm_head(*config_and_inputs) 204 | tester.check_gpt2_lm_head_output(output_result) 205 | tester.check_gpt2_lm_head_loss_output(output_result) 206 | 207 | output_result = tester.create_gpt2_double_heads(*config_and_inputs) 208 | tester.check_gpt2_double_heads_output(output_result) 209 | tester.check_gpt2_double_heads_loss_output(output_result) 210 | 211 | @classmethod 212 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 213 | """Creates a random int32 tensor of the shape within the vocab size.""" 214 | if rng is None: 215 | rng = random.Random() 216 | 217 | total_dims = 1 218 | for dim in shape: 219 | total_dims *= dim 220 | 221 | values = [] 222 | for _ in range(total_dims): 223 | values.append(rng.randint(0, vocab_size - 1)) 224 | 225 | return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() 226 | 227 | 228 | if __name__ == "__main__": 229 | unittest.main() 230 | -------------------------------------------------------------------------------- /tests/modeling_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import unittest 21 | import json 22 | import random 23 | import shutil 24 | import pytest 25 | 26 | import torch 27 | 28 | from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel, 29 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) 30 | from pytorch_pretrained_bert.modeling_openai import PRETRAINED_MODEL_ARCHIVE_MAP 31 | 32 | class OpenAIGPTModelTest(unittest.TestCase): 33 | class OpenAIGPTModelTester(object): 34 | 35 | def __init__(self, 36 | parent, 37 | batch_size=13, 38 | seq_length=7, 39 | is_training=True, 40 | use_position_ids=True, 41 | use_token_type_ids=True, 42 | use_labels=True, 43 | vocab_size=99, 44 | n_special=1, 45 | n_positions=33, 46 | n_embd=32, 47 | n_layer=5, 48 | n_head=4, 49 | n_choices=3, 50 | afn="gelu", 51 | resid_pdrop=0.1, 52 | attn_pdrop=0.1, 53 | embd_pdrop=0.1, 54 | type_sequence_label_size=2, 55 | initializer_range=0.02, 56 | num_labels=3, 57 | scope=None): 58 | self.parent = parent 59 | self.batch_size = batch_size 60 | self.seq_length = seq_length 61 | self.is_training = is_training 62 | self.use_position_ids = use_position_ids 63 | self.use_token_type_ids = use_token_type_ids 64 | self.use_labels = use_labels 65 | self.vocab_size = vocab_size 66 | self.n_special = n_special 67 | 
self.n_positions = n_positions 68 | self.n_embd = n_embd 69 | self.n_layer = n_layer 70 | self.n_head = n_head 71 | self.afn = afn 72 | self.n_choices = n_choices 73 | self.resid_pdrop = resid_pdrop 74 | self.attn_pdrop = attn_pdrop 75 | self.embd_pdrop = embd_pdrop 76 | self.type_sequence_label_size = type_sequence_label_size 77 | self.initializer_range = initializer_range 78 | self.num_labels = num_labels 79 | self.scope = scope 80 | 81 | def prepare_config_and_inputs(self): 82 | input_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size) 83 | 84 | position_ids = None 85 | if self.use_position_ids: 86 | position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions) 87 | 88 | token_type_ids = None 89 | if self.use_token_type_ids: 90 | total_voc = self.vocab_size + self.n_special 91 | token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc) 92 | 93 | mc_labels = None 94 | lm_labels = None 95 | mc_token_ids = None 96 | if self.use_labels: 97 | mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) 98 | lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels) 99 | mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length) 100 | 101 | config = OpenAIGPTConfig( 102 | vocab_size_or_config_json_file=self.vocab_size, 103 | n_positions=self.n_positions, 104 | n_special=self.n_special, 105 | n_embd=self.n_embd, 106 | n_layer=self.n_layer, 107 | n_head=self.n_head, 108 | afn=self.afn, 109 | resid_pdrop=self.resid_pdrop, 110 | attn_pdrop=self.attn_pdrop, 111 | embd_pdrop=self.embd_pdrop, 112 | initializer_range=self.initializer_range) 113 | 114 | return (config, input_ids, token_type_ids, position_ids, 115 | mc_labels, lm_labels, mc_token_ids) 116 | 117 | def create_openai_model(self, config, 
input_ids, token_type_ids, position_ids, 118 | mc_labels, lm_labels, mc_token_ids): 119 | model = OpenAIGPTModel(config) 120 | model.eval() 121 | hidden_states = model(input_ids, position_ids, token_type_ids) 122 | outputs = { 123 | "hidden_states": hidden_states, 124 | } 125 | return outputs 126 | 127 | def check_openai_model_output(self, result): 128 | self.parent.assertListEqual( 129 | list(result["hidden_states"].size()), 130 | [self.batch_size, self.n_choices, self.seq_length, self.n_embd]) 131 | 132 | 133 | def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids, 134 | mc_labels, lm_labels, mc_token_ids): 135 | model = OpenAIGPTLMHeadModel(config) 136 | model.eval() 137 | loss = model(input_ids, position_ids, token_type_ids, lm_labels) 138 | lm_logits = model(input_ids, position_ids, token_type_ids) 139 | outputs = { 140 | "loss": loss, 141 | "lm_logits": lm_logits, 142 | } 143 | return outputs 144 | 145 | def check_openai_lm_head_output(self, result): 146 | total_voc = self.n_special + self.vocab_size 147 | self.parent.assertListEqual( 148 | list(result["lm_logits"].size()), 149 | [self.batch_size, self.n_choices, self.seq_length, total_voc]) 150 | 151 | def check_openai_lm_head_loss_output(self, result): 152 | self.parent.assertListEqual( 153 | list(result["loss"].size()), 154 | []) 155 | 156 | def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids, 157 | mc_labels, lm_labels, mc_token_ids): 158 | model = OpenAIGPTDoubleHeadsModel(config) 159 | model.eval() 160 | loss = model(input_ids, mc_token_ids, 161 | lm_labels=lm_labels, mc_labels=mc_labels, 162 | token_type_ids=token_type_ids, position_ids=position_ids) 163 | lm_logits, mc_logits = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids) 164 | outputs = { 165 | "loss": loss, 166 | "lm_logits": lm_logits, 167 | "mc_logits": mc_logits, 168 | } 169 | return outputs 170 | 171 | def check_openai_double_heads_output(self, 
result): 172 | total_voc = self.n_special + self.vocab_size 173 | self.parent.assertListEqual( 174 | list(result["lm_logits"].size()), 175 | [self.batch_size, self.n_choices, self.seq_length, total_voc]) 176 | self.parent.assertListEqual( 177 | list(result["mc_logits"].size()), 178 | [self.batch_size, self.n_choices]) 179 | 180 | def check_openai_double_heads_loss_output(self, result): 181 | self.parent.assertListEqual( 182 | [list(l.size()) for l in result["loss"]], 183 | [[], []]) 184 | 185 | def test_default(self): 186 | self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self)) 187 | 188 | def test_config_to_json_string(self): 189 | config = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37) 190 | obj = json.loads(config.to_json_string()) 191 | self.assertEqual(obj["vocab_size"], 99) 192 | self.assertEqual(obj["n_embd"], 37) 193 | 194 | def test_config_to_json_file(self): 195 | config_first = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_embd=37) 196 | json_file_path = "/tmp/config.json" 197 | config_first.to_json_file(json_file_path) 198 | config_second = OpenAIGPTConfig.from_json_file(json_file_path) 199 | os.remove(json_file_path) 200 | self.assertEqual(config_second.to_dict(), config_first.to_dict()) 201 | 202 | @pytest.mark.slow 203 | def test_model_from_pretrained(self): 204 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 205 | for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 206 | model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir) 207 | shutil.rmtree(cache_dir) 208 | self.assertIsNotNone(model) 209 | 210 | def run_tester(self, tester): 211 | config_and_inputs = tester.prepare_config_and_inputs() 212 | output_result = tester.create_openai_model(*config_and_inputs) 213 | tester.check_openai_model_output(output_result) 214 | 215 | output_result = tester.create_openai_lm_head(*config_and_inputs) 216 | tester.check_openai_lm_head_output(output_result) 217 | 
tester.check_openai_lm_head_loss_output(output_result) 218 | 219 | output_result = tester.create_openai_double_heads(*config_and_inputs) 220 | tester.check_openai_double_heads_output(output_result) 221 | tester.check_openai_double_heads_loss_output(output_result) 222 | 223 | @classmethod 224 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 225 | """Creates a random int32 tensor of the shape within the vocab size.""" 226 | if rng is None: 227 | rng = random.Random() 228 | 229 | total_dims = 1 230 | for dim in shape: 231 | total_dims *= dim 232 | 233 | values = [] 234 | for _ in range(total_dims): 235 | values.append(rng.randint(0, vocab_size - 1)) 236 | 237 | return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() 238 | 239 | 240 | if __name__ == "__main__": 241 | unittest.main() 242 | -------------------------------------------------------------------------------- /tests/modeling_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import unittest 21 | import json 22 | import random 23 | import shutil 24 | import pytest 25 | 26 | import torch 27 | 28 | from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) 29 | from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP 30 | 31 | class TransfoXLModelTest(unittest.TestCase): 32 | class TransfoXLModelTester(object): 33 | 34 | def __init__(self, 35 | parent, 36 | batch_size=13, 37 | seq_length=7, 38 | mem_len=30, 39 | clamp_len=15, 40 | is_training=True, 41 | use_labels=True, 42 | vocab_size=99, 43 | cutoffs=[10, 50, 80], 44 | d_model=32, 45 | d_embed=32, 46 | n_head=4, 47 | d_head=8, 48 | d_inner=128, 49 | div_val=2, 50 | n_layer=5, 51 | scope=None, 52 | seed=1): 53 | self.parent = parent 54 | self.batch_size = batch_size 55 | self.seq_length = seq_length 56 | self.mem_len = mem_len 57 | self.clamp_len = clamp_len 58 | self.is_training = is_training 59 | self.use_labels = use_labels 60 | self.vocab_size = vocab_size 61 | self.cutoffs = cutoffs 62 | self.d_model = d_model 63 | self.d_embed = d_embed 64 | self.n_head = n_head 65 | self.d_head = d_head 66 | self.d_inner = d_inner 67 | self.div_val = div_val 68 | self.n_layer = n_layer 69 | self.scope = scope 70 | self.seed = seed 71 | 72 | def prepare_config_and_inputs(self): 73 | input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 74 | input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 75 | 76 | lm_labels = None 77 | if self.use_labels: 78 | lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 79 | 80 | config = TransfoXLConfig( 81 | vocab_size_or_config_json_file=self.vocab_size, 82 | mem_len=self.mem_len, 83 | clamp_len=self.clamp_len, 84 | 
cutoffs=self.cutoffs, 85 | d_model=self.d_model, 86 | d_embed=self.d_embed, 87 | n_head=self.n_head, 88 | d_head=self.d_head, 89 | d_inner=self.d_inner, 90 | div_val=self.div_val, 91 | n_layer=self.n_layer) 92 | 93 | return (config, input_ids_1, input_ids_2, lm_labels) 94 | 95 | def set_seed(self): 96 | random.seed(self.seed) 97 | torch.manual_seed(self.seed) 98 | 99 | def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels): 100 | model = TransfoXLModel(config) 101 | model.eval() 102 | 103 | hidden_states_1, mems_1 = model(input_ids_1) 104 | hidden_states_2, mems_2 = model(input_ids_2, mems_1) 105 | outputs = { 106 | "hidden_states_1": hidden_states_1, 107 | "mems_1": mems_1, 108 | "hidden_states_2": hidden_states_2, 109 | "mems_2": mems_2, 110 | } 111 | return outputs 112 | 113 | def check_transfo_xl_model_output(self, result): 114 | self.parent.assertListEqual( 115 | list(result["hidden_states_1"].size()), 116 | [self.batch_size, self.seq_length, self.d_model]) 117 | self.parent.assertListEqual( 118 | list(result["hidden_states_2"].size()), 119 | [self.batch_size, self.seq_length, self.d_model]) 120 | self.parent.assertListEqual( 121 | list(list(mem.size()) for mem in result["mems_1"]), 122 | [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) 123 | self.parent.assertListEqual( 124 | list(list(mem.size()) for mem in result["mems_2"]), 125 | [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) 126 | 127 | 128 | def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): 129 | model = TransfoXLLMHeadModel(config) 130 | model.eval() 131 | 132 | loss_1, mems_1a = model(input_ids_1, target=lm_labels) 133 | lm_logits_1, mems_1b = model(input_ids_1) 134 | 135 | loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a) 136 | lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b) 137 | 138 | outputs = { 139 | "loss_1": loss_1, 140 | "mems_1a": mems_1a, 141 | "lm_logits_1": lm_logits_1, 142 | 
"mems_1b": mems_1b, 143 | "loss_2": loss_2, 144 | "mems_2a": mems_2a, 145 | "lm_logits_2": lm_logits_2, 146 | "mems_2b": mems_2b, 147 | } 148 | return outputs 149 | 150 | def check_transfo_xl_lm_head_output(self, result): 151 | self.parent.assertListEqual( 152 | list(result["loss_1"].size()), 153 | [self.batch_size, self.seq_length]) 154 | self.parent.assertListEqual( 155 | list(result["lm_logits_1"].size()), 156 | [self.batch_size, self.seq_length, self.vocab_size]) 157 | self.parent.assertListEqual( 158 | list(list(mem.size()) for mem in result["mems_1a"]), 159 | [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) 160 | self.parent.assertListEqual( 161 | list(list(mem.size()) for mem in result["mems_1b"]), 162 | [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) 163 | self.parent.assertListEqual( 164 | list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]), 165 | list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"])) 166 | 167 | self.parent.assertListEqual( 168 | list(result["loss_2"].size()), 169 | [self.batch_size, self.seq_length]) 170 | self.parent.assertListEqual( 171 | list(result["lm_logits_2"].size()), 172 | [self.batch_size, self.seq_length, self.vocab_size]) 173 | self.parent.assertListEqual( 174 | list(list(mem.size()) for mem in result["mems_2a"]), 175 | [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) 176 | self.parent.assertListEqual( 177 | list(list(mem.size()) for mem in result["mems_2b"]), 178 | [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) 179 | self.parent.assertListEqual( 180 | list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]), 181 | list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"])) 182 | 183 | def test_default(self): 184 | self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self)) 185 | 186 | def test_config_to_json_string(self): 187 | config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37) 188 | obj = 
json.loads(config.to_json_string()) 189 | self.assertEqual(obj["n_token"], 96) 190 | self.assertEqual(obj["d_embed"], 37) 191 | 192 | def test_config_to_json_file(self): 193 | config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37) 194 | json_file_path = "/tmp/config.json" 195 | config_first.to_json_file(json_file_path) 196 | config_second = TransfoXLConfig.from_json_file(json_file_path) 197 | os.remove(json_file_path) 198 | self.assertEqual(config_second.to_dict(), config_first.to_dict()) 199 | 200 | @pytest.mark.slow 201 | def test_model_from_pretrained(self): 202 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 203 | for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 204 | model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) 205 | shutil.rmtree(cache_dir) 206 | self.assertIsNotNone(model) 207 | 208 | def run_tester(self, tester): 209 | config_and_inputs = tester.prepare_config_and_inputs() 210 | 211 | tester.set_seed() 212 | output_result = tester.create_transfo_xl_model(*config_and_inputs) 213 | tester.check_transfo_xl_model_output(output_result) 214 | 215 | tester.set_seed() 216 | output_result = tester.create_transfo_xl_lm_head(*config_and_inputs) 217 | tester.check_transfo_xl_lm_head_output(output_result) 218 | 219 | @classmethod 220 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 221 | """Creates a random int32 tensor of the shape within the vocab size.""" 222 | if rng is None: 223 | rng = random.Random() 224 | 225 | total_dims = 1 226 | for dim in shape: 227 | total_dims *= dim 228 | 229 | values = [] 230 | for _ in range(total_dims): 231 | values.append(rng.randint(0, vocab_size - 1)) 232 | 233 | return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() 234 | 235 | 236 | if __name__ == "__main__": 237 | unittest.main() 238 | -------------------------------------------------------------------------------- /tests/optimization_test.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | 21 | import torch 22 | 23 | from pytorch_pretrained_bert import BertAdam 24 | from pytorch_pretrained_bert import OpenAIAdam 25 | from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \ 26 | WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule 27 | import numpy as np 28 | 29 | 30 | class OptimizationTest(unittest.TestCase): 31 | 32 | def assertListAlmostEqual(self, list1, list2, tol): 33 | self.assertEqual(len(list1), len(list2)) 34 | for a, b in zip(list1, list2): 35 | self.assertAlmostEqual(a, b, delta=tol) 36 | 37 | def test_adam(self): 38 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 39 | target = torch.tensor([0.4, 0.2, -0.5]) 40 | criterion = torch.nn.MSELoss() 41 | # No warmup, constant schedule, no gradient clipping 42 | optimizer = BertAdam(params=[w], lr=2e-1, 43 | weight_decay=0.0, 44 | max_grad_norm=-1) 45 | for _ in range(100): 46 | loss = criterion(w, target) 47 | loss.backward() 48 | optimizer.step() 49 | w.grad.detach_() # No zero_grad() function on 
simple tensors. we do it ourselves. 50 | w.grad.zero_() 51 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 52 | 53 | 54 | class ScheduleInitTest(unittest.TestCase): 55 | def test_bert_sched_init(self): 56 | m = torch.nn.Linear(50, 50) 57 | optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None) 58 | self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) 59 | optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none") 60 | self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) 61 | optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000) 62 | self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule)) 63 | # shouldn't fail 64 | 65 | def test_openai_sched_init(self): 66 | m = torch.nn.Linear(50, 50) 67 | optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None) 68 | self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) 69 | optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none") 70 | self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR)) 71 | optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000) 72 | self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule)) 73 | # shouldn't fail 74 | 75 | 76 | class WarmupCosineWithRestartsTest(unittest.TestCase): 77 | def test_it(self): 78 | m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5) 79 | x = np.arange(0, 1000) 80 | y = [m.get_lr(xe) for xe in x] 81 | y = np.asarray(y) 82 | expected_zeros = y[[0, 200, 400, 600, 800]] 83 | print(expected_zeros) 84 | expected_ones = y[[50, 250, 450, 650, 850]] 85 | print(expected_ones) 86 | self.assertTrue(np.allclose(expected_ones, 1)) 87 | self.assertTrue(np.allclose(expected_zeros, 0)) 88 | 89 | 90 | if __name__ == "__main__": 91 | unittest.main() 92 | 
-------------------------------------------------------------------------------- /tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | import shutil 21 | import pytest 22 | 23 | from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP 24 | 25 | 26 | class GPT2TokenizationTest(unittest.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """ 30 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 31 | "lo", "low", "er", 32 | "low", "lowest", "newer", "wider"] 33 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 34 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 35 | with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: 36 | fp.write(json.dumps(vocab_tokens)) 37 | vocab_file = fp.name 38 | with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: 39 | fp.write("\n".join(merges)) 40 | merges_file = fp.name 41 | 42 | tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["", ""]) 43 | os.remove(vocab_file) 44 | os.remove(merges_file) 45 | 46 | text = "lower" 47 | bpe_tokens = ["low", "er"] 48 | tokens = tokenizer.tokenize(text) 49 | self.assertListEqual(tokens, bpe_tokens) 50 | 51 | input_tokens = tokens + [""] 52 | input_bpe_tokens = [13, 12, 16] 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 55 | 56 | vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") 57 | tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/") 58 | os.remove(vocab_file) 59 | os.remove(merges_file) 60 | os.remove(special_tokens_file) 61 | 62 | self.assertListEqual( 63 | [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks, 64 | tokenizer.special_tokens, tokenizer.special_tokens_decoder], 65 | [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, 66 | tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) 67 | 68 | # @pytest.mark.slow 69 | def test_tokenizer_from_pretrained(self): 70 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 71 | for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: 72 | tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir) 73 | shutil.rmtree(cache_dir) 74 | self.assertIsNotNone(tokenizer) 75 | 76 | if __name__ == '__main__': 77 | unittest.main() 78 | 
-------------------------------------------------------------------------------- /tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | import shutil 21 | import pytest 22 | 23 | from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP 24 | 25 | 26 | class OpenAIGPTTokenizationTest(unittest.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """ 30 | vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", 31 | "w", "r", "t", 32 | "lo", "low", "er", 33 | "low", "lowest", "newer", "wider"] 34 | vocab_tokens = dict(zip(vocab, range(len(vocab)))) 35 | merges = ["#version: 0.2", "l o", "lo w", "e r", ""] 36 | with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp: 37 | fp.write(json.dumps(vocab_tokens)) 38 | vocab_file = fp.name 39 | with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp: 40 | fp.write("\n".join(merges)) 41 | merges_file = fp.name 42 | 43 | tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["", ""]) 44 | os.remove(vocab_file) 45 | os.remove(merges_file) 46 | 47 | text = "lower" 48 | bpe_tokens = ["low", "er"] 49 | tokens = tokenizer.tokenize(text) 50 | self.assertListEqual(tokens, bpe_tokens) 51 | 52 | input_tokens = tokens + [""] 53 | input_bpe_tokens = [14, 15, 20] 54 | self.assertListEqual( 55 | tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) 56 | 57 | vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/") 58 | tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/") 59 | os.remove(vocab_file) 60 | os.remove(merges_file) 61 | os.remove(special_tokens_file) 62 | 63 | self.assertListEqual( 64 | [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks, 65 | tokenizer.special_tokens, tokenizer.special_tokens_decoder], 66 | [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks, 67 | tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder]) 68 | 69 | @pytest.mark.slow 70 | def test_tokenizer_from_pretrained(self): 71 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 72 | for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: 73 | tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir) 74 | shutil.rmtree(cache_dir) 75 | self.assertIsNotNone(tokenizer) 76 | 77 | 78 | if __name__ == '__main__': 79 | 
unittest.main() 80 | -------------------------------------------------------------------------------- /tests/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | import shutil 21 | import pytest 22 | 23 | from pytorch_pretrained_bert.tokenization import (BasicTokenizer, 24 | BertTokenizer, 25 | WordpieceTokenizer, 26 | _is_control, _is_punctuation, 27 | _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP) 28 | 29 | 30 | class TokenizationTest(unittest.TestCase): 31 | 32 | def test_full_tokenizer(self): 33 | vocab_tokens = [ 34 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 35 | "##ing", "," 36 | ] 37 | with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer: 38 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 39 | 40 | vocab_file = vocab_writer.name 41 | 42 | tokenizer = BertTokenizer(vocab_file) 43 | os.remove(vocab_file) 44 | 45 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 46 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 47 | 48 | self.assertListEqual( 49 | 
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 50 | 51 | vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") 52 | tokenizer.from_pretrained(vocab_file) 53 | os.remove(vocab_file) 54 | 55 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 56 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 57 | 58 | self.assertListEqual( 59 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 60 | 61 | @pytest.mark.slow 62 | def test_tokenizer_from_pretrained(self): 63 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 64 | for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: 65 | tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir) 66 | shutil.rmtree(cache_dir) 67 | self.assertIsNotNone(tokenizer) 68 | 69 | def test_chinese(self): 70 | tokenizer = BasicTokenizer() 71 | 72 | self.assertListEqual( 73 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 74 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 75 | 76 | def test_basic_tokenizer_lower(self): 77 | tokenizer = BasicTokenizer(do_lower_case=True) 78 | 79 | self.assertListEqual( 80 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 81 | ["hello", "!", "how", "are", "you", "?"]) 82 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 83 | 84 | def test_basic_tokenizer_no_lower(self): 85 | tokenizer = BasicTokenizer(do_lower_case=False) 86 | 87 | self.assertListEqual( 88 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 89 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 90 | 91 | def test_wordpiece_tokenizer(self): 92 | vocab_tokens = [ 93 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 94 | "##ing" 95 | ] 96 | 97 | vocab = {} 98 | for (i, token) in enumerate(vocab_tokens): 99 | vocab[token] = i 100 | tokenizer = WordpieceTokenizer(vocab=vocab) 101 | 102 | self.assertListEqual(tokenizer.tokenize(""), []) 103 | 104 | self.assertListEqual( 105 | tokenizer.tokenize("unwanted running"), 106 | ["un", "##want", "##ed", "runn", "##ing"]) 107 | 108 | self.assertListEqual( 109 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 110 | 111 | def test_is_whitespace(self): 112 | self.assertTrue(_is_whitespace(u" ")) 113 | self.assertTrue(_is_whitespace(u"\t")) 114 | self.assertTrue(_is_whitespace(u"\r")) 115 | self.assertTrue(_is_whitespace(u"\n")) 116 | self.assertTrue(_is_whitespace(u"\u00A0")) 117 | 118 | self.assertFalse(_is_whitespace(u"A")) 119 | self.assertFalse(_is_whitespace(u"-")) 120 | 121 | def test_is_control(self): 122 | self.assertTrue(_is_control(u"\u0005")) 123 | 124 | self.assertFalse(_is_control(u"A")) 125 | self.assertFalse(_is_control(u" ")) 126 | self.assertFalse(_is_control(u"\t")) 127 | self.assertFalse(_is_control(u"\r")) 128 | 129 | def test_is_punctuation(self): 130 | self.assertTrue(_is_punctuation(u"-")) 131 | self.assertTrue(_is_punctuation(u"$")) 132 | self.assertTrue(_is_punctuation(u"`")) 133 | self.assertTrue(_is_punctuation(u".")) 134 | 135 | self.assertFalse(_is_punctuation(u"A")) 136 | self.assertFalse(_is_punctuation(u" ")) 137 | 138 | 139 | if __name__ == '__main__': 140 | unittest.main() 141 | -------------------------------------------------------------------------------- /tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | import shutil 21 | import pytest 22 | 23 | from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP 24 | 25 | 26 | class TransfoXLTokenizationTest(unittest.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", "," 31 | ] 32 | with open("/tmp/transfo_xl_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer: 33 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 34 | vocab_file = vocab_writer.name 35 | 36 | tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True) 37 | tokenizer.build_vocab() 38 | os.remove(vocab_file) 39 | 40 | tokens = tokenizer.tokenize(u" UNwanted , running") 41 | self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) 42 | 43 | self.assertListEqual( 44 | tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 45 | 46 | vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/") 47 | tokenizer.from_pretrained(vocab_file) 48 | os.remove(vocab_file) 49 | 50 | tokens = tokenizer.tokenize(u" UNwanted , running") 51 | self.assertListEqual(tokens, ["", "unwanted", ",", "running"]) 52 | 53 | self.assertListEqual( 54 | 
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7]) 55 | 56 | 57 | def test_full_tokenizer_lower(self): 58 | tokenizer = TransfoXLTokenizer(lower_case=True) 59 | 60 | self.assertListEqual( 61 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 62 | ["hello", "!", "how", "are", "you", "?"]) 63 | 64 | def test_full_tokenizer_no_lower(self): 65 | tokenizer = TransfoXLTokenizer(lower_case=False) 66 | 67 | self.assertListEqual( 68 | tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "), 69 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 70 | 71 | @pytest.mark.slow 72 | def test_tokenizer_from_pretrained(self): 73 | cache_dir = "/tmp/pytorch_pretrained_bert_test/" 74 | for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: 75 | tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir) 76 | shutil.rmtree(cache_dir) 77 | self.assertIsNotNone(tokenizer) 78 | 79 | if __name__ == '__main__': 80 | unittest.main() 81 | --------------------------------------------------------------------------------