├── .coveragerc ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── eter.py ├── examples ├── ZJL │ ├── chinese_L-12_H-768_A-12 │ │ └── aaaa │ └── data │ │ ├── dev │ │ └── Result(example).csv │ │ ├── test │ │ ├── Result(example).csv │ │ ├── Test_reviews.csv │ │ └── test.csv │ │ └── train │ │ ├── Train_labels.csv │ │ ├── Train_reviews.csv │ │ └── train.csv ├── requirements.txt ├── run.sh ├── run_zhijiang.py └── utils_zhijiang.py ├── hubconf.py ├── hubconfs ├── bert_hubconf.py ├── gpt2_hubconf.py ├── gpt_hubconf.py ├── transformer_xl_hubconf.py ├── xlm_hubconf.py └── xlnet_hubconf.1.py ├── pytorch_transformers ├── __init__.py ├── __main__.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── convert_xlm_checkpoint_to_pytorch.py ├── convert_xlnet_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling_bert.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── modeling_utils.py ├── modeling_xlm.py ├── modeling_xlnet.py ├── optimization.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── fixtures │ │ ├── input.txt │ │ ├── sample_text.txt │ │ └── test_sentencepiece.model │ ├── modeling_bert_test.py │ ├── modeling_common_test.py │ ├── modeling_gpt2_test.py │ ├── modeling_openai_test.py │ ├── modeling_transfo_xl_test.py │ ├── modeling_xlm_test.py │ ├── modeling_xlnet_test.py │ ├── optimization_test.py │ ├── tokenization_bert_test.py │ ├── tokenization_gpt2_test.py │ ├── tokenization_openai_test.py │ ├── tokenization_tests_commons.py │ ├── tokenization_transfo_xl_test.py │ ├── tokenization_utils_test.py │ ├── tokenization_xlm_test.py │ └── tokenization_xlnet_test.py ├── tokenization_bert.py ├── tokenization_gpt2.py ├── tokenization_openai.py ├── tokenization_transfo_xl.py ├── tokenization_utils.py ├── tokenization_xlm.py └── tokenization_xlnet.py ├── requirements.txt └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=pytorch_transformers 3 | omit = 4 | # skip conversion scripts from testing for now 5 | */convert_* 6 | */__main__.py 7 | [report] 8 | exclude_lines = 9 | pragma: no cover 10 | raise 11 | except 12 | register_parameter -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from GitHub's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # vscode 119 | .vscode 120 | 121 | # TF code 122 | tensorflow_code 123 | 124 | # Models 125 | models 126 | proc_data 127 | 128 | # examples 129 | runs 130 | examples/runs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Based on PyTorch-Transformers 2 | -------------------------------------------------------------------------------- /eter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | 5 | # Read every line from stdin, split on whitespace 6 | lines = [] 7 | for line in sys.stdin: 8 | lines.append(line.split()) 9 | # The first input line carries the four integer parameters 10 | N, M, S, D = map(int, lines[0]) 11 | cakes = lines[1] 12 | # Echo each of the remaining input lines 13 | for i in range(2, M): 14 | print(lines[i]) -------------------------------------------------------------------------------- /examples/ZJL/chinese_L-12_H-768_A-12/aaaa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heyangHEY/BERT-CRF/b869ca749a44647ddd6bac09b59f42fd55d0884c/examples/ZJL/chinese_L-12_H-768_A-12/aaaa -------------------------------------------------------------------------------- /examples/ZJL/data/dev/Result(example).csv: -------------------------------------------------------------------------------- 1 | 1,味道,香香的,气味,正面 2 | 1,物流,神速,物流,正面 3 | 2,_,好贵,价格,负面 4 | 3,_,_,_,_ -------------------------------------------------------------------------------- /examples/ZJL/data/test/Result(example).csv: -------------------------------------------------------------------------------- 1 | 1,味道,香香的,气味,正面 2 | 1,物流,神速,物流,正面 3 | 2,_,好贵,价格,负面 4 | 3,_,_,_,_ -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | scikit-learn -------------------------------------------------------------------------------- /examples/run.sh: -------------------------------------------------------------------------------- 1 | python utils_zhijiang.py \ 2 | --data_dir /ZJL/examples/ZJL/data/ \ 3 | --train_review train/Train_reviews.csv \ 4 | --train_result train/Train_labels.csv \ 5 | --train_file train/train.csv \ 6 | --test_review test/Test_reviews.csv \ 7 | --test_file test/test.csv \ 8 | --dev_file dev/dev.csv \ 9 | --split_ratio 0.1 10 | 11 | cuda=0 12 | CUDA_VISIBLE_DEVICES=$cuda python run_zhijiang.py \ 13 | --data_dir /ZJL/examples/ZJL/data/ \ 14 | --model_type bert \ 15 | --model_name_or_path bert-base-chinese \ 16 | --task_name zhijiang \ 17 | --output_dir /ZJL/examples/ZJL/result \ 18 | --do_train \ 19 | --evaluate_during_training \ 20 | --per_gpu_train_batch_size 40 \ 21 | --per_gpu_eval_batch_size 40 \ 22 | --num_train_epochs 10.0 \ 23 | --save_steps 10 \ 24 | --logging_steps 10 \ 25 | --learning_rate 5e-5 \ 26 | --warmup_steps 73 27 | 28 | # --evaluate_during_training \ 29 | # --do_eval \ 30 | # --eval_all_checkpoints \ 31 | #BERT_CHINESE_DIR=/ZJL/chinese_L-12_H-768_A-12/ 32 | 33 | 34 | #cuda=0 35 | #CUDA_VISIBLE_DEVICES=$cuda python run_zhijiang.py \ 36 | # --data_dir /ZJL/examples/ZJL/data/ \ 37 | # --model_type bert \ 38 | # --model_name_or_path bert-base-chinese \ 39 | # --task_name zhijiang \ 40 | # --output_dir /ZJL/examples/ZJL/result/checkpoint-640 \ 41 | # --do_test \ 42 | # --per_gpu_test_batch_size 8
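Note: the five comma-separated columns in the Result(example).csv fixtures above are not documented in this repo; judging from the sample rows they appear to be review id, aspect term, opinion term, category, and sentiment polarity, with _ marking an empty slot. A minimal parsing sketch under that assumption:

import csv

def read_results(path):
    # Yield one dict per labeled row; '_' columns become None.
    with open(path, encoding='utf-8') as f:
        for rid, aspect, opinion, category, polarity in csv.reader(f):
            yield {
                'id': int(rid),
                'aspect': None if aspect == '_' else aspect,
                'opinion': None if opinion == '_' else opinion,
                'category': None if category == '_' else category,
                'polarity': None if polarity == '_' else polarity,
            }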
-------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex'] 2 | 3 | from hubconfs.bert_hubconf import ( 4 | bertTokenizer, 5 | bertModel, 6 | bertForNextSentencePrediction, 7 | bertForPreTraining, 8 | bertForMaskedLM, 9 | bertForSequenceClassification, 10 | bertForMultipleChoice, 11 | bertForQuestionAnswering, 12 | bertForTokenClassification 13 | ) 14 | from hubconfs.gpt_hubconf import ( 15 | openAIGPTTokenizer, 16 | openAIGPTModel, 17 | openAIGPTLMHeadModel, 18 | openAIGPTDoubleHeadsModel 19 | ) 20 | from hubconfs.gpt2_hubconf import ( 21 | gpt2Tokenizer, 22 | gpt2Model, 23 | gpt2LMHeadModel, 24 | gpt2DoubleHeadsModel 25 | ) 26 | from hubconfs.transformer_xl_hubconf import ( 27 | transformerXLTokenizer, 28 | transformerXLModel, 29 | transformerXLLMHeadModel 30 | ) 31 |
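hubconf.py only declares torch.hub entry points; they can be discovered and inspected at runtime with PyTorch's standard hub API (no repo-specific assumptions here):

import torch

# List every entry point exposed by this repo's hubconf.py
print(torch.hub.list('huggingface/pytorch-transformers'))

# Print the docstring of a single entry point
print(torch.hub.help('huggingface/pytorch-transformers', 'gpt2Tokenizer'))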
-------------------------------------------------------------------------------- /hubconfs/gpt2_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 2 | from pytorch_transformers.modeling_gpt2 import ( 3 | GPT2Model, 4 | GPT2LMHeadModel, 5 | GPT2DoubleHeadsModel 6 | ) 7 | 8 | # A lot of models share the same param doc. Use a decorator 9 | # to save typing 10 | gpt2_docstring = """ 11 | Params: 12 | pretrained_model_name_or_path: either: 13 | - a str with the name of a pre-trained model to load selected in the list of: 14 | . `gpt2`, `gpt2-medium` 15 | - a path or url to a pretrained model archive containing: 16 | . `gpt2_config.json` a configuration file for the model 17 | . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance 18 | - a path or url to a pretrained model archive containing: 19 | . `gpt2_config.json` a configuration file for the model 20 | . a TensorFlow checkpoint with trained weights 21 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 22 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 23 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 24 | *inputs, **kwargs: additional input for the specific GPT-2 class 25 | """ 26 | 27 | 28 | def _append_from_pretrained_docstring(docstr): 29 | def docstring_decorator(fn): 30 | fn.__doc__ = fn.__doc__ + docstr 31 | return fn 32 | return docstring_decorator 33 | 34 | 35 | def gpt2Tokenizer(*args, **kwargs): 36 | """ 37 | Instantiate a GPT-2 BPE tokenizer for OpenAI GPT-2 from a pre-trained/customized vocab file. 38 | Peculiarities: 39 | - Byte-level BPE 40 | 41 | Args: 42 | pretrained_model_name_or_path: Path to pretrained model archive 43 | or one of pre-trained vocab configs below. 44 | * gpt2 45 | Keyword args: 46 | special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...) 47 | Default: None 48 | max_len: An artificial maximum length to truncate tokenized sequences to; 49 | Effective maximum length is always the minimum of this 50 | value (if specified) and the underlying model's 51 | sequence length. 52 | Default: None 53 | 54 | Example: 55 | >>> import torch 56 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 57 | 58 | >>> text = "Who was Jim Henson ?" 59 | >>> indexed_tokens = tokenizer.encode(text) 60 | """ 61 | tokenizer = GPT2Tokenizer.from_pretrained(*args, **kwargs) 62 | return tokenizer 63 | 64 | 65 | @_append_from_pretrained_docstring(gpt2_docstring) 66 | def gpt2Model(*args, **kwargs): 67 | """ 68 | gpt2Model is the basic OpenAI GPT-2 Transformer model based on 69 | identical stacked masked self-attention blocks and pre-trained 70 | on a large-scale dataset using a language modeling signal. 71 | 72 | Example: 73 | # Load the tokenizer 74 | >>> import torch 75 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 76 | 77 | # Prepare tokenized input 78 | >>> text_1 = "Who was Jim Henson ?" 79 | >>> text_2 = "Jim Henson was a puppeteer" 80 | >>> indexed_tokens_1 = tokenizer.encode(text_1) 81 | >>> indexed_tokens_2 = tokenizer.encode(text_2) 82 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 83 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 84 | 85 | # Load gpt2Model 86 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Model', 'gpt2') 87 | >>> model.eval() 88 | 89 | # Predict hidden states features for each layer 90 | # past can be used to reuse precomputed hidden states in subsequent predictions 91 | >>> with torch.no_grad(): 92 | hidden_states_1, past = model(tokens_tensor_1) 93 | hidden_states_2, past = model(tokens_tensor_2, past=past) 94 | """ 95 | model = GPT2Model.from_pretrained(*args, **kwargs) 96 | return model 97 | 98 | 99 | @_append_from_pretrained_docstring(gpt2_docstring) 100 | def gpt2LMHeadModel(*args, **kwargs): 101 | """ 102 | gpt2LMHeadModel is the OpenAI GPT-2 Transformer model with the 103 | tied (pre-trained) language modeling head on top. 104 | 105 | Example: 106 | # Load the tokenizer 107 | >>> import torch 108 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 109 | 110 | # Prepare tokenized input 111 | >>> text_1 = "Who was Jim Henson ?" 112 | >>> text_2 = "Jim Henson was a puppeteer" 113 | >>> indexed_tokens_1 = tokenizer.encode(text_1) 114 | >>> indexed_tokens_2 = tokenizer.encode(text_2) 115 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 116 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 117 | 118 | # Load gpt2LMHeadModel 119 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2LMHeadModel', 'gpt2') 120 | >>> model.eval() 121 | 122 | # Predict hidden states features for each layer 123 | # past can be used to reuse precomputed hidden states in subsequent predictions 124 | >>> with torch.no_grad(): 125 | predictions_1, past = model(tokens_tensor_1) 126 | predictions_2, past = model(tokens_tensor_2, past=past) 127 | 128 | # Get the predicted last token 129 | >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 130 | >>> predicted_token = tokenizer.decode([predicted_index]) 131 | >>> assert predicted_token == ' who' 132 | """ 133 | model = GPT2LMHeadModel.from_pretrained(*args, **kwargs) 134 | return model 135 | 136 | 137 | @_append_from_pretrained_docstring(gpt2_docstring) 138 | def gpt2DoubleHeadsModel(*args, **kwargs): 139 | """ 140 | gpt2DoubleHeadsModel is the OpenAI GPT-2 Transformer model with the 141 | tied (pre-trained) language modeling head and a multiple choice 142 | classification head (only initialized, not pre-trained). 143 | 144 | Example: 145 | # Load the tokenizer 146 | >>> import torch 147 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 148 | 149 | # Prepare tokenized input 150 | >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 151 | >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 152 | >>> tokenized_text1 = tokenizer.tokenize(text1) 153 | >>> tokenized_text2 = tokenizer.tokenize(text2) 154 | >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 155 | >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 156 | >>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 157 | >>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 158 | 159 | # Load gpt2DoubleHeadsModel 160 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2DoubleHeadsModel', 'gpt2') 161 | >>> model.eval() 162 | 163 | # Predict hidden states features for each layer 164 | >>> with torch.no_grad(): 165 | lm_logits, multiple_choice_logits, presents = model(tokens_tensor, mc_token_ids) 166 | """ 167 | model = GPT2DoubleHeadsModel.from_pretrained(*args, **kwargs) 168 | return model 169 |
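The past caching shown in the gpt2Model and gpt2LMHeadModel docstrings above generalizes to autoregressive generation: encode the prompt once, then feed only the newest token together with the cached past. A greedy-decoding sketch built on the same logits, past = model(ids, past=past) call pattern (the loop itself is illustrative, not part of this repo):

import torch

tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2')
model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2LMHeadModel', 'gpt2')
model.eval()

tokens = torch.tensor([tokenizer.encode("Who was Jim Henson ?")])
past = None
generated = []
with torch.no_grad():
    for _ in range(20):  # generate 20 tokens greedily
        logits, past = model(tokens, past=past)
        next_token = torch.argmax(logits[0, -1, :]).item()
        generated.append(next_token)
        tokens = torch.tensor([[next_token]])  # only the new token; context is carried by past
print(tokenizer.decode(generated))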
-------------------------------------------------------------------------------- /hubconfs/gpt_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer 2 | from pytorch_transformers.modeling_openai import ( 3 | OpenAIGPTModel, 4 | OpenAIGPTLMHeadModel, 5 | OpenAIGPTDoubleHeadsModel 6 | ) 7 | 8 | # Dependencies that are not specified in global hubconf.py 9 | specific_dependencies = ['spacy', 'ftfy'] 10 | 11 | # A lot of models share the same param doc. Use a decorator 12 | # to save typing 13 | gpt_docstring = """ 14 | OpenAI GPT uses a single embedding matrix to store the word and special embeddings. 15 | Special tokens are additional tokens that are not pre-trained: [SEP], [CLS]... 16 | Special tokens need to be trained during the fine-tuning if you use them. 17 | The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function. 18 | 19 | The embeddings are ordered as follows in the token embedding matrix: 20 | [0, ---------------------- 21 | ... -> word embeddings 22 | config.vocab_size - 1, ______________________ 23 | config.vocab_size, 24 | ... -> special embeddings 25 | config.vocab_size + config.n_special - 1] ______________________ 26 | 27 | where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is: 28 | total_tokens_embeddings = config.vocab_size + config.n_special 29 | You should use the associated indices to index the embeddings. 30 | 31 | Params: 32 | pretrained_model_name_or_path: either: 33 | - a str with the name of a pre-trained model to load selected in the list of: 34 | . `openai-gpt` 35 | - a path or url to a pretrained model archive containing: 36 | . `openai_gpt_config.json` a configuration file for the model 37 | . `pytorch_model.bin` a PyTorch dump of an OpenAIGPTModel instance 38 | - a path or url to a pretrained model archive containing: 39 | . `openai-gpt-config.json` a configuration file for the model 40 | . a series of NumPy files containing OpenAI TensorFlow trained weights 41 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 42 | cache_dir: an optional path to a folder in which the pre-trained models will be cached.
43 | state_dict: an optional state dictionary (collections.OrderedDict object) 44 | to use instead of pre-trained models 45 | *inputs, **kwargs: additional input for the specific OpenAI-GPT class 46 | """ 47 | 48 | 49 | def _append_from_pretrained_docstring(docstr): 50 | def docstring_decorator(fn): 51 | fn.__doc__ = fn.__doc__ + docstr 52 | return fn 53 | return docstring_decorator 54 | 55 | 56 | def openAIGPTTokenizer(*args, **kwargs): 57 | """ 58 | Instantiate a BPE tokenizer for OpenAI GPT from a pre-trained/customized vocab file. 59 | Peculiarities: 60 | - lower-cases all inputs 61 | - uses SpaCy tokenizer ('en' model) and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not. 62 | - argument special_tokens and function set_special_tokens: 63 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 64 | 65 | Args: 66 | pretrained_model_name_or_path: Path to pretrained model archive 67 | or one of pre-trained vocab configs below. 68 | * openai-gpt 69 | Keyword args: 70 | special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...) 71 | Default: None 72 | max_len: An artificial maximum length to truncate tokenized sequences to; 73 | Effective maximum length is always the minimum of this 74 | value (if specified) and the underlying model's 75 | sequence length. 76 | Default: None 77 | 78 | Example: 79 | >>> import torch 80 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 81 | 82 | >>> text = "Who was Jim Henson ? Jim Henson was a puppeteer" 83 | >>> tokenized_text = tokenizer.tokenize(text) 84 | >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 85 | [763, 509, 4265, 2298, 945, 257, 4265, 2298, 945, 509, 246, 10148, 39041, 483] 86 | """ 87 | tokenizer = OpenAIGPTTokenizer.from_pretrained(*args, **kwargs) 88 | return tokenizer 89 | 90 | 91 | @_append_from_pretrained_docstring(gpt_docstring) 92 | def openAIGPTModel(*args, **kwargs): 93 | """ 94 | OpenAIGPTModel is the basic OpenAI GPT Transformer model based on 95 | identical stacked masked self-attention blocks and pre-trained 96 | on a large-scale dataset using a language modeling signal. 97 | 98 | Example: 99 | # Load the tokenizer 100 | >>> import torch 101 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 102 | 103 | # Prepare tokenized input 104 | >>> text = "Who was Jim Henson ? Jim Henson was a puppeteer" 105 | >>> tokenized_text = tokenizer.tokenize(text) 106 | >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 107 | >>> tokens_tensor = torch.tensor([indexed_tokens]) 108 | 109 | # Load openAIGPTModel 110 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTModel', 'openai-gpt') 111 | >>> model.eval() 112 | 113 | # Predict hidden states features for each layer 114 | >>> with torch.no_grad(): 115 | hidden_states = model(tokens_tensor) 116 | """ 117 | model = OpenAIGPTModel.from_pretrained(*args, **kwargs) 118 | return model 119 | 120 | 121 | @_append_from_pretrained_docstring(gpt_docstring) 122 | def openAIGPTLMHeadModel(*args, **kwargs): 123 | """ 124 | OpenAIGPTLMHeadModel is the OpenAI GPT Transformer model with the 125 | tied (pre-trained) language modeling head on top.
126 | 127 | Example: 128 | # Load the tokenizer 129 | >>> import torch 130 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 131 | 132 | # Prepare tokenized input 133 | >>> text = "Who was Jim Henson ? Jim Henson was a puppeteer" 134 | >>> tokenized_text = tokenizer.tokenize(text) 135 | >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 136 | >>> tokens_tensor = torch.tensor([indexed_tokens]) 137 | 138 | # Load openAIGPTLMHeadModel 139 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTLMHeadModel', 'openai-gpt') 140 | >>> model.eval() 141 | 142 | # Predict hidden states features for each layer 143 | >>> with torch.no_grad(): 144 | predictions = model(tokens_tensor) 145 | 146 | # Get the predicted last token 147 | >>> predicted_index = torch.argmax(predictions[0, -1, :]).item() 148 | >>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0] 149 | '.' 150 | """ 151 | model = OpenAIGPTLMHeadModel.from_pretrained(*args, **kwargs) 152 | return model 153 | 154 | 155 | @_append_from_pretrained_docstring(gpt_docstring) 156 | def openAIGPTDoubleHeadsModel(*args, **kwargs): 157 | """ 158 | OpenAIGPTDoubleHeadsModel is the OpenAI GPT Transformer model with the 159 | tied (pre-trained) language modeling head and a multiple choice 160 | classification head (only initialized, not pre-trained). 161 | 162 | Example: 163 | # Load the tokenizer 164 | >>> import torch 165 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTTokenizer', 'openai-gpt') 166 | 167 | # Prepare tokenized input 168 | >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 169 | >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 170 | >>> tokenized_text1 = tokenizer.tokenize(text1) 171 | >>> tokenized_text2 = tokenizer.tokenize(text2) 172 | >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 173 | >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 174 | >>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 175 | >>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 176 | 177 | # Load openAIGPTDoubleHeadsModel 178 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTDoubleHeadsModel', 'openai-gpt') 179 | >>> model.eval() 180 | 181 | # Predict hidden states features for each layer 182 | >>> with torch.no_grad(): 183 | lm_logits, multiple_choice_logits = model(tokens_tensor, mc_token_ids) 184 | """ 185 | model = OpenAIGPTDoubleHeadsModel.from_pretrained(*args, **kwargs) 186 | return model 187 |
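The embedding layout described in gpt_docstring above (word embeddings in [0, vocab_size - 1], special-token embeddings immediately after) can be made concrete with a short sketch; it only uses the config attributes and the set_num_special_tokens function that the docstring itself mentions:

import torch

model = torch.hub.load('huggingface/pytorch-transformers', 'openAIGPTModel', 'openai-gpt')
model.set_num_special_tokens(1)  # grow the embedding matrix by one special slot

config = model.config
first_special_index = config.vocab_size  # special embeddings start right after the words
total_tokens_embeddings = config.vocab_size + config.n_special
assert first_special_index < total_tokens_embeddings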
-------------------------------------------------------------------------------- /hubconfs/transformer_xl_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer 2 | from pytorch_transformers.modeling_transfo_xl import ( 3 | TransfoXLModel, 4 | TransfoXLLMHeadModel 5 | ) 6 | 7 | # A lot of models share the same param doc. Use a decorator 8 | # to save typing 9 | transformer_xl_docstring = """ 10 | Transformer-XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs, which means that: 11 | - you don't need to specify position embedding indices 12 | - the tokens in the vocabulary have to be sorted in order of decreasing frequency. 13 | 14 | Params: 15 | pretrained_model_name_or_path: either: 16 | - a str with the name of a pre-trained model to load selected in the list of: 17 | . `transfo-xl-wt103` 18 | - a path or url to a pretrained model archive containing: 19 | . `transfo_xl_config.json` a configuration file for the model 20 | . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance 21 | - a path or url to a pretrained model archive containing: 22 | . `transfo_xl_config.json` a configuration file for the model 23 | . `model.chkpt` a TensorFlow checkpoint 24 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 25 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 26 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 27 | *inputs, **kwargs: additional input for the specific TransformerXL class 28 | """ 29 | 30 | 31 | def _append_from_pretrained_docstring(docstr): 32 | def docstring_decorator(fn): 33 | fn.__doc__ = fn.__doc__ + docstr 34 | return fn 35 | return docstring_decorator 36 | 37 | 38 | def transformerXLTokenizer(*args, **kwargs): 39 | """ 40 | Instantiate a Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl 41 | 42 | Args: 43 | pretrained_model_name_or_path: Path to pretrained model archive 44 | or one of pre-trained vocab configs below. 45 | * transfo-xl-wt103 46 | 47 | Example: 48 | >>> import torch 49 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103') 50 | 51 | >>> text = "Who was Jim Henson ?" 52 | >>> tokenized_text = tokenizer.tokenize(text) 53 | >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) 54 | """ 55 | tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs) 56 | return tokenizer 57 | 58 | 59 | @_append_from_pretrained_docstring(transformer_xl_docstring) 60 | def transformerXLModel(*args, **kwargs): 61 | """ 62 | transformerXLModel is the basic Transformer XL model. 63 | 64 | Example: 65 | # Load the tokenizer 66 | >>> import torch 67 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103') 68 | 69 | # Prepare tokenized input 70 | >>> text_1 = "Who was Jim Henson ?"
71 | >>> text_2 = "Jim Henson was a puppeteer" 72 | >>> tokenized_text_1 = tokenizer.tokenize(text_1) 73 | >>> tokenized_text_2 = tokenizer.tokenize(text_2) 74 | >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1) 75 | >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2) 76 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 77 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 78 | 79 | # Load transformerXLModel 80 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLModel', 'transfo-xl-wt103') 81 | >>> model.eval() 82 | 83 | # Predict hidden states features for each layer 84 | # We can re-use the memory cells in a subsequent call to attend a longer context 85 | >>> with torch.no_grad(): 86 | hidden_states_1, mems_1 = model(tokens_tensor_1) 87 | hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1) 88 | """ 89 | model = TransfoXLModel.from_pretrained(*args, **kwargs) 90 | return model 91 | 92 | 93 | @_append_from_pretrained_docstring(transformer_xl_docstring) 94 | def transformerXLLMHeadModel(*args, **kwargs): 95 | """ 96 | transformerXLLMHeadModel is the Transformer XL model with the 97 | tied (pre-trained) language modeling head on top. 98 | 99 | Example: 100 | # Load the tokenizer 101 | >>> import torch 102 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103') 103 | 104 | # Prepare tokenized input 105 | >>> text_1 = "Who was Jim Henson ?" 106 | >>> text_2 = "Jim Henson was a puppeteer" 107 | >>> tokenized_text_1 = tokenizer.tokenize(text_1) 108 | >>> tokenized_text_2 = tokenizer.tokenize(text_2) 109 | >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1) 110 | >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2) 111 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 112 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 113 | 114 | # Load transformerXLLMHeadModel 115 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLLMHeadModel', 'transfo-xl-wt103') 116 | >>> model.eval() 117 | 118 | # Predict hidden states features for each layer 119 | # We can re-use the memory cells in a subsequent call to attend a longer context 120 | >>> with torch.no_grad(): 121 | predictions_1, mems_1 = model(tokens_tensor_1) 122 | predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1) 123 | 124 | # Get the predicted last token 125 | >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 126 | >>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0] 127 | >>> assert predicted_token == 'who' 128 | """ 129 | model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs) 130 | return model 131 |
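Because mems carry the hidden states of earlier segments, Transformer-XL can process a long text segment by segment without re-encoding the past. A sketch of that loop, using the same call pattern as the docstring examples above (the segment size of 4 is arbitrary, chosen only for illustration):

import torch

tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')
model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLModel', 'transfo-xl-wt103')
model.eval()

text = "Who was Jim Henson ? Jim Henson was a puppeteer"
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
segments = [ids[i:i + 4] for i in range(0, len(ids), 4)]

mems = None
with torch.no_grad():
    for seg in segments:
        # each call attends over the current segment plus the cached mems
        hidden_states, mems = model(torch.tensor([seg]), mems=mems)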
-------------------------------------------------------------------------------- /hubconfs/xlm_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_xlm import XLMTokenizer 2 | from pytorch_transformers.modeling_xlm import ( 3 | XLMConfig, 4 | XLMModel, 5 | XLMWithLMHeadModel, 6 | XLMForSequenceClassification, 7 | XLMForQuestionAnswering 8 | ) 9 | 10 | # A lot of models share the same param doc. Use a decorator 11 | # to save typing 12 | xlm_start_docstring = """ 13 | Model class adapted from the XLM Transformer model of 14 | "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau 15 | Paper: https://arxiv.org/abs/1901.07291 16 | Original code: https://github.com/facebookresearch/XLM 17 | 18 | Example: 19 | # Load the tokenizer 20 | >>> import torch 21 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 22 | 23 | # Prepare tokenized input 24 | >>> text_1 = "Who was Jim Henson ?" 25 | >>> text_2 = "Jim Henson was a puppeteer" 26 | >>> indexed_tokens_1 = tokenizer.encode(text_1) 27 | >>> indexed_tokens_2 = tokenizer.encode(text_2) 28 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 29 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 30 | """ 31 | 32 | # A lot of models share the same param doc. Use a decorator 33 | # to save typing 34 | xlm_end_docstring = """ 35 | Params: 36 | pretrained_model_name_or_path: either: 37 | - a str with the name of a pre-trained model to load selected in the list of: 38 | . `xlm-mlm-en-2048` 39 | - a path or url to a pretrained model archive containing: 40 | . `config.json` a configuration file for the model 41 | . `pytorch_model.bin` a PyTorch dump created using the `convert_xlm_checkpoint_to_pytorch` conversion script 42 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 43 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 44 | *inputs, **kwargs: additional input for the specific XLM class 45 | """ 46 | 47 | 48 | def _begin_with_docstring(docstr): 49 | def docstring_decorator(fn): 50 | fn.__doc__ = docstr + fn.__doc__ 51 | return fn 52 | return docstring_decorator 53 | 54 | def _end_with_docstring(docstr): 55 | def docstring_decorator(fn): 56 | fn.__doc__ = fn.__doc__ + docstr 57 | return fn 58 | return docstring_decorator 59 | 60 | 61 | def xlmTokenizer(*args, **kwargs): 62 | """ 63 | Instantiate an XLM BPE tokenizer for XLM from a pre-trained vocab file. 64 | 65 | Args: 66 | pretrained_model_name_or_path: Path to pretrained model archive 67 | or one of pre-trained vocab configs below. 68 | * xlm-mlm-en-2048 69 | Keyword args: 70 | special_tokens: Special tokens in vocabulary that are not pretrained 71 | Default: None 72 | max_len: An artificial maximum length to truncate tokenized sequences to; 73 | Effective maximum length is always the minimum of this 74 | value (if specified) and the underlying model's 75 | sequence length. 76 | Default: None 77 | 78 | Example: 79 | >>> import torch 80 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 81 | 82 | >>> text = "Who was Jim Henson ?"
83 | >>> indexed_tokens = tokenizer.encode(text) 84 | """ 85 | tokenizer = XLMTokenizer.from_pretrained(*args, **kwargs) 86 | return tokenizer 87 | 88 | 89 | @_begin_with_docstring(xlm_start_docstring) 90 | @_end_with_docstring(xlm_end_docstring) 91 | def xlmModel(*args, **kwargs): 92 | """ 93 | # Load xlmModel 94 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'xlmModel', 'xlm-mlm-en-2048') 95 | >>> model.eval() 96 | 97 | # Predict hidden states features for each layer 98 | >>> with torch.no_grad(): 99 | hidden_states_1, mems = model(tokens_tensor_1) 100 | hidden_states_2, mems = model(tokens_tensor_2, mems=mems) 101 | """ 102 | model = XLMModel.from_pretrained(*args, **kwargs) 103 | return model 104 | 105 | 106 | @_begin_with_docstring(xlm_start_docstring) 107 | @_end_with_docstring(xlm_end_docstring) 108 | def xlmLMHeadModel(*args, **kwargs): 109 | """ 110 | # Prepare tokenized input 111 | >>> text_1 = "Who was Jim Henson ?" 112 | >>> text_2 = "Jim Henson was a puppeteer" 113 | >>> indexed_tokens_1 = tokenizer.encode(text_1) 114 | >>> indexed_tokens_2 = tokenizer.encode(text_2) 115 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 116 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 117 | 118 | # Load xlmLMHeadModel 119 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'xlmLMHeadModel', 'xlm-mlm-en-2048') 120 | >>> model.eval() 121 | 122 | # Predict hidden states features for each layer 123 | >>> with torch.no_grad(): 124 | predictions_1, mems = model(tokens_tensor_1) 125 | predictions_2, mems = model(tokens_tensor_2, mems=mems) 126 | 127 | # Get the predicted last token 128 | >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 129 | >>> predicted_token = tokenizer.decode([predicted_index]) 130 | >>> assert predicted_token == ' who' 131 | """ 132 | model = XLMWithLMHeadModel.from_pretrained(*args, **kwargs) 133 | return model 134 | 135 | 136 | # @_end_with_docstring(xlnet_docstring) 137 | # def xlnetForSequenceClassification(*args, **kwargs): 138 | # """ 139 | # xlnetModel is the basic XLNet Transformer model from 140 | # "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 141 | # by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 142 | 143 | # Example: 144 | # # Load the tokenizer 145 | # >>> import torch 146 | # >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlm-mlm-en-2048') 147 | 148 | # # Prepare tokenized input 149 | # >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 150 | # >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 151 | # >>> tokenized_text1 = tokenizer.tokenize(text1) 152 | # >>> tokenized_text2 = tokenizer.tokenize(text2) 153 | # >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 154 | # >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 155 | # >>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 156 | # >>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 157 | 158 | # # Load xlnetForSequenceClassification 159 | # >>> model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetForSequenceClassification', 'xlm-mlm-en-2048') 160 | # >>> model.eval() 161 | 162 | # # Predict sequence classes logits 163 | # >>> with torch.no_grad(): 164 | # lm_logits, mems = model(tokens_tensor) 165 | # """ 166 | # model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs) 167 | # return model 168 |
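The commented-out classification entry point above was never wired into the hub, but XLMForSequenceClassification (imported at the top of this file) can be used directly. A hedged sketch — the num_labels value is an illustrative assumption, and the classification head is freshly initialized, so the logits are meaningless until fine-tuning:

import torch
from pytorch_transformers import XLMTokenizer, XLMForSequenceClassification

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
model = XLMForSequenceClassification.from_pretrained('xlm-mlm-en-2048', num_labels=2)
model.eval()

tokens_tensor = torch.tensor([tokenizer.encode("Who was Jim Henson ?")])
with torch.no_grad():
    logits = model(tokens_tensor)[0]  # shape (1, num_labels)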
-------------------------------------------------------------------------------- /hubconfs/xlnet_hubconf.1.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_xlnet import XLNetTokenizer 2 | from pytorch_transformers.modeling_xlnet import ( 3 | XLNetConfig, 4 | XLNetModel, 5 | XLNetLMHeadModel, 6 | # XLNetForSequenceClassification 7 | ) 8 | 9 | # A lot of models share the same param doc. Use a decorator 10 | # to save typing 11 | xlnet_docstring = """ 12 | Params: 13 | pretrained_model_name_or_path: either: 14 | - a str with the name of a pre-trained model to load selected in the list of: 15 | . `xlnet-large-cased` 16 | - a path or url to a pretrained model archive containing: 17 | . `config.json` a configuration file for the model 18 | . `pytorch_model.bin` a PyTorch dump of an XLNetForPreTraining instance 19 | - a path or url to a pretrained model archive containing: 20 | . `xlnet_config.json` a configuration file for the model 21 | . `model.chkpt` a TensorFlow checkpoint 22 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 23 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 24 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 25 | *inputs, **kwargs: additional input for the specific XLNet class 26 | """ 27 | 28 | 29 | def _append_from_pretrained_docstring(docstr): 30 | def docstring_decorator(fn): 31 | fn.__doc__ = fn.__doc__ + docstr 32 | return fn 33 | return docstring_decorator 34 | 35 | 36 | def xlnetTokenizer(*args, **kwargs): 37 | """ 38 | Instantiate an XLNet sentencepiece tokenizer for XLNet from a pre-trained vocab file. 39 | Peculiarities: 40 | - requires Google sentencepiece (https://github.com/google/sentencepiece) 41 | 42 | Args: 43 | pretrained_model_name_or_path: Path to pretrained model archive 44 | or one of pre-trained vocab configs below. 45 | * xlnet-large-cased 46 | Keyword args: 47 | special_tokens: Special tokens in vocabulary that are not pretrained 48 | Default: None 49 | max_len: An artificial maximum length to truncate tokenized sequences to; 50 | Effective maximum length is always the minimum of this 51 | value (if specified) and the underlying model's 52 | sequence length. 53 | Default: None 54 | 55 | Example: 56 | >>> import torch 57 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 58 | 59 | >>> text = "Who was Jim Henson ?"
60 | >>> indexed_tokens = tokenizer.encode(text) 61 | """ 62 | tokenizer = XLNetTokenizer.from_pretrained(*args, **kwargs) 63 | return tokenizer 64 | 65 | 66 | @_append_from_pretrained_docstring(xlnet_docstring) 67 | def xlnetModel(*args, **kwargs): 68 | """ 69 | xlnetModel is the basic XLNet Transformer model from 70 | "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 71 | by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 72 | 73 | Example: 74 | # Load the tokenizer 75 | >>> import torch 76 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 77 | 78 | # Prepare tokenized input 79 | >>> text_1 = "Who was Jim Henson ?" 80 | >>> text_2 = "Jim Henson was a puppeteer" 81 | >>> indexed_tokens_1 = tokenizer.encode(text_1) 82 | >>> indexed_tokens_2 = tokenizer.encode(text_2) 83 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 84 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 85 | 86 | # Load xlnetModel 87 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetModel', 'xlnet-large-cased') 88 | >>> model.eval() 89 | 90 | # Predict hidden states features for each layer 91 | >>> with torch.no_grad(): 92 | hidden_states_1, mems = model(tokens_tensor_1) 93 | hidden_states_2, mems = model(tokens_tensor_2, mems=mems) 94 | """ 95 | model = XLNetModel.from_pretrained(*args, **kwargs) 96 | return model 97 | 98 | 99 | @_append_from_pretrained_docstring(xlnet_docstring) 100 | def xlnetLMHeadModel(*args, **kwargs): 101 | """ 102 | xlnetLMHeadModel is the XLNet Transformer model from 103 | "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 104 | by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 105 | with a tied (pre-trained) language modeling head on top. 106 | 107 | Example: 108 | # Load the tokenizer 109 | >>> import torch 110 | >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 111 | 112 | # Prepare tokenized input 113 | >>> text_1 = "Who was Jim Henson ?" 114 | >>> text_2 = "Jim Henson was a puppeteer" 115 | >>> indexed_tokens_1 = tokenizer.encode(text_1) 116 | >>> indexed_tokens_2 = tokenizer.encode(text_2) 117 | >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 118 | >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 119 | 120 | # Load xlnetLMHeadModel 121 | >>> model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetLMHeadModel', 'xlnet-large-cased') 122 | >>> model.eval() 123 | 124 | # Predict hidden states features for each layer 125 | >>> with torch.no_grad(): 126 | predictions_1, mems = model(tokens_tensor_1) 127 | predictions_2, mems = model(tokens_tensor_2, mems=mems) 128 | 129 | # Get the predicted last token 130 | >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 131 | >>> predicted_token = tokenizer.decode([predicted_index]) 132 | >>> assert predicted_token == ' who' 133 | """ 134 | model = XLNetLMHeadModel.from_pretrained(*args, **kwargs) 135 | return model 136 | 137 | 138 | # @_append_from_pretrained_docstring(xlnet_docstring) 139 | # def xlnetForSequenceClassification(*args, **kwargs): 140 | # """ 141 | # xlnetModel is the basic XLNet Transformer model from 142 | # "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 143 | # by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V.
Le 144 | 145 | # Example: 146 | # # Load the tokenizer 147 | # >>> import torch 148 | # >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 149 | 150 | # # Prepare tokenized input 151 | # >>> text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 152 | # >>> text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 153 | # >>> tokenized_text1 = tokenizer.tokenize(text1) 154 | # >>> tokenized_text2 = tokenizer.tokenize(text2) 155 | # >>> indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 156 | # >>> indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 157 | # >>> tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 158 | # >>> mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 159 | 160 | # # Load xlnetForSequenceClassification 161 | # >>> model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetForSequenceClassification', 'xlnet-large-cased') 162 | # >>> model.eval() 163 | 164 | # # Predict sequence classes logits 165 | # >>> with torch.no_grad(): 166 | # lm_logits, mems = model(tokens_tensor) 167 | # """ 168 | # model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs) 169 | # return model 170 | -------------------------------------------------------------------------------- /pytorch_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE 7 | from .tokenization_xlm import XLMTokenizer 8 | from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization) 9 | 10 | from .modeling_bert import (BertConfig, BertModel, BertForPreTraining, 11 | BertForMaskedLM, BertForNextSentencePrediction, 12 | BertForSequenceClassification, BertForMultipleChoice, 13 | BertForTokenClassification, BertForQuestionAnswering, 14 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 15 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 16 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 17 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 18 | load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, 19 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) 20 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 21 | load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, 22 | TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) 23 | from .modeling_gpt2 import (GPT2Config, GPT2Model, 24 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 25 | load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, 26 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) 27 | from .modeling_xlnet import (XLNetConfig, 28 | XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, 29 | XLNetForSequenceClassification, XLNetForQuestionAnswering, 30 | load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, 31 | XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) 32 | from .modeling_xlm import (XLMConfig, XLMModel, 33 | XLMWithLMHeadModel, XLMForSequenceClassification, 34 | XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, 35 | XLM_PRETRAINED_MODEL_ARCHIVE_MAP) 36 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, 
TF_WEIGHTS_NAME,
37 |                              PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
38 | 
39 | from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
40 |                            WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
41 | 
42 | from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
43 | 
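# Everything imported above is re-exported at the package root, so downstream
# code can pull tokenizers, models, the optimizer and the schedules from a
# single import. A minimal sketch (any of the standard pretrained shortcut
# names works in place of 'bert-base-uncased'):
#
#     >>> import torch
#     >>> from pytorch_transformers import BertTokenizer, BertModel
#     >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#     >>> model = BertModel.from_pretrained('bert-base-uncased')
#     >>> model.eval()
#     >>> input_ids = torch.tensor([tokenizer.encode("Who was Jim Henson ?")])
#     >>> with torch.no_grad():
#     ...     sequence_output, pooled_output = model(input_ids)[:2]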
--------------------------------------------------------------------------------
/pytorch_transformers/__main__.py:
--------------------------------------------------------------------------------
1 | # coding: utf8
2 | def main():
3 |     import sys
4 |     if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
5 |         print(
6 |         "Should be used as one of: \n"
7 |         ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
8 |         ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
9 |         ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
10 |         ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
11 |         ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
12 |         ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
13 |     else:
14 |         if sys.argv[1] == "bert":
15 |             try:
16 |                 from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
17 |             except ImportError:
18 |                 print("pytorch_transformers can only be used from the command line to convert TensorFlow checkpoints to PyTorch models. "
19 |                     "In that case, it requires TensorFlow to be installed. Please see "
20 |                     "https://www.tensorflow.org/install/ for installation instructions.")
21 |                 raise
22 | 
23 |             if len(sys.argv) != 5:
24 |                 # pylint: disable=line-too-long
25 |                 print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
26 |             else:
27 |                 PYTORCH_DUMP_OUTPUT = sys.argv.pop()
28 |                 TF_CONFIG = sys.argv.pop()
29 |                 TF_CHECKPOINT = sys.argv.pop()
30 |                 convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
31 |         elif sys.argv[1] == "gpt":
32 |             from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
33 |             if len(sys.argv) < 4 or len(sys.argv) > 5:
34 |                 # pylint: disable=line-too-long
35 |                 print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
36 |             else:
37 |                 OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
38 |                 PYTORCH_DUMP_OUTPUT = sys.argv[3]
39 |                 if len(sys.argv) == 5:
40 |                     OPENAI_GPT_CONFIG = sys.argv[4]
41 |                 else:
42 |                     OPENAI_GPT_CONFIG = ""
43 |                 convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
44 |                                                      OPENAI_GPT_CONFIG,
45 |                                                      PYTORCH_DUMP_OUTPUT)
46 |         elif sys.argv[1] == "transfo_xl":
47 |             try:
48 |                 from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
49 |             except ImportError:
50 |                 print("pytorch_transformers can only be used from the command line to convert TensorFlow checkpoints to PyTorch models. "
51 |                     "In that case, it requires TensorFlow to be installed. Please see "
52 |                     "https://www.tensorflow.org/install/ for installation instructions.")
53 |                 raise
54 |             if len(sys.argv) < 4 or len(sys.argv) > 5:
55 |                 # pylint: disable=line-too-long
56 |                 print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
57 |             else:
58 |                 if 'ckpt' in sys.argv[2].lower():
59 |                     TF_CHECKPOINT = sys.argv[2]
60 |                     TF_DATASET_FILE = ""
61 |                 else:
62 |                     TF_DATASET_FILE = sys.argv[2]
63 |                     TF_CHECKPOINT = ""
64 |                 PYTORCH_DUMP_OUTPUT = sys.argv[3]
65 |                 if len(sys.argv) == 5:
66 |                     TF_CONFIG = sys.argv[4]
67 |                 else:
68 |                     TF_CONFIG = ""
69 |                 convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
70 |         elif sys.argv[1] == "gpt2":
71 |             try:
72 |                 from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
73 |             except ImportError:
74 |                 print("pytorch_transformers can only be used from the command line to convert TensorFlow checkpoints to PyTorch models. "
75 |                     "In that case, it requires TensorFlow to be installed. Please see "
76 |                     "https://www.tensorflow.org/install/ for installation instructions.")
77 |                 raise
78 | 
79 |             if len(sys.argv) < 4 or len(sys.argv) > 5:
80 |                 # pylint: disable=line-too-long
81 |                 print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
82 |             else:
83 |                 TF_CHECKPOINT = sys.argv[2]
84 |                 PYTORCH_DUMP_OUTPUT = sys.argv[3]
85 |                 if len(sys.argv) == 5:
86 |                     TF_CONFIG = sys.argv[4]
87 |                 else:
88 |                     TF_CONFIG = ""
89 |                 convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
90 |         elif sys.argv[1] == "xlnet":
91 |             try:
92 |                 from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
93 |             except ImportError:
94 |                 print("pytorch_transformers can only be used from the command line to convert TensorFlow checkpoints to PyTorch models. "
95 |                     "In that case, it requires TensorFlow to be installed. Please see "
96 |                     "https://www.tensorflow.org/install/ for installation instructions.")
97 |                 raise
98 | 
99 |             if len(sys.argv) < 5 or len(sys.argv) > 6:
100 |                 # pylint: disable=line-too-long
101 |                 print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
102 |             else:
103 |                 TF_CHECKPOINT = sys.argv[2]
104 |                 TF_CONFIG = sys.argv[3]
105 |                 PYTORCH_DUMP_OUTPUT = sys.argv[4]
106 |                 if len(sys.argv) == 6:
107 |                     FINETUNING_TASK = sys.argv[5]
108 |                 else:
109 |                     FINETUNING_TASK = None
110 | 
111 |                 convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
112 |                                                     TF_CONFIG,
113 |                                                     PYTORCH_DUMP_OUTPUT,
114 |                                                     FINETUNING_TASK)
115 |         elif sys.argv[1] == "xlm":
116 |             from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
117 | 
118 |             if len(sys.argv) != 4:
119 |                 # pylint: disable=line-too-long
120 |                 print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
121 |             else:
122 |                 XLM_CHECKPOINT_PATH = sys.argv[2]
123 |                 PYTORCH_DUMP_OUTPUT = sys.argv[3]
124 | 
125 |                 convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)
126 | 
127 | if __name__ == '__main__':
128 |     main()
129 | 
--------------------------------------------------------------------------------
/pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT-2 checkpoint."""
16 | 
17 | from __future__ import absolute_import, division, print_function
18 | 
19 | import argparse
20 | from io import open
21 | 
22 | import torch
23 | 
24 | from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
25 |                                                 GPT2Config,
26 |                                                 GPT2Model,
27 |                                                 load_tf_weights_in_gpt2)
28 | 
29 | import logging
30 | logging.basicConfig(level=logging.INFO)
31 | 
32 | 
33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
34 |     # Construct model
35 |     if gpt2_config_file == "":
36 |         config = GPT2Config()
37 |     else:
38 |         config = GPT2Config(gpt2_config_file)
39 |     model = GPT2Model(config)
40 | 
41 |     # Load weights from numpy
42 |     load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)
43 | 
44 |     # Save pytorch-model
45 |     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
46 |     pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
47 |     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
48 |     torch.save(model.state_dict(), pytorch_weights_dump_path)
49 |     print("Save configuration file to {}".format(pytorch_config_dump_path))
50 |     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
51 |         f.write(config.to_json_string())
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     parser = argparse.ArgumentParser()
56 |     ## Required parameters
57 |     parser.add_argument("--gpt2_checkpoint_path",
58 |                         default = None,
59 |                         type = str,
60 |                         required = True,
61 |                         help = "Path to the TensorFlow checkpoint.")
62 |     parser.add_argument("--pytorch_dump_folder_path",
63 |                         default = None,
64 |                         type = str,
65 |                         required = True,
66 |                         help = "Path to the output PyTorch model.")
67 |     parser.add_argument("--gpt2_config_file",
68 |                         default = "",
69 |                         type = str,
70 |                         help = "An optional config json file corresponding to the pre-trained GPT-2 model. \n"
71 |                                "This specifies the model architecture.")
72 |     args = parser.parse_args()
73 |     convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
74 |                                        args.gpt2_config_file,
75 |                                        args.pytorch_dump_folder_path)
76 | 
--------------------------------------------------------------------------------
/pytorch_transformers/convert_openai_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert OpenAI GPT checkpoint."""
16 | 
17 | from __future__ import absolute_import, division, print_function
18 | 
19 | import argparse
20 | from io import open
21 | 
22 | import torch
23 | 
24 | from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
25 |                                                   OpenAIGPTConfig,
26 |                                                   OpenAIGPTModel,
27 |                                                   load_tf_weights_in_openai_gpt)
28 | 
29 | import logging
30 | logging.basicConfig(level=logging.INFO)
31 | 
32 | 
33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
34 |     # Construct model
35 |     if openai_config_file == "":
36 |         config = OpenAIGPTConfig()
37 |     else:
38 |         config = OpenAIGPTConfig(openai_config_file)
39 |     model = OpenAIGPTModel(config)
40 | 
41 |     # Load weights from numpy
42 |     load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)
43 | 
44 |     # Save pytorch-model
45 |     pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
46 |     pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
47 |     print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
48 |     torch.save(model.state_dict(), pytorch_weights_dump_path)
49 |     print("Save configuration file to {}".format(pytorch_config_dump_path))
50 |     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
51 |         f.write(config.to_json_string())
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     parser = argparse.ArgumentParser()
56 |     ## Required parameters
57 |     parser.add_argument("--openai_checkpoint_folder_path",
58 |                         default = None,
59 |                         type = str,
60 |                         required = True,
61 |                         help = "Path to the TensorFlow checkpoint folder.")
62 |     parser.add_argument("--pytorch_dump_folder_path",
63 |                         default = None,
64 |                         type = str,
65 |                         required = True,
66 |                         help = "Path to the output PyTorch model.")
67 |     parser.add_argument("--openai_config_file",
68 |                         default = "",
69 |                         type = str,
70 |                         help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
71 |                                "This specifies the model architecture.")
72 |     args = parser.parse_args()
73 |     convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
74 |                                          args.openai_config_file,
75 |                                          args.pytorch_dump_folder_path)
76 | 
--------------------------------------------------------------------------------
/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Convert a Hugging Face PyTorch checkpoint to a TensorFlow checkpoint."""
17 | 
18 | import os
19 | import argparse
20 | import torch
21 | import numpy as np
22 | import tensorflow as tf
23 | from pytorch_pretrained_bert.modeling import BertModel
24 | 
25 | 
26 | def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
27 | 
28 |     """
29 |     :param model: BertModel Pytorch model instance to be converted
30 |     :param ckpt_dir: Tensorflow model directory
31 |     :param model_name: model name
32 |     :return:
33 | 
34 |     Currently supported HF models:
35 |         Y BertModel
36 |         N BertForMaskedLM
37 |         N BertForPreTraining
38 |         N BertForMultipleChoice
39 |         N BertForNextSentencePrediction
40 |         N BertForSequenceClassification
41 |         N BertForQuestionAnswering
42 |     """
43 | 
44 |     tensors_to_transpose = (
45 |         "dense.weight",
46 |         "attention.self.query",
47 |         "attention.self.key",
48 |         "attention.self.value"
49 |     )
50 | 
51 |     var_map = (
52 |         ('layer.', 'layer_'),
53 |         ('word_embeddings.weight', 'word_embeddings'),
54 |         ('position_embeddings.weight', 'position_embeddings'),
55 |         ('token_type_embeddings.weight', 'token_type_embeddings'),
56 |         ('.', '/'),
57 |         ('LayerNorm/weight', 'LayerNorm/gamma'),
58 |         ('LayerNorm/bias', 'LayerNorm/beta'),
59 |         ('weight', 'kernel')
60 |     )
61 | 
62 |     if not os.path.isdir(ckpt_dir):
63 |         os.makedirs(ckpt_dir)
64 | 
65 |     session = tf.Session()
66 |     state_dict = model.state_dict()
67 |     tf_vars = []
68 | 
69 |     def to_tf_var_name(name: str):
70 |         for patt, repl in iter(var_map):
71 |             name = name.replace(patt, repl)
72 |         return 'bert/{}'.format(name)
73 | 
74 |     def assign_tf_var(tensor: np.ndarray, name: str):
75 |         tmp_var = tf.Variable(initial_value=tensor)
76 |         tf_var = tf.get_variable(dtype=tmp_var.dtype, shape=tmp_var.shape, name=name)
77 |         op = tf.assign(ref=tf_var, value=tmp_var)
78 |         session.run(tf.variables_initializer([tmp_var, tf_var]))
79 |         session.run(fetches=[op, tf_var])
80 |         return tf_var
81 | 
82 |     for var_name in state_dict:
83 |         tf_name = to_tf_var_name(var_name)
84 |         torch_tensor = state_dict[var_name].numpy()
85 |         if any([x in var_name for x in tensors_to_transpose]):
86 |             torch_tensor = torch_tensor.T
87 |         tf_tensor = assign_tf_var(tensor=torch_tensor, name=tf_name)
88 |         tf_vars.append(tf_tensor)
89 |         print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
90 | 
91 |     saver = tf.train.Saver(tf_vars)
92 |     saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
93 | 
94 | 
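# For illustration only: to_tf_var_name (the local helper defined inside the
# function above, shown here as if it were importable) applies var_map in
# order and prepends 'bert/', so typical translations look like this sketch:
#
#     >>> to_tf_var_name('encoder.layer.0.attention.self.query.weight')
#     'bert/encoder/layer_0/attention/self/query/kernel'
#     >>> to_tf_var_name('embeddings.LayerNorm.weight')
#     'bert/embeddings/LayerNorm/gamma'
#     >>> to_tf_var_name('embeddings.LayerNorm.bias')
#     'bert/embeddings/LayerNorm/beta'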
95 | def main(raw_args=None):
96 |     parser = argparse.ArgumentParser()
97 |     parser.add_argument("--model_name",
98 |                         type=str,
99 |                         required=True,
100 |                         help="model name e.g. bert-base-uncased")
101 |     parser.add_argument("--cache_dir",
102 |                         type=str,
103 |                         default=None,
104 |                         required=False,
105 |                         help="Directory containing pytorch model")
106 |     parser.add_argument("--pytorch_model_path",
107 |                         type=str,
108 |                         required=True,
109 |                         help="/path/to/<pytorch-model-name>.bin")
110 |     parser.add_argument("--tf_cache_dir",
111 |                         type=str,
112 |                         required=True,
113 |                         help="Directory in which to save tensorflow model")
114 |     args = parser.parse_args(raw_args)
115 | 
116 |     model = BertModel.from_pretrained(
117 |         pretrained_model_name_or_path=args.model_name,
118 |         state_dict=torch.load(args.pytorch_model_path),
119 |         cache_dir=args.cache_dir
120 |     )
121 | 
122 |     convert_pytorch_checkpoint_to_tf(
123 |         model=model,
124 |         ckpt_dir=args.tf_cache_dir,
125 |         model_name=args.model_name
126 |     )
127 | 
128 | 
129 | if __name__ == "__main__":
130 |     main()
131 | 
--------------------------------------------------------------------------------
/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import argparse
22 | import torch
23 | 
24 | from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
25 | 
26 | import logging
27 | logging.basicConfig(level=logging.INFO)
28 | 
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
30 |     # Initialise PyTorch model
31 |     config = BertConfig.from_json_file(bert_config_file)
32 |     print("Building PyTorch model from configuration: {}".format(str(config)))
33 |     model = BertForPreTraining(config)
34 | 
35 |     # Load weights from tf checkpoint
36 |     load_tf_weights_in_bert(model, config, tf_checkpoint_path)
37 | 
38 |     # Save pytorch-model
39 |     print("Save PyTorch model to {}".format(pytorch_dump_path))
40 |     torch.save(model.state_dict(), pytorch_dump_path)
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     parser = argparse.ArgumentParser()
45 |     ## Required parameters
46 |     parser.add_argument("--tf_checkpoint_path",
47 |                         default = None,
48 |                         type = str,
49 |                         required = True,
50 |                         help = "Path to the TensorFlow checkpoint.")
51 |     parser.add_argument("--bert_config_file",
52 |                         default = None,
53 |                         type = str,
54 |                         required = True,
55 |                         help = "The config json file corresponding to the pre-trained BERT model.
\n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_transformers.tokenization_transfo_xl as data_utils 27 | from pytorch_transformers.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 33 | 34 | if sys.version_info[0] == 2: 35 | import cPickle as pickle 36 | else: 37 | import pickle 38 | 39 | import logging 40 | logging.basicConfig(level=logging.INFO) 41 | 42 | # We do this to be able to load python 2 datasets pickles 43 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
44 | data_utils.Vocab = data_utils.TransfoXLTokenizer
45 | data_utils.Corpus = data_utils.TransfoXLCorpus
46 | sys.modules['data_utils'] = data_utils
47 | sys.modules['vocabulary'] = data_utils
48 | 
49 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
50 |                                              transfo_xl_config_file,
51 |                                              pytorch_dump_folder_path,
52 |                                              transfo_xl_dataset_file):
53 |     if transfo_xl_dataset_file:
54 |         # Convert a pre-processed corpus (see original TensorFlow repo)
55 |         with open(transfo_xl_dataset_file, "rb") as fp:
56 |             corpus = pickle.load(fp, encoding="latin1")
57 |         # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
58 |         pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file']
59 |         print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
60 |         corpus_vocab_dict = corpus.vocab.__dict__
61 |         torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)
62 | 
63 |         corpus_dict_no_vocab = corpus.__dict__
64 |         corpus_dict_no_vocab.pop('vocab', None)
65 |         pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
66 |         print("Save dataset to {}".format(pytorch_dataset_dump_path))
67 |         torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)
68 | 
69 |     if tf_checkpoint_path:
70 |         # Convert a pre-trained TensorFlow model
71 |         config_path = os.path.abspath(transfo_xl_config_file)
72 |         tf_path = os.path.abspath(tf_checkpoint_path)
73 | 
74 |         print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
75 |         # Initialise PyTorch model
76 |         if transfo_xl_config_file == "":
77 |             config = TransfoXLConfig()
78 |         else:
79 |             config = TransfoXLConfig(transfo_xl_config_file)
80 |         print("Building PyTorch model from configuration: {}".format(str(config)))
81 |         model = TransfoXLLMHeadModel(config)
82 | 
83 |         model = load_tf_weights_in_transfo_xl(model, config, tf_path)
84 |         # Save pytorch-model
85 |         pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
86 |         pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
87 |         print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
88 |         torch.save(model.state_dict(), pytorch_weights_dump_path)
89 |         print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
90 |         with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
91 |             f.write(config.to_json_string())
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     parser = argparse.ArgumentParser()
96 |     parser.add_argument("--pytorch_dump_folder_path",
97 |                         default = None,
98 |                         type = str,
99 |                         required = True,
100 |                         help = "Path to the folder to store the PyTorch model or dataset/vocab.")
101 |     parser.add_argument("--tf_checkpoint_path",
102 |                         default = "",
103 |                         type = str,
104 |                         help = "An optional path to a TensorFlow checkpoint to be converted.")
105 |     parser.add_argument("--transfo_xl_config_file",
106 |                         default = "",
107 |                         type = str,
108 |                         help = "An optional config json file corresponding to the pre-trained Transformer XL model.
\n" 109 | "This specifies the model architecture.") 110 | parser.add_argument("--transfo_xl_dataset_file", 111 | default = "", 112 | type = str, 113 | help = "An optional dataset file to be converted in a vocabulary.") 114 | args = parser.parse_args() 115 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 116 | args.transfo_xl_config_file, 117 | args.pytorch_dump_folder_path, 118 | args.transfo_xl_dataset_file) 119 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME 27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | model = chkpt['model'] 37 | 38 | config = chkpt['params'] 39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.Tensor, numpy.ndarray))) 40 | 41 | vocab = chkpt['dico_word2id'] 42 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 48 | 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model, pytorch_weights_dump_path) 51 | 52 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 54 | f.write(json.dumps(config, indent=2) + "\n") 55 | 56 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(vocab, indent=2) + "\n") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | ## Required parameters 64 | parser.add_argument("--xlm_checkpoint_path", 65 | default = None, 66 | type = str, 67 | required = True, 68 | help = "Path the official PyTorch dump.") 69 | parser.add_argument("--pytorch_dump_folder_path", 70 | default = None, 71 | type = str, 72 | required = True, 73 | help = "Path to the output PyTorch model.") 74 | args = 
61 | if __name__ == "__main__":
62 |     parser = argparse.ArgumentParser()
63 |     ## Required parameters
64 |     parser.add_argument("--xlm_checkpoint_path",
65 |                         default = None,
66 |                         type = str,
67 |                         required = True,
68 |                         help = "Path to the official PyTorch dump.")
69 |     parser.add_argument("--pytorch_dump_folder_path",
70 |                         default = None,
71 |                         type = str,
72 |                         required = True,
73 |                         help = "Path to the output PyTorch model.")
74 |     args = parser.parse_args()
75 |     convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path)
76 | 
--------------------------------------------------------------------------------
/pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert XLNet checkpoint."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import os
22 | import argparse
23 | import torch
24 | 
25 | from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
26 |                                                  XLNetConfig,
27 |                                                  XLNetLMHeadModel, XLNetForQuestionAnswering,
28 |                                                  XLNetForSequenceClassification,
29 |                                                  load_tf_weights_in_xlnet)
30 | 
31 | GLUE_TASKS_NUM_LABELS = {
32 |     "cola": 2,
33 |     "mnli": 3,
34 |     "mrpc": 2,
35 |     "sst-2": 2,
36 |     "sts-b": 1,
37 |     "qqp": 2,
38 |     "qnli": 2,
39 |     "rte": 2,
40 |     "wnli": 2,
41 | }
42 | 
43 | import logging
44 | logging.basicConfig(level=logging.INFO)
45 | 
46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
47 |     # Initialise PyTorch model
48 |     config = XLNetConfig.from_json_file(bert_config_file)
49 | 
50 |     finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
51 |     if finetuning_task in GLUE_TASKS_NUM_LABELS:
52 |         print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
53 |         config.finetuning_task = finetuning_task
54 |         config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
55 |         model = XLNetForSequenceClassification(config)
56 |     elif 'squad' in finetuning_task:
57 |         config.finetuning_task = finetuning_task
58 |         model = XLNetForQuestionAnswering(config)
59 |     else:
60 |         model = XLNetLMHeadModel(config)
61 | 
62 |     # Load weights from tf checkpoint
63 |     load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)
64 | 
65 |     # Save pytorch-model
66 |     pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
67 |     pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
68 |     print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
69 |     torch.save(model.state_dict(), pytorch_weights_dump_path)
70 |     print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
71 |     with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
72 |         f.write(config.to_json_string())
73 | 
74 | 
75 | if __name__ == "__main__":
76 |     parser = argparse.ArgumentParser()
77 |     ## Required parameters
78 |     parser.add_argument("--tf_checkpoint_path",
79 |                         default = None,
80 |                         type = str,
81 |                         required = True,
82 |                         help = "Path to the TensorFlow checkpoint.")
83 |     parser.add_argument("--xlnet_config_file",
84 |                         default = None,
85 |                         type = str,
86 |                         required = True,
87 |                         help = "The config json file corresponding to the pre-trained XLNet model. \n"
88 |                                "This specifies the model architecture.")
89 |     parser.add_argument("--pytorch_dump_folder_path",
90 |                         default = None,
91 |                         type = str,
92 |                         required = True,
93 |                         help = "Path to the folder to store the PyTorch model or dataset/vocab.")
94 |     parser.add_argument("--finetuning_task",
95 |                         default = None,
96 |                         type = str,
97 |                         help = "Name of a task on which the XLNet TensorFlow model was fine-tuned")
98 |     args = parser.parse_args()
99 |     print(args)
100 | 
101 |     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
102 |                                         args.xlnet_config_file,
103 |                                         args.pytorch_dump_folder_path,
104 |                                         args.finetuning_task)
105 | 
--------------------------------------------------------------------------------
/pytorch_transformers/file_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with the local dataset cache.
3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4 | Copyright by the AllenNLP authors.
5 | """
6 | from __future__ import (absolute_import, division, print_function, unicode_literals)
7 | 
8 | import sys
9 | import json
10 | import logging
11 | import os
12 | import shutil
13 | import tempfile
14 | import fnmatch
15 | from functools import wraps
16 | from hashlib import sha256
17 | from io import open
18 | 
19 | import boto3
20 | import requests
21 | from botocore.exceptions import ClientError
22 | from tqdm import tqdm
23 | 
24 | try:
25 |     from torch.hub import _get_torch_home
26 |     torch_cache_home = _get_torch_home()
27 | except ImportError:
28 |     torch_cache_home = os.path.expanduser(
29 |         os.getenv('TORCH_HOME', os.path.join(
30 |             os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
31 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
32 | 
33 | try:
34 |     from urllib.parse import urlparse
35 | except ImportError:
36 |     from urlparse import urlparse
37 | 
38 | try:
39 |     from pathlib import Path
40 |     PYTORCH_PRETRAINED_BERT_CACHE = Path(
41 |         os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
42 | except (AttributeError, ImportError):
43 |     PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
44 |                                               default_cache_path)
45 | 
46 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
47 | 
48 | 
49 | def url_to_filename(url, etag=None):
50 |     """
51 |     Convert `url` into a hashed filename in a repeatable way.
52 |     If `etag` is specified, append its hash to the url's, delimited
53 |     by a period.
54 |     """
55 |     url_bytes = url.encode('utf-8')
56 |     url_hash = sha256(url_bytes)
57 |     filename = url_hash.hexdigest()
58 | 
59 |     if etag:
60 |         etag_bytes = etag.encode('utf-8')
61 |         etag_hash = sha256(etag_bytes)
62 |         filename += '.' + etag_hash.hexdigest()
63 | 
64 |     return filename
65 | 
66 | 
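# Illustration of the naming scheme above (the URL is hypothetical and the
# digests are abbreviated placeholders): the cache filename is sha256(url),
# with '.' + sha256(etag) appended when an ETag is available, so a changed
# remote file gets a fresh cache entry:
#
#     >>> url_to_filename('https://example.com/model.bin')
#     '<64-hex-char sha256 of the url>'
#     >>> url_to_filename('https://example.com/model.bin', etag='"abc"')
#     '<sha256 of the url>.<sha256 of the etag>'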
67 | def filename_to_url(filename, cache_dir=None):
68 |     """
69 |     Return the url and etag (which may be ``None``) stored for `filename`.
70 |     Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
71 |     """
72 |     if cache_dir is None:
73 |         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
74 |     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
75 |         cache_dir = str(cache_dir)
76 | 
77 |     cache_path = os.path.join(cache_dir, filename)
78 |     if not os.path.exists(cache_path):
79 |         raise EnvironmentError("file {} not found".format(cache_path))
80 | 
81 |     meta_path = cache_path + '.json'
82 |     if not os.path.exists(meta_path):
83 |         raise EnvironmentError("file {} not found".format(meta_path))
84 | 
85 |     with open(meta_path, encoding="utf-8") as meta_file:
86 |         metadata = json.load(meta_file)
87 |     url = metadata['url']
88 |     etag = metadata['etag']
89 | 
90 |     return url, etag
91 | 
92 | 
93 | def cached_path(url_or_filename, cache_dir=None):
94 |     """
95 |     Given something that might be a URL (or might be a local path),
96 |     determine which. If it's a URL, download the file and cache it, and
97 |     return the path to the cached file. If it's already a local path,
98 |     make sure the file exists and then return the path.
99 |     """
100 |     if cache_dir is None:
101 |         cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
102 |     if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
103 |         url_or_filename = str(url_or_filename)
104 |     if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
105 |         cache_dir = str(cache_dir)
106 | 
107 |     parsed = urlparse(url_or_filename)
108 | 
109 |     if parsed.scheme in ('http', 'https', 's3'):
110 |         # URL, so get it from the cache (downloading if necessary)
111 |         return get_from_cache(url_or_filename, cache_dir)
112 |     elif os.path.exists(url_or_filename):
113 |         # File, and it exists.
114 |         return url_or_filename
115 |     elif parsed.scheme == '':
116 |         # File, but it doesn't exist.
117 |         raise EnvironmentError("file {} not found".format(url_or_filename))
118 |     else:
119 |         # Something unknown
120 |         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
121 | 
122 | 
123 | def split_s3_path(url):
124 |     """Split a full s3 path into the bucket name and path."""
125 |     parsed = urlparse(url)
126 |     if not parsed.netloc or not parsed.path:
127 |         raise ValueError("bad s3 path {}".format(url))
128 |     bucket_name = parsed.netloc
129 |     s3_path = parsed.path
130 |     # Remove '/' at beginning of path.
131 |     if s3_path.startswith("/"):
132 |         s3_path = s3_path[1:]
133 |     return bucket_name, s3_path
134 | 
135 | 
136 | def s3_request(func):
137 |     """
138 |     Wrapper function for s3 requests in order to create more helpful error
139 |     messages.
140 | """ 141 | 142 | @wraps(func) 143 | def wrapper(url, *args, **kwargs): 144 | try: 145 | return func(url, *args, **kwargs) 146 | except ClientError as exc: 147 | if int(exc.response["Error"]["Code"]) == 404: 148 | raise EnvironmentError("file {} not found".format(url)) 149 | else: 150 | raise 151 | 152 | return wrapper 153 | 154 | 155 | @s3_request 156 | def s3_etag(url): 157 | """Check ETag on S3 object.""" 158 | s3_resource = boto3.resource("s3") 159 | bucket_name, s3_path = split_s3_path(url) 160 | s3_object = s3_resource.Object(bucket_name, s3_path) 161 | return s3_object.e_tag 162 | 163 | 164 | @s3_request 165 | def s3_get(url, temp_file): 166 | """Pull a file directly from S3.""" 167 | s3_resource = boto3.resource("s3") 168 | bucket_name, s3_path = split_s3_path(url) 169 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 170 | 171 | 172 | def http_get(url, temp_file): 173 | req = requests.get(url, stream=True) 174 | content_length = req.headers.get('Content-Length') 175 | total = int(content_length) if content_length is not None else None 176 | progress = tqdm(unit="B", total=total) 177 | for chunk in req.iter_content(chunk_size=1024): 178 | if chunk: # filter out keep-alive new chunks 179 | progress.update(len(chunk)) 180 | temp_file.write(chunk) 181 | progress.close() 182 | 183 | 184 | def get_from_cache(url, cache_dir=None): 185 | """ 186 | Given a URL, look for the corresponding dataset in the local cache. 187 | If it's not there, download it. Then return the path to the cached file. 188 | """ 189 | if cache_dir is None: 190 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 191 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 192 | cache_dir = str(cache_dir) 193 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str): 194 | cache_dir = str(cache_dir) 195 | 196 | if not os.path.exists(cache_dir): 197 | os.makedirs(cache_dir) 198 | 199 | # Get eTag to add to filename, if it exists. 200 | if url.startswith("s3://"): 201 | etag = s3_etag(url) 202 | else: 203 | try: 204 | response = requests.head(url, allow_redirects=True) 205 | if response.status_code != 200: 206 | etag = None 207 | else: 208 | etag = response.headers.get("ETag") 209 | except EnvironmentError: 210 | etag = None 211 | 212 | if sys.version_info[0] == 2 and etag is not None: 213 | etag = etag.decode('utf-8') 214 | filename = url_to_filename(url, etag) 215 | 216 | # get cache path to put the file 217 | cache_path = os.path.join(cache_dir, filename) 218 | 219 | # If we don't have a connection (etag is None) and can't identify the file 220 | # try to get the last downloaded one 221 | if not os.path.exists(cache_path) and etag is None: 222 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 223 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 224 | if matching_files: 225 | cache_path = os.path.join(cache_dir, matching_files[-1]) 226 | 227 | if not os.path.exists(cache_path): 228 | # Download to temporary file, then copy to cache dir once finished. 229 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
230 | with tempfile.NamedTemporaryFile() as temp_file: 231 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 232 | 233 | # GET file object 234 | if url.startswith("s3://"): 235 | s3_get(url, temp_file) 236 | else: 237 | http_get(url, temp_file) 238 | 239 | # we are copying the file before closing it, so flush to avoid truncation 240 | temp_file.flush() 241 | # shutil.copyfileobj() starts at the current position, so go to the start 242 | temp_file.seek(0) 243 | 244 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 245 | with open(cache_path, 'wb') as cache_file: 246 | shutil.copyfileobj(temp_file, cache_file) 247 | 248 | logger.info("creating metadata file for %s", cache_path) 249 | meta = {'url': url, 'etag': etag} 250 | meta_path = cache_path + '.json' 251 | with open(meta_path, 'w') as meta_file: 252 | output_string = json.dumps(meta) 253 | if sys.version_info[0] == 2 and isinstance(output_string, str): 254 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 255 | meta_file.write(output_string) 256 | 257 | logger.info("removing temp file %s", temp_file.name) 258 | 259 | return cache_path 260 | -------------------------------------------------------------------------------- /pytorch_transformers/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | 40 | def lr_lambda(step): 41 | if step < warmup_steps: 42 | return float(step) / float(max(1.0, warmup_steps)) 43 | return 1. 44 | 45 | super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch) 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 
52 |     """
53 |     def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
54 | 
55 |         def lr_lambda(step):
56 |             if step < warmup_steps:
57 |                 return float(step) / float(max(1, warmup_steps))
58 |             return max(0.0, float(t_total - step) / float(max(1.0, t_total - warmup_steps)))
59 | 
60 |         super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
61 | 
62 | 
63 | class WarmupCosineSchedule(LambdaLR):
64 |     """ Linear warmup and then cosine decay.
65 |         Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
66 |         Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
67 |         If `cycles` (default=0.5) is set to a different value, the learning rate instead follows `cycles` full cosine periods after warmup.
68 |     """
69 |     warn_t_total = True
70 |     def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
71 | 
72 |         def lr_lambda(step):
73 |             if step < warmup_steps:
74 |                 return float(step) / float(max(1.0, warmup_steps))
75 |             else:
76 |                 progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))  # progress after warmup
77 |                 return max(0.0, 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress)))
78 | 
79 |         super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
80 | 
81 | class WarmupCosineWithHardRestartsSchedule(LambdaLR):
82 |     """ Linear warmup and then cosine cycles with hard restarts.
83 |         Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
84 |         After warmup the learning rate follows `cycles` cosine decays from 1. to 0., restarting hard
85 |         at 1. at the start of each cycle (the default `cycles`=1. gives a single decay without restarts).
86 |     """
87 |     def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
88 | 
89 |         def lr_lambda(step):
90 |             if step < warmup_steps:
91 |                 return float(step) / float(max(1, warmup_steps))
92 |             else:
93 |                 progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))  # progress after warmup
94 |                 if progress >= 1.0:
95 |                     return 0.0
96 |                 return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0))))
97 | 
98 |         super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
99 | 
100 | 
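# A minimal sketch of how the schedules above combine with the AdamW optimizer
# defined below (hyperparameters are illustrative; `model` and
# `train_dataloader` are assumed to exist). The schedules subclass LambdaLR,
# so the scheduler is stepped once per optimization step:
#
#     >>> optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
#     >>> scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)
#     >>> for batch in train_dataloader:
#     ...     loss = model(**batch)[0]
#     ...     loss.backward()
#     ...     optimizer.step()
#     ...     scheduler.step()
#     ...     optimizer.zero_grad()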
101 | class AdamW(Optimizer):
102 |     """ Implements Adam algorithm with weight decay fix.
103 | 
104 |     Parameters:
105 |         lr (float): learning rate. Default 1e-3.
106 |         betas (tuple of 2 floats): Adam's beta parameters (b1, b2). Default: (0.9, 0.999)
107 |         eps (float): Adam's epsilon. Default: 1e-6
108 |         weight_decay (float): Weight decay. Default: 0.0
109 |         correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in the BERT TF repository). Default True.
110 |     """
111 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
112 |         if lr < 0.0:
113 |             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
114 |         if not 0.0 <= betas[0] < 1.0:
115 |             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
116 |         if not 0.0 <= betas[1] < 1.0:
117 |             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
118 |         if not 0.0 <= eps:
119 |             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
120 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
121 |                         correct_bias=correct_bias)
122 |         super(AdamW, self).__init__(params, defaults)
123 | 
124 |     def step(self, closure=None):
125 |         """Performs a single optimization step.
126 | 
127 |         Arguments:
128 |             closure (callable, optional): A closure that reevaluates the model
129 |                 and returns the loss.
130 |         """
131 |         loss = None
132 |         if closure is not None:
133 |             loss = closure()
134 | 
135 |         for group in self.param_groups:
136 |             for p in group['params']:
137 |                 if p.grad is None:
138 |                     continue
139 |                 grad = p.grad.data
140 |                 if grad.is_sparse:
141 |                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
142 | 
143 |                 state = self.state[p]
144 | 
145 |                 # State initialization
146 |                 if len(state) == 0:
147 |                     state['step'] = 0
148 |                     # Exponential moving average of gradient values
149 |                     state['exp_avg'] = torch.zeros_like(p.data)
150 |                     # Exponential moving average of squared gradient values
151 |                     state['exp_avg_sq'] = torch.zeros_like(p.data)
152 | 
153 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
154 |                 beta1, beta2 = group['betas']
155 | 
156 |                 state['step'] += 1
157 | 
158 |                 # Decay the first and second moment running average coefficient
159 |                 # In-place operations to update the averages at the same time
160 |                 exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
161 |                 exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
162 |                 denom = exp_avg_sq.sqrt().add_(group['eps'])
163 | 
164 |                 step_size = group['lr']
165 |                 if group['correct_bias']:  # BERT's TF implementation runs with correct_bias=False (no bias correction)
166 |                     bias_correction1 = 1.0 - beta1 ** state['step']
167 |                     bias_correction2 = 1.0 - beta2 ** state['step']
168 |                     step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
169 | 
170 |                 p.data.addcdiv_(-step_size, exp_avg, denom)
171 | 
172 |                 # Just adding the square of the weights to the loss function is *not*
173 |                 # the correct way of using L2 regularization/weight decay with Adam,
174 |                 # since that will interact with the m and v parameters in strange ways.
175 |                 #
176 |                 # Instead we want to decay the weights in a manner that doesn't interact
177 |                 # with the m/v parameters. This is equivalent to adding the square
178 |                 # of the weights to the loss with plain (non-momentum) SGD.
179 | # Add weight decay at the end (fixed version) 180 | if group['weight_decay'] > 0.0: 181 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 182 | 183 | return loss 184 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heyangHEY/BERT-CRF/b869ca749a44647ddd6bac09b59f42fd55d0884c/pytorch_transformers/tests/__init__.py -------------------------------------------------------------------------------- /pytorch_transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/input.txt: -------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 
16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heyangHEY/BERT-CRF/b869ca749a44647ddd6bac09b59f42fd55d0884c/pytorch_transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import pytest 21 | 22 | 23 | from pytorch_transformers import (GPT2Config, GPT2Model, 24 | GPT2LMHeadModel, GPT2DoubleHeadsModel) 25 | 26 | from .modeling_common_test import CommonTestCases, ConfigTester 27 | 28 | class GPT2ModelTest(unittest.TestCase): 29 | 30 | def test_config(self): 31 | config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) 32 | config_tester.run_common_tests() 33 | 34 | def test_model(self): 35 | model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model, 36 | lm_head_model_class=GPT2LMHeadModel, 37 | double_head_model_class=GPT2DoubleHeadsModel) 38 | model_tester.run_common_tests(test_presents=True) 39 | 40 | @pytest.mark.slow 41 | def test_pretrained(self): 42 | model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model, 43 | lm_head_model_class=GPT2LMHeadModel, 44 | double_head_model_class=GPT2DoubleHeadsModel) 45 | model_tester.run_slow_tests() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import pytest 21 | 22 | 23 | from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, 24 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) 25 | 26 | from .modeling_common_test import CommonTestCases, ConfigTester 27 | 28 | class OpenAIModelTest(unittest.TestCase): 29 | 30 | def test_config(self): 31 | config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) 32 | config_tester.run_common_tests() 33 | 34 | def test_model(self): 35 | model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel, 36 | lm_head_model_class=OpenAIGPTLMHeadModel, 37 | double_head_model_class=OpenAIGPTDoubleHeadsModel) 38 | model_tester.run_common_tests(test_presents=False) 39 | 40 | @pytest.mark.slow 41 | def test_pretrained(self): 42 | model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel, 43 | lm_head_model_class=OpenAIGPTLMHeadModel, 44 | double_head_model_class=OpenAIGPTDoubleHeadsModel) 45 | model_tester.run_slow_tests() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
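# NOTE: test_presents=False in the OpenAI GPT suite above (vs. True for GPT-2)
# reflects a model difference: in this library the original GPT model does not
# return the `presents` key/value cache that GPT2Model emits for fast
# incremental decoding.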
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import unittest 21 | import json 22 | import random 23 | import shutil 24 | import pytest 25 | 26 | import torch 27 | 28 | from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) 29 | from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP 30 | 31 | from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor 32 | 33 | class TransfoXLModelTest(CommonTestCases.CommonModelTester): 34 | 35 | all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) 36 | test_pruning = False 37 | test_torchscript = False 38 | test_resize_embeddings = False 39 | 40 | class TransfoXLModelTester(object): 41 | 42 | def __init__(self, 43 | parent, 44 | batch_size=13, 45 | seq_length=7, 46 | mem_len=30, 47 | clamp_len=15, 48 | is_training=True, 49 | use_labels=True, 50 | vocab_size=99, 51 | cutoffs=[10, 50, 80], 52 | hidden_size=32, 53 | d_embed=32, 54 | num_attention_heads=4, 55 | d_head=8, 56 | d_inner=128, 57 | div_val=2, 58 | num_hidden_layers=5, 59 | scope=None, 60 | seed=1, 61 | ): 62 | self.parent = parent 63 | self.batch_size = batch_size 64 | self.seq_length = seq_length 65 | self.mem_len = mem_len 66 | self.key_len = seq_length + mem_len 67 | self.clamp_len = clamp_len 68 | self.is_training = is_training 69 | self.use_labels = use_labels 70 | self.vocab_size = vocab_size 71 | self.cutoffs = cutoffs 72 | self.hidden_size = hidden_size 73 | self.d_embed = d_embed 74 | self.num_attention_heads = num_attention_heads 75 | self.d_head = d_head 76 | self.d_inner = d_inner 77 | self.div_val = div_val 78 | self.num_hidden_layers = num_hidden_layers 79 | self.scope = scope 80 | self.seed = seed 81 | 82 | def prepare_config_and_inputs(self): 83 | input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 84 | input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 85 | 86 | lm_labels = None 87 | if self.use_labels: 88 | lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 89 | 90 | config = TransfoXLConfig( 91 | vocab_size_or_config_json_file=self.vocab_size, 92 | mem_len=self.mem_len, 93 | clamp_len=self.clamp_len, 94 | cutoffs=self.cutoffs, 95 | d_model=self.hidden_size, 96 | d_embed=self.d_embed, 97 | n_head=self.num_attention_heads, 98 | d_head=self.d_head, 99 | d_inner=self.d_inner, 100 | div_val=self.div_val, 101 | n_layer=self.num_hidden_layers) 102 | 103 | return (config, input_ids_1, input_ids_2, lm_labels) 104 | 105 | def set_seed(self): 106 | random.seed(self.seed) 107 | torch.manual_seed(self.seed) 108 | 109 | def create_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels): 110 | model = TransfoXLModel(config) 111 | model.eval() 112 | 113 | hidden_states_1, mems_1 = model(input_ids_1) 114 | hidden_states_2, mems_2 = model(input_ids_2, mems_1) 115 | outputs = { 116 | "hidden_states_1": hidden_states_1, 117 | "mems_1": mems_1, 118 | "hidden_states_2": hidden_states_2, 119 | "mems_2": mems_2, 120 | } 121 | return outputs 122 | 123 | def check_transfo_xl_model_output(self, result): 124 | self.parent.assertListEqual( 125 | list(result["hidden_states_1"].size()), 126 | [self.batch_size, self.seq_length, self.hidden_size]) 127 | self.parent.assertListEqual( 128 | list(result["hidden_states_2"].size()), 129 | [self.batch_size, self.seq_length, self.hidden_size]) 130 | self.parent.assertListEqual( 131 | 
list(list(mem.size()) for mem in result["mems_1"]), 132 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 133 | self.parent.assertListEqual( 134 | list(list(mem.size()) for mem in result["mems_2"]), 135 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 136 | 137 | 138 | def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): 139 | model = TransfoXLLMHeadModel(config) 140 | model.eval() 141 | 142 | lm_logits_1, mems_1 = model(input_ids_1) 143 | loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels) 144 | lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1) 145 | loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1) 146 | 147 | outputs = { 148 | "loss_1": loss_1, 149 | "mems_1": mems_1, 150 | "lm_logits_1": lm_logits_1, 151 | "loss_2": loss_2, 152 | "mems_2": mems_2, 153 | "lm_logits_2": lm_logits_2, 154 | } 155 | return outputs 156 | 157 | def check_transfo_xl_lm_head_output(self, result): 158 | self.parent.assertListEqual( 159 | list(result["loss_1"].size()), 160 | [self.batch_size, self.seq_length]) 161 | self.parent.assertListEqual( 162 | list(result["lm_logits_1"].size()), 163 | [self.batch_size, self.seq_length, self.vocab_size]) 164 | self.parent.assertListEqual( 165 | list(list(mem.size()) for mem in result["mems_1"]), 166 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 167 | 168 | self.parent.assertListEqual( 169 | list(result["loss_2"].size()), 170 | [self.batch_size, self.seq_length]) 171 | self.parent.assertListEqual( 172 | list(result["lm_logits_2"].size()), 173 | [self.batch_size, self.seq_length, self.vocab_size]) 174 | self.parent.assertListEqual( 175 | list(list(mem.size()) for mem in result["mems_2"]), 176 | [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers) 177 | 178 | def prepare_config_and_inputs_for_common(self): 179 | config_and_inputs = self.prepare_config_and_inputs() 180 | (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs 181 | inputs_dict = {'input_ids': input_ids_1} 182 | return config, inputs_dict 183 | 184 | 185 | def setUp(self): 186 | self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self) 187 | self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37) 188 | 189 | def test_config(self): 190 | self.config_tester.run_common_tests() 191 | 192 | def test_transfo_xl_model(self): 193 | self.model_tester.set_seed() 194 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 195 | output_result = self.model_tester.create_transfo_xl_model(*config_and_inputs) 196 | self.model_tester.check_transfo_xl_model_output(output_result) 197 | 198 | def test_transfo_xl_lm_head(self): 199 | self.model_tester.set_seed() 200 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 201 | output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) 202 | self.model_tester.check_transfo_xl_lm_head_output(output_result) 203 | 204 | @pytest.mark.slow 205 | def test_model_from_pretrained(self): 206 | cache_dir = "/tmp/pytorch_transformers_test/" 207 | for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 208 | model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) 209 | shutil.rmtree(cache_dir) 210 | self.assertIsNotNone(model) 211 | 212 | 213 | if __name__ == "__main__": 214 | unittest.main() 215 | -------------------------------------------------------------------------------- 
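The mems round-trip exercised above is the Transformer-XL-specific behavior: each forward pass returns an updated memory that the next segment attends to. A minimal usage sketch (toy dimensions copied from the tester defaults above; illustrative only, not part of the test suite):

import torch
from pytorch_transformers import TransfoXLConfig, TransfoXLLMHeadModel

# Toy configuration mirroring TransfoXLModelTester above.
config = TransfoXLConfig(vocab_size_or_config_json_file=99, mem_len=30, clamp_len=15,
                         cutoffs=[10, 50, 80], d_model=32, d_embed=32, n_head=4,
                         d_head=8, d_inner=128, div_val=2, n_layer=5)
model = TransfoXLLMHeadModel(config)
model.eval()

chunk_1 = torch.randint(0, 99, (13, 7))        # [batch_size, seq_length]
chunk_2 = torch.randint(0, 99, (13, 7))

lm_logits_1, mems = model(chunk_1)             # mems: n_layer tensors of [mem_len, batch, d_model]
lm_logits_2, mems = model(chunk_2, mems=mems)  # second segment attends to the cached memory

Because mems already encode the earlier segment, the second call sees an effective context of mem_len + seq_length tokens without recomputing the first chunk.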
/pytorch_transformers/tests/modeling_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | 23 | from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification) 24 | from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP 25 | 26 | from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) 27 | 28 | 29 | class XLMModelTest(CommonTestCases.CommonModelTester): 30 | 31 | all_model_classes = (XLMModel, XLMWithLMHeadModel, 32 | XLMForQuestionAnswering, XLMForSequenceClassification) 33 | # , XLMForSequenceClassification, XLMForTokenClassification), 34 | 35 | class XLMModelTester(object): 36 | 37 | def __init__(self, 38 | parent, 39 | batch_size=13, 40 | seq_length=7, 41 | is_training=True, 42 | use_input_lengths=True, 43 | use_token_type_ids=True, 44 | use_labels=True, 45 | gelu_activation=True, 46 | sinusoidal_embeddings=False, 47 | causal=False, 48 | asm=False, 49 | n_langs=2, 50 | vocab_size=99, 51 | n_special=0, 52 | hidden_size=32, 53 | num_hidden_layers=5, 54 | num_attention_heads=4, 55 | hidden_dropout_prob=0.1, 56 | attention_probs_dropout_prob=0.1, 57 | max_position_embeddings=512, 58 | type_vocab_size=16, 59 | type_sequence_label_size=2, 60 | initializer_range=0.02, 61 | num_labels=3, 62 | num_choices=4, 63 | summary_type="last", 64 | use_proj=True, 65 | scope=None, 66 | ): 67 | self.parent = parent 68 | self.batch_size = batch_size 69 | self.seq_length = seq_length 70 | self.is_training = is_training 71 | self.use_input_lengths = use_input_lengths 72 | self.use_token_type_ids = use_token_type_ids 73 | self.use_labels = use_labels 74 | self.gelu_activation = gelu_activation 75 | self.sinusoidal_embeddings = sinusoidal_embeddings 76 | self.asm = asm 77 | self.n_langs = n_langs 78 | self.vocab_size = vocab_size 79 | self.n_special = n_special 80 | self.summary_type = summary_type 81 | self.causal = causal 82 | self.use_proj = use_proj 83 | self.hidden_size = hidden_size 84 | self.num_hidden_layers = num_hidden_layers 85 | self.num_attention_heads = num_attention_heads 86 | self.hidden_dropout_prob = hidden_dropout_prob 87 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 88 | self.max_position_embeddings = max_position_embeddings 89 | self.n_langs = n_langs 90 | self.type_sequence_label_size = type_sequence_label_size 91 | self.initializer_range = initializer_range 92 | self.summary_type = summary_type 93 | self.num_labels = num_labels 94 | self.num_choices = num_choices 95 | self.scope = scope 96 | 97 | def prepare_config_and_inputs(self): 98 | input_ids = 
ids_tensor([self.batch_size, self.seq_length], self.vocab_size) 99 | input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float() 100 | 101 | input_lengths = None 102 | if self.use_input_lengths: 103 | input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2 # small variation of seq_length 104 | 105 | token_type_ids = None 106 | if self.use_token_type_ids: 107 | token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs) 108 | 109 | sequence_labels = None 110 | token_labels = None 111 | is_impossible_labels = None 112 | if self.use_labels: 113 | sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) 114 | token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) 115 | is_impossible_labels = ids_tensor([self.batch_size], 2).float() 116 | 117 | config = XLMConfig( 118 | vocab_size_or_config_json_file=self.vocab_size, 119 | n_special=self.n_special, 120 | emb_dim=self.hidden_size, 121 | n_layers=self.num_hidden_layers, 122 | n_heads=self.num_attention_heads, 123 | dropout=self.hidden_dropout_prob, 124 | attention_dropout=self.attention_probs_dropout_prob, 125 | gelu_activation=self.gelu_activation, 126 | sinusoidal_embeddings=self.sinusoidal_embeddings, 127 | asm=self.asm, 128 | causal=self.causal, 129 | n_langs=self.n_langs, 130 | max_position_embeddings=self.max_position_embeddings, 131 | initializer_range=self.initializer_range, 132 | summary_type=self.summary_type, 133 | use_proj=self.use_proj) 134 | 135 | return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask 136 | 137 | def check_loss_output(self, result): 138 | self.parent.assertListEqual( 139 | list(result["loss"].size()), 140 | []) 141 | 142 | def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): 143 | model = XLMModel(config=config) 144 | model.eval() 145 | outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids) 146 | outputs = model(input_ids, langs=token_type_ids) 147 | outputs = model(input_ids) 148 | sequence_output = outputs[0] 149 | result = { 150 | "sequence_output": sequence_output, 151 | } 152 | self.parent.assertListEqual( 153 | list(result["sequence_output"].size()), 154 | [self.batch_size, self.seq_length, self.hidden_size]) 155 | 156 | 157 | def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): 158 | model = XLMWithLMHeadModel(config) 159 | model.eval() 160 | 161 | loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels) 162 | 163 | result = { 164 | "loss": loss, 165 | "logits": logits, 166 | } 167 | 168 | self.parent.assertListEqual( 169 | list(result["loss"].size()), 170 | []) 171 | self.parent.assertListEqual( 172 | list(result["logits"].size()), 173 | [self.batch_size, self.seq_length, self.vocab_size]) 174 | 175 | 176 | def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): 177 | model = XLMForQuestionAnswering(config) 178 | model.eval() 179 | 180 | outputs = model(input_ids) 181 | start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs 182 | 183 | outputs = model(input_ids, start_positions=sequence_labels, 184 | end_positions=sequence_labels, 185 | cls_index=sequence_labels, 186 
| is_impossible=is_impossible_labels, 187 | p_mask=input_mask) 188 | 189 | outputs = model(input_ids, start_positions=sequence_labels, 190 | end_positions=sequence_labels, 191 | cls_index=sequence_labels, 192 | is_impossible=is_impossible_labels) 193 | 194 | (total_loss,) = outputs 195 | 196 | outputs = model(input_ids, start_positions=sequence_labels, 197 | end_positions=sequence_labels) 198 | 199 | (total_loss,) = outputs 200 | 201 | result = { 202 | "loss": total_loss, 203 | "start_top_log_probs": start_top_log_probs, 204 | "start_top_index": start_top_index, 205 | "end_top_log_probs": end_top_log_probs, 206 | "end_top_index": end_top_index, 207 | "cls_logits": cls_logits, 208 | } 209 | 210 | self.parent.assertListEqual( 211 | list(result["loss"].size()), 212 | []) 213 | self.parent.assertListEqual( 214 | list(result["start_top_log_probs"].size()), 215 | [self.batch_size, model.config.start_n_top]) 216 | self.parent.assertListEqual( 217 | list(result["start_top_index"].size()), 218 | [self.batch_size, model.config.start_n_top]) 219 | self.parent.assertListEqual( 220 | list(result["end_top_log_probs"].size()), 221 | [self.batch_size, model.config.start_n_top * model.config.end_n_top]) 222 | self.parent.assertListEqual( 223 | list(result["end_top_index"].size()), 224 | [self.batch_size, model.config.start_n_top * model.config.end_n_top]) 225 | self.parent.assertListEqual( 226 | list(result["cls_logits"].size()), 227 | [self.batch_size]) 228 | 229 | 230 | def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): 231 | model = XLMForSequenceClassification(config) 232 | model.eval() 233 | 234 | (logits,) = model(input_ids) 235 | loss, logits = model(input_ids, labels=sequence_labels) 236 | 237 | result = { 238 | "loss": loss, 239 | "logits": logits, 240 | } 241 | 242 | self.parent.assertListEqual( 243 | list(result["loss"].size()), 244 | []) 245 | self.parent.assertListEqual( 246 | list(result["logits"].size()), 247 | [self.batch_size, self.type_sequence_label_size]) 248 | 249 | 250 | def prepare_config_and_inputs_for_common(self): 251 | config_and_inputs = self.prepare_config_and_inputs() 252 | (config, input_ids, token_type_ids, input_lengths, 253 | sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs 254 | inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths} 255 | return config, inputs_dict 256 | 257 | def setUp(self): 258 | self.model_tester = XLMModelTest.XLMModelTester(self) 259 | self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37) 260 | 261 | def test_config(self): 262 | self.config_tester.run_common_tests() 263 | 264 | def test_xlm_model(self): 265 | config_and_inputs = self.model_tester.prepare_config_and_inputs() 266 | self.model_tester.create_and_check_xlm_model(*config_and_inputs) 267 | 268 | # config_and_inputs = tester.prepare_config_and_inputs() 269 | # tester.create_and_check_xlm_for_masked_lm(*config_and_inputs) 270 | 271 | # config_and_inputs = tester.prepare_config_and_inputs() 272 | # tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs) 273 | 274 | # config_and_inputs = tester.prepare_config_and_inputs() 275 | # tester.create_and_check_xlm_for_question_answering(*config_and_inputs) 276 | 277 | # config_and_inputs = tester.prepare_config_and_inputs() 278 | # tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs) 279 | 280 | # 
config_and_inputs = tester.prepare_config_and_inputs() 281 | # tester.create_and_check_xlm_for_token_classification(*config_and_inputs) 282 | 283 | @pytest.mark.slow 284 | def test_model_from_pretrained(self): 285 | cache_dir = "/tmp/pytorch_transformers_test/" 286 | for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 287 | model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) 288 | shutil.rmtree(cache_dir) 289 | self.assertIsNotNone(model) 290 | 291 | 292 | if __name__ == "__main__": 293 | unittest.main() 294 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | 21 | import torch 22 | 23 | from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, 24 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 25 | 26 | import numpy as np 27 | 28 | 29 | def unwrap_schedule(scheduler, num_steps=10): 30 | lrs = [] 31 | for _ in range(num_steps): 32 | scheduler.step() 33 | lrs.append(scheduler.get_lr()) 34 | return lrs 35 | 36 | class OptimizationTest(unittest.TestCase): 37 | 38 | def assertListAlmostEqual(self, list1, list2, tol): 39 | self.assertEqual(len(list1), len(list2)) 40 | for a, b in zip(list1, list2): 41 | self.assertAlmostEqual(a, b, delta=tol) 42 | 43 | def test_adam_w(self): 44 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 45 | target = torch.tensor([0.4, 0.2, -0.5]) 46 | criterion = torch.nn.MSELoss() 47 | # No warmup, constant schedule, no gradient clipping 48 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 49 | for _ in range(100): 50 | loss = criterion(w, target) 51 | loss.backward() 52 | optimizer.step() 53 | w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves. 54 | w.grad.zero_() 55 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 56 | 57 | 58 | class ScheduleInitTest(unittest.TestCase): 59 | m = torch.nn.Linear(50, 50) 60 | optimizer = AdamW(m.parameters(), lr=10.) 61 | num_steps = 10 62 | 63 | def assertListAlmostEqual(self, list1, list2, tol): 64 | self.assertEqual(len(list1), len(list2)) 65 | for a, b in zip(list1, list2): 66 | self.assertAlmostEqual(a, b, delta=tol) 67 | 68 | def test_constant_scheduler(self): 69 | scheduler = ConstantLRSchedule(self.optimizer) 70 | lrs = unwrap_schedule(scheduler, self.num_steps) 71 | expected_learning_rates = [10.] 
* self.num_steps 72 | self.assertEqual(len(lrs[0]), 1) 73 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 74 | 75 | def test_warmup_constant_scheduler(self): 76 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 77 | lrs = unwrap_schedule(scheduler, self.num_steps) 78 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 79 | self.assertEqual(len(lrs[0]), 1) 80 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 81 | 82 | def test_warmup_linear_scheduler(self): 83 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 84 | lrs = unwrap_schedule(scheduler, self.num_steps) 85 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 86 | self.assertEqual(len(lrs[0]), 1) 87 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 88 | 89 | def test_warmup_cosine_scheduler(self): 90 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 91 | lrs = unwrap_schedule(scheduler, self.num_steps) 92 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 93 | self.assertEqual(len(lrs[0]), 1) 94 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 95 | 96 | def test_warmup_cosine_hard_restart_scheduler(self): 97 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 98 | lrs = unwrap_schedule(scheduler, self.num_steps) 99 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 100 | self.assertEqual(len(lrs[0]), 1) 101 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 102 | 103 | 104 | if __name__ == "__main__": 105 | unittest.main() 106 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_bert_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
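# NOTE: the scheduler expectations in optimization_test.py above pin down the
# exact warmup shapes. In a fine-tuning loop the same pieces combine roughly
# like this (a sketch, assuming a `model` and a `loader` of batches; the
# scheduler is stepped once per optimizer step, as unwrap_schedule does):
#
#     from pytorch_transformers import AdamW, WarmupLinearSchedule
#
#     optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
#     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)
#     for batch in loader:
#         loss = model(**batch)[0]
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()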
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory 28 | 29 | class TokenizationTest(unittest.TestCase): 30 | 31 | def test_full_tokenizer(self): 32 | vocab_tokens = [ 33 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 34 | "##ing", ",", "low", "lowest", 35 | ] 36 | with TemporaryDirectory() as tmpdirname: 37 | vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 38 | with open(vocab_file, "w", encoding='utf-8') as vocab_writer: 39 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 40 | 41 | input_text = u"UNwant\u00E9d,running" 42 | output_text = u"unwanted, running" 43 | 44 | create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname) 45 | 46 | tokenizer = BertTokenizer(vocab_file) 47 | 48 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 49 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 50 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 51 | 52 | def test_chinese(self): 53 | tokenizer = BasicTokenizer() 54 | 55 | self.assertListEqual( 56 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 57 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 58 | 59 | def test_basic_tokenizer_lower(self): 60 | tokenizer = BasicTokenizer(do_lower_case=True) 61 | 62 | self.assertListEqual( 63 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 64 | ["hello", "!", "how", "are", "you", "?"]) 65 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 66 | 67 | def test_basic_tokenizer_no_lower(self): 68 | tokenizer = BasicTokenizer(do_lower_case=False) 69 | 70 | self.assertListEqual( 71 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 72 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 73 | 74 | def test_wordpiece_tokenizer(self): 75 | vocab_tokens = [ 76 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 77 | "##ing" 78 | ] 79 | 80 | vocab = {} 81 | for (i, token) in enumerate(vocab_tokens): 82 | vocab[token] = i 83 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 84 | 85 | self.assertListEqual(tokenizer.tokenize(""), []) 86 | 87 | self.assertListEqual( 88 | tokenizer.tokenize("unwanted running"), 89 | ["un", "##want", "##ed", "runn", "##ing"]) 90 | 91 | self.assertListEqual( 92 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 93 | 94 | def test_is_whitespace(self): 95 | self.assertTrue(_is_whitespace(u" ")) 96 | self.assertTrue(_is_whitespace(u"\t")) 97 | self.assertTrue(_is_whitespace(u"\r")) 98 | self.assertTrue(_is_whitespace(u"\n")) 99 | self.assertTrue(_is_whitespace(u"\u00A0")) 100 | 101 | self.assertFalse(_is_whitespace(u"A")) 102 | self.assertFalse(_is_whitespace(u"-")) 103 | 104 | def test_is_control(self): 105 | self.assertTrue(_is_control(u"\u0005")) 106 | 107 | self.assertFalse(_is_control(u"A")) 108 | self.assertFalse(_is_control(u" ")) 109 | self.assertFalse(_is_control(u"\t")) 110 | self.assertFalse(_is_control(u"\r")) 111 | 112 | def test_is_punctuation(self): 113 | self.assertTrue(_is_punctuation(u"-")) 114 | self.assertTrue(_is_punctuation(u"$")) 115 | self.assertTrue(_is_punctuation(u"`")) 116 | self.assertTrue(_is_punctuation(u".")) 117 | 118 | self.assertFalse(_is_punctuation(u"A")) 119 | self.assertFalse(_is_punctuation(u" ")) 120 | 121 | 122 | if __name__ == '__main__': 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory 24 | 25 | class GPT2TokenizationTest(unittest.TestCase): 26 | 27 | def test_full_tokenizer(self): 28 | """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """
29 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
30 |                  "lo", "low", "er",
31 |                  "low", "lowest", "newer", "wider", "<unk>"]
32 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
33 |         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
34 |         special_tokens_map = {"unk_token": "<unk>"}
35 | 
36 |         with TemporaryDirectory() as tmpdirname:
37 |             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
38 |             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
39 |             with open(vocab_file, "w") as fp:
40 |                 fp.write(json.dumps(vocab_tokens))
41 |             with open(merges_file, "w") as fp:
42 |                 fp.write("\n".join(merges))
43 | 
44 |             input_text = u"lower newer"
45 |             output_text = u"lower<unk>newer"
46 | 
47 |             create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
48 | 
49 |             tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
50 |             text = "lower"
51 |             bpe_tokens = ["low", "er"]
52 |             tokens = tokenizer.tokenize(text)
53 |             self.assertListEqual(tokens, bpe_tokens)
54 | 
55 |             input_tokens = tokens + [tokenizer.unk_token]
56 |             input_bpe_tokens = [13, 12, 17]
57 |             self.assertListEqual(
58 |                 tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
59 | 
60 | 
61 | if __name__ == '__main__':
62 |     unittest.main()
63 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_openai_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | import json
20 | 
21 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
24 | 
25 | 
26 | class OpenAIGPTTokenizationTest(unittest.TestCase):
27 | 
28 |     def test_full_tokenizer(self):
29 |         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
30 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
31 |                  "w</w>", "r</w>", "t</w>",
32 |                  "lo", "low", "er</w>",
33 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
34 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
35 |         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
36 | 
37 |         with TemporaryDirectory() as tmpdirname:
38 |             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
39 |             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
40 |             with open(vocab_file, "w") as fp:
41 |                 fp.write(json.dumps(vocab_tokens))
42 |             with open(merges_file, "w") as fp:
43 |                 fp.write("\n".join(merges))
44 | 
45 |             input_text = u"lower newer"
46 |             output_text = u"lower newer"
47 | 
48 |             create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
49 | 
50 |             tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
51 | 
52 |             text = "lower"
53 |             bpe_tokens = ["low", "er</w>"]
54 |             tokens = tokenizer.tokenize(text)
55 |             self.assertListEqual(tokens, bpe_tokens)
56 | 
57 |             input_tokens = tokens + ["<unk>"]
58 |             input_bpe_tokens = [14, 15, 20]
59 |             self.assertListEqual(
60 |                 tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
61 | 
62 | 
63 | if __name__ == '__main__':
64 |     unittest.main()
65 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_tests_commons.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
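# NOTE: the two BPE tests above show the on-disk format these tokenizers load:
# a JSON token-to-id vocabulary plus a ranked merges file (first line is a
# version header, one merge per rank thereafter). A condensed sketch of the
# same setup outside the unittest harness (toy vocabulary, illustrative paths):
#
#     import json
#     vocab = {"l": 0, "o": 1, "w": 2, "e": 3, "r": 4,
#              "lo": 5, "low": 6, "er": 7, "<unk>": 8}
#     with open("vocab.json", "w") as fp:
#         json.dump(vocab, fp)
#     with open("merges.txt", "w") as fp:
#         fp.write("\n".join(["#version: 0.2", "l o", "lo w", "e r", ""]))
#
#     from pytorch_transformers import GPT2Tokenizer
#     tokenizer = GPT2Tokenizer("vocab.json", "merges.txt", unk_token="<unk>")
#     assert tokenizer.tokenize("lower") == ["low", "er"]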
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import sys 19 | from io import open 20 | import tempfile 21 | import shutil 22 | 23 | if sys.version_info[0] == 2: 24 | import cPickle as pickle 25 | 26 | class TemporaryDirectory(object): 27 | """Context manager for tempfile.mkdtemp() so it's usable with "with" statement.""" 28 | def __enter__(self): 29 | self.name = tempfile.mkdtemp() 30 | return self.name 31 | def __exit__(self, exc_type, exc_value, traceback): 32 | shutil.rmtree(self.name) 33 | else: 34 | import pickle 35 | TemporaryDirectory = tempfile.TemporaryDirectory 36 | unicode = str 37 | 38 | 39 | def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs): 40 | tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) 41 | 42 | before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 43 | 44 | with TemporaryDirectory() as tmpdirname: 45 | tokenizer.save_pretrained(tmpdirname) 46 | tokenizer = tokenizer.from_pretrained(tmpdirname) 47 | 48 | after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 49 | tester.assertListEqual(before_tokens, after_tokens) 50 | 51 | def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs): 52 | tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) 53 | tester.assertIsNotNone(tokenizer) 54 | 55 | text = u"Munich and Berlin are nice cities" 56 | subwords = tokenizer.tokenize(text) 57 | 58 | with TemporaryDirectory() as tmpdirname: 59 | 60 | filename = os.path.join(tmpdirname, u"tokenizer.bin") 61 | pickle.dump(tokenizer, open(filename, "wb")) 62 | 63 | tokenizer_new = pickle.load(open(filename, "rb")) 64 | 65 | subwords_loaded = tokenizer_new.tokenize(text) 66 | 67 | tester.assertListEqual(subwords, subwords_loaded) 68 | 69 | 70 | def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs): 71 | tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) 72 | 73 | vocab_size = tokenizer.vocab_size 74 | all_size = len(tokenizer) 75 | 76 | tester.assertNotEqual(vocab_size, 0) 77 | tester.assertEqual(vocab_size, all_size) 78 | 79 | new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] 80 | added_toks = tokenizer.add_tokens(new_toks) 81 | vocab_size_2 = tokenizer.vocab_size 82 | all_size_2 = len(tokenizer) 83 | 84 | tester.assertNotEqual(vocab_size_2, 0) 85 | tester.assertEqual(vocab_size, vocab_size_2) 86 | tester.assertEqual(added_toks, len(new_toks)) 87 | tester.assertEqual(all_size_2, all_size + len(new_toks)) 88 | 89 | tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") 90 | tester.assertGreaterEqual(len(tokens), 4) 91 | tester.assertGreater(tokens[0], tokenizer.vocab_size - 1) 92 | tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 93 | 94 | new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", 95 | 'pad_token': "<<<<<|||>|>>>>|>"} 96 | added_toks_2 = tokenizer.add_special_tokens(new_toks_2) 97 | vocab_size_3 = tokenizer.vocab_size 98 | all_size_3 = len(tokenizer) 99 | 100 | tester.assertNotEqual(vocab_size_3, 0) 101 | tester.assertEqual(vocab_size, vocab_size_3) 102 | tester.assertEqual(added_toks_2, len(new_toks_2)) 103 | tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) 104 | 105 | tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") 106 | 107 | tester.assertGreaterEqual(len(tokens), 6) 108 | tester.assertGreater(tokens[0], tokenizer.vocab_size - 1) 109 | tester.assertGreater(tokens[0], 
tokens[1]) 110 | tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 111 | tester.assertGreater(tokens[-2], tokens[-3]) 112 | tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token)) 113 | tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) 114 | 115 | 116 | def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): 117 | tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs) 118 | 119 | tokens = tokenizer.tokenize(input_text) 120 | ids = tokenizer.convert_tokens_to_ids(tokens) 121 | ids_2 = tokenizer.encode(input_text) 122 | tester.assertListEqual(ids, ids_2) 123 | 124 | tokens_2 = tokenizer.convert_ids_to_tokens(ids) 125 | text_2 = tokenizer.decode(ids) 126 | 127 | tester.assertEqual(text_2, output_text) 128 | 129 | tester.assertNotEqual(len(tokens_2), 0) 130 | tester.assertIsInstance(text_2, (str, unicode)) 131 | 132 | 133 | def create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): 134 | weights_list = list(tokenizer_class.max_model_input_sizes.keys()) 135 | weights_lists_2 = [] 136 | for file_id, map_list in tokenizer_class.pretrained_vocab_files_map.items(): 137 | weights_lists_2.append(list(map_list.keys())) 138 | 139 | for weights_list_2 in weights_lists_2: 140 | tester.assertListEqual(weights_list, weights_list_2) 141 | 142 | 143 | def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs): 144 | create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs) 145 | create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs) 146 | create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs) 147 | create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs) 148 | create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs) 149 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_transfo_xl_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
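# NOTE: create_and_check_add_tokens_tokenizer above encodes the vocabulary
# contract all tokenizers share: `vocab_size` stays fixed while `len(tokenizer)`
# grows as tokens are added. Condensed (a sketch; GPT2Tokenizer shown, assuming
# its pretrained files are reachable for download):
#
#     from pytorch_transformers import GPT2Tokenizer
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     base = len(tokenizer)
#     assert tokenizer.add_tokens(["new_tok_1", "new_tok_2"]) == 2
#     assert len(tokenizer) == base + 2 and tokenizer.vocab_size == base
#     tokenizer.add_special_tokens({"pad_token": "<PAD>"})
#     assert tokenizer.pad_token == "<PAD>"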
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | from io import open
20 | 
21 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
24 | 
25 | class TransfoXLTokenizationTest(unittest.TestCase):
26 | 
27 |     def test_full_tokenizer(self):
28 |         vocab_tokens = [
29 |             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
30 |             "running", ",", "low", "l",
31 |         ]
32 |         with TemporaryDirectory() as tmpdirname:
33 |             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
34 |             with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
35 |                 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
36 | 
37 |             input_text = u"<unk> UNwanted , running"
38 |             output_text = u"<unk> unwanted, running"
39 | 
40 |             create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
41 | 
42 |             tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
43 | 
44 |             tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
45 |             self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
46 | 
47 |             self.assertListEqual(
48 |                 tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
49 | 
50 |     def test_full_tokenizer_lower(self):
51 |         tokenizer = TransfoXLTokenizer(lower_case=True)
52 | 
53 |         self.assertListEqual(
54 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
55 |             ["hello", "!", "how", "are", "you", "?"])
56 | 
57 |     def test_full_tokenizer_no_lower(self):
58 |         tokenizer = TransfoXLTokenizer(lower_case=False)
59 | 
60 |         self.assertListEqual(
61 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
62 |             ["HeLLo", "!", "how", "Are", "yoU", "?"])
63 | 
64 | 
65 | if __name__ == '__main__':
66 |     unittest.main()
67 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_utils_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from pytorch_transformers import PreTrainedTokenizer 23 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | class TokenizerUtilsTest(unittest.TestCase): 26 | def check_tokenizer_from_pretrained(self, tokenizer_class): 27 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 28 | for model_name in s3_models[:1]: 29 | tokenizer = tokenizer_class.from_pretrained(model_name) 30 | self.assertIsNotNone(tokenizer) 31 | self.assertIsInstance(tokenizer, tokenizer_class) 32 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 33 | 34 | for special_tok in tokenizer.all_special_tokens: 35 | if six.PY2: 36 | self.assertIsInstance(special_tok, unicode) 37 | else: 38 | self.assertIsInstance(special_tok, str) 39 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 40 | self.assertIsInstance(special_tok_id, int) 41 | 42 | def test_pretrained_tokenizers(self): 43 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | import json 20 | 21 | from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES 22 | 23 | from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory 24 | 25 | class XLMTokenizationTest(unittest.TestCase): 26 | 27 | def test_full_tokenizer(self): 28 | """ Adapted from Sennrich et al. 
2015 and https://github.com/rsennrich/subword-nmt """
29 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
30 |                  "w</w>", "r</w>", "t</w>",
31 |                  "lo", "low", "er</w>",
32 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
33 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
34 |         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
35 | 
36 |         with TemporaryDirectory() as tmpdirname:
37 |             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
38 |             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
39 |             with open(vocab_file, "w") as fp:
40 |                 fp.write(json.dumps(vocab_tokens))
41 |             with open(merges_file, "w") as fp:
42 |                 fp.write("\n".join(merges))
43 | 
44 |             input_text = u"lower newer"
45 |             output_text = u"lower newer"
46 | 
47 |             create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
48 | 
49 |             tokenizer = XLMTokenizer(vocab_file, merges_file)
50 | 
51 |             text = "lower"
52 |             bpe_tokens = ["low", "er</w>"]
53 |             tokens = tokenizer.tokenize(text)
54 |             self.assertListEqual(tokens, bpe_tokens)
55 | 
56 |             input_tokens = tokens + ["<unk>"]
57 |             input_bpe_tokens = [14, 15, 20]
58 |             self.assertListEqual(
59 |                 tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
60 | 
61 | 
62 | if __name__ == '__main__':
63 |     unittest.main()
64 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_xlnet_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
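# NOTE: the "</w>" suffixes in the XLM/OpenAI toy vocabularies above are
# end-of-word markers; detokenization folds them back into spaces, which is why
# output_text round-trips to "lower newer". A sketch of the convention,
# mirroring what the convert_tokens_to_string methods of those tokenizers do:
#
#     tokens = ["low", "er</w>", "new", "er</w>"]
#     text = "".join(tokens).replace("</w>", " ").strip()
#     assert text == "lower newer"
#
# The byte-level GPT-2 BPE has no such marker; it encodes the space itself.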
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | 
20 | from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
21 | 
22 | from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
23 | 
24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
25 |                             'fixtures/test_sentencepiece.model')
26 | 
27 | class XLNetTokenizationTest(unittest.TestCase):
28 | 
29 |     def test_full_tokenizer(self):
30 |         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
31 | 
32 |         with TemporaryDirectory() as tmpdirname:
33 |             tokenizer.save_pretrained(tmpdirname)
34 | 
35 |             input_text = u"This is a test"
36 |             output_text = u"This is a test"
37 | 
38 |             create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
39 | 
40 |         tokens = tokenizer.tokenize(u'This is a test')
41 |         self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
42 | 
43 |         self.assertListEqual(
44 |             tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
45 | 
46 |         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
47 |         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
48 |                                       u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
49 |                                       u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
50 |                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
51 |         ids = tokenizer.convert_tokens_to_ids(tokens)
52 |         self.assertListEqual(
53 |             ids, [8, 21, 84, 55, 24, 19, 7, 0,
54 |                   602, 347, 347, 347, 3, 12, 66,
55 |                   46, 72, 80, 6, 0, 4])
56 | 
57 |         back_tokens = tokenizer.convert_ids_to_tokens(ids)
58 |         self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
59 |                                            u'or', u'n', SPIECE_UNDERLINE + u'in',
60 |                                            SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
61 |                                            SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
62 |                                            SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
63 |                                            u'<unk>', u'.'])
64 | 
65 |     def test_tokenizer_lower(self):
66 |         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
67 |         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
68 |         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
69 |                                       u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
70 |                                       u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
71 |                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
72 |         self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
73 | 
74 |     def test_tokenizer_no_lower(self):
75 |         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
76 |         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
77 |         self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
78 |                                       u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
79 |                                       u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
80 |                                       SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
81 | 
82 | 
83 | if __name__ == '__main__':
84 |     unittest.main()
85 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for OpenAI GPT."""
16 | from __future__ import (absolute_import, division, print_function,
17 |                         unicode_literals)
18 | 
19 | import sys
20 | import json
21 | import logging
22 | import os
23 | import regex as re
24 | from io import open
25 | 
26 | try:
27 |     from functools import lru_cache
28 | except ImportError:
29 |     # Just a dummy decorator to get the checks to run on python2
30 |     # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
31 |     def lru_cache():
32 |         return lambda func: func
33 | 
34 | from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
35 | 
36 | logger = logging.getLogger(__name__)
37 | 
38 | VOCAB_FILES_NAMES = {
39 |     'vocab_file': 'vocab.json',
40 |     'merges_file': 'merges.txt',
41 | }
42 | 
43 | PRETRAINED_VOCAB_FILES_MAP = {
44 |     'vocab_file':
45 |     {
46 |         'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
47 |         'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
48 |     },
49 |     'merges_file':
50 |     {
51 |         'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
52 |         'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
53 |     },
54 | }
55 | 
56 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
57 |     'gpt2': 1024,
58 |     'gpt2-medium': 1024,
59 | }
60 | 
61 | @lru_cache()
62 | def bytes_to_unicode():
63 |     """
64 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
65 |     The reversible bpe codes work on unicode strings.
66 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
67 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
68 |     This is a significant percentage of your normal, say, 32K bpe vocab.
69 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
70 |     This also avoids mapping to whitespace/control characters the bpe code barfs on.
71 |     """
72 |     _chr = unichr if sys.version_info[0] == 2 else chr
73 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
74 |     cs = bs[:]
75 |     n = 0
76 |     for b in range(2**8):
77 |         if b not in bs:
78 |             bs.append(b)
79 |             cs.append(2**8+n)
80 |             n += 1
81 |     cs = [_chr(n) for n in cs]
82 |     return dict(zip(bs, cs))
83 | 
84 | def get_pairs(word):
85 |     """Return set of symbol pairs in a word.
86 | 
87 |     Word is represented as tuple of symbols (symbols being variable-length strings).
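    Example (illustrative): get_pairs(("l", "o", "w")) returns {("l", "o"), ("o", "w")}.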
88 | """ 89 | pairs = set() 90 | prev_char = word[0] 91 | for char in word[1:]: 92 | pairs.add((prev_char, char)) 93 | prev_char = char 94 | return pairs 95 | 96 | class GPT2Tokenizer(PreTrainedTokenizer): 97 | """ 98 | GPT-2 BPE tokenizer. Peculiarities: 99 | - Byte-level BPE 100 | """ 101 | vocab_files_names = VOCAB_FILES_NAMES 102 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 103 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 104 | 105 | def __init__(self, vocab_file, merges_file, errors='replace', 106 | bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): 107 | super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs) 108 | 109 | self.encoder = json.load(open(vocab_file)) 110 | self.decoder = {v:k for k,v in self.encoder.items()} 111 | self.errors = errors # how to handle errors in decoding 112 | self.byte_encoder = bytes_to_unicode() 113 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 114 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 115 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 116 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 117 | self.cache = {} 118 | 119 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 120 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 121 | 122 | @property 123 | def vocab_size(self): 124 | return len(self.encoder) 125 | 126 | def bpe(self, token): 127 | if token in self.cache: 128 | return self.cache[token] 129 | word = tuple(token) 130 | pairs = get_pairs(word) 131 | 132 | if not pairs: 133 | return token 134 | 135 | while True: 136 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 137 | if bigram not in self.bpe_ranks: 138 | break 139 | first, second = bigram 140 | new_word = [] 141 | i = 0 142 | while i < len(word): 143 | try: 144 | j = word.index(first, i) 145 | new_word.extend(word[i:j]) 146 | i = j 147 | except: 148 | new_word.extend(word[i:]) 149 | break 150 | 151 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 152 | new_word.append(first+second) 153 | i += 2 154 | else: 155 | new_word.append(word[i]) 156 | i += 1 157 | new_word = tuple(new_word) 158 | word = new_word 159 | if len(word) == 1: 160 | break 161 | else: 162 | pairs = get_pairs(word) 163 | word = ' '.join(word) 164 | self.cache[token] = word 165 | return word 166 | 167 | def _tokenize(self, text): 168 | """ Tokenize a string. """ 169 | bpe_tokens = [] 170 | for token in re.findall(self.pat, text): 171 | if sys.version_info[0] == 2: 172 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 173 | else: 174 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 175 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 176 | return bpe_tokens 177 | 178 | def _convert_token_to_id(self, token): 179 | """ Converts a token (str/unicode) in an id using the vocab. """ 180 | if token in self.encoder: 181 | return self.encoder.get(token) 182 | return self.encoder.get(self.unk_token) 183 | 184 | def _convert_id_to_token(self, index): 185 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 186 | return self.decoder.get(index) 187 | 188 | def convert_tokens_to_string(self, tokens): 189 | """ Converts a sequence of tokens (string) in a single string. 
""" 190 | text = ''.join(tokens) 191 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 192 | return text 193 | 194 | def save_vocabulary(self, save_directory): 195 | """Save the tokenizer vocabulary and merge files to a directory.""" 196 | if not os.path.isdir(save_directory): 197 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 198 | return 199 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 200 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 201 | 202 | with open(vocab_file, 'w', encoding='utf-8') as f: 203 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 204 | 205 | index = 0 206 | with open(merge_file, "w", encoding="utf-8") as writer: 207 | writer.write(u'#version: 0.2\n') 208 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 209 | if index != token_index: 210 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 211 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 212 | index = token_index 213 | writer.write(' '.join(bpe_tokens) + u'\n') 214 | index += 1 215 | 216 | return vocab_file, merge_file 217 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | from .tokenization_bert import BasicTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = { 31 | 'vocab_file': 'vocab.json', 32 | 'merges_file': 'merges.txt', 33 | } 34 | 35 | PRETRAINED_VOCAB_FILES_MAP = { 36 | 'vocab_file': 37 | { 38 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 39 | }, 40 | 'merges_file': 41 | { 42 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 43 | }, 44 | } 45 | 46 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 47 | 'openai-gpt': 512, 48 | } 49 | 50 | def get_pairs(word): 51 | """ 52 | Return set of symbol pairs in a word. 
53 |     word is represented as tuple of symbols (symbols being variable-length strings)
54 |     """
55 |     pairs = set()
56 |     prev_char = word[0]
57 |     for char in word[1:]:
58 |         pairs.add((prev_char, char))
59 |         prev_char = char
60 |     return pairs
61 |
62 | def text_standardize(text):
63 |     """
64 |     Fixes some issues the spacy tokenizer had on books corpus;
65 |     also does some whitespace standardization.
66 |     """
67 |     text = text.replace('—', '-')
68 |     text = text.replace('–', '-')
69 |     text = text.replace('―', '-')
70 |     text = text.replace('…', '...')
71 |     text = text.replace('´', "'")
72 |     text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
73 |     text = re.sub(r'\s*\n\s*', ' \n ', text)
74 |     text = re.sub(r'[^\S\n]+', ' ', text)
75 |     return text.strip()
76 |
77 | class OpenAIGPTTokenizer(PreTrainedTokenizer):
78 |     """
79 |     BPE tokenizer. Peculiarities:
80 |         - lower case all inputs
81 |         - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
82 |     """
83 |     vocab_files_names = VOCAB_FILES_NAMES
84 |     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
85 |     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
86 |
87 |     def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
88 |         super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
89 |
90 |         try:
91 |             import ftfy
92 |             import spacy
93 |             self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
94 |             self.fix_text = ftfy.fix_text
95 |         except ImportError:
96 |             logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.")
97 |             self.nlp = BasicTokenizer(do_lower_case=True)
98 |             self.fix_text = None
99 |
100 |         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
101 |         self.decoder = {v:k for k,v in self.encoder.items()}
102 |         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
103 |         merges = [tuple(merge.split()) for merge in merges]
104 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
105 |         self.cache = {}
106 |
107 |     @property
108 |     def vocab_size(self):
109 |         return len(self.encoder)
110 |
111 |     def bpe(self, token):
112 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
113 |         if token in self.cache:
114 |             return self.cache[token]
115 |         pairs = get_pairs(word)
116 |
117 |         if not pairs:
118 |             return token+'</w>'
119 |
120 |         while True:
121 |             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
122 |             if bigram not in self.bpe_ranks:
123 |                 break
124 |             first, second = bigram
125 |             new_word = []
126 |             i = 0
127 |             while i < len(word):
128 |                 try:
129 |                     j = word.index(first, i)
130 |                     new_word.extend(word[i:j])
131 |                     i = j
132 |                 except ValueError:
133 |                     new_word.extend(word[i:])
134 |                     break
135 |
136 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
137 |                     new_word.append(first+second)
138 |                     i += 2
139 |                 else:
140 |                     new_word.append(word[i])
141 |                     i += 1
142 |             new_word = tuple(new_word)
143 |             word = new_word
144 |             if len(word) == 1:
145 |                 break
146 |             else:
147 |                 pairs = get_pairs(word)
148 |         word = ' '.join(word)
149 |         if word == '\n  </w>':
150 |             word = '\n</w>'
151 |         self.cache[token] = word
152 |         return word
153 |
154 |     def _tokenize(self, text):
155 |         """ Tokenize a string. """
""" 156 | split_tokens = [] 157 | if self.fix_text is None: 158 | # Using BERT's BasicTokenizer 159 | text = self.nlp.tokenize(text) 160 | for token in text: 161 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 162 | else: 163 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 164 | text = self.nlp(text_standardize(self.fix_text(text))) 165 | for token in text: 166 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 167 | return split_tokens 168 | 169 | def _convert_token_to_id(self, token): 170 | """ Converts a token (str/unicode) in an id using the vocab. """ 171 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 172 | 173 | def _convert_id_to_token(self, index): 174 | """Converts an id in a token (BPE) using the vocab.""" 175 | return self.decoder.get(index, self.unk_token) 176 | 177 | def convert_tokens_to_string(self, tokens): 178 | """ Converts a sequence of tokens (string) in a single string. """ 179 | out_string = ''.join(tokens).replace('', ' ').strip() 180 | return out_string 181 | 182 | def save_vocabulary(self, save_directory): 183 | """Save the tokenizer vocabulary and merge files to a directory.""" 184 | if not os.path.isdir(save_directory): 185 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 186 | return 187 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 188 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 189 | 190 | with open(vocab_file, 'w', encoding='utf-8') as f: 191 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 192 | 193 | index = 0 194 | with open(merge_file, "w", encoding="utf-8") as writer: 195 | writer.write(u'#version: 0.2\n') 196 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 197 | if index != token_index: 198 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 199 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 200 | index = token_index 201 | writer.write(' '.join(bpe_tokens) + u'\n') 202 | index += 1 203 | 204 | return vocab_file, merge_file 205 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_xlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | from .tokenization_bert import BasicTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = { 31 | 'vocab_file': 'vocab.json', 32 | 'merges_file': 'merges.txt', 33 | } 34 | 35 | PRETRAINED_VOCAB_FILES_MAP = { 36 | 'vocab_file': 37 | { 38 | 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json", 39 | 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json", 40 | 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json", 41 | 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json", 42 | 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json", 43 | 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json", 44 | 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json", 45 | 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", 46 | }, 47 | 'merges_file': 48 | { 49 | 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", 50 | 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", 51 | 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 52 | 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt", 53 | 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt", 54 | 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt", 55 | 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 56 | 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", 57 | }, 58 | } 59 | 60 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 61 | 'xlm-mlm-en-2048': 512, 62 | 'xlm-mlm-ende-1024': 512, 63 | 'xlm-mlm-enfr-1024': 512, 64 | 'xlm-mlm-enro-1024': 512, 65 | 'xlm-mlm-tlm-xnli15-1024': 512, 66 | 'xlm-mlm-xnli15-1024': 512, 67 | 'xlm-clm-enfr-1024': 512, 68 | 'xlm-clm-ende-1024': 512, 69 | } 70 | 71 | def get_pairs(word): 72 | """ 73 | Return set of symbol pairs in a word. 
74 |     word is represented as tuple of symbols (symbols being variable-length strings)
75 |     """
76 |     pairs = set()
77 |     prev_char = word[0]
78 |     for char in word[1:]:
79 |         pairs.add((prev_char, char))
80 |         prev_char = char
81 |     return pairs
82 |
83 | def text_standardize(text):
84 |     """
85 |     Fixes some issues the spacy tokenizer had on books corpus;
86 |     also does some whitespace standardization.
87 |     """
88 |     text = text.replace('—', '-')
89 |     text = text.replace('–', '-')
90 |     text = text.replace('―', '-')
91 |     text = text.replace('…', '...')
92 |     text = text.replace('´', "'")
93 |     text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
94 |     text = re.sub(r'\s*\n\s*', ' \n ', text)
95 |     text = re.sub(r'[^\S\n]+', ' ', text)
96 |     return text.strip()
97 |
98 | class XLMTokenizer(PreTrainedTokenizer):
99 |     """
100 |     BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
101 |
102 |     - lower case all inputs
103 |
104 |     - uses SpaCy tokenizer and ftfy for pre-BPE tokenization \
105 |       if they are installed, \
106 |       falling back to BERT's BasicTokenizer if not.
107 |
108 |     - argument ``special_tokens`` and function ``set_special_tokens`` can be used to add additional symbols \
109 |       (ex: "__classify__") to a vocabulary.
110 |     """
111 |     vocab_files_names = VOCAB_FILES_NAMES
112 |     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
113 |     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
114 |
115 |     def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
116 |                  sep_token="</s>", pad_token="<pad>", cls_token="</s>",
117 |                  mask_token="<special1>", additional_special_tokens=["<special0>",
118 |                  "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
119 |                  "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
120 |         super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
121 |                                            sep_token=sep_token, pad_token=pad_token,
122 |                                            cls_token=cls_token, mask_token=mask_token,
123 |                                            additional_special_tokens=additional_special_tokens,
124 |                                            **kwargs)
125 |         try:
126 |             import ftfy
127 |             import spacy
128 |             self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
129 |             self.fix_text = ftfy.fix_text
130 |         except ImportError:
131 |             logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.")
132 |             self.nlp = BasicTokenizer(do_lower_case=True)
133 |             self.fix_text = None
134 |
135 |         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
136 |         self.decoder = {v:k for k,v in self.encoder.items()}
137 |         merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
138 |         merges = [tuple(merge.split()[:2]) for merge in merges]
139 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
140 |         self.cache = {}
141 |
142 |     @property
143 |     def vocab_size(self):
144 |         return len(self.encoder)
145 |
146 |     def bpe(self, token):
147 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
148 |         if token in self.cache:
149 |             return self.cache[token]
150 |         pairs = get_pairs(word)
151 |
152 |         if not pairs:
153 |             return token+'</w>'
154 |
155 |         while True:
156 |             bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
157 |             if bigram not in self.bpe_ranks:
158 |                 break
159 |             first, second = bigram
160 |             new_word = []
161 |             i = 0
162 |             while i < len(word):
163 |                 try:
164 |                     j = word.index(first, i)
165 |                     new_word.extend(word[i:j])
166 |                     i = j
167 |                 except ValueError:
168 |                     new_word.extend(word[i:])
169 |                     break
170 |
171 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
172 |                     new_word.append(first+second)
173 |                     i += 2
174 |                 else:
175 |                     new_word.append(word[i])
176 |                     i += 1
177 |             new_word = tuple(new_word)
178 |             word = new_word
179 |             if len(word) == 1:
180 |                 break
181 |             else:
182 |                 pairs = get_pairs(word)
183 |         word = ' '.join(word)
184 |         if word == '\n  </w>':
185 |             word = '\n</w>'
186 |         self.cache[token] = word
187 |         return word
188 |
189 |     def _tokenize(self, text):
190 |         """ Tokenize a string. """
191 |         split_tokens = []
192 |         if self.fix_text is None:
193 |             # Using BERT's BasicTokenizer
194 |             text = self.nlp.tokenize(text)
195 |             for token in text:
196 |                 split_tokens.extend([t for t in self.bpe(token).split(' ')])
197 |         else:
198 |             # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
199 |             text = self.nlp(text_standardize(self.fix_text(text)))
200 |             for token in text:
201 |                 split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
202 |         return split_tokens
203 |
204 |     def _convert_token_to_id(self, token):
205 |         """ Converts a token (str/unicode) to an id using the vocab. """
206 |         return self.encoder.get(token, self.encoder.get(self.unk_token))
207 |
208 |     def _convert_id_to_token(self, index):
209 |         """Converts an index (integer) to a token (string/unicode) using the vocab."""
210 |         return self.decoder.get(index, self.unk_token)
211 |
212 |     def convert_tokens_to_string(self, tokens):
213 |         """ Converts a sequence of tokens (string) into a single string. """
214 |         out_string = ''.join(tokens).replace('</w>', ' ').strip()
215 |         return out_string
216 |
217 |     def save_vocabulary(self, save_directory):
218 |         """Save the tokenizer vocabulary and merge files to a directory."""
219 |         if not os.path.isdir(save_directory):
220 |             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
221 |             return
222 |         vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
223 |         merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
224 |
225 |         with open(vocab_file, 'w', encoding='utf-8') as f:
226 |             f.write(json.dumps(self.encoder, ensure_ascii=False))
227 |
228 |         index = 0
229 |         with open(merge_file, "w", encoding="utf-8") as writer:
230 |             for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
231 |                 if index != token_index:
232 |                     logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
233 |                                    " Please check that the tokenizer is not corrupted!".format(merge_file))
234 |                     index = token_index
235 |                 writer.write(' '.join(bpe_tokens) + u'\n')
236 |                 index += 1
237 |
238 |         return vocab_file, merge_file
239 |
--------------------------------------------------------------------------------
/pytorch_transformers/tokenization_xlnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
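# Illustrative note (a sketch, not drawn from this file): unlike the '</w>'
# end-of-word marker used by the BPE tokenizers above, SentencePiece (used
# below) prefixes word-initial pieces with u'\u2581' (SPIECE_UNDERLINE), so
# detokenization is a simple replace:
#
#     ''.join([u'\u2581This', u'\u2581is']).replace(u'\u2581', u' ').strip()
#     # -> u'This is'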
15 | """ Tokenization classes for XLNet model.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | from shutil import copyfile 22 | 23 | import unicodedata 24 | import six 25 | 26 | from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} 31 | 32 | PRETRAINED_VOCAB_FILES_MAP = { 33 | 'vocab_file': 34 | { 35 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model", 36 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", 37 | } 38 | } 39 | 40 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 41 | 'xlnet-base-cased': None, 42 | 'xlnet-large-cased': None, 43 | } 44 | 45 | SPIECE_UNDERLINE = u'▁' 46 | 47 | # Segments (not really needed) 48 | SEG_ID_A = 0 49 | SEG_ID_B = 1 50 | SEG_ID_CLS = 2 51 | SEG_ID_SEP = 3 52 | SEG_ID_PAD = 4 53 | 54 | class XLNetTokenizer(PreTrainedTokenizer): 55 | """ 56 | SentencePiece based tokenizer. Peculiarities: 57 | 58 | - requires `SentencePiece `_ 59 | """ 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | 64 | def __init__(self, vocab_file, max_len=None, 65 | do_lower_case=False, remove_space=True, keep_accents=False, 66 | bos_token="", eos_token="", unk_token="", sep_token="", 67 | pad_token="", cls_token="", mask_token="", 68 | additional_special_tokens=["", ""], **kwargs): 69 | super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, 70 | unk_token=unk_token, sep_token=sep_token, 71 | pad_token=pad_token, cls_token=cls_token, 72 | mask_token=mask_token, additional_special_tokens= 73 | additional_special_tokens, **kwargs) 74 | try: 75 | import sentencepiece as spm 76 | except ImportError: 77 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 78 | "pip install sentencepiece") 79 | 80 | self.do_lower_case = do_lower_case 81 | self.remove_space = remove_space 82 | self.keep_accents = keep_accents 83 | self.vocab_file = vocab_file 84 | 85 | self.sp_model = spm.SentencePieceProcessor() 86 | self.sp_model.Load(vocab_file) 87 | 88 | @property 89 | def vocab_size(self): 90 | return len(self.sp_model) 91 | 92 | def __getstate__(self): 93 | state = self.__dict__.copy() 94 | state["sp_model"] = None 95 | return state 96 | 97 | def __setstate__(self, d): 98 | self.__dict__ = d 99 | try: 100 | import sentencepiece as spm 101 | except ImportError: 102 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 103 | "pip install sentencepiece") 104 | self.sp_model = spm.SentencePieceProcessor() 105 | self.sp_model.Load(self.vocab_file) 106 | 107 | def preprocess_text(self, inputs): 108 | if self.remove_space: 109 | outputs = ' '.join(inputs.strip().split()) 110 | else: 111 | outputs = inputs 112 | outputs = outputs.replace("``", '"').replace("''", '"') 113 | 114 | if six.PY2 and isinstance(outputs, str): 115 | outputs = outputs.decode('utf-8') 116 | 117 | if not self.keep_accents: 118 | outputs = unicodedata.normalize('NFKD', outputs) 119 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 120 | if self.do_lower_case: 121 | outputs = outputs.lower() 122 | 123 | return outputs 
124 |
125 |     def _tokenize(self, text, return_unicode=True, sample=False):
126 |         """ Tokenize a string.
127 |             return_unicode is used only for py2
128 |         """
129 |         text = self.preprocess_text(text)
130 |         # note(zhiliny): in some systems, sentencepiece only accepts str for py2
131 |         if six.PY2 and isinstance(text, unicode):
132 |             text = text.encode('utf-8')
133 |
134 |         if not sample:
135 |             pieces = self.sp_model.EncodeAsPieces(text)
136 |         else:
137 |             pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
138 |         new_pieces = []
139 |         for piece in pieces:
140 |             if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
141 |                 cur_pieces = self.sp_model.EncodeAsPieces(
142 |                     piece[:-1].replace(SPIECE_UNDERLINE, ''))
143 |                 if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
144 |                     if len(cur_pieces[0]) == 1:
145 |                         cur_pieces = cur_pieces[1:]
146 |                     else:
147 |                         cur_pieces[0] = cur_pieces[0][1:]
148 |                 cur_pieces.append(piece[-1])
149 |                 new_pieces.extend(cur_pieces)
150 |             else:
151 |                 new_pieces.append(piece)
152 |
153 |         # note(zhiliny): convert back to unicode for py2
154 |         if six.PY2 and return_unicode:
155 |             ret_pieces = []
156 |             for piece in new_pieces:
157 |                 if isinstance(piece, str):
158 |                     piece = piece.decode('utf-8')
159 |                 ret_pieces.append(piece)
160 |             new_pieces = ret_pieces
161 |
162 |         return new_pieces
163 |
164 |     def _convert_token_to_id(self, token):
165 |         """ Converts a token (str/unicode) to an id using the vocab. """
166 |         return self.sp_model.PieceToId(token)
167 |
168 |     def _convert_id_to_token(self, index, return_unicode=True):
169 |         """Converts an index (integer) to a token (string/unicode) using the vocab."""
170 |         token = self.sp_model.IdToPiece(index)
171 |         if six.PY2 and return_unicode and isinstance(token, str):
172 |             token = token.decode('utf-8')
173 |         return token
174 |
175 |     def convert_tokens_to_string(self, tokens):
176 |         """Converts a sequence of tokens (strings for sub-words) into a single string."""
177 |         out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
178 |         return out_string
179 |
180 |     def save_vocabulary(self, save_directory):
181 |         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
182 |             to a directory.
183 |         """
184 |         if not os.path.isdir(save_directory):
185 |             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
186 |             return
187 |         out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
188 |
189 |         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
190 |             copyfile(self.vocab_file, out_vocab_file)
191 |
192 |         return (out_vocab_file,)
193 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch
2 | torch>=0.4.1
3 | # progress bars in model download and training scripts
4 | tqdm
5 | # Accessing files from S3 directly.
6 | boto3
7 | # Used for downloading models over HTTP
8 | requests
9 | # For OpenAI GPT
10 | regex
11 | # For XLNet
12 | sentencepiece
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple checklist from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3 |
4 | To create the package for pypi.
5 |
6 | 1. Change the version in __init__.py and setup.py.
7 |
8 | 2. Commit these changes with the message: "Release: VERSION"
9 |
10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
11 |    Push the tag to git: git push --tags origin master
12 |
13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between
14 |    creating the wheel and the source distribution (obviously).
15 |
16 |    For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory.
17 |    (this will build a wheel for the python version you use to build it - make sure you use python 3.x).
18 |
19 |    For the sources, run: "python setup.py sdist"
20 |    You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp.
21 |
22 | 5. Check that everything looks correct by uploading the package to the pypi test server:
23 |
24 |    twine upload dist/* -r pypitest
25 |    (pypi suggests using twine as other methods upload files via plaintext.)
26 |
27 |    Check that you can install it in a virtualenv by running:
28 |    pip install -i https://testpypi.python.org/pypi pytorch-transformers
29 |
30 | 6. Upload the final version to actual pypi:
31 |    twine upload dist/* -r pypi
32 |
33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
34 |
35 | """
36 | from io import open
37 | from setuptools import find_packages, setup
38 |
39 | setup(
40 |     name="pytorch_transformers",
41 |     version="1.0.0",
42 |     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors",
43 |     author_email="thomas@huggingface.co",
44 |     description="Repository of pre-trained NLP Transformer models: BERT, GPT & GPT-2, Transformer-XL, XLNet and XLM",
45 |     long_description=open("README.md", "r", encoding='utf-8').read(),
46 |     long_description_content_type="text/markdown",
47 |     keywords='NLP deep learning transformer pytorch BERT GPT GPT-2 google openai CMU',
48 |     license='Apache',
49 |     url="https://github.com/huggingface/pytorch-transformers",
50 |     packages=find_packages(exclude=["*.tests", "*.tests.*",
51 |                                     "tests.*", "tests"]),
52 |     install_requires=['torch>=0.4.1',
53 |                       'numpy',
54 |                       'boto3',
55 |                       'requests',
56 |                       'tqdm',
57 |                       'regex',
58 |                       'sentencepiece'],
59 |     entry_points={
60 |         'console_scripts': [
61 |             "pytorch_transformers=pytorch_transformers.__main__:main",
62 |         ]
63 |     },
64 |     # python_requires='>=3.5.0',
65 |     tests_require=['pytest'],
66 |     classifiers=[
67 |         'Intended Audience :: Science/Research',
68 |         'License :: OSI Approved :: Apache Software License',
69 |         'Programming Language :: Python :: 3',
70 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
71 |     ],
72 | )
73 |
--------------------------------------------------------------------------------
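A minimal usage sketch of the tokenizers listed above (illustrative: it assumes the package is installed and that from_pretrained() can reach the S3 URLs given in the vocabulary maps):

    from pytorch_transformers import GPT2Tokenizer, XLNetTokenizer

    # Byte-level BPE (GPT-2): the encode/decode round trip is lossless.
    gpt2_tok = GPT2Tokenizer.from_pretrained('gpt2')
    tokens = gpt2_tok.tokenize(u"Hello world")
    ids = gpt2_tok.convert_tokens_to_ids(tokens)
    text = gpt2_tok.convert_tokens_to_string(gpt2_tok.convert_ids_to_tokens(ids))
    assert text == u"Hello world"

    # SentencePiece (XLNet): pieces carry the u'\u2581' word-boundary marker.
    xlnet_tok = XLNetTokenizer.from_pretrained('xlnet-base-cased')
    print(xlnet_tok.tokenize(u"Hello world"))  # e.g. [u'\u2581Hello', u'\u2581world']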