├── MANIFEST.in ├── pytorch_transformers ├── tests │ ├── __init__.py │ ├── fixtures │ │ ├── input.txt │ │ ├── test_sentencepiece.model │ │ └── sample_text.txt │ ├── conftest.py │ ├── modeling_auto_test.py │ ├── tokenization_auto_test.py │ ├── tokenization_utils_test.py │ ├── modeling_gpt2_test.py │ ├── modeling_openai_test.py │ ├── tokenization_openai_test.py │ ├── tokenization_gpt2_test.py │ ├── tokenization_transfo_xl_test.py │ ├── tokenization_xlm_test.py │ ├── tokenization_roberta_test.py │ ├── tokenization_xlnet_test.py │ ├── tokenization_bert_test.py │ ├── tokenization_tests_commons.py │ └── optimization_test.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_xlm_checkpoint_to_pytorch.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── __init__.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_xlnet_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── tokenization_auto.py └── __main__.py ├── examples ├── requirements.txt ├── tests_samples │ ├── .gitignore │ └── MRPC │ │ ├── dev.tsv │ │ └── train.tsv ├── test_examples.py ├── lm_finetuning │ └── README.md └── single_model_scripts │ └── run_transfo_xl.py ├── docs ├── source │ ├── _static │ │ ├── css │ │ │ ├── Calibre-Light.ttf │ │ │ ├── Calibre-Medium.otf │ │ │ ├── Calibre-Thin.otf │ │ │ ├── Calibre-Regular.otf │ │ │ ├── code-snippets.css │ │ │ └── huggingface.css │ │ └── js │ │ │ └── custom.js │ ├── imgs │ │ ├── warmup_cosine_schedule.png │ │ ├── warmup_linear_schedule.png │ │ ├── warmup_constant_schedule.png │ │ ├── warmup_cosine_hard_restarts_schedule.png │ │ └── warmup_cosine_warm_restarts_schedule.png │ ├── main_classes │ │ ├── configuration.rst │ │ ├── model.rst │ │ ├── tokenizer.rst │ │ └── optimizer_schedules.rst │ ├── model_doc │ │ ├── transformerxl.rst │ │ ├── gpt2.rst │ │ ├── roberta.rst │ │ ├── gpt.rst │ │ ├── auto.rst │ │ ├── xlm.rst │ │ ├── xlnet.rst │ │ └── bert.rst │ ├── bertology.rst │ ├── notebooks.rst │ ├── installation.rst │ ├── index.rst │ ├── converting_tensorflow_models.rst │ ├── migration.md │ ├── conf.py │ └── torchscript.rst ├── Makefile ├── requirements.txt └── README.md ├── docker └── Dockerfile ├── .github ├── ISSUE_TEMPLATE │ ├── question-help.md │ ├── feature-request.md │ ├── bug-report.md │ └── migration.md └── stale.yml ├── .coveragerc ├── requirements.txt ├── hubconf.py ├── .circleci └── config.yml ├── .gitignore ├── setup.py └── hubconfs ├── transformer_xl_hubconf.py ├── xlm_hubconf.py ├── gpt2_hubconf.py └── xlnet_hubconf.1.py /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | scikit-learn -------------------------------------------------------------------------------- /examples/tests_samples/.gitignore: -------------------------------------------------------------------------------- 1 | *.* 2 | cache* 3 | temp* 4 | !*.tsv 5 | !*.json 6 | !.gitignore -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/input.txt: 
-------------------------------------------------------------------------------- 1 | Who was Jim Henson ? ||| Jim Henson was a puppeteer 2 | -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Light.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/_static/css/Calibre-Light.ttf -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Medium.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/_static/css/Calibre-Medium.otf -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Thin.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/_static/css/Calibre-Thin.otf -------------------------------------------------------------------------------- /docs/source/_static/css/Calibre-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/_static/css/Calibre-Regular.otf -------------------------------------------------------------------------------- /docs/source/imgs/warmup_cosine_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/imgs/warmup_cosine_schedule.png -------------------------------------------------------------------------------- /docs/source/imgs/warmup_linear_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/imgs/warmup_linear_schedule.png -------------------------------------------------------------------------------- /docs/source/imgs/warmup_constant_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/imgs/warmup_constant_schedule.png -------------------------------------------------------------------------------- /docs/source/imgs/warmup_cosine_hard_restarts_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/imgs/warmup_cosine_hard_restarts_schedule.png -------------------------------------------------------------------------------- /docs/source/imgs/warmup_cosine_warm_restarts_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/docs/source/imgs/warmup_cosine_warm_restarts_schedule.png -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/test_sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nlpyang/pytorch-transformers/HEAD/pytorch_transformers/tests/fixtures/test_sentencepiece.model -------------------------------------------------------------------------------- /docker/Dockerfile: 
-------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:latest 2 | 3 | RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext 4 | 5 | RUN pip install pytorch_transformers 6 | 7 | WORKDIR /workspace -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question-help.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓Questions & Help" 3 | about: Start a general discussion related to PyTorch Transformers 4 | --- 5 | 6 | ## ❓ Questions & Help 7 | 8 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=pytorch_transformers 3 | omit = 4 | # skip conversion scripts from testing for now 5 | */convert_* 6 | */__main__.py 7 | [report] 8 | exclude_lines = 9 | pragma: no cover 10 | raise 11 | except 12 | register_parameter -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | torch>=1.0.0 3 | # progress bars in model download and training scripts 4 | tqdm 5 | # Accessing files from S3 directly. 6 | boto3 7 | # Used for downloading models over HTTP 8 | requests 9 | # For OpenAI GPT 10 | regex 11 | # For XLNet 12 | sentencepiece -------------------------------------------------------------------------------- /docs/source/_static/css/code-snippets.css: -------------------------------------------------------------------------------- 1 | 2 | .highlight .c1, .highlight .sd{ 3 | color: #999 4 | } 5 | 6 | .highlight .nn, .highlight .k, .highlight .s1, .highlight .nb, .highlight .bp, .highlight .kc { 7 | color: #FB8D68; 8 | } 9 | 10 | .highlight .kn, .highlight .nv, .highlight .s2, .highlight .ow { 11 | color: #6670FF; 12 | } -------------------------------------------------------------------------------- /docs/source/main_classes/configuration.rst: -------------------------------------------------------------------------------- 1 | Configuration 2 | ---------------------------------------------------- 3 | 4 | The base class ``PretrainedConfig`` implements the common methods for loading/saving a configuration either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS S3 repository). 5 | 6 | ``PretrainedConfig`` 7 | ~~~~~~~~~~~~~~~~~~~~~ 8 | 
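A minimal sketch of the loading/saving round trip (the ``bert-base-uncased`` shortcut name and the local directory below are placeholder examples; ``from_pretrained`` accepts either a shortcut name or a path):

.. code-block:: python

    from pytorch_transformers import BertConfig

    # Download the configuration from S3 (cached locally on first use)
    config = BertConfig.from_pretrained('bert-base-uncased')

    # Save it to an existing local directory and reload it from there
    config.save_pretrained('./my_model_directory/')
    config = BertConfig.from_pretrained('./my_model_directory/')

9 | .. 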
autoclass:: pytorch_transformers.PretrainedConfig 10 | :members: 11 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | 11 | 12 | def pytest_collection_modifyitems(config, items): 13 | if config.getoption("--runslow"): 14 | # --runslow given in cli: do not skip slow tests 15 | return 16 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 17 | for item in items: 18 | if "slow" in item.keywords: 19 | item.add_marker(skip_slow) 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature Request" 3 | about: Submit a proposal/request for a new PyTorch Transformers feature 4 | --- 5 | 6 | ## 🚀 Feature 7 | 8 | 9 | 10 | ## Motivation 11 | 12 | 13 | 14 | ## Additional context 15 | 16 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | Babel==2.7.0 3 | certifi==2019.6.16 4 | chardet==3.0.4 5 | commonmark==0.9.0 6 | docutils==0.14 7 | future==0.17.1 8 | idna==2.8 9 | imagesize==1.1.0 10 | Jinja2==2.10.1 11 | MarkupSafe==1.1.1 12 | packaging==19.0 13 | Pygments==2.4.2 14 | pyparsing==2.4.0 15 | pytz==2019.1 16 | recommonmark==0.5.0 17 | requests==2.22.0 18 | six==1.12.0 19 | snowballstemmer==1.9.0 20 | Sphinx==2.1.2 21 | sphinx-rtd-theme==0.4.3 22 | sphinxcontrib-applehelp==1.0.1 23 | sphinxcontrib-devhelp==1.0.1 24 | sphinxcontrib-htmlhelp==1.0.2 25 | sphinxcontrib-jsmath==1.0.1 26 | sphinxcontrib-qthelp==1.0.2 27 | sphinxcontrib-serializinghtml==1.1.3 28 | urllib3==1.25.3 29 | -------------------------------------------------------------------------------- /docs/source/model_doc/transformerxl.rst: -------------------------------------------------------------------------------- 1 | Transformer XL 2 | ---------------------------------------------------- 3 | 4 | 5 | ``TransfoXLConfig`` 6 | ~~~~~~~~~~~~~~~~~~~~~ 7 | 8 | .. autoclass:: pytorch_transformers.TransfoXLConfig 9 | :members: 10 | 11 | 12 | ``TransfoXLTokenizer`` 13 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 14 | 15 | .. 
autoclass:: pytorch_transformers.TransfoXLTokenizer 16 | :members: 17 | 18 | 19 | ``TransfoXLModel`` 20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 21 | 22 | .. autoclass:: pytorch_transformers.TransfoXLModel 23 | :members: 24 | 25 | 26 | ``TransfoXLLMHeadModel`` 27 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 28 | 29 | .. autoclass:: pytorch_transformers.TransfoXLLMHeadModel 30 | :members: 31 | -------------------------------------------------------------------------------- /docs/source/main_classes/model.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ---------------------------------------------------- 3 | 4 | The base class ``PreTrainedModel`` implements the common methods for loading/saving a model either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS S3 repository). 5 | 6 | ``PreTrainedModel`` also implements a few methods which are common among all the models to: 7 | 8 | - resize the input token embeddings when new tokens are added to the vocabulary 9 | - prune the attention heads of the model. 10 | 11 | ``PreTrainedModel`` 12 | ~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.PreTrainedModel 15 | :members: 16 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false -------------------------------------------------------------------------------- /docs/source/model_doc/gpt2.rst: -------------------------------------------------------------------------------- 1 | OpenAI GPT2 2 | ---------------------------------------------------- 3 | 4 | ``GPT2Config`` 5 | ~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autoclass:: pytorch_transformers.GPT2Config 8 | :members: 9 | 10 | 11 | ``GPT2Tokenizer`` 12 | ~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.GPT2Tokenizer 15 | :members: 16 | 17 | 18 | ``GPT2Model`` 19 | ~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: pytorch_transformers.GPT2Model 22 | :members: 23 | 24 | 25 | ``GPT2LMHeadModel`` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: pytorch_transformers.GPT2LMHeadModel 29 | :members: 30 | 31 | 32 | ``GPT2DoubleHeadsModel`` 33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | 35 | .. autoclass:: pytorch_transformers.GPT2DoubleHeadsModel 36 | :members: 37 | -------------------------------------------------------------------------------- /docs/source/model_doc/roberta.rst: -------------------------------------------------------------------------------- 1 | RoBERTa 2 | ---------------------------------------------------- 3 | 4 | ``RobertaConfig`` 5 | ~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. 
autoclass:: pytorch_transformers.RobertaConfig 8 | :members: 9 | 10 | 11 | ``RobertaTokenizer`` 12 | ~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.RobertaTokenizer 15 | :members: 16 | 17 | 18 | ``RobertaModel`` 19 | ~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: pytorch_transformers.RobertaModel 22 | :members: 23 | 24 | 25 | ``RobertaForMaskedLM`` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: pytorch_transformers.RobertaForMaskedLM 29 | :members: 30 | 31 | 32 | ``RobertaForSequenceClassification`` 33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | 35 | .. autoclass:: pytorch_transformers.RobertaForSequenceClassification 36 | :members: 37 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex'] 2 | 3 | from hubconfs.bert_hubconf import ( 4 | bertTokenizer, 5 | bertModel, 6 | bertForNextSentencePrediction, 7 | bertForPreTraining, 8 | bertForMaskedLM, 9 | bertForSequenceClassification, 10 | bertForMultipleChoice, 11 | bertForQuestionAnswering, 12 | bertForTokenClassification 13 | ) 14 | from hubconfs.gpt_hubconf import ( 15 | openAIGPTTokenizer, 16 | openAIGPTModel, 17 | openAIGPTLMHeadModel, 18 | openAIGPTDoubleHeadsModel 19 | ) 20 | from hubconfs.gpt2_hubconf import ( 21 | gpt2Tokenizer, 22 | gpt2Model, 23 | gpt2LMHeadModel, 24 | gpt2DoubleHeadsModel 25 | ) 26 | from hubconfs.transformer_xl_hubconf import ( 27 | transformerXLTokenizer, 28 | transformerXLModel, 29 | transformerXLLMHeadModel 30 | ) 31 | -------------------------------------------------------------------------------- /docs/source/model_doc/gpt.rst: -------------------------------------------------------------------------------- 1 | OpenAI GPT 2 | ---------------------------------------------------- 3 | 4 | ``OpenAIGPTConfig`` 5 | ~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autoclass:: pytorch_transformers.OpenAIGPTConfig 8 | :members: 9 | 10 | 11 | ``OpenAIGPTTokenizer`` 12 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.OpenAIGPTTokenizer 15 | :members: 16 | 17 | 18 | ``OpenAIGPTModel`` 19 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: pytorch_transformers.OpenAIGPTModel 22 | :members: 23 | 24 | 25 | ``OpenAIGPTLMHeadModel`` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: pytorch_transformers.OpenAIGPTLMHeadModel 29 | :members: 30 | 31 | 32 | ``OpenAIGPTDoubleHeadsModel`` 33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | 35 | .. autoclass:: pytorch_transformers.OpenAIGPTDoubleHeadsModel 36 | :members: 37 | -------------------------------------------------------------------------------- /docs/source/main_classes/tokenizer.rst: -------------------------------------------------------------------------------- 1 | Tokenizer 2 | ---------------------------------------------------- 3 | 4 | The base class ``PreTrainedTokenizer`` implements the common methods for loading/saving a tokenizer either from a local file or directory, or from a pretrained tokenizer provided by the library (downloaded from HuggingFace's AWS S3 repository). 
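A minimal sketch of that loading/saving round trip (``bert-base-uncased`` and the local directory below are placeholder examples; any other shortcut name or path works the same way):

.. code-block:: python

    from pytorch_transformers import BertTokenizer

    # Download the vocabulary from S3 (cached locally on first use)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Save the vocabulary and special-token mappings to an existing directory,
    # then reload the tokenizer from disk
    tokenizer.save_pretrained('./my_tokenizer_directory/')
    tokenizer = BertTokenizer.from_pretrained('./my_tokenizer_directory/')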
5 | 6 | ``PreTrainedTokenizer`` is the main entry point into tokenizers as it also implements the main methods for using all the tokenizers: 7 | 8 | - tokenizing, converting tokens to ids and back and encoding/decoding, 9 | - adding new tokens to the vocabulary in a way that is independent of the underlying structure (BPE, SentencePiece...), 10 | - managing special tokens (adding them, assigning them to roles, making sure they are not split during tokenization) 11 | 12 | ``PreTrainedTokenizer`` 13 | ~~~~~~~~~~~~~~~~~~~~~~~~ 14 | 15 | .. autoclass:: pytorch_transformers.PreTrainedTokenizer 16 | :members: 17 | -------------------------------------------------------------------------------- /docs/source/model_doc/auto.rst: -------------------------------------------------------------------------------- 1 | AutoModels 2 | ----------- 3 | 4 | In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying to the ``from_pretrained`` method. 5 | 6 | AutoClasses are here to do this job for you so that you automatically retrieve the relevant model given the name/path to the pretrained weights/config/vocabulary: 7 | 8 | Instantiating one of ``AutoModel``, ``AutoConfig`` and ``AutoTokenizer`` will directly create an instance of the relevant architecture class (ex: ``model = AutoModel.from_pretrained('bert-base-cased')`` will create an instance of ``BertModel``). 9 | 10 | 11 | ``AutoConfig`` 12 | ~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.AutoConfig 15 | :members: 16 | 17 | 18 | ``AutoModel`` 19 | ~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: pytorch_transformers.AutoModel 22 | :members: 23 | 24 | 25 | ``AutoTokenizer`` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: pytorch_transformers.AutoTokenizer 29 | :members: 30 | -------------------------------------------------------------------------------- /docs/source/model_doc/xlm.rst: -------------------------------------------------------------------------------- 1 | XLM 2 | ---------------------------------------------------- 3 | 4 | ``XLMConfig`` 5 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autoclass:: pytorch_transformers.XLMConfig 8 | :members: 9 | 10 | ``XLMTokenizer`` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autoclass:: pytorch_transformers.XLMTokenizer 14 | :members: 15 | 16 | ``XLMModel`` 17 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 18 | 19 | .. autoclass:: pytorch_transformers.XLMModel 20 | :members: 21 | 22 | 23 | ``XLMWithLMHeadModel`` 24 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 25 | 26 | .. autoclass:: pytorch_transformers.XLMWithLMHeadModel 27 | :members: 28 | 29 | 30 | ``XLMForSequenceClassification`` 31 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 32 | 33 | .. autoclass:: pytorch_transformers.XLMForSequenceClassification 34 | :members: 35 | 36 | 37 | ``XLMForQuestionAnswering`` 38 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 39 | 40 | .. autoclass:: pytorch_transformers.XLMForQuestionAnswering 41 | :members: 42 | -------------------------------------------------------------------------------- /docs/source/model_doc/xlnet.rst: -------------------------------------------------------------------------------- 1 | XLNet 2 | ---------------------------------------------------- 3 | 4 | ``XLNetConfig`` 5 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. 
autoclass:: pytorch_transformers.XLNetConfig 8 | :members: 9 | 10 | 11 | ``XLNetTokenizer`` 12 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.XLNetTokenizer 15 | :members: 16 | 17 | 18 | ``XLNetModel`` 19 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: pytorch_transformers.XLNetModel 22 | :members: 23 | 24 | 25 | ``XLNetLMHeadModel`` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: pytorch_transformers.XLNetLMHeadModel 29 | :members: 30 | 31 | 32 | ``XLNetForSequenceClassification`` 33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | 35 | .. autoclass:: pytorch_transformers.XLNetForSequenceClassification 36 | :members: 37 | 38 | 39 | ``XLNetForQuestionAnswering`` 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | 42 | .. autoclass:: pytorch_transformers.XLNetForQuestionAnswering 43 | :members: 44 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | jobs: 3 | build_py3: 4 | working_directory: ~/pytorch-transformers 5 | docker: 6 | - image: circleci/python:3.5 7 | resource_class: large 8 | parallelism: 4 9 | steps: 10 | - checkout 11 | - run: sudo pip install --progress-bar off . 12 | - run: sudo pip install pytest codecov pytest-cov 13 | - run: sudo pip install tensorboardX scikit-learn 14 | - run: python -m pytest -sv ./pytorch_transformers/tests/ --cov 15 | - run: python -m pytest -sv ./examples/ 16 | - run: codecov 17 | build_py2: 18 | working_directory: ~/pytorch-transformers 19 | resource_class: large 20 | parallelism: 4 21 | docker: 22 | - image: circleci/python:2.7 23 | steps: 24 | - checkout 25 | - run: sudo pip install --progress-bar off . 26 | - run: sudo pip install pytest codecov pytest-cov 27 | - run: python -m pytest -sv ./pytorch_transformers/tests/ --cov 28 | - run: codecov 29 | workflows: 30 | version: 2 31 | build_and_test: 32 | jobs: 33 | - build_py3 34 | - build_py2 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Submit a bug report to help us improve PyTorch Transformers 4 | --- 5 | 6 | ## 🐛 Bug 7 | 8 | 9 | 10 | Model I am using (Bert, XLNet....): 11 | 12 | Language I am using the model on (English, Chinese....): 13 | 14 | The problem arises when using: 15 | * [ ] the official example scripts: (give details) 16 | * [ ] my own modified scripts: (give details) 17 | 18 | The task I am working on is: 19 | * [ ] an official GLUE/SQUaD task: (give the name) 20 | * [ ] my own task or dataset: (give details) 21 | 22 | ## To Reproduce 23 | 24 | Steps to reproduce the behavior: 25 | 26 | 1. 27 | 2. 28 | 3. 29 | 30 | 31 | 32 | ## Expected behavior 33 | 34 | 35 | 36 | ## Environment 37 | 38 | * OS: 39 | * Python version: 40 | * PyTorch version: 41 | * PyTorch Transformers version (or branch): 42 | * Using GPU? 43 | * Distributed or parallel setup? 
44 | * Any other relevant information: 45 | 46 | ## Additional context 47 | 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/migration.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Migration from PyTorch-pretrained-Bert" 3 | about: Report a problem when migrating from PyTorch-pretrained-Bert to PyTorch-Transformers 4 | --- 5 | 6 | ## 📚 Migration 7 | 8 | 9 | 10 | Model I am using (Bert, XLNet....): 11 | 12 | Language I am using the model on (English, Chinese....): 13 | 14 | The problem arises when using: 15 | * [ ] the official example scripts: (give details) 16 | * [ ] my own modified scripts: (give details) 17 | 18 | The task I am working on is: 19 | * [ ] an official GLUE/SQUaD task: (give the name) 20 | * [ ] my own task or dataset: (give details) 21 | 22 | Details of the issue: 23 | 24 | 25 | 26 | ## Environment 27 | 28 | * OS: 29 | * Python version: 30 | * PyTorch version: 31 | * PyTorch Transformers version (or branch): 32 | * Using GPU? 33 | * Distributed or parallel setup? 34 | * Any other relevant information: 35 | 36 | ## Checklist 37 | 38 | - [ ] I have read the migration guide in the readme. 39 | - [ ] I checked if a related official extension example runs on my machine. 40 | 41 | ## Additional context 42 | 43 | -------------------------------------------------------------------------------- /examples/tests_samples/MRPC/dev.tsv: -------------------------------------------------------------------------------- 1 | Quality #1 ID #2 ID #1 String #2 String 2 | 1 1355540 1355592 He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . " The foodservice pie business does not fit our long-term growth strategy . 3 | 0 2029631 2029565 Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war . His wife said he was " 100 percent behind George Bush " and looked forward to using his years of training in the war . 4 | 0 487993 487952 The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat . The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent . 5 | 1 1989515 1989458 The AFL-CIO is waiting until October to decide if it will endorse a candidate . The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries . 6 | 0 1783137 1782659 No dates have been set for the civil or the criminal trial . No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty . 7 | 1 3039165 3039036 Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed . It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status . 8 | -------------------------------------------------------------------------------- /examples/tests_samples/MRPC/train.tsv: -------------------------------------------------------------------------------- 1 | Quality #1 ID #2 ID #1 String #2 String 2 | 1 1355540 1355592 He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . " The foodservice pie business does not fit our long-term growth strategy . 
3 | 0 2029631 2029565 Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war . His wife said he was " 100 percent behind George Bush " and looked forward to using his years of training in the war . 4 | 0 487993 487952 The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat . The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent . 5 | 1 1989515 1989458 The AFL-CIO is waiting until October to decide if it will endorse a candidate . The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries . 6 | 0 1783137 1782659 No dates have been set for the civil or the criminal trial . No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty . 7 | 1 3039165 3039036 Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed . It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status . 8 | -------------------------------------------------------------------------------- /docs/source/bertology.rst: -------------------------------------------------------------------------------- 1 | BERTology 2 | --------- 3 | 4 | There is a growing field of study concerned with investigating the inner workings of large-scale transformers like BERT (that some call "BERTology"). Some good examples of this field are: 5 | 6 | 7 | * BERT Rediscovers the Classical NLP Pipeline by Ian Tenney, Dipanjan Das, Ellie Pavlick: https://arxiv.org/abs/1905.05950 8 | * Are Sixteen Heads Really Better than One? by Paul Michel, Omer Levy, Graham Neubig: https://arxiv.org/abs/1905.10650 9 | * What Does BERT Look At? An Analysis of BERT's Attention by Kevin Clark, Urvashi Khandelwal, Omer Levy, Christopher D. Manning: https://arxiv.org/abs/1906.04341 10 | 11 | In order to help this new field develop, we have included a few additional features in the BERT/GPT/GPT-2 models to help people access the inner representations, mainly adapted from the great work of Paul Michel (https://arxiv.org/abs/1905.10650): 12 | 13 | 14 | * accessing all the hidden-states of BERT/GPT/GPT-2, 15 | * accessing all the attention weights for each head of BERT/GPT/GPT-2, 16 | * retrieving head output values and gradients to be able to compute head importance scores and prune heads as explained in https://arxiv.org/abs/1905.10650. 17 | 18 | To help you understand and use these features, we have added a specific example script: `bertology.py `_ which extracts information from and prunes a model pre-trained on GLUE. 19 | -------------------------------------------------------------------------------- /docs/source/main_classes/optimizer_schedules.rst: -------------------------------------------------------------------------------- 1 | Optimizer 2 | ---------------------------------------------------- 3 | 4 | The ``.optimization`` module provides: 5 | 6 | - an optimizer with fixed weight decay that can be used to fine-tune models, and 7 | - several schedules in the form of schedule objects that inherit from ``_LRSchedule``: 8 | 9 | ``AdamW`` 10 | ~~~~~~~~~~~~~~~~ 11 | 
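A minimal fine-tuning sketch pairing ``AdamW`` with one of the warmup schedules (``model``, ``compute_loss``, the learning rate and the step counts below are placeholders, not prescribed values):

.. code-block:: python

    from pytorch_transformers import AdamW, WarmupLinearSchedule

    # model is any PreTrainedModel instance set up for training
    optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)

    for step in range(1000):
        loss = compute_loss(model)  # placeholder for a forward pass on a batch
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the learning-rate schedule
        optimizer.zero_grad()

12 | .. 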
autoclass:: pytorch_transformers.AdamW 13 | :members: 14 | 15 | Schedules 16 | ---------------------------------------------------- 17 | 18 | Learning Rate Schedules 19 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 20 | 21 | .. autoclass:: pytorch_transformers.ConstantLRSchedule 22 | :members: 23 | 24 | 25 | .. autoclass:: pytorch_transformers.WarmupConstantSchedule 26 | :members: 27 | 28 | .. image:: /imgs/warmup_constant_schedule.png 29 | :target: /imgs/warmup_constant_schedule.png 30 | :alt: 31 | 32 | 33 | .. autoclass:: pytorch_transformers.WarmupCosineSchedule 34 | :members: 35 | 36 | .. image:: /imgs/warmup_cosine_schedule.png 37 | :target: /imgs/warmup_cosine_schedule.png 38 | :alt: 39 | 40 | 41 | .. autoclass:: pytorch_transformers.WarmupCosineWithHardRestartsSchedule 42 | :members: 43 | 44 | .. image:: /imgs/warmup_cosine_hard_restarts_schedule.png 45 | :target: /imgs/warmup_cosine_hard_restarts_schedule.png 46 | :alt: 47 | 48 | 49 | 50 | .. autoclass:: pytorch_transformers.WarmupLinearSchedule 51 | :members: 52 | 53 | .. image:: /imgs/warmup_linear_schedule.png 54 | :target: /imgs/warmup_linear_schedule.png 55 | :alt: 56 | -------------------------------------------------------------------------------- /docs/source/model_doc/bert.rst: -------------------------------------------------------------------------------- 1 | BERT 2 | ---------------------------------------------------- 3 | 4 | ``BertConfig`` 5 | ~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | .. autoclass:: pytorch_transformers.BertConfig 8 | :members: 9 | 10 | 11 | ``BertTokenizer`` 12 | ~~~~~~~~~~~~~~~~~~~~~ 13 | 14 | .. autoclass:: pytorch_transformers.BertTokenizer 15 | :members: 16 | 17 | 18 | ``BertModel`` 19 | ~~~~~~~~~~~~~~~~~~~~ 20 | 21 | .. autoclass:: pytorch_transformers.BertModel 22 | :members: 23 | 24 | 25 | ``BertForPreTraining`` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autoclass:: pytorch_transformers.BertForPreTraining 29 | :members: 30 | 31 | 32 | ``BertForMaskedLM`` 33 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 34 | 35 | .. autoclass:: pytorch_transformers.BertForMaskedLM 36 | :members: 37 | 38 | 39 | ``BertForNextSentencePrediction`` 40 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 41 | 42 | .. autoclass:: pytorch_transformers.BertForNextSentencePrediction 43 | :members: 44 | 45 | 46 | ``BertForSequenceClassification`` 47 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 48 | 49 | .. autoclass:: pytorch_transformers.BertForSequenceClassification 50 | :members: 51 | 52 | 53 | ``BertForMultipleChoice`` 54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 55 | 56 | .. autoclass:: pytorch_transformers.BertForMultipleChoice 57 | :members: 58 | 59 | 60 | ``BertForTokenClassification`` 61 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 62 | 63 | .. autoclass:: pytorch_transformers.BertForTokenClassification 64 | :members: 65 | 66 | 67 | ``BertForQuestionAnswering`` 68 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 69 | 70 | .. autoclass:: pytorch_transformers.BertForQuestionAnswering 71 | :members: 72 | 73 | -------------------------------------------------------------------------------- /docs/source/notebooks.rst: -------------------------------------------------------------------------------- 1 | Notebooks 2 | ================================================ 3 | 4 | We include `three Jupyter Notebooks `_ that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model. 
5 | 6 | 7 | * 8 | The first notebook (\ `Comparing-TF-and-PT-models.ipynb `_\ ) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models. 9 | 10 | * 11 | The second notebook (\ `Comparing-TF-and-PT-models-SQuAD.ipynb `_\ ) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the ``BertForQuestionAnswering`` model and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models. 12 | 13 | * 14 | The third notebook (\ `Comparing-TF-and-PT-models-MLM-NSP.ipynb `_\ ) compares the predictions computed by the TensorFlow and the PyTorch models for masked token language modeling using the pre-trained masked language modeling model. 15 | 16 | Please follow the instructions given in the notebooks to run and modify them. 17 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Generating the documentation 2 | 3 | To generate the documentation, you first have to build it. Several packages are necessary to build the docs; 4 | you can install them using: 5 | 6 | ```bash 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## Packages installed 11 | 12 | Here's an overview of all the packages installed. If you ran the previous command installing all packages from 13 | `requirements.txt`, you do not need to run the following commands. 14 | 15 | Building it requires the package `sphinx`, which you can 16 | install using: 17 | 18 | ```bash 19 | pip install -U sphinx 20 | ``` 21 | 22 | You will also need the custom [theme](https://github.com/readthedocs/sphinx_rtd_theme) by 23 | [Read The Docs](https://readthedocs.org/) installed. You can install it using the following command: 24 | 25 | ```bash 26 | pip install sphinx_rtd_theme 27 | ``` 28 | 29 | The third necessary package is the `recommonmark` package to accept Markdown as well as reStructuredText: 30 | 31 | ```bash 32 | pip install recommonmark 33 | ``` 34 | 35 | ## Building the documentation 36 | 37 | Once you have set up `sphinx`, you can build the documentation by running the following command in the `/docs` folder: 38 | 39 | ```bash 40 | make html 41 | ``` 42 | 43 | --- 44 | **NOTE** 45 | 46 | If you are adding/removing elements from the toc-tree or from any structural item, it is recommended to clean the build 47 | directory before rebuilding. Run the following command to clean and build: 48 | 49 | ```bash 50 | make clean && make html 51 | ``` 52 | 53 | --- 54 | 55 | It should build the static app that will be available under `/docs/_build/html` 56 | 57 | ## Adding a new element to the tree (toc-tree) 58 | 59 | Accepted files are reStructuredText (.rst) and Markdown (.md). Create a file with its extension and put it 60 | in the source directory. You can then link it to the toc-tree by putting the filename without the extension. 61 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel 25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 26 | 27 | from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) 28 | 29 | 30 | class AutoModelTest(unittest.TestCase): 31 | def test_model_from_pretrained(self): 32 | logging.basicConfig(level=logging.INFO) 33 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 34 | config = AutoConfig.from_pretrained(model_name) 35 | self.assertIsNotNone(config) 36 | self.assertIsInstance(config, BertConfig) 37 | 38 | model = AutoModel.from_pretrained(model_name) 39 | model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) 40 | self.assertIsNotNone(model) 41 | self.assertIsInstance(model, BertModel) 42 | for value in loading_info.values(): 43 | self.assertEqual(len(value), 0) 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_auto_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import shutil 21 | import pytest 22 | import logging 23 | 24 | from pytorch_transformers import AutoTokenizer, BertTokenizer, GPT2Tokenizer 25 | from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 26 | from pytorch_transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP 27 | 28 | 29 | class AutoTokenizerTest(unittest.TestCase): 30 | def test_tokenizer_from_pretrained(self): 31 | logging.basicConfig(level=logging.INFO) 32 | for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 33 | tokenizer = AutoTokenizer.from_pretrained(model_name) 34 | self.assertIsNotNone(tokenizer) 35 | self.assertIsInstance(tokenizer, BertTokenizer) 36 | self.assertGreater(len(tokenizer), 0) 37 | 38 | for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: 39 | tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.assertIsNotNone(tokenizer) 41 | self.assertIsInstance(tokenizer, GPT2Tokenizer) 42 | self.assertGreater(len(tokenizer), 0) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_utils_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import six 21 | 22 | from pytorch_transformers import PreTrainedTokenizer 23 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 24 | 25 | class TokenizerUtilsTest(unittest.TestCase): 26 | def check_tokenizer_from_pretrained(self, tokenizer_class): 27 | s3_models = list(tokenizer_class.max_model_input_sizes.keys()) 28 | for model_name in s3_models[:1]: 29 | tokenizer = tokenizer_class.from_pretrained(model_name) 30 | self.assertIsNotNone(tokenizer) 31 | self.assertIsInstance(tokenizer, tokenizer_class) 32 | self.assertIsInstance(tokenizer, PreTrainedTokenizer) 33 | 34 | for special_tok in tokenizer.all_special_tokens: 35 | if six.PY2: 36 | self.assertIsInstance(special_tok, unicode) 37 | else: 38 | self.assertIsInstance(special_tok, str) 39 | special_tok_id = tokenizer.convert_tokens_to_ids(special_tok) 40 | self.assertIsInstance(special_tok_id, int) 41 | 42 | def test_pretrained_tokenizers(self): 43 | self.check_tokenizer_from_pretrained(GPT2Tokenizer) 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_gpt2_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import pytest 21 | 22 | 23 | from pytorch_transformers import (GPT2Config, GPT2Model, 24 | GPT2LMHeadModel, GPT2DoubleHeadsModel) 25 | 26 | from .modeling_common_test import CommonTestCases, ConfigTester 27 | 28 | class GPT2ModelTest(unittest.TestCase): 29 | 30 | def test_config(self): 31 | config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) 32 | config_tester.run_common_tests() 33 | 34 | def test_model(self): 35 | model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model, 36 | lm_head_model_class=GPT2LMHeadModel, 37 | double_head_model_class=GPT2DoubleHeadsModel) 38 | model_tester.run_common_tests(test_presents=True) 39 | 40 | @pytest.mark.slow 41 | def test_pretrained(self): 42 | model_tester = CommonTestCases.GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model, 43 | lm_head_model_class=GPT2LMHeadModel, 44 | double_head_model_class=GPT2DoubleHeadsModel) 45 | model_tester.run_slow_tests() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from GitHub's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other info into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # vscode 119 | .vscode 120 | 121 | # TF code 122 | tensorflow_code 123 | 124 | # Models 125 | models 126 | proc_data 127 | 128 | # examples 129 | runs 130 | examples/runs 131 | 132 | data -------------------------------------------------------------------------------- /pytorch_transformers/tests/modeling_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import pytest 21 | 22 | 23 | from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, 24 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) 25 | 26 | from .modeling_common_test import CommonTestCases, ConfigTester 27 | 28 | class OpenAIModelTest(unittest.TestCase): 29 | 30 | def test_config(self): 31 | config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37) 32 | config_tester.run_common_tests() 33 | 34 | def test_model(self): 35 | model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel, 36 | lm_head_model_class=OpenAIGPTLMHeadModel, 37 | double_head_model_class=OpenAIGPTDoubleHeadsModel) 38 | model_tester.run_common_tests(test_presents=False) 39 | 40 | @pytest.mark.slow 41 | def test_pretrained(self): 42 | model_tester = CommonTestCases.GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel, 43 | lm_head_model_class=OpenAIGPTLMHeadModel, 44 | double_head_model_class=OpenAIGPTDoubleHeadsModel) 45 | model_tester.run_slow_tests() 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /docs/source/_static/js/custom.js: -------------------------------------------------------------------------------- 1 | function addIcon() { 2 | const huggingFaceLogo = "http://lysand.re/huggingface_logo.svg"; 3 | const image = document.createElement("img"); 4 | image.setAttribute("src", huggingFaceLogo); 5 | 6 | const div = document.createElement("div"); 7 | div.appendChild(image); 8 | div.style.textAlign = 'center'; 9 | div.style.paddingTop = '30px'; 10 | div.style.backgroundColor = '#6670FF'; 11 | 12 | const scrollDiv = document.getElementsByClassName("wy-side-scroll")[0]; 13 | scrollDiv.prepend(div); 14 | } 15 | 16 | function addCustomFooter() { 17 | const customFooter = document.createElement("div"); 18 | const questionOrIssue = document.createElement("div"); 19 | questionOrIssue.innerHTML = "Stuck? 
Read our Blog posts or Create an issue"; 20 | customFooter.appendChild(questionOrIssue); 21 | customFooter.classList.add("footer"); 22 | 23 | const social = document.createElement("div"); 24 | social.classList.add("footer__Social"); 25 | 26 | const imageDetails = [ 27 | { link: "https://huggingface.co", imageLink: "http://lysand.re/icons/website.svg" }, 28 | { link: "https://twitter.com/huggingface", imageLink: "http://lysand.re/icons/twitter.svg" }, 29 | { link: "https://github.com/huggingface", imageLink: "http://lysand.re/icons/github.svg" }, 30 | { link: "https://www.linkedin.com/company/huggingface/", imageLink: "http://lysand.re/icons/linkedin.svg" } 31 | ]; 32 | 33 | imageDetails.forEach(imageLinks => { 34 | const link = document.createElement("a"); 35 | const image = document.createElement("img"); 36 | image.src = imageLinks.imageLink; 37 | link.href = imageLinks.link; 38 | image.style.width = "30px"; 39 | image.classList.add("footer__CustomImage"); 40 | link.appendChild(image); 41 | social.appendChild(link); 42 | }); 43 | 44 | customFooter.appendChild(social); 45 | document.getElementsByTagName("footer")[0].appendChild(customFooter); 46 | } 47 | 48 | function onLoad() { 49 | addIcon(); 50 | addCustomFooter(); 51 | } 52 | 53 | window.addEventListener("load", onLoad); 54 | 55 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ================================================ 3 | 4 | PyTorch-Transformers is tested on Python 2.7 and 3.5+ (examples are tested only on Python 3.5+) and PyTorch 1.1.0. 5 | 6 | With pip 7 | ^^^^^^^^ 8 | 9 | PyTorch Transformers can be installed using pip as follows: 10 | 11 | .. code-block:: bash 12 | 13 | pip install pytorch-transformers 14 | 15 | From source 16 | ^^^^^^^^^^^ 17 | 18 | To install from source, clone the repository and install with: 19 | 20 | .. code-block:: bash 21 | 22 | git clone https://github.com/huggingface/pytorch-transformers.git 23 | cd pytorch-transformers 24 | pip install [--editable] . 25 | 26 | 27 | Tests 28 | ^^^^^ 29 | 30 | An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the `tests folder `_ and example tests in the `examples folder `_. 31 | 32 | Tests can be run using ``pytest`` (install it if needed with ``pip install pytest``). 33 | 34 | Run all the tests from the root of the cloned repository with the commands: 35 | 36 | .. code-block:: bash 37 | 38 | python -m pytest -sv ./pytorch_transformers/tests/ 39 | python -m pytest -sv ./examples/ 40 | 41 | 42 | OpenAI GPT original tokenization workflow 43 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 44 | 45 | If you want to reproduce the original tokenization process of the ``OpenAI GPT`` paper, you will need to install ``ftfy`` (use version 4.4.3 if you are using Python 2) and ``SpaCy``: 46 | 47 | .. code-block:: bash 48 | 49 | pip install spacy ftfy==4.4.3 50 | python -m spacy download en 51 | 52 | If you don't install ``ftfy`` and ``SpaCy``\ , the ``OpenAI GPT`` tokenizer will default to tokenizing using BERT's ``BasicTokenizer`` followed by Byte-Pair Encoding (which should be fine for most usage, don't worry). 53 | 54 | 55 | Do you want to run a Transformer model on a mobile device? 56 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 57 | 58 | You should check out our `swift-coreml-transformers `_ repo. 
59 | 
60 | It contains an example of a conversion script from a PyTorch-trained Transformer model (here, ``GPT-2``) to a CoreML model that runs on iOS devices.
61 | 
62 | It also contains an implementation of BERT for question answering.
63 | 
64 | At some point in the future, you'll be able to seamlessly move from pre-training or fine-tuning models in PyTorch to productizing them in CoreML,
65 | or prototype a model or an app in CoreML then research its hyperparameters or architecture from PyTorch. Super exciting!
--------------------------------------------------------------------------------
/pytorch_transformers/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Convert BERT checkpoint."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import argparse
22 | import torch
23 | 
24 | from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
25 | 
26 | import logging
27 | logging.basicConfig(level=logging.INFO)
28 | 
29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
30 |     # Initialise PyTorch model
31 |     config = BertConfig.from_json_file(bert_config_file)
32 |     print("Building PyTorch model from configuration: {}".format(str(config)))
33 |     model = BertForPreTraining(config)
34 | 
35 |     # Load weights from tf checkpoint
36 |     load_tf_weights_in_bert(model, config, tf_checkpoint_path)
37 | 
38 |     # Save pytorch-model
39 |     print("Save PyTorch model to {}".format(pytorch_dump_path))
40 |     torch.save(model.state_dict(), pytorch_dump_path)
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     parser = argparse.ArgumentParser()
45 |     ## Required parameters
46 |     parser.add_argument("--tf_checkpoint_path",
47 |                         default = None,
48 |                         type = str,
49 |                         required = True,
50 |                         help = "Path to the TensorFlow checkpoint.")
51 |     parser.add_argument("--bert_config_file",
52 |                         default = None,
53 |                         type = str,
54 |                         required = True,
55 |                         help = "The config json file corresponding to the pre-trained BERT model. 
\n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Pytorch-Transformers 2 | ================================================================================================================================================ 3 | 4 | PyTorch-Transformers is a library of state-of-the-art pre-trained models for Natural Language Processing (NLP). 5 | 6 | The library currently contains PyTorch implementations, pre-trained model weights, usage scripts and conversion utilities for the following models: 7 | 8 | 1. `BERT `_ (from Google) released with the paper `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding `_ by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 9 | 2. `GPT `_ (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `_ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 10 | 3. `GPT-2 `_ (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `_ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 11 | 4. `Transformer-XL `_ (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `_ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. 12 | 5. `XLNet `_ (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `_ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. 13 | 6. `XLM `_ (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `_ by Guillaume Lample and Alexis Conneau. 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | :caption: Notes 18 | 19 | installation 20 | quickstart 21 | pretrained_models 22 | examples 23 | notebooks 24 | serialization 25 | converting_tensorflow_models 26 | migration 27 | bertology 28 | torchscript 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :caption: Main classes 33 | 34 | main_classes/configuration 35 | main_classes/model 36 | main_classes/tokenizer 37 | main_classes/optimizer_schedules 38 | 39 | .. toctree:: 40 | :maxdepth: 2 41 | :caption: Package Reference 42 | 43 | model_doc/auto 44 | model_doc/bert 45 | model_doc/gpt 46 | model_doc/transformerxl 47 | model_doc/gpt2 48 | model_doc/xlm 49 | model_doc/xlnet 50 | model_doc/roberta 51 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_openai_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | import json
20 | 
21 | from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import CommonTestCases
24 | 
25 | 
26 | class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
27 | 
28 |     tokenizer_class = OpenAIGPTTokenizer
29 | 
30 |     def setUp(self):
31 |         super(OpenAIGPTTokenizationTest, self).setUp()
32 | 
33 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
34 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
35 |                  "w</w>", "r</w>", "t</w>",
36 |                  "lo", "low", "er</w>",
37 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
38 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
39 |         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
40 | 
41 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
42 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
43 |         with open(self.vocab_file, "w") as fp:
44 |             fp.write(json.dumps(vocab_tokens))
45 |         with open(self.merges_file, "w") as fp:
46 |             fp.write("\n".join(merges))
47 | 
48 |     def get_tokenizer(self):
49 |         return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
50 | 
51 |     def get_input_output_texts(self):
52 |         input_text = u"lower newer"
53 |         output_text = u"lower newer"
54 |         return input_text, output_text
55 | 
56 | 
57 |     def test_full_tokenizer(self):
58 |         tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
59 | 
60 |         text = "lower"
61 |         bpe_tokens = ["low", "er</w>"]
62 |         tokens = tokenizer.tokenize(text)
63 |         self.assertListEqual(tokens, bpe_tokens)
64 | 
65 |         input_tokens = tokens + ["<unk>"]
66 |         input_bpe_tokens = [14, 15, 20]
67 |         self.assertListEqual(
68 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
69 | 
70 | 
71 | if __name__ == '__main__':
72 |     unittest.main()
73 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_gpt2_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
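# [Editor's note] Added illustration -- not part of the original test file.
# A minimal sketch of the greedy byte-pair-encoding merging this test exercises:
# starting from single characters, toy merge rules are applied in priority order,
# so "lower" becomes ["low", "er"]. (The real GPT2Tokenizer picks the best-ranked
# merge at each step; for this toy vocabulary the result is the same.)
def _toy_bpe(word, merge_rules):
    symbols = list(word)
    for left, right in merge_rules:
        i = 0
        while i < len(symbols) - 1:
            if (symbols[i], symbols[i + 1]) == (left, right):
                symbols[i:i + 2] = [left + right]  # fuse the adjacent pair into one symbol
            else:
                i += 1
    return symbols

assert _toy_bpe("lower", [("l", "o"), ("lo", "w"), ("e", "r")]) == ["low", "er"]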
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | import json
20 | 
21 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import CommonTestCases
24 | 
25 | class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
26 | 
27 |     tokenizer_class = GPT2Tokenizer
28 | 
29 |     def setUp(self):
30 |         super(GPT2TokenizationTest, self).setUp()
31 | 
32 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
33 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
34 |                  "lo", "low", "er",
35 |                  "low", "lowest", "newer", "wider", "<unk>"]
36 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
37 |         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
38 |         self.special_tokens_map = {"unk_token": "<unk>"}
39 | 
40 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
41 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
42 |         with open(self.vocab_file, "w") as fp:
43 |             fp.write(json.dumps(vocab_tokens))
44 |         with open(self.merges_file, "w") as fp:
45 |             fp.write("\n".join(merges))
46 | 
47 |     def get_tokenizer(self):
48 |         return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
49 | 
50 |     def get_input_output_texts(self):
51 |         input_text = u"lower newer"
52 |         output_text = u"lowernewer"
53 |         return input_text, output_text
54 | 
55 |     def test_full_tokenizer(self):
56 |         tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
57 |         text = "lower"
58 |         bpe_tokens = ["low", "er"]
59 |         tokens = tokenizer.tokenize(text)
60 |         self.assertListEqual(tokens, bpe_tokens)
61 | 
62 |         input_tokens = tokens + [tokenizer.unk_token]
63 |         input_bpe_tokens = [13, 12, 17]
64 |         self.assertListEqual(
65 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
66 | 
67 | 
68 | if __name__ == '__main__':
69 |     unittest.main()
70 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_transfo_xl_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
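# [Editor's note] Added illustration -- not part of the original test file.
# Unlike the BPE tokenizers above, TransfoXLTokenizer is word-level: it splits on
# whitespace, optionally lower-cases, and maps out-of-vocabulary words to "<unk>".
# A minimal sketch of that behaviour:
def _toy_word_tokenize(text, vocab, lower_case=True):
    tokens = text.strip().split()
    if lower_case:
        tokens = [t.lower() for t in tokens]
    return [t if t in vocab else "<unk>" for t in tokens]

assert _toy_word_tokenize("<unk> UNwanted , running",
                          {"unwanted", ",", "running"}) == ["<unk>", "unwanted", ",", "running"]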
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | from io import open
20 | 
21 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import CommonTestCases
24 | 
25 | class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 | 
27 |     tokenizer_class = TransfoXLTokenizer
28 | 
29 |     def setUp(self):
30 |         super(TransfoXLTokenizationTest, self).setUp()
31 | 
32 |         vocab_tokens = [
33 |             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
34 |             "running", ",", "low", "l",
35 |         ]
36 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
37 |         with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
38 |             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
39 | 
40 |     def get_tokenizer(self):
41 |         return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
42 | 
43 |     def get_input_output_texts(self):
44 |         input_text = u"<unk> UNwanted , running"
45 |         output_text = u"<unk> unwanted, running"
46 |         return input_text, output_text
47 | 
48 |     def test_full_tokenizer(self):
49 |         tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)
50 | 
51 |         tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
52 |         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
53 | 
54 |         self.assertListEqual(
55 |             tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
56 | 
57 |     def test_full_tokenizer_lower(self):
58 |         tokenizer = TransfoXLTokenizer(lower_case=True)
59 | 
60 |         self.assertListEqual(
61 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
62 |             ["hello", "!", "how", "are", "you", "?"])
63 | 
64 |     def test_full_tokenizer_no_lower(self):
65 |         tokenizer = TransfoXLTokenizer(lower_case=False)
66 | 
67 |         self.assertListEqual(
68 |             tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
69 |             ["HeLLo", "!", "how", "Are", "yoU", "?"])
70 | 
71 | 
72 | if __name__ == '__main__':
73 |     unittest.main()
74 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3 | 
4 | To create the package for pypi.
5 | 
6 | 1. Change the version in __init__.py and setup.py.
7 | 
8 | 2. Commit these changes with the message: "Release: VERSION"
9 | 
10 | 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
11 |    Push the tag to git: git push --tags origin master
12 | 
13 | 4. Build both the sources and the wheel. Do not change anything in setup.py between
14 |    creating the wheel and the source distribution (obviously).
15 | 
16 |    For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory.
17 |    (this will build a wheel for the python version you use to build it - make sure you use python 3.x).
18 | 
19 |    For the sources, run: "python setup.py sdist"
20 |    You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp.
21 | 
22 | 5. Check that everything looks correct by uploading the package to the pypi test server:
23 | 
24 |    twine upload dist/* -r pypitest
25 |    (pypi suggests using twine as other methods upload files via plaintext.)
26 | 
27 |    Check that you can install it in a virtualenv by running:
28 |    pip install -i https://testpypi.python.org/pypi pytorch-transformers
29 | 
30 | 6. 
Upload the final version to actual pypi: 31 | twine upload dist/* -r pypi 32 | 33 | 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 34 | 35 | """ 36 | from io import open 37 | from setuptools import find_packages, setup 38 | 39 | setup( 40 | name="pytorch_transformers", 41 | version="1.1.0", 42 | author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Google AI Language Team Authors, Open AI team Authors", 43 | author_email="thomas@huggingface.co", 44 | description="Repository of pre-trained NLP Transformer models: BERT & RoBERTa, GPT & GPT-2, Transformer-XL, XLNet and XLM", 45 | long_description=open("README.md", "r", encoding='utf-8').read(), 46 | long_description_content_type="text/markdown", 47 | keywords='NLP deep learning transformer pytorch BERT GPT GPT-2 google openai CMU', 48 | license='Apache', 49 | url="https://github.com/huggingface/pytorch-transformers", 50 | packages=find_packages(exclude=["*.tests", "*.tests.*", 51 | "tests.*", "tests"]), 52 | install_requires=['torch>=1.0.0', 53 | 'numpy', 54 | 'boto3', 55 | 'requests', 56 | 'tqdm', 57 | 'regex', 58 | 'sentencepiece'], 59 | entry_points={ 60 | 'console_scripts': [ 61 | "pytorch_transformers=pytorch_transformers.__main__:main", 62 | ] 63 | }, 64 | # python_requires='>=3.5.0', 65 | tests_require=['pytest'], 66 | classifiers=[ 67 | 'Intended Audience :: Science/Research', 68 | 'License :: OSI Approved :: Apache Software License', 69 | 'Programming Language :: Python :: 3', 70 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 71 | ], 72 | ) 73 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
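# [Editor's note] Added illustration -- not part of the original script.
# XLM checkpoints store BPE continuation tokens with a trailing "@@" marker; the
# converter below strips that marker and instead appends "</w>" to word-final
# tokens (skipping the first few special-token ids), matching the format expected
# by the XLM tokenizer in this library. A standalone sketch of that rewrite:
def _demo_rewrite_vocab(word2id, num_special=13):
    return dict(
        (s + '</w>' if s.find('@@') == -1 and i > num_special else s.replace('@@', ''), i)
        for s, i in word2id.items())

assert _demo_rewrite_vocab({'hell@@': 14, 'o': 15}) == {'hell': 14, 'o</w>': 15}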
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME 27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | model = chkpt['model'] 37 | 38 | config = chkpt['params'] 39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 40 | 41 | vocab = chkpt['dico_word2id'] 42 | vocab = dict((s + '' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 48 | 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model, pytorch_weights_dump_path) 51 | 52 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 54 | f.write(json.dumps(config, indent=2) + "\n") 55 | 56 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(vocab, indent=2) + "\n") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | ## Required parameters 64 | parser.add_argument("--xlm_checkpoint_path", 65 | default = None, 66 | type = str, 67 | required = True, 68 | help = "Path the official PyTorch dump.") 69 | parser.add_argument("--pytorch_dump_folder_path", 70 | default = None, 71 | type = str, 72 | required = True, 73 | help = "Path to the output PyTorch model.") 74 | args = parser.parse_args() 75 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. 
\n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | from .tokenization_auto import AutoTokenizer 3 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 4 | from .tokenization_openai import OpenAIGPTTokenizer 5 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 6 | from .tokenization_gpt2 import GPT2Tokenizer 7 | from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE 8 | from .tokenization_xlm import XLMTokenizer 9 | from .tokenization_roberta import RobertaTokenizer 10 | 11 | from .tokenization_utils import (PreTrainedTokenizer) 12 | 13 | from .modeling_auto import (AutoConfig, AutoModel) 14 | 15 | from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining, 16 | BertForMaskedLM, BertForNextSentencePrediction, 17 | BertForSequenceClassification, BertForMultipleChoice, 18 | BertForTokenClassification, BertForQuestionAnswering, 19 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 20 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 21 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTPreTrainedModel, OpenAIGPTModel, 22 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 23 | load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, 24 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) 25 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, 26 | load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, 27 | TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) 28 | from .modeling_gpt2 import (GPT2Config, GPT2PreTrainedModel, GPT2Model, 29 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 30 | load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, 31 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) 32 | from .modeling_xlnet import (XLNetConfig, 33 | XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, 34 | XLNetForSequenceClassification, XLNetForQuestionAnswering, 35 | load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, 36 | XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) 37 | from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel, 38 | XLMWithLMHeadModel, XLMForSequenceClassification, 39 | XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, 40 | XLM_PRETRAINED_MODEL_ARCHIVE_MAP) 41 | from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, 42 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) 43 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, 44 | PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) 45 | 46 | from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, 47 | WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 48 | 49 | from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path) 50 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlm_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import unittest
19 | import json
20 | 
21 | from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
22 | 
23 | from .tokenization_tests_commons import CommonTestCases
24 | 
25 | class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 | 
27 |     tokenizer_class = XLMTokenizer
28 | 
29 |     def setUp(self):
30 |         super(XLMTokenizationTest, self).setUp()
31 | 
32 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
33 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
34 |                  "w</w>", "r</w>", "t</w>",
35 |                  "lo", "low", "er</w>",
36 |                  "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
37 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
38 |         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
39 | 
40 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
41 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
42 |         with open(self.vocab_file, "w") as fp:
43 |             fp.write(json.dumps(vocab_tokens))
44 |         with open(self.merges_file, "w") as fp:
45 |             fp.write("\n".join(merges))
46 | 
47 |     def get_tokenizer(self):
48 |         return XLMTokenizer.from_pretrained(self.tmpdirname)
49 | 
50 |     def get_input_output_texts(self):
51 |         input_text = u"lower newer"
52 |         output_text = u"lower newer"
53 |         return input_text, output_text
54 | 
55 |     def test_full_tokenizer(self):
56 |         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
57 |         tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
58 | 
59 |         text = "lower"
60 |         bpe_tokens = ["low", "er</w>"]
61 |         tokens = tokenizer.tokenize(text)
62 |         self.assertListEqual(tokens, bpe_tokens)
63 | 
64 |         input_tokens = tokens + ["<unk>"]
65 |         input_bpe_tokens = [14, 15, 20]
66 |         self.assertListEqual(
67 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
68 | 
69 |     def test_sequence_builders(self):
70 |         tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
71 | 
72 |         text = tokenizer.encode("sequence builders")
73 |         text_2 = tokenizer.encode("multi-sequence build")
74 | 
75 |         encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
76 |         encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
77 | 
78 |         assert encoded_sentence == [1] + text + [1]
79 |         assert encoded_pair == [1] + text + [1] + text_2 + [1]
80 | 
81 | if __name__ == '__main__':
82 |     unittest.main()
83 | 
--------------------------------------------------------------------------------
/pytorch_transformers/tests/tokenization_roberta_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals
16 | 
17 | import os
18 | import json
19 | import unittest
20 | 
21 | from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
22 | from .tokenization_tests_commons import CommonTestCases
23 | 
24 | 
25 | class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
26 |     tokenizer_class = RobertaTokenizer
27 | 
28 |     def setUp(self):
29 |         super(RobertaTokenizationTest, self).setUp()
30 | 
31 |         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
32 |         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
33 |                  "lo", "low", "er",
34 |                  "low", "lowest", "newer", "wider", "<unk>"]
35 |         vocab_tokens = dict(zip(vocab, range(len(vocab))))
36 |         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
37 |         self.special_tokens_map = {"unk_token": "<unk>"}
38 | 
39 |         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
40 |         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
41 |         with open(self.vocab_file, "w") as fp:
42 |             fp.write(json.dumps(vocab_tokens))
43 |         with open(self.merges_file, "w") as fp:
44 |             fp.write("\n".join(merges))
45 | 
46 |     def get_tokenizer(self):
47 |         return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
48 | 
49 |     def get_input_output_texts(self):
50 |         input_text = u"lower newer"
51 |         output_text = u"lowernewer"
52 |         return input_text, output_text
53 | 
54 |     def test_full_tokenizer(self):
55 |         tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
56 |         text = "lower"
57 |         bpe_tokens = ["low", "er"]
58 |         tokens = tokenizer.tokenize(text)
59 |         self.assertListEqual(tokens, bpe_tokens)
60 | 
61 |         input_tokens = tokens + [tokenizer.unk_token]
62 |         input_bpe_tokens = [13, 12, 17]
63 |         self.assertListEqual(
64 |             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
65 | 
66 |     def roberta_dict_integration_testing(self):
67 |         tokenizer = self.get_tokenizer()
68 | 
69 |         self.assertListEqual(
70 |             tokenizer.encode('Hello world!'),
71 |             [0, 31414, 232, 328, 2]
72 |         )
73 |         self.assertListEqual(
74 |             tokenizer.encode('Hello world! 
cécé herlolip 418'), 75 | [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] 76 | ) 77 | 78 | def test_sequence_builders(self): 79 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 80 | 81 | text = tokenizer.encode("sequence builders") 82 | text_2 = tokenizer.encode("multi-sequence build") 83 | 84 | encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True) 85 | encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True) 86 | 87 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 88 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 89 | 90 | assert encoded_sentence == encoded_text_from_decode 91 | assert encoded_pair == encoded_pair_from_decode 92 | 93 | 94 | if __name__ == '__main__': 95 | unittest.main() 96 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/fixtures/sample_text.txt: -------------------------------------------------------------------------------- 1 | This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত 2 | Text should be one-sentence-per-line, with empty lines between documents. 3 | This sample text is public domain and was randomly selected from Project Guttenberg. 4 | 5 | The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors. 6 | Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity. 7 | Possibly this may have been the reason why early risers in that locality, during the rainy season, adopted a thoughtful habit of body, and seldom lifted their eyes to the rifted or india-ink washed skies above them. 8 | "Cass" Beard had risen early that morning, but not with a view to discovery. 9 | A leak in his cabin roof,--quite consistent with his careless, improvident habits,--had roused him at 4 A. M., with a flooded "bunk" and wet blankets. 10 | The chips from his wood pile refused to kindle a fire to dry his bed-clothes, and he had recourse to a more provident neighbor's to supply the deficiency. 11 | This was nearly opposite. 12 | Mr. Cassius crossed the highway, and stopped suddenly. 13 | Something glittered in the nearest red pool before him. 14 | Gold, surely! 15 | But, wonderful to relate, not an irregular, shapeless fragment of crude ore, fresh from Nature's crucible, but a bit of jeweler's handicraft in the form of a plain gold ring. 16 | Looking at it more attentively, he saw that it bore the inscription, "May to Cass." 17 | Like most of his fellow gold-seekers, Cass was superstitious. 18 | 19 | The fountain of classic wisdom, Hypatia herself. 20 | As the ancient sage--the name is unimportant to a monk--pumped water nightly that he might study by day, so I, the guardian of cloaks and parasols, at the sacred doors of her lecture-room, imbibe celestial knowledge. 21 | From my youth I felt in me a soul above the matter-entangled herd. 22 | She revealed to me the glorious fact, that I am a spark of Divinity itself. 23 | A fallen star, I am, sir!' 
continued he, pensively, stroking his lean stomach--'a fallen star!--fallen, if the dignity of philosophy will allow of the simile, among the hogs of the lower world--indeed, even into the hog-bucket itself. Well, after all, I will show you the way to the Archbishop's. 24 | There is a philosophic pleasure in opening one's treasures to the modest young. 25 | Perhaps you will assist me by carrying this basket of fruit?' And the little man jumped up, put his basket on Philammon's head, and trotted off up a neighbouring street. 26 | Philammon followed, half contemptuous, half wondering at what this philosophy might be, which could feed the self-conceit of anything so abject as his ragged little apish guide; 27 | but the novel roar and whirl of the street, the perpetual stream of busy faces, the line of curricles, palanquins, laden asses, camels, elephants, which met and passed him, and squeezed him up steps and into doorways, as they threaded their way through the great Moon-gate into the ample street beyond, drove everything from his mind but wondering curiosity, and a vague, helpless dread of that great living wilderness, more terrible than any dead wilderness of sand which he had left behind. 28 | Already he longed for the repose, the silence of the Laura--for faces which knew him and smiled upon him; but it was too late to turn back now. 29 | His guide held on for more than a mile up the great main street, crossed in the centre of the city, at right angles, by one equally magnificent, at each end of which, miles away, appeared, dim and distant over the heads of the living stream of passengers, the yellow sand-hills of the desert; 30 | while at the end of the vista in front of them gleamed the blue harbour, through a network of countless masts. 31 | At last they reached the quay at the opposite end of the street; 32 | and there burst on Philammon's astonished eyes a vast semicircle of blue sea, ringed with palaces and towers. 33 | He stopped involuntarily; and his little guide stopped also, and looked askance at the young monk, to watch the effect which that grand panorama should produce on him. 34 | -------------------------------------------------------------------------------- /examples/test_examples.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 HuggingFace Inc.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
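# [Editor's note] Added illustration -- not part of the original file. These tests
# drive the example scripts end-to-end by patching sys.argv and calling each
# script's main() directly, e.g.:
#
#     with patch.object(sys, "argv", ["run_glue.py", "--do_train", "--do_eval", ...]):
#         result = run_glue.main()
#
# which keeps the scripts' argparse-based CLIs testable without spawning subprocesses.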
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import sys 20 | import unittest 21 | import argparse 22 | import logging 23 | 24 | try: 25 | # python 3.4+ can use builtin unittest.mock instead of mock package 26 | from unittest.mock import patch 27 | except ImportError: 28 | from mock import patch 29 | 30 | import run_glue 31 | import run_squad 32 | import run_generation 33 | 34 | logging.basicConfig(level=logging.DEBUG) 35 | 36 | logger = logging.getLogger() 37 | 38 | def get_setup_file(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('-f') 41 | args = parser.parse_args() 42 | return args.f 43 | 44 | class ExamplesTests(unittest.TestCase): 45 | 46 | def test_run_glue(self): 47 | stream_handler = logging.StreamHandler(sys.stdout) 48 | logger.addHandler(stream_handler) 49 | 50 | testargs = ["run_glue.py", 51 | "--data_dir=./examples/tests_samples/MRPC/", 52 | "--task_name=mrpc", 53 | "--do_train", 54 | "--do_eval", 55 | "--output_dir=./examples/tests_samples/temp_dir", 56 | "--per_gpu_train_batch_size=2", 57 | "--per_gpu_eval_batch_size=1", 58 | "--learning_rate=1e-4", 59 | "--max_steps=10", 60 | "--warmup_steps=2", 61 | "--overwrite_output_dir", 62 | "--seed=42"] 63 | model_type, model_name = ("--model_type=bert", 64 | "--model_name_or_path=bert-base-uncased") 65 | with patch.object(sys, 'argv', testargs + [model_type, model_name]): 66 | result = run_glue.main() 67 | for value in result.values(): 68 | self.assertGreaterEqual(value, 0.75) 69 | 70 | def test_run_squad(self): 71 | stream_handler = logging.StreamHandler(sys.stdout) 72 | logger.addHandler(stream_handler) 73 | 74 | testargs = ["run_squad.py", 75 | "--train_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json", 76 | "--predict_file=./examples/tests_samples/SQUAD/dev-v2.0-small.json", 77 | "--model_name=bert-base-uncased", 78 | "--output_dir=./examples/tests_samples/temp_dir", 79 | "--max_steps=10", 80 | "--warmup_steps=2", 81 | "--do_train", 82 | "--do_eval", 83 | "--version_2_with_negative", 84 | "--learning_rate=1e-4", 85 | "--per_gpu_train_batch_size=2", 86 | "--per_gpu_eval_batch_size=1", 87 | "--overwrite_output_dir", 88 | "--seed=42"] 89 | model_type, model_name = ("--model_type=bert", 90 | "--model_name_or_path=bert-base-uncased") 91 | with patch.object(sys, 'argv', testargs + [model_type, model_name]): 92 | result = run_squad.main() 93 | self.assertGreaterEqual(result['f1'], 30) 94 | self.assertGreaterEqual(result['exact'], 30) 95 | 96 | def test_generation(self): 97 | stream_handler = logging.StreamHandler(sys.stdout) 98 | logger.addHandler(stream_handler) 99 | 100 | testargs = ["run_generation.py", 101 | "--prompt=Hello", 102 | "--length=10", 103 | "--seed=42"] 104 | model_type, model_name = ("--model_type=openai-gpt", 105 | "--model_name_or_path=openai-gpt") 106 | with patch.object(sys, 'argv', testargs + [model_type, model_name]): 107 | result = run_generation.main() 108 | self.assertGreaterEqual(len(result), 10) 109 | 110 | if __name__ == "__main__": 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = str, 86 | required = True, 87 | help = "The config json file corresponding to the pre-trained XLNet model. 
\n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /docs/source/converting_tensorflow_models.rst: -------------------------------------------------------------------------------- 1 | Converting Tensorflow Checkpoints 2 | ================================================ 3 | 4 | A command-line interface is provided to convert original Bert/GPT/GPT-2/Transformer-XL/XLNet/XLM checkpoints in models than be loaded using the ``from_pretrained`` methods of the library. 5 | 6 | BERT 7 | ^^^^ 8 | 9 | You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google `_\ ) in a PyTorch save file by using the `convert_tf_checkpoint_to_pytorch.py `_ script. 10 | 11 | This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py `_\ , `run_bert_classifier.py `_ and `run_bert_squad.py `_\ ). 12 | 13 | You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with ``bert_model.ckpt``\ ) but be sure to keep the configuration file (\ ``bert_config.json``\ ) and the vocabulary file (\ ``vocab.txt``\ ) as these are needed for the PyTorch model too. 14 | 15 | To run this specific conversion script you will need to have TensorFlow and PyTorch installed (\ ``pip install tensorflow``\ ). The rest of the repository only requires PyTorch. 16 | 17 | Here is an example of the conversion process for a pre-trained ``BERT-Base Uncased`` model: 18 | 19 | .. code-block:: shell 20 | 21 | export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 22 | 23 | pytorch_transformers bert \ 24 | $BERT_BASE_DIR/bert_model.ckpt \ 25 | $BERT_BASE_DIR/bert_config.json \ 26 | $BERT_BASE_DIR/pytorch_model.bin 27 | 28 | You can download Google's pre-trained models for the conversion `here `__. 29 | 30 | OpenAI GPT 31 | ^^^^^^^^^^ 32 | 33 | Here is an example of the conversion process for a pre-trained OpenAI GPT model, assuming that your NumPy checkpoint save as the same format than OpenAI pretrained model (see `here `__\ ) 34 | 35 | .. code-block:: shell 36 | 37 | export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights 38 | 39 | pytorch_transformers gpt \ 40 | $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \ 41 | $PYTORCH_DUMP_OUTPUT \ 42 | [OPENAI_GPT_CONFIG] 43 | 44 | OpenAI GPT-2 45 | ^^^^^^^^^^^^ 46 | 47 | Here is an example of the conversion process for a pre-trained OpenAI GPT-2 model (see `here `__\ ) 48 | 49 | .. 
50 | 
51 |    export OPENAI_GPT2_CHECKPOINT_PATH=/path/to/gpt2/pretrained/weights
52 | 
53 |    pytorch_transformers gpt2 \
54 |      $OPENAI_GPT2_CHECKPOINT_PATH \
55 |      $PYTORCH_DUMP_OUTPUT \
56 |      [OPENAI_GPT2_CONFIG]
57 | 
58 | Transformer-XL
59 | ^^^^^^^^^^^^^^
60 | 
61 | Here is an example of the conversion process for a pre-trained Transformer-XL model (see `here `__\ )
62 | 
63 | .. code-block:: shell
64 | 
65 |    export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint
66 | 
67 |    pytorch_transformers transfo_xl \
68 |      $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
69 |      $PYTORCH_DUMP_OUTPUT \
70 |      [TRANSFO_XL_CONFIG]
71 | 
72 | 
73 | XLNet
74 | ^^^^^
75 | 
76 | Here is an example of the conversion process for a pre-trained XLNet model, fine-tuned on STS-B using the TensorFlow script:
77 | 
78 | .. code-block:: shell
79 | 
80 |    export XLNET_CHECKPOINT_PATH=/path/to/xlnet/checkpoint
81 |    export XLNET_CONFIG_PATH=/path/to/xlnet/config
82 | 
83 |    pytorch_transformers xlnet \
84 |      $XLNET_CHECKPOINT_PATH \
85 |      $XLNET_CONFIG_PATH \
86 |      $PYTORCH_DUMP_OUTPUT \
87 |      STS-B
88 | 
89 | 
90 | XLM
91 | ^^^
92 | 
93 | Here is an example of the conversion process for a pre-trained XLM model:
94 | 
95 | .. code-block:: shell
96 | 
97 |    export XLM_CHECKPOINT_PATH=/path/to/xlm/checkpoint
98 | 
99 |    pytorch_transformers xlm \
100 |      $XLM_CHECKPOINT_PATH \
101 |      $PYTORCH_DUMP_OUTPUT
102 | 
--------------------------------------------------------------------------------
/pytorch_transformers/convert_pytorch_checkpoint_to_tf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
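# [Editor's note] Added illustration -- not part of the original script.
# The converter below renames PyTorch state_dict keys to TensorFlow variable
# names via ordered string substitutions (see var_map), e.g.:
#
#     encoder.layer.0.attention.output.LayerNorm.weight
#         -> bert/encoder/layer_0/attention/output/LayerNorm/gamma
#
# and transposes the tensors listed in tensors_to_transpose, since TF dense
# kernels are stored as the transpose of torch.nn.Linear weight matrices.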
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from pytorch_transformers.modeling import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. 
bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /docs/source/_static/css/huggingface.css: -------------------------------------------------------------------------------- 1 | huggingface.css 2 | 3 | /* The literal code blocks */ 4 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 5 | color: #6670FF; 6 | } 7 | 8 | /* To keep the logo centered */ 9 | .wy-side-scroll { 10 | width: auto; 11 | font-size: 20px; 12 | } 13 | 14 | /* The div that holds the Hugging Face logo */ 15 | .HuggingFaceDiv { 16 | width: 100% 17 | } 18 | 19 | /* The research field on top of the toc tree */ 20 | .wy-side-nav-search{ 21 | background-color: #6670FF; 22 | } 23 | 24 | /* The toc tree */ 25 | .wy-nav-side{ 26 | background-color: #6670FF; 27 | } 28 | 29 | /* The selected items in the toc tree */ 30 | .wy-menu-vertical li.current{ 31 | background-color: #A6B0FF; 32 | } 33 | 34 | /* When a list item that does belong to the selected block from the toc tree is hovered */ 35 | .wy-menu-vertical li.current a:hover{ 36 | background-color: #B6C0FF; 37 | } 38 | 39 | /* When a list item that does NOT belong to the selected block from the toc tree is hovered. 
*/ 40 | .wy-menu-vertical li a:hover{ 41 | background-color: #A7AFFB; 42 | } 43 | 44 | /* The text items on the toc tree */ 45 | .wy-menu-vertical a { 46 | color: #FFFFDD; 47 | font-family: Calibre-Light; 48 | } 49 | .wy-menu-vertical header, .wy-menu-vertical p.caption{ 50 | color: white; 51 | font-family: Calibre-Light; 52 | } 53 | 54 | /* The color inside the selected toc tree block */ 55 | .wy-menu-vertical li.toctree-l2 a, .wy-menu-vertical li.toctree-l3 a, .wy-menu-vertical li.toctree-l4 a { 56 | color: black; 57 | } 58 | 59 | /* Inside the depth-2 selected toc tree block */ 60 | .wy-menu-vertical li.toctree-l2.current>a { 61 | background-color: #B6C0FF 62 | } 63 | .wy-menu-vertical li.toctree-l2.current li.toctree-l3>a { 64 | background-color: #C6D0FF 65 | } 66 | 67 | /* Inside the depth-3 selected toc tree block */ 68 | .wy-menu-vertical li.toctree-l3.current li.toctree-l4>a{ 69 | background-color: #D6E0FF 70 | } 71 | 72 | /* Inside code snippets */ 73 | .rst-content dl:not(.docutils) dt{ 74 | font-size: 15px; 75 | } 76 | 77 | /* Links */ 78 | a { 79 | color: #6670FF; 80 | } 81 | 82 | /* Content bars */ 83 | .rst-content dl:not(.docutils) dt { 84 | background-color: rgba(251, 141, 104, 0.1); 85 | border-right: solid 2px #FB8D68; 86 | border-left: solid 2px #FB8D68; 87 | color: #FB8D68; 88 | font-family: Calibre-Light; 89 | border-top: none; 90 | font-style: normal !important; 91 | } 92 | 93 | /* Expand button */ 94 | .wy-menu-vertical li.toctree-l2 span.toctree-expand, 95 | .wy-menu-vertical li.on a span.toctree-expand, .wy-menu-vertical li.current>a span.toctree-expand, 96 | .wy-menu-vertical li.toctree-l3 span.toctree-expand{ 97 | color: black; 98 | } 99 | 100 | /* Max window size */ 101 | .wy-nav-content{ 102 | max-width: 1200px; 103 | } 104 | 105 | /* Mobile header */ 106 | .wy-nav-top{ 107 | background-color: #6670FF; 108 | } 109 | 110 | 111 | /* Source spans */ 112 | .rst-content .viewcode-link, .rst-content .viewcode-back{ 113 | color: #6670FF; 114 | font-size: 110%; 115 | letter-spacing: 2px; 116 | text-transform: uppercase; 117 | } 118 | 119 | /* It would be better for table to be visible without horizontal scrolling */ 120 | .wy-table-responsive table td, .wy-table-responsive table th{ 121 | white-space: normal; 122 | } 123 | 124 | .footer { 125 | margin-top: 20px; 126 | } 127 | 128 | .footer__Social { 129 | display: flex; 130 | flex-direction: row; 131 | } 132 | 133 | .footer__CustomImage { 134 | margin: 2px 5px 0 0; 135 | } 136 | 137 | /* class and method names in doc */ 138 | .rst-content dl:not(.docutils) tt.descname, .rst-content dl:not(.docutils) tt.descclassname, .rst-content dl:not(.docutils) tt.descname, .rst-content dl:not(.docutils) code.descname, .rst-content dl:not(.docutils) tt.descclassname, .rst-content dl:not(.docutils) code.descclassname{ 139 | font-family: Calibre; 140 | font-size: 20px !important; 141 | } 142 | 143 | /* class name in doc*/ 144 | .rst-content dl:not(.docutils) tt.descname, .rst-content dl:not(.docutils) tt.descname, .rst-content dl:not(.docutils) code.descname{ 145 | margin-right: 10px; 146 | font-family: Calibre-Medium; 147 | } 148 | 149 | /* Method and class parameters */ 150 | .sig-param{ 151 | line-height: 23px; 152 | } 153 | 154 | /* Class introduction "class" string at beginning */ 155 | .rst-content dl:not(.docutils) .property{ 156 | font-size: 18px; 157 | color: black; 158 | } 159 | 160 | 161 | /* FONTS */ 162 | body{ 163 | font-family: Calibre; 164 | font-size: 16px; 165 | } 166 | 167 | h1 { 168 | font-family: Calibre-Thin; 169 
| font-size: 70px; 170 | } 171 | 172 | h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend{ 173 | font-family: Calibre-Medium; 174 | } 175 | 176 | @font-face { 177 | font-family: Calibre-Medium; 178 | src: url(./Calibre-Medium.otf); 179 | font-weight:400; 180 | } 181 | 182 | @font-face { 183 | font-family: Calibre; 184 | src: url(./Calibre-Regular.otf); 185 | font-weight:400; 186 | } 187 | 188 | @font-face { 189 | font-family: Calibre-Light; 190 | src: url(./Calibre-Light.ttf); 191 | font-weight:400; 192 | } 193 | 194 | @font-face { 195 | font-family: Calibre-Thin; 196 | src: url(./Calibre-Thin.otf); 197 | font-weight:400; 198 | } 199 | 200 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_xlnet_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | 20 | from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE) 21 | 22 | from .tokenization_tests_commons import CommonTestCases 23 | 24 | SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 25 | 'fixtures/test_sentencepiece.model') 26 | 27 | class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): 28 | 29 | tokenizer_class = XLNetTokenizer 30 | 31 | def setUp(self): 32 | super(XLNetTokenizationTest, self).setUp() 33 | 34 | # We have a SentencePiece fixture for testing 35 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 36 | tokenizer.save_pretrained(self.tmpdirname) 37 | 38 | def get_tokenizer(self): 39 | return XLNetTokenizer.from_pretrained(self.tmpdirname) 40 | 41 | def get_input_output_texts(self): 42 | input_text = u"This is a test" 43 | output_text = u"This is a test" 44 | return input_text, output_text 45 | 46 | 47 | def test_full_tokenizer(self): 48 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True) 49 | 50 | tokens = tokenizer.tokenize(u'This is a test') 51 | self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est']) 52 | 53 | self.assertListEqual( 54 | tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382]) 55 | 56 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.") 57 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', 58 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'', 59 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', 60 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.']) 61 | ids = tokenizer.convert_tokens_to_ids(tokens) 62 | self.assertListEqual( 63 | ids, [8, 21, 84, 55, 24, 19, 7, 0, 64 | 602, 347, 347, 347, 3, 12, 66, 65 | 46, 72, 80, 6, 0, 4]) 
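# The two 0 entries in the expected ids above are the <unk> id: the pieces
# u'9' and u'é' are not in the small sample SentencePiece vocab, so the
# round trip back to tokens below maps them to u'<unk>'.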
66 | 
67 | back_tokens = tokenizer.convert_ids_to_tokens(ids)
68 | self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
69 | u'or', u'n', SPIECE_UNDERLINE + u'in',
70 | SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
71 | SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
72 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
73 | u'<unk>', u'.'])
74 | 
75 | def test_tokenizer_lower(self):
76 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
77 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
78 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
79 | u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
80 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
81 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
82 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
83 | 
84 | def test_tokenizer_no_lower(self):
85 | tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
86 | tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
87 | self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
88 | u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
89 | u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
90 | SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
91 | 
92 | def test_sequence_builders(self):
93 | tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
94 | 
95 | text = tokenizer.encode("sequence builders")
96 | text_2 = tokenizer.encode("multi-sequence build")
97 | 
98 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
99 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
100 | 
101 | assert encoded_sentence == text + [4, 3]
102 | assert encoded_pair == text + [4] + text_2 + [4, 3]
103 | 
104 | 
105 | if __name__ == '__main__':
106 | unittest.main()
107 | 
-------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_bert_test.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import unittest 19 | from io import open 20 | 21 | from pytorch_transformers.tokenization_bert import (BasicTokenizer, 22 | BertTokenizer, 23 | WordpieceTokenizer, 24 | _is_control, _is_punctuation, 25 | _is_whitespace, VOCAB_FILES_NAMES) 26 | 27 | from .tokenization_tests_commons import CommonTestCases 28 | 29 | class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): 30 | 31 | tokenizer_class = BertTokenizer 32 | 33 | def setUp(self): 34 | super(BertTokenizationTest, self).setUp() 35 | 36 | vocab_tokens = [ 37 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 38 | "##ing", ",", "low", "lowest", 39 | ] 40 | self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) 41 | with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: 42 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 43 | 44 | def get_tokenizer(self): 45 | return BertTokenizer.from_pretrained(self.tmpdirname) 46 | 47 | def get_input_output_texts(self): 48 | input_text = u"UNwant\u00E9d,running" 49 | output_text = u"unwanted, running" 50 | return input_text, output_text 51 | 52 | def test_full_tokenizer(self): 53 | tokenizer = BertTokenizer(self.vocab_file) 54 | 55 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 56 | self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 57 | self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 58 | 59 | def test_chinese(self): 60 | tokenizer = BasicTokenizer() 61 | 62 | self.assertListEqual( 63 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 64 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 65 | 66 | def test_basic_tokenizer_lower(self): 67 | tokenizer = BasicTokenizer(do_lower_case=True) 68 | 69 | self.assertListEqual( 70 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 71 | ["hello", "!", "how", "are", "you", "?"]) 72 | self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 73 | 74 | def test_basic_tokenizer_no_lower(self): 75 | tokenizer = BasicTokenizer(do_lower_case=False) 76 | 77 | self.assertListEqual( 78 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 79 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 80 | 81 | def test_wordpiece_tokenizer(self): 82 | vocab_tokens = [ 83 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 84 | "##ing" 85 | ] 86 | 87 | vocab = {} 88 | for (i, token) in enumerate(vocab_tokens): 89 | vocab[token] = i 90 | tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") 91 | 92 | self.assertListEqual(tokenizer.tokenize(""), []) 93 | 94 | self.assertListEqual( 95 | tokenizer.tokenize("unwanted running"), 96 | ["un", "##want", "##ed", "runn", "##ing"]) 97 | 98 | self.assertListEqual( 99 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 100 | 101 | def test_is_whitespace(self): 102 | self.assertTrue(_is_whitespace(u" ")) 103 | self.assertTrue(_is_whitespace(u"\t")) 104 | self.assertTrue(_is_whitespace(u"\r")) 105 | self.assertTrue(_is_whitespace(u"\n")) 106 | self.assertTrue(_is_whitespace(u"\u00A0")) 107 | 108 | self.assertFalse(_is_whitespace(u"A")) 109 | self.assertFalse(_is_whitespace(u"-")) 110 | 111 | def test_is_control(self): 112 | self.assertTrue(_is_control(u"\u0005")) 113 | 114 | self.assertFalse(_is_control(u"A")) 115 | self.assertFalse(_is_control(u" ")) 116 | self.assertFalse(_is_control(u"\t")) 117 | self.assertFalse(_is_control(u"\r")) 118 | 119 | def test_is_punctuation(self): 120 | self.assertTrue(_is_punctuation(u"-")) 121 | self.assertTrue(_is_punctuation(u"$")) 122 | self.assertTrue(_is_punctuation(u"`")) 123 | self.assertTrue(_is_punctuation(u".")) 124 | 125 | self.assertFalse(_is_punctuation(u"A")) 126 | self.assertFalse(_is_punctuation(u" ")) 127 | 128 | def test_sequence_builders(self): 129 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 130 | 131 | text = tokenizer.encode("sequence builders") 132 | text_2 = tokenizer.encode("multi-sequence build") 133 | 134 | encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) 135 | encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) 136 | 137 | assert encoded_sentence == [101] + text + [102] 138 | assert encoded_pair == [101] + text + [102] + text_2 + [102] 139 | 140 | if __name__ == '__main__': 141 | unittest.main() 142 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_transformers.tokenization_transfo_xl as data_utils 27 | 28 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from pytorch_transformers.modeling_transfo_xl import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | 
parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained BERT model. \n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /docs/source/migration.md: -------------------------------------------------------------------------------- 1 | # Migrating from pytorch-pretrained-bert 2 | 3 | 4 | Here is a quick summary of what you should take care of when migrating from `pytorch-pretrained-bert` to `pytorch-transformers` 5 | 6 | ### Models always output `tuples` 7 | 8 | The main breaking change when migrating from `pytorch-pretrained-bert` to `pytorch-transformers` is that the models forward method always outputs a `tuple` with various elements depending on the model and the configuration parameters. 9 | 10 | The exact content of the tuples for each model are detailled in the models' docstrings and the [documentation](https://huggingface.co/pytorch-transformers/). 11 | 12 | In pretty much every case, you will be fine by taking the first element of the output as the output you previously used in `pytorch-pretrained-bert`. 13 | 14 | Here is a `pytorch-pretrained-bert` to `pytorch-transformers` conversion example for a `BertForSequenceClassification` classification model: 15 | 16 | ```python 17 | # Let's load our model 18 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased') 19 | 20 | # If you used to have this line in pytorch-pretrained-bert: 21 | loss = model(input_ids, labels=labels) 22 | 23 | # Now just use this line in pytorch-transformers to extract the loss from the output tuple: 24 | outputs = model(input_ids, labels=labels) 25 | loss = outputs[0] 26 | 27 | # In pytorch-transformers you can also have access to the logits: 28 | loss, logits = outputs[:2] 29 | 30 | # And even the attention weigths if you configure the model to output them (and other outputs too, see the docstrings and documentation) 31 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True) 32 | outputs = model(input_ids, labels=labels) 33 | loss, logits, attentions = outputs 34 | ``` 35 | 36 | ### Serialization 37 | 38 | Breaking change in the `from_pretrained()`method: 39 | 40 | 1. Models are now set in evaluation mode by default when instantiated with the `from_pretrained()` method. To train them don't forget to set them back in training mode (`model.train()`) to activate the dropout modules. 41 | 42 | 2. The additional `*inputs` and `**kwargs` arguments supplied to the `from_pretrained()` method used to be directly passed to the underlying model's class `__init__()` method. 
They are now used to update the model configuration attribute first, which can break derived model classes built on the previous `BertForSequenceClassification` examples. More precisely, the positional arguments `*inputs` provided to `from_pretrained()` are directly forwarded to the model's `__init__()` method, while the keyword arguments `**kwargs` (i) that match configuration class attributes are used to update said attributes and (ii) that don't match any configuration class attribute are forwarded to the model's `__init__()` method.
43 | 
44 | Also, while not a breaking change, the serialization methods have been standardized and you should probably switch to the new `save_pretrained(save_directory)` method if you were using any other serialization method before.
45 | 
46 | Here is an example:
47 | 
48 | ```python
49 | ### Let's load a model and tokenizer
50 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
51 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
52 | 
53 | ### Do some stuff to our model and tokenizer
54 | # Ex: add new tokens to the vocabulary and embeddings of our model
55 | tokenizer.add_tokens(['[SPECIAL_TOKEN_1]', '[SPECIAL_TOKEN_2]'])
56 | model.resize_token_embeddings(len(tokenizer))
57 | # Train our model
58 | train(model)
59 | 
60 | ### Now let's save our model and tokenizer to a directory
61 | model.save_pretrained('./my_saved_model_directory/')
62 | tokenizer.save_pretrained('./my_saved_model_directory/')
63 | 
64 | ### Reload the model and the tokenizer
65 | model = BertForSequenceClassification.from_pretrained('./my_saved_model_directory/')
66 | tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')
67 | ```
68 | 
69 | ### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules
70 | 
71 | The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
72 | 
73 | - it only implements the weight decay correction,
74 | - schedules are now external (see below),
75 | - gradient clipping is now also external (see below).
76 | 
77 | The new `AdamW` optimizer matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.
78 | 
79 | The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and are no longer part of the optimizer.
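For reference, the library ships several schedule classes that share the same constructor pattern: `ConstantLRSchedule`, `WarmupConstantSchedule`, `WarmupCosineSchedule`, `WarmupCosineWithHardRestartsSchedule` and `WarmupLinearSchedule` (the names imported by the library's own test suite). A minimal sketch, assuming `optimizer` is the `AdamW` instance created as in the conversion example below:

```python
from pytorch_transformers import WarmupCosineSchedule

# Linear warmup over the first 100 steps, then cosine decay over the
# remaining 900 of the 1000 total steps; call scheduler.step() once per
# training step, after optimizer.step().
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=100, t_total=1000)
```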
80 | 
81 | Here is a conversion example from `BertAdam` with a linear warmup and decay schedule to `AdamW` and the same schedule:
82 | 
83 | ```python
84 | # Parameters:
85 | lr = 1e-3
86 | max_grad_norm = 1.0
87 | num_total_steps = 1000
88 | num_warmup_steps = 100
89 | warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1
90 | 
91 | ### Previously BertAdam optimizer was instantiated like this:
92 | optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps)
93 | ### and used like this:
94 | for batch in train_data:
95 | loss = model(batch)
96 | loss.backward()
97 | optimizer.step()
98 | 
99 | ### In PyTorch-Transformers, the optimizer and schedule are split and instantiated like this:
100 | optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
101 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
102 | ### and used like this:
103 | for batch in train_data:
104 | loss = model(batch)
105 | loss.backward()
106 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
107 | optimizer.step()
108 | scheduler.step()  # Update the learning rate schedule after the optimizer update
109 | ```
110 | 
-------------------------------------------------------------------------------- /pytorch_transformers/tokenization_auto.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Auto Tokenizer class. """
16 | 
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 | 
19 | import logging
20 | 
21 | from .tokenization_bert import BertTokenizer
22 | from .tokenization_openai import OpenAIGPTTokenizer
23 | from .tokenization_gpt2 import GPT2Tokenizer
24 | from .tokenization_transfo_xl import TransfoXLTokenizer
25 | from .tokenization_xlnet import XLNetTokenizer
26 | from .tokenization_xlm import XLMTokenizer
27 | from .tokenization_roberta import RobertaTokenizer
28 | 
29 | logger = logging.getLogger(__name__)
30 | 
31 | class AutoTokenizer(object):
32 | r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class
33 | that will be instantiated as one of the tokenizer classes of the library
34 | when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
35 | class method.
36 | 
37 | The `from_pretrained()` method takes care of returning the correct tokenizer class instance
38 | using pattern matching on the `pretrained_model_name_or_path` string.
39 | 
40 | The tokenizer class to instantiate is selected as the first pattern matching
41 | in the `pretrained_model_name_or_path` string (in the following order):
42 | - contains `roberta`: RobertaTokenizer (RoBERTa model, checked before `bert` since `roberta` contains `bert`)
43 | - contains `bert`: BertTokenizer (Bert model)
44 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
45 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
46 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
47 | - contains `xlnet`: XLNetTokenizer (XLNet model)
48 | - contains `xlm`: XLMTokenizer (XLM model)
49 | 
50 | This class cannot be instantiated using `__init__()` (doing so throws an error).
51 | """
52 | def __init__(self):
53 | raise EnvironmentError("AutoTokenizer is designed to be instantiated "
54 | "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.")
55 | 
56 | @classmethod
57 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
58 | r""" Instantiate one of the tokenizer classes of the library
59 | from a pre-trained model vocabulary.
60 | 
61 | The tokenizer class to instantiate is selected as the first pattern matching
62 | in the `pretrained_model_name_or_path` string (in the following order):
63 | - contains `roberta`: RobertaTokenizer (RoBERTa model)
64 | - contains `bert`: BertTokenizer (Bert model)
65 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
66 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
67 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
68 | - contains `xlnet`: XLNetTokenizer (XLNet model)
69 | - contains `xlm`: XLMTokenizer (XLM model)
70 | 
71 | Params:
72 | **pretrained_model_name_or_path**: either:
73 | - a string with the `shortcut name` of a pre-trained model vocabulary to load from cache
74 | or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
75 | - a path to a `directory` containing vocabulary files saved
76 | using the `save_pretrained(save_directory)` method.
77 | - a path or url to a saved vocabulary `file`.
78 | **cache_dir**: (`optional`) string:
79 | Path to a directory in which a downloaded pre-trained model
80 | vocabulary should be cached if the standard cache should not be used.
81 | 
82 | Examples::
83 | 
84 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')    # Download vocabulary from S3 and cache.
85 | tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')  # E.g. 
tokenizer was saved using `save_pretrained('./test/saved_model/')`
86 | 
87 | """
88 | if 'roberta' in pretrained_model_name_or_path:
89 | return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
90 | elif 'bert' in pretrained_model_name_or_path:
91 | return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
92 | elif 'openai-gpt' in pretrained_model_name_or_path:
93 | return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
94 | elif 'gpt2' in pretrained_model_name_or_path:
95 | return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
96 | elif 'transfo-xl' in pretrained_model_name_or_path:
97 | return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
98 | elif 'xlnet' in pretrained_model_name_or_path:
99 | return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
100 | elif 'xlm' in pretrained_model_name_or_path:
101 | return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
102 | 
103 | raise ValueError("Unrecognized model identifier in {}. Should contain one of "
104 | "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
105 | "'xlm', 'roberta'".format(pretrained_model_name_or_path))
106 | 
-------------------------------------------------------------------------------- /hubconfs/transformer_xl_hubconf.py: --------------------------------------------------------------------------------
1 | from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
2 | from pytorch_transformers.modeling_transfo_xl import (
3 | TransfoXLModel,
4 | TransfoXLLMHeadModel
5 | )
6 | 
7 | # A lot of models share the same param doc. Use a decorator
8 | # to save typing
9 | transformer_xl_docstring = """
10 | Transformer XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs, which means that:
11 | - you don't need to specify position embedding indices
12 | - the tokens in the vocabulary have to be sorted by decreasing frequency.
13 | 
14 | Params:
15 | pretrained_model_name_or_path: either:
16 | - a str with the name of a pre-trained model to load selected in the list of:
17 | . `transfo-xl-wt103`
18 | - a path or url to a pretrained model archive containing:
19 | . `transfo_xl_config.json` a configuration file for the model
20 | . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
21 | - a path or url to a pretrained model archive containing:
22 | . `transfo_xl_config.json` a configuration file for the model
23 | . `model.chkpt` a TensorFlow checkpoint
24 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint
25 | cache_dir: an optional path to a folder in which the pre-trained models will be cached.
26 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
27 | *inputs, **kwargs: additional input for the specific TransformerXL class
28 | """
29 | 
30 | 
31 | def _append_from_pretrained_docstring(docstr):
32 | def docstring_decorator(fn):
33 | fn.__doc__ = fn.__doc__ + docstr
34 | return fn
35 | return docstring_decorator
36 | 
37 | 
38 | def transformerXLTokenizer(*args, **kwargs):
39 | """
40 | Instantiate a Transformer-XL tokenizer adapted from the Vocab class in https://github.com/kimiyoung/transformer-xl
41 | 
42 | Args:
43 | pretrained_model_name_or_path: Path to pretrained model archive
44 | or one of pre-trained vocab configs below. 
45 | * transfo-xl-wt103
46 | 
47 | Example:
48 | import torch
49 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')
50 | 
51 | text = "Who was Jim Henson ?"
52 | tokenized_text = tokenizer.tokenize(text)
53 | indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
54 | """
55 | tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs)
56 | return tokenizer
57 | 
58 | 
59 | @_append_from_pretrained_docstring(transformer_xl_docstring)
60 | def transformerXLModel(*args, **kwargs):
61 | """
62 | transformerXLModel is the basic Transformer XL model.
63 | 
64 | Example:
65 | # Load the tokenizer
66 | import torch
67 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')
68 | 
69 | # Prepare tokenized input
70 | text_1 = "Who was Jim Henson ?"
71 | text_2 = "Jim Henson was a puppeteer"
72 | tokenized_text_1 = tokenizer.tokenize(text_1)
73 | tokenized_text_2 = tokenizer.tokenize(text_2)
74 | indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
75 | indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
76 | tokens_tensor_1 = torch.tensor([indexed_tokens_1])
77 | tokens_tensor_2 = torch.tensor([indexed_tokens_2])
78 | 
79 | # Load transformerXLModel
80 | model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLModel', 'transfo-xl-wt103')
81 | model.eval()
82 | 
83 | # Predict hidden states features for each layer
84 | # We can re-use the memory cells in a subsequent call to attend a longer context
85 | with torch.no_grad():
86 | hidden_states_1, mems_1 = model(tokens_tensor_1)
87 | hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
88 | """
89 | model = TransfoXLModel.from_pretrained(*args, **kwargs)
90 | return model
91 | 
92 | 
93 | @_append_from_pretrained_docstring(transformer_xl_docstring)
94 | def transformerXLLMHeadModel(*args, **kwargs):
95 | """
96 | transformerXLLMHeadModel is the Transformer XL model with the
97 | tied (pre-trained) language modeling head on top.
98 | 
99 | Example:
100 | # Load the tokenizer
101 | import torch
102 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLTokenizer', 'transfo-xl-wt103')
103 | 
104 | # Prepare tokenized input
105 | text_1 = "Who was Jim Henson ?"
106 | text_2 = "Jim Henson was a puppeteer" 107 | tokenized_text_1 = tokenizer.tokenize(text_1) 108 | tokenized_text_2 = tokenizer.tokenize(text_2) 109 | indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1) 110 | indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2) 111 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 112 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 113 | 114 | # Load transformerXLLMHeadModel 115 | model = torch.hub.load('huggingface/pytorch-transformers', 'transformerXLLMHeadModel', 'transfo-xl-wt103') 116 | model.eval() 117 | 118 | # Predict hidden states features for each layer 119 | # We can re-use the memory cells in a subsequent call to attend a longer context 120 | with torch.no_grad(): 121 | predictions_1, mems_1 = model(tokens_tensor_1) 122 | predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1) 123 | 124 | # Get the predicted last token 125 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 126 | predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0] 127 | assert predicted_token == 'who' 128 | """ 129 | model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs) 130 | return model 131 | -------------------------------------------------------------------------------- /examples/lm_finetuning/README.md: -------------------------------------------------------------------------------- 1 | # BERT Model Finetuning using Masked Language Modeling objective 2 | 3 | ## Introduction 4 | 5 | The three example scripts in this folder can be used to **fine-tune** a pre-trained BERT model using the pretraining objective (combination of masked language modeling and next sentence prediction loss). In general, pretrained models like BERT are first trained with a pretraining objective (masked language modeling and next sentence prediction for BERT) on a large and general natural language corpus. A classifier head is then added on top of the pre-trained architecture and the model is quickly fine-tuned on a target task, while still (hopefully) retaining its general language understanding. This greatly reduces overfitting and yields state-of-the-art results, especially when training data for the target task are limited. 6 | 7 | The [ULMFiT paper](https://arxiv.org/abs/1801.06146) took a slightly different approach, however, and added an intermediate step in which the model is fine-tuned on text **from the same domain as the target task and using the pretraining objective** before the final stage in which the classifier head is added and the model is trained on the target task itself. This paper reported significantly improved results from this step, and found that they could get high-quality classifications even with only tiny numbers (<1000) of labelled training examples, as long as they had a lot of unlabelled data from the target domain. 8 | 9 | Although this wasn't covered in the original BERT paper, domain-specific fine-tuning of Transformer models has [recently been reported by other authors](https://arxiv.org/pdf/1905.05583.pdf), and they report performance improvements as well. 10 | 11 | ## Input format 12 | 13 | The scripts in this folder expect a single file as input, consisting of untokenized text, with one **sentence** per line, and one blank line between documents. 
The reason for the sentence splitting is that part of BERT's training involves a _next sentence_ objective in which the model must predict whether two sequences of text are contiguous text from the same document or not, and to avoid making the task _too easy_, the split point between the sequences is always at the end of a sentence. The linebreaks in the file are therefore necessary to mark the points where the text can be split.
14 | 
15 | ## Usage
16 | 
17 | There are two ways to fine-tune a language model using these scripts. The first _quick_ approach is to use [`simple_lm_finetuning.py`](./simple_lm_finetuning.py). It does everything in a single script, but generates training instances that consist of just two sentences. This is quite different from the BERT paper, where (confusingly) the NextSentence task concatenated sentences together from each document to form two long multi-sentence sequences, which the paper just referred to as _sentences_. The difference between this simple approach and the original paper approach can have a significant effect for long sequences, since two sentences will be much shorter than the max sequence length. In this case, most of each training example will just consist of blank padding characters, which wastes a lot of computation and results in a model that isn't really training on long sequences.
18 | 
19 | As such, the preferred approach (assuming you have documents containing multiple contiguous sentences from your target domain) is to use [`pregenerate_training_data.py`](./pregenerate_training_data.py) to pre-process your data into training examples following the methodology used for LM training in the original BERT paper and repository. Since there is a significant random component to training data generation for BERT, this script includes an option to generate multiple _epochs_ of pre-processed data, to avoid training on the same random splits each epoch. Generating an epoch of data for each training epoch should result in a better final model, and so we recommend doing so.
20 | 
21 | You can then train on the pregenerated data using [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py), pointing it to the folder created by [`pregenerate_training_data.py`](./pregenerate_training_data.py). Note that you should use the same `bert_model` and case options for both! Also note that `max_seq_len` does not need to be specified for the [`finetune_on_pregenerated.py`](./finetune_on_pregenerated.py) script, as it is inferred from the training examples.
22 | 
23 | There are various options that can be tweaked, but they are mostly set to the values from the BERT paper/repository and the default values should make sense. The most relevant ones are:
24 | 
25 | - `--max_seq_len`: Controls the length of training examples (in wordpiece tokens) seen by the model. Defaults to 128 but can be set as high as 512. Higher values may yield stronger language models at the cost of slower and more memory-intensive training.
26 | - `--fp16`: Enables fast half-precision training on recent GPUs.
27 | 
28 | In addition, if memory usage is an issue, especially when training on a single GPU, reducing `--train_batch_size` from the default 32 to a lower number (4-16) can be helpful, or leaving `--train_batch_size` at the default and increasing `--gradient_accumulation_steps` to 2-8. Changing `--gradient_accumulation_steps` may be preferable as alterations to the batch size may require corresponding changes in the learning rate to compensate.
There is also a `--reduce_memory` option for both the `pregenerate_training_data.py` and `finetune_on_pregenerated.py` scripts that spills data to disk in shelf objects or numpy memmaps rather than retaining it in memory, which significantly reduces memory usage with little performance impact.
29 | 
30 | ## Examples
31 | 
32 | ### Simple fine-tuning
33 | 
34 | ```
35 | python3 simple_lm_finetuning.py \
36 | --train_corpus my_corpus.txt \
37 | --bert_model bert-base-uncased \
38 | --do_lower_case \
39 | --output_dir finetuned_lm/ \
40 | --do_train
41 | ```
42 | 
43 | ### Pregenerating training data
44 | 
45 | ```
46 | python3 pregenerate_training_data.py \
47 | --train_corpus my_corpus.txt \
48 | --bert_model bert-base-uncased \
49 | --do_lower_case \
50 | --output_dir training/ \
51 | --epochs_to_generate 3 \
52 | --max_seq_len 256
53 | ```
54 | 
55 | ### Training on pregenerated data
56 | 
57 | ```
58 | python3 finetune_on_pregenerated.py \
59 | --pregenerated_data training/ \
60 | --bert_model bert-base-uncased \
61 | --do_lower_case \
62 | --output_dir finetuned_lm/ \
63 | --epochs 3
64 | ```
65 | 
-------------------------------------------------------------------------------- /docs/source/conf.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 | 
9 | # -- Path setup --------------------------------------------------------------
10 | 
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | import os
16 | import sys
17 | sys.path.insert(0, os.path.abspath('../..'))
18 | 
19 | 
20 | # -- Project information -----------------------------------------------------
21 | 
22 | project = u'pytorch-transformers'
23 | copyright = u'2019, huggingface'
24 | author = u'huggingface'
25 | 
26 | # The short X.Y version
27 | version = u''
28 | # The full version, including alpha/beta/rc tags
29 | release = u'1.0.0'
30 | 
31 | 
32 | # -- General configuration ---------------------------------------------------
33 | 
34 | # If your documentation needs a minimal Sphinx version, state it here.
35 | #
36 | # needs_sphinx = '1.0'
37 | 
38 | # Add any Sphinx extension module names here, as strings. They can be
39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
40 | # ones.
41 | extensions = [
42 | 'sphinx.ext.autodoc',
43 | 'sphinx.ext.coverage',
44 | 'sphinx.ext.napoleon',
45 | 'recommonmark',
46 | 'sphinx.ext.viewcode'
47 | ]
48 | 
49 | # Add any paths that contain templates here, relative to this directory.
50 | templates_path = ['_templates']
51 | 
52 | # The suffix(es) of source filenames.
53 | # You can specify multiple suffix as a list of string:
54 | #
55 | source_suffix = ['.rst', '.md']
56 | # source_suffix = '.rst'
57 | 
58 | # The master toctree document.
59 | master_doc = 'index'
60 | 
61 | # The language for content autogenerated by Sphinx. Refer to documentation
62 | # for a list of supported languages.
63 | #
64 | # This is also used if you do content translation via gettext catalogs.
65 | # Usually you set "language" from the command line for these cases. 
66 | language = None 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path. 71 | exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store'] 72 | 73 | # The name of the Pygments (syntax highlighting) style to use. 74 | pygments_style = None 75 | 76 | 77 | # -- Options for HTML output ------------------------------------------------- 78 | 79 | # The theme to use for HTML and HTML Help pages. See the documentation for 80 | # a list of builtin themes. 81 | # 82 | html_theme = 'sphinx_rtd_theme' 83 | 84 | # Theme options are theme-specific and customize the look and feel of a theme 85 | # further. For a list of options available for each theme, see the 86 | # documentation. 87 | # 88 | html_theme_options = { 89 | 'analytics_id': 'UA-83738774-2' 90 | } 91 | 92 | # Add any paths that contain custom static files (such as style sheets) here, 93 | # relative to this directory. They are copied after the builtin static files, 94 | # so a file named "default.css" will overwrite the builtin "default.css". 95 | html_static_path = ['_static'] 96 | 97 | # Custom sidebar templates, must be a dictionary that maps document names 98 | # to template names. 99 | # 100 | # The default sidebars (for documents that don't match any pattern) are 101 | # defined by theme itself. Builtin themes are using these templates by 102 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 103 | # 'searchbox.html']``. 104 | # 105 | # html_sidebars = {} 106 | 107 | 108 | # -- Options for HTMLHelp output --------------------------------------------- 109 | 110 | # Output file base name for HTML help builder. 111 | htmlhelp_basename = 'pytorch-transformersdoc' 112 | 113 | 114 | # -- Options for LaTeX output ------------------------------------------------ 115 | 116 | latex_elements = { 117 | # The paper size ('letterpaper' or 'a4paper'). 118 | # 119 | # 'papersize': 'letterpaper', 120 | 121 | # The font size ('10pt', '11pt' or '12pt'). 122 | # 123 | # 'pointsize': '10pt', 124 | 125 | # Additional stuff for the LaTeX preamble. 126 | # 127 | # 'preamble': '', 128 | 129 | # Latex figure (float) alignment 130 | # 131 | # 'figure_align': 'htbp', 132 | } 133 | 134 | # Grouping the document tree into LaTeX files. List of tuples 135 | # (source start file, target name, title, 136 | # author, documentclass [howto, manual, or own class]). 137 | latex_documents = [ 138 | (master_doc, 'pytorch-transformers.tex', u'pytorch-transformers Documentation', 139 | u'huggingface', 'manual'), 140 | ] 141 | 142 | 143 | # -- Options for manual page output ------------------------------------------ 144 | 145 | # One entry per manual page. List of tuples 146 | # (source start file, name, description, authors, manual section). 147 | man_pages = [ 148 | (master_doc, 'pytorch-transformers', u'pytorch-transformers Documentation', 149 | [author], 1) 150 | ] 151 | 152 | 153 | # -- Options for Texinfo output ---------------------------------------------- 154 | 155 | # Grouping the document tree into Texinfo files. 
List of tuples 156 | # (source start file, target name, title, author, 157 | # dir menu entry, description, category) 158 | texinfo_documents = [ 159 | (master_doc, 'pytorch-transformers', u'pytorch-transformers Documentation', 160 | author, 'pytorch-transformers', 'One line description of project.', 161 | 'Miscellaneous'), 162 | ] 163 | 164 | 165 | # -- Options for Epub output ------------------------------------------------- 166 | 167 | # Bibliographic Dublin Core info. 168 | epub_title = project 169 | 170 | # The unique identifier of the text. This can be a ISBN number 171 | # or the project homepage. 172 | # 173 | # epub_identifier = '' 174 | 175 | # A unique identification for the text. 176 | # 177 | # epub_uid = '' 178 | 179 | # A list of files that should not be packed into the epub file. 180 | epub_exclude_files = ['search.html'] 181 | 182 | def setup(app): 183 | app.add_stylesheet('css/huggingface.css') 184 | app.add_stylesheet('css/code-snippets.css') 185 | app.add_js_file('js/custom.js') 186 | 187 | # -- Extension configuration ------------------------------------------------- 188 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/tokenization_tests_commons.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 HuggingFace Inc. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import os 18 | import sys 19 | from io import open 20 | import tempfile 21 | import shutil 22 | import unittest 23 | 24 | if sys.version_info[0] == 2: 25 | import cPickle as pickle 26 | 27 | class TemporaryDirectory(object): 28 | """Context manager for tempfile.mkdtemp() so it's usable with "with" statement.""" 29 | def __enter__(self): 30 | self.name = tempfile.mkdtemp() 31 | return self.name 32 | def __exit__(self, exc_type, exc_value, traceback): 33 | shutil.rmtree(self.name) 34 | else: 35 | import pickle 36 | TemporaryDirectory = tempfile.TemporaryDirectory 37 | unicode = str 38 | 39 | 40 | class CommonTestCases: 41 | 42 | class CommonTokenizerTester(unittest.TestCase): 43 | 44 | tokenizer_class = None 45 | 46 | def setUp(self): 47 | self.tmpdirname = tempfile.mkdtemp() 48 | 49 | def tearDown(self): 50 | shutil.rmtree(self.tmpdirname) 51 | 52 | def get_tokenizer(self): 53 | raise NotImplementedError 54 | 55 | def get_input_output_texts(self): 56 | raise NotImplementedError 57 | 58 | def test_save_and_load_tokenizer(self): 59 | tokenizer = self.get_tokenizer() 60 | 61 | before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 62 | 63 | with TemporaryDirectory() as tmpdirname: 64 | tokenizer.save_pretrained(tmpdirname) 65 | tokenizer = tokenizer.from_pretrained(tmpdirname) 66 | 67 | after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running") 68 | self.assertListEqual(before_tokens, after_tokens) 69 | 70 | def test_pickle_tokenizer(self): 71 | tokenizer = self.get_tokenizer() 72 | self.assertIsNotNone(tokenizer) 73 | 74 | text = u"Munich and Berlin are nice cities" 75 | subwords = tokenizer.tokenize(text) 76 | 77 | with TemporaryDirectory() as tmpdirname: 78 | 79 | filename = os.path.join(tmpdirname, u"tokenizer.bin") 80 | pickle.dump(tokenizer, open(filename, "wb")) 81 | 82 | tokenizer_new = pickle.load(open(filename, "rb")) 83 | 84 | subwords_loaded = tokenizer_new.tokenize(text) 85 | 86 | self.assertListEqual(subwords, subwords_loaded) 87 | 88 | 89 | def test_add_tokens_tokenizer(self): 90 | tokenizer = self.get_tokenizer() 91 | 92 | vocab_size = tokenizer.vocab_size 93 | all_size = len(tokenizer) 94 | 95 | self.assertNotEqual(vocab_size, 0) 96 | self.assertEqual(vocab_size, all_size) 97 | 98 | new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"] 99 | added_toks = tokenizer.add_tokens(new_toks) 100 | vocab_size_2 = tokenizer.vocab_size 101 | all_size_2 = len(tokenizer) 102 | 103 | self.assertNotEqual(vocab_size_2, 0) 104 | self.assertEqual(vocab_size, vocab_size_2) 105 | self.assertEqual(added_toks, len(new_toks)) 106 | self.assertEqual(all_size_2, all_size + len(new_toks)) 107 | 108 | tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") 109 | self.assertGreaterEqual(len(tokens), 4) 110 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1) 111 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 112 | 113 | new_toks_2 = {'eos_token': ">>>>|||<||<<|<<", 114 | 'pad_token': "<<<<<|||>|>>>>|>"} 115 | added_toks_2 = tokenizer.add_special_tokens(new_toks_2) 116 | vocab_size_3 = tokenizer.vocab_size 117 | all_size_3 = len(tokenizer) 118 | 119 | self.assertNotEqual(vocab_size_3, 0) 120 | self.assertEqual(vocab_size, vocab_size_3) 121 | self.assertEqual(added_toks_2, len(new_toks_2)) 122 | self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) 123 | 124 | tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd 
<<<<<|||>|>>>>|> l") 125 | 126 | self.assertGreaterEqual(len(tokens), 6) 127 | self.assertGreater(tokens[0], tokenizer.vocab_size - 1) 128 | self.assertGreater(tokens[0], tokens[1]) 129 | self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) 130 | self.assertGreater(tokens[-2], tokens[-3]) 131 | self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token)) 132 | self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) 133 | 134 | 135 | def test_required_methods_tokenizer(self): 136 | tokenizer = self.get_tokenizer() 137 | input_text, output_text = self.get_input_output_texts() 138 | 139 | tokens = tokenizer.tokenize(input_text) 140 | ids = tokenizer.convert_tokens_to_ids(tokens) 141 | ids_2 = tokenizer.encode(input_text) 142 | self.assertListEqual(ids, ids_2) 143 | 144 | tokens_2 = tokenizer.convert_ids_to_tokens(ids) 145 | text_2 = tokenizer.decode(ids) 146 | 147 | self.assertEqual(text_2, output_text) 148 | 149 | self.assertNotEqual(len(tokens_2), 0) 150 | self.assertIsInstance(text_2, (str, unicode)) 151 | 152 | 153 | def test_pretrained_model_lists(self): 154 | weights_list = list(self.tokenizer_class.max_model_input_sizes.keys()) 155 | weights_lists_2 = [] 156 | for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items(): 157 | weights_lists_2.append(list(map_list.keys())) 158 | 159 | for weights_list_2 in weights_lists_2: 160 | self.assertListEqual(weights_list, weights_list_2) 161 | -------------------------------------------------------------------------------- /pytorch_transformers/tests/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
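# Illustrative (non-test) usage of the schedules exercised below -- a sketch only,
# mirroring how these tests step the scheduler once per optimizer update:
#
#     optimizer = AdamW(model.parameters(), lr=1e-4)
#     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)
#     for batch in data:                     # `model`, `data` and `compute_loss` are hypothetical
#         loss = compute_loss(model, batch)
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()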
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import unittest 20 | import os 21 | 22 | import torch 23 | 24 | from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, 25 | WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 26 | 27 | from .tokenization_tests_commons import TemporaryDirectory 28 | 29 | 30 | def unwrap_schedule(scheduler, num_steps=10): 31 | lrs = [] 32 | for _ in range(num_steps): 33 | scheduler.step() 34 | lrs.append(scheduler.get_lr()) 35 | return lrs 36 | 37 | def unwrap_and_save_reload_schedule(scheduler, num_steps=10): 38 | lrs = [] 39 | for step in range(num_steps): 40 | scheduler.step() 41 | lrs.append(scheduler.get_lr()) 42 | if step == num_steps // 2: 43 | with TemporaryDirectory() as tmpdirname: 44 | file_name = os.path.join(tmpdirname, 'schedule.bin') 45 | torch.save(scheduler.state_dict(), file_name) 46 | 47 | state_dict = torch.load(file_name) 48 | scheduler.load_state_dict(state_dict) 49 | return lrs 50 | 51 | class OptimizationTest(unittest.TestCase): 52 | 53 | def assertListAlmostEqual(self, list1, list2, tol): 54 | self.assertEqual(len(list1), len(list2)) 55 | for a, b in zip(list1, list2): 56 | self.assertAlmostEqual(a, b, delta=tol) 57 | 58 | def test_adam_w(self): 59 | w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) 60 | target = torch.tensor([0.4, 0.2, -0.5]) 61 | criterion = torch.nn.MSELoss() 62 | # No warmup, constant schedule, no gradient clipping 63 | optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0) 64 | for _ in range(100): 65 | loss = criterion(w, target) 66 | loss.backward() 67 | optimizer.step() 68 | w.grad.detach_() # No zero_grad() function on simple tensors. We do it ourselves. 69 | w.grad.zero_() 70 | self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) 71 | 72 | 73 | class ScheduleInitTest(unittest.TestCase): 74 | m = torch.nn.Linear(50, 50) 75 | optimizer = AdamW(m.parameters(), lr=10.) 76 | num_steps = 10 77 | 78 | def assertListAlmostEqual(self, list1, list2, tol): 79 | self.assertEqual(len(list1), len(list2)) 80 | for a, b in zip(list1, list2): 81 | self.assertAlmostEqual(a, b, delta=tol) 82 | 83 | def test_constant_scheduler(self): 84 | scheduler = ConstantLRSchedule(self.optimizer) 85 | lrs = unwrap_schedule(scheduler, self.num_steps) 86 | expected_learning_rates = [10.] 
* self.num_steps 87 | self.assertEqual(len(lrs[0]), 1) 88 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 89 | 90 | scheduler = ConstantLRSchedule(self.optimizer) 91 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 92 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 93 | 94 | def test_warmup_constant_scheduler(self): 95 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 96 | lrs = unwrap_schedule(scheduler, self.num_steps) 97 | expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] 98 | self.assertEqual(len(lrs[0]), 1) 99 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 100 | 101 | scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) 102 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 103 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 104 | 105 | def test_warmup_linear_scheduler(self): 106 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 107 | lrs = unwrap_schedule(scheduler, self.num_steps) 108 | expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] 109 | self.assertEqual(len(lrs[0]), 1) 110 | self.assertListEqual([l[0] for l in lrs], expected_learning_rates) 111 | 112 | scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) 113 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 114 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 115 | 116 | def test_warmup_cosine_scheduler(self): 117 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 118 | lrs = unwrap_schedule(scheduler, self.num_steps) 119 | expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] 120 | self.assertEqual(len(lrs[0]), 1) 121 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 122 | 123 | scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) 124 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 125 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 126 | 127 | def test_warmup_cosine_hard_restart_scheduler(self): 128 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 129 | lrs = unwrap_schedule(scheduler, self.num_steps) 130 | expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] 131 | self.assertEqual(len(lrs[0]), 1) 132 | self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) 133 | 134 | scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) 135 | lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) 136 | self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) 137 | 138 | if __name__ == "__main__": 139 | unittest.main() 140 | -------------------------------------------------------------------------------- /docs/source/torchscript.rst: -------------------------------------------------------------------------------- 1 | TorchScript 2 | ================================================ 3 | 4 | .. note:: 5 | This is the very beginning of our experiments with TorchScript and we are still exploring its capabilities 6 | with variable-input-size models. 
This is an area of focus for us and we will deepen our analysis in upcoming 7 | releases, with more code examples, a more flexible implementation, and benchmarks comparing Python-based code 8 | with compiled TorchScript. 9 | 10 | 11 | According to PyTorch's documentation: "TorchScript is a way to create serializable and optimizable models from PyTorch code". 12 | PyTorch's two modules `JIT and TRACE <https://pytorch.org/docs/stable/jit.html>`_ allow the developer to export 13 | their model to be re-used in other programs, such as efficiency-oriented C++ programs. 14 | 15 | We have provided an interface that allows the export of `pytorch-transformers` models to TorchScript so that they can 16 | be reused in a different environment than a PyTorch-based Python program. Here we explain how to use our models so that 17 | they can be exported, and what to be mindful of when using these models with TorchScript. 18 | 19 | Exporting a model requires two things: 20 | 21 | * dummy inputs with which to execute a forward pass of the model. 22 | * the model instantiated with the ``torchscript`` flag. 23 | 24 | These requirements have several implications that developers should be careful about; they are detailed below. 25 | 26 | 27 | Implications 28 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 29 | 30 | TorchScript flag and tied weights 31 | ------------------------------------------------ 32 | This flag is necessary because most of the language models in this repository have tied weights between their 33 | ``Embedding`` layer and their ``Decoding`` layer. TorchScript does not allow the export of models that have tied weights, 34 | so it is necessary to untie the weights beforehand. 35 | 36 | This implies that models instantiated with the ``torchscript`` flag have their ``Embedding`` layer and ``Decoding`` layer 37 | separate, which means that they should not be trained down the line. Training would de-synchronize the two layers, 38 | leading to unexpected results. 39 | 40 | This is not the case for models that do not have a language model head, as those do not have tied weights. These models 41 | can be safely exported without the ``torchscript`` flag. 42 | 43 | Dummy inputs and standard lengths 44 | ------------------------------------------------ 45 | 46 | The dummy inputs are used for a forward pass of the model. While the inputs' values are propagating through the layers, 47 | PyTorch keeps track of the different operations executed on each tensor. These recorded operations are then used 48 | to create the "trace" of the model. 49 | 50 | The trace is created relative to the inputs' dimensions. It is therefore constrained by the dimensions of the dummy 51 | input, and will not work for any other sequence length or batch size. When trying with a different size, an error such 52 | as: 53 | 54 | ``The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`` 55 | 56 | will be raised. It is therefore recommended to trace the model with a dummy input size at least as large as the largest 57 | input that will be fed to the model during inference. Padding can be performed to fill the missing values. As the model 58 | will have been traced with a large input size, however, the dimensions of the different matrices will be large as well, 59 | resulting in more calculations. 60 | 61 | It is recommended to be careful of the total number of operations performed on each input and to follow performance closely 62 | when exporting models with varying sequence lengths.
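As an illustration, a model meant to serve sequences of up to 128 tokens could be traced with a dummy input padded to
that length. A minimal sketch, assuming a BERT-style tokenizer ``enc`` as in the saving example below, and assuming
that id ``0`` is the padding token of its vocabulary:

.. code-block:: python

    max_length = 128
    tokens = enc.tokenize("[CLS] Who was Jim Henson ? [SEP]")
    indexed_tokens = enc.convert_tokens_to_ids(tokens)
    # Pad up to the length the model will be traced with (0 is assumed to be the [PAD] id)
    padded_tokens = indexed_tokens + [0] * (max_length - len(indexed_tokens))
    tokens_tensor = torch.tensor([padded_tokens])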
63 | 64 | Using TorchScript in Python 65 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 66 | 67 | Below are examples of how to use Python to save and load models, as well as how to use a traced model for inference. 68 | 69 | Saving a model 70 | ------------------------------------------------ 71 | 72 | This snippet shows how to use TorchScript to export a ``BertModel``. Here the ``BertModel`` is instantiated 73 | according to a ``BertConfig`` class and then saved to disk under the filename ``traced_bert.pt``. 74 | 75 | .. code-block:: python 76 | 77 | from pytorch_transformers import BertModel, BertTokenizer, BertConfig 78 | import torch 79 | 80 | enc = BertTokenizer.from_pretrained("bert-base-uncased") 81 | 82 | # Tokenizing input text 83 | text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" 84 | tokenized_text = enc.tokenize(text) 85 | 86 | # Masking one of the input tokens 87 | masked_index = 8 88 | tokenized_text[masked_index] = '[MASK]' 89 | indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) 90 | segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] 91 | 92 | # Creating a dummy input 93 | tokens_tensor = torch.tensor([indexed_tokens]) 94 | segments_tensors = torch.tensor([segments_ids]) 95 | dummy_input = [tokens_tensor, segments_tensors] 96 | 97 | # Initializing the model with the torchscript flag 98 | # Flag set to True even though it is not necessary as this model does not have an LM Head. 99 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 100 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True) 101 | 102 | # Instantiating the model 103 | model = BertModel(config) 104 | 105 | # The model needs to be in evaluation mode 106 | model.eval() 107 | 108 | # If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag 109 | model = BertModel.from_pretrained("bert-base-uncased", torchscript=True) 110 | 111 | # Creating the trace 112 | traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) 113 | torch.jit.save(traced_model, "traced_bert.pt") 114 | 115 | Loading a model 116 | ------------------------------------------------ 117 | 118 | This snippet shows how to load the ``BertModel`` that was previously saved to disk under the name ``traced_bert.pt``. 119 | We are re-using the previously initialised ``dummy_input``. 120 | 121 | .. code-block:: python 122 | 123 | loaded_model = torch.jit.load("traced_bert.pt") 124 | loaded_model.eval() 125 | 126 | all_encoder_layers, pooled_output = loaded_model(*dummy_input) 127 | 128 | Using a traced model for inference 129 | ------------------------------------------------ 130 | 131 | Using the traced model for inference is as simple as using its ``__call__`` dunder method: 132 | 133 | .. code-block:: python 134 | 135 | traced_model(tokens_tensor, segments_tensors) 136 | -------------------------------------------------------------------------------- /hubconfs/xlm_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_xlm import XLMTokenizer 2 | from pytorch_transformers.modeling_xlm import ( 3 | XLMConfig, 4 | XLMModel, 5 | XLMWithLMHeadModel, 6 | XLMForSequenceClassification, 7 | XLMForQuestionAnswering 8 | ) 9 | 10 | # A lot of models share the same param doc. 
Use a decorator 11 | # to save typing 12 | xlm_start_docstring = """ 13 | Model class adapted from the XLM Transformer model of 14 | "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau 15 | Paper: https://arxiv.org/abs/1901.07291 16 | Original code: https://github.com/facebookresearch/XLM 17 | 18 | Example: 19 | # Load the tokenizer 20 | import torch 21 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 22 | 23 | # Prepare tokenized input 24 | text_1 = "Who was Jim Henson ?" 25 | text_2 = "Jim Henson was a puppeteer" 26 | indexed_tokens_1 = tokenizer.encode(text_1) 27 | indexed_tokens_2 = tokenizer.encode(text_2) 28 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 29 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 30 | """ 31 | 32 | # A lot of models share the same param doc. Use a decorator 33 | # to save typing 34 | xlm_end_docstring = """ 35 | Params: 36 | pretrained_model_name_or_path: either: 37 | - a str with the name of a pre-trained model to load selected in the list of: 38 | . `xlm-mlm-en-2048` 39 | - a path or url to a pretrained model archive containing: 40 | . `config.json` a configuration file for the model 41 | . `pytorch_model.bin` a PyTorch dump created using the `convert_xlm_checkpoint_to_pytorch` conversion script 42 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 43 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 44 | *inputs, **kwargs: additional input for the specific XLM class 45 | """ 46 | 47 | 48 | def _begin_with_docstring(docstr): 49 | def docstring_decorator(fn): 50 | fn.__doc__ = docstr + fn.__doc__  # prepend, as the name indicates 51 | return fn 52 | return docstring_decorator 53 | 54 | def _end_with_docstring(docstr): 55 | def docstring_decorator(fn): 56 | fn.__doc__ = fn.__doc__ + docstr 57 | return fn 58 | return docstring_decorator 59 | 60 | 61 | def xlmTokenizer(*args, **kwargs): 62 | """ 63 | Instantiate an XLM BPE tokenizer for XLM from a pre-trained vocab file. 64 | 65 | Args: 66 | pretrained_model_name_or_path: Path to pretrained model archive 67 | or one of pre-trained vocab configs below. 68 | * xlm-mlm-en-2048 69 | Keyword args: 70 | special_tokens: Special tokens in vocabulary that are not pretrained 71 | Default: None 72 | max_len: An artificial maximum length to truncate tokenized sequences to; 73 | Effective maximum length is always the minimum of this 74 | value (if specified) and the underlying model's 75 | sequence length. 76 | Default: None 77 | 78 | Example: 79 | import torch 80 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 81 | 82 | text = "Who was Jim Henson ?" 
83 | indexed_tokens = tokenizer.encode(text) 84 | """ 85 | tokenizer = XLMTokenizer.from_pretrained(*args, **kwargs) 86 | return tokenizer 87 | 88 | 89 | @_begin_with_docstring(xlm_start_docstring) 90 | @_end_with_docstring(xlm_end_docstring) 91 | def xlmModel(*args, **kwargs): 92 | """ 93 | # Load xlmModel 94 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlmModel', 'xlm-mlm-en-2048') 95 | model.eval() 96 | 97 | # Predict hidden states features for each layer 98 | with torch.no_grad(): 99 | hidden_states_1 = model(tokens_tensor_1)[0] 100 | hidden_states_2 = model(tokens_tensor_2)[0] 101 | """ 102 | model = XLMModel.from_pretrained(*args, **kwargs) 103 | return model 104 | 105 | 106 | @_begin_with_docstring(xlm_start_docstring) 107 | @_end_with_docstring(xlm_end_docstring) 108 | def xlmLMHeadModel(*args, **kwargs): 109 | """ 110 | # Prepare tokenized input 111 | text_1 = "Who was Jim Henson ?" 112 | text_2 = "Jim Henson was a puppeteer" 113 | indexed_tokens_1 = tokenizer.encode(text_1) 114 | indexed_tokens_2 = tokenizer.encode(text_2) 115 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 116 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 117 | 118 | # Load xlmLMHeadModel 119 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlmLMHeadModel', 'xlm-mlm-en-2048') 120 | model.eval() 121 | 122 | # Predict language modeling logits for each token 123 | with torch.no_grad(): 124 | predictions_1 = model(tokens_tensor_1)[0] 125 | predictions_2 = model(tokens_tensor_2)[0] 126 | 127 | # Get the predicted last token 128 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 129 | predicted_token = tokenizer.decode([predicted_index]) 130 | assert predicted_token == ' who' 131 | """ 132 | model = XLMWithLMHeadModel.from_pretrained(*args, **kwargs) 133 | return model 134 | 135 | 136 | # @_end_with_docstring(xlm_end_docstring) 137 | # def xlmForSequenceClassification(*args, **kwargs): 138 | #     """ 139 | #     xlmForSequenceClassification is the XLM Transformer model from 140 | #     "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau, 141 | #     with a sequence classification head on top 142 | 143 | #     Example: 144 | #         # Load the tokenizer 145 | #         import torch 146 | #         tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlmTokenizer', 'xlm-mlm-en-2048') 147 | 148 | #         #  Prepare tokenized input 149 | #         text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 150 | #         text2 = "Who was Jim Henson ? 
Jim Henson was a mysterious young man" 151 | #         tokenized_text1 = tokenizer.tokenize(text1) 152 | #         tokenized_text2 = tokenizer.tokenize(text2) 153 | #         indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 154 | #         indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 155 | #         tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 156 | #         mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 157 | 158 | #         # Load xlmForSequenceClassification 159 | #         model = torch.hub.load('huggingface/pytorch-transformers', 'xlmForSequenceClassification', 'xlm-mlm-en-2048') 160 | #         model.eval() 161 | 162 | #         # Predict sequence class logits 163 | #         with torch.no_grad(): 164 | #             lm_logits, mems = model(tokens_tensor) 165 | #     """ 166 | #     model = XLMForSequenceClassification.from_pretrained(*args, **kwargs) 167 | #     return model 168 | -------------------------------------------------------------------------------- /examples/single_model_scripts/run_transfo_xl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | #     http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ PyTorch Transformer XL model evaluation script. 17 | Adapted from https://github.com/kimiyoung/transformer-xl. 
18 | In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py 19 | 20 | With its default values, this script evaluates a pretrained Transformer-XL on WikiText-103 21 | """ 22 | from __future__ import absolute_import, division, print_function, unicode_literals 23 | 24 | import argparse 25 | import logging 26 | import time 27 | import math 28 | 29 | import torch 30 | 31 | from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer 32 | 33 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s', 34 | datefmt = '%m/%d/%Y %H:%M:%S', 35 | level = logging.INFO) 36 | logger = logging.getLogger(__name__) 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') 40 | parser.add_argument('--model_name', type=str, default='transfo-xl-wt103', 41 | help='pretrained model name') 42 | parser.add_argument('--split', type=str, default='test', 43 | choices=['all', 'valid', 'test'], 44 | help='which split to evaluate') 45 | parser.add_argument('--batch_size', type=int, default=10, 46 | help='batch size') 47 | parser.add_argument('--tgt_len', type=int, default=128, 48 | help='number of tokens to predict') 49 | parser.add_argument('--ext_len', type=int, default=0, 50 | help='length of the extended context') 51 | parser.add_argument('--mem_len', type=int, default=1600, 52 | help='length of the retained previous heads') 53 | parser.add_argument('--clamp_len', type=int, default=1000, 54 | help='max positional embedding index') 55 | parser.add_argument('--no_cuda', action='store_true', 56 | help='Do not use CUDA even though CUDA is available') 57 | parser.add_argument('--work_dir', type=str, required=True, 58 | help='path to the work_dir') 59 | parser.add_argument('--no_log', action='store_true', 60 | help='do not log the eval result') 61 | parser.add_argument('--same_length', action='store_true', 62 | help='set same length attention with masking') 63 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 64 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 65 | args = parser.parse_args() 66 | assert args.ext_len >= 0, 'extended context length must be non-negative' 67 | 68 | if args.server_ip and args.server_port: 69 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 70 | import ptvsd 71 | print("Waiting for debugger attach") 72 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 73 | ptvsd.wait_for_attach() 74 | 75 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 76 | logger.info("device: {}".format(device)) 77 | 78 | # Load a pre-processed dataset 79 | # You can also build the corpus yourself using TransfoXLCorpus methods 80 | # The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax 81 | # and tokenizing the dataset 82 | # The pre-processed corpus is a conversion of the original dataset (using the conversion script) 83 | tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name) 84 | corpus = TransfoXLCorpus.from_pretrained(args.model_name) 85 | ntokens = len(corpus.vocab) 86 | 87 | va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, 88 | device=device, ext_len=args.ext_len) 89 | te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, 90 | device=device, ext_len=args.ext_len) 91 | 92 | # Load a 
pre-trained model 93 | model = TransfoXLLMHeadModel.from_pretrained(args.model_name) 94 | model = model.to(device) 95 | 96 | logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( 97 | args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) 98 | 99 | model.reset_length(args.tgt_len, args.ext_len, args.mem_len) 100 | if args.clamp_len > 0: 101 | model.clamp_len = args.clamp_len 102 | if args.same_length: 103 | model.same_length = True 104 | 105 | ############################################################################### 106 | # Evaluation code 107 | ############################################################################### 108 | def evaluate(eval_iter): 109 | # Turn on evaluation mode which disables dropout. 110 | model.eval() 111 | total_len, total_loss = 0, 0. 112 | start_time = time.time() 113 | with torch.no_grad(): 114 | mems = None 115 | for idx, (data, target, seq_len) in enumerate(eval_iter): 116 | ret = model(data, target, mems) 117 | loss, _, mems = ret 118 | loss = loss.mean() 119 | total_loss += seq_len * loss.item() 120 | total_len += seq_len 121 | total_time = time.time() - start_time 122 | logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format( 123 | total_time, 1000 * total_time / (idx+1))) 124 | return total_loss / total_len 125 | 126 | # Run on test data. 127 | if args.split == 'all': 128 | test_loss = evaluate(te_iter) 129 | valid_loss = evaluate(va_iter) 130 | elif args.split == 'valid': 131 | valid_loss = evaluate(va_iter) 132 | test_loss = None 133 | elif args.split == 'test': 134 | test_loss = evaluate(te_iter) 135 | valid_loss = None 136 | 137 | def format_log(loss, split): 138 | log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( 139 | split, loss, math.exp(loss)) 140 | return log_str 141 | 142 | log_str = '' 143 | if valid_loss is not None: 144 | log_str += format_log(valid_loss, 'valid') 145 | if test_loss is not None: 146 | log_str += format_log(test_loss, 'test') 147 | 148 | logger.info('=' * 100) 149 | logger.info(log_str) 150 | logger.info('=' * 100) 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /pytorch_transformers/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: 5 | print( 6 | "Should be used as one of: \n" 7 | ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" 8 | ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" 9 | ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" 10 | ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" 11 | ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" 12 | ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") 13 | else: 14 | if sys.argv[1] == "bert": 15 | try: 16 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 17 | except ImportError: 18 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. " 19 | "In that case, it requires TensorFlow to be installed. 
Please see " 20 | "https://www.tensorflow.org/install/ for installation instructions.") 21 | raise 22 | 23 | if len(sys.argv) != 5: 24 | # pylint: disable=line-too-long 25 | print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 26 | else: 27 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 28 | TF_CONFIG = sys.argv.pop() 29 | TF_CHECKPOINT = sys.argv.pop() 30 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 31 | elif sys.argv[1] == "gpt": 32 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 33 | if len(sys.argv) < 4 or len(sys.argv) > 5: 34 | # pylint: disable=line-too-long 35 | print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") 36 | else: 37 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 38 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 39 | if len(sys.argv) == 5: 40 | OPENAI_GPT_CONFIG = sys.argv[4] 41 | else: 42 | OPENAI_GPT_CONFIG = "" 43 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 44 | OPENAI_GPT_CONFIG, 45 | PYTORCH_DUMP_OUTPUT) 46 | elif sys.argv[1] == "transfo_xl": 47 | try: 48 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 49 | except ImportError: 50 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. " 51 | "In that case, it requires TensorFlow to be installed. Please see " 52 | "https://www.tensorflow.org/install/ for installation instructions.") 53 | raise 54 | if len(sys.argv) < 4 or len(sys.argv) > 5: 55 | # pylint: disable=line-too-long 56 | print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 57 | else: 58 | if 'ckpt' in sys.argv[2].lower(): 59 | TF_CHECKPOINT = sys.argv[2] 60 | TF_DATASET_FILE = "" 61 | else: 62 | TF_DATASET_FILE = sys.argv[2] 63 | TF_CHECKPOINT = "" 64 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 65 | if len(sys.argv) == 5: 66 | TF_CONFIG = sys.argv[4] 67 | else: 68 | TF_CONFIG = "" 69 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 70 | elif sys.argv[1] == "gpt2": 71 | try: 72 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 73 | except ImportError: 74 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. " 75 | "In that case, it requires TensorFlow to be installed. Please see " 76 | "https://www.tensorflow.org/install/ for installation instructions.") 77 | raise 78 | 79 | if len(sys.argv) < 4 or len(sys.argv) > 5: 80 | # pylint: disable=line-too-long 81 | print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 82 | else: 83 | TF_CHECKPOINT = sys.argv[2] 84 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 85 | if len(sys.argv) == 5: 86 | TF_CONFIG = sys.argv[4] 87 | else: 88 | TF_CONFIG = "" 89 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 90 | elif sys.argv[1] == "xlnet": 91 | try: 92 | from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch 93 | except ImportError: 94 | print("pytorch_transformers can only be used from the command line to convert TensorFlow models to PyTorch. " 95 | "In that case, it requires TensorFlow to be installed. 
Please see " 96 | "https://www.tensorflow.org/install/ for installation instructions.") 97 | raise 98 | 99 | if len(sys.argv) < 5 or len(sys.argv) > 6: 100 | # pylint: disable=line-too-long 101 | print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") 102 | else: 103 | TF_CHECKPOINT = sys.argv[2] 104 | TF_CONFIG = sys.argv[3] 105 | PYTORCH_DUMP_OUTPUT = sys.argv[4] 106 | if len(sys.argv) == 6: 107 | FINETUNING_TASK = sys.argv[5] 108 | else: 109 | FINETUNING_TASK = None 110 | 111 | convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, 112 | TF_CONFIG, 113 | PYTORCH_DUMP_OUTPUT, 114 | FINETUNING_TASK) 115 | elif sys.argv[1] == "xlm": 116 | from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch 117 | 118 | if len(sys.argv) != 4: 119 | # pylint: disable=line-too-long 120 | print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") 121 | else: 122 | XLM_CHECKPOINT_PATH = sys.argv[2] 123 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 124 | 125 | convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT) 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /hubconfs/gpt2_hubconf.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer 2 | from pytorch_transformers.modeling_gpt2 import ( 3 | GPT2Model, 4 | GPT2LMHeadModel, 5 | GPT2DoubleHeadsModel 6 | ) 7 | 8 | # A lot of models share the same param doc. Use a decorator 9 | # to save typing 10 | gpt2_docstring = """ 11 | Params: 12 | pretrained_model_name_or_path: either: 13 | - a str with the name of a pre-trained model to load selected in the list of: 14 | . `gpt2`, `gpt2-medium` 15 | - a path or url to a pretrained model archive containing: 16 | . `gpt2_config.json` a configuration file for the model 17 | . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance 18 | - a path or url to a pretrained model archive containing: 19 | . `gpt2_config.json` a configuration file for the model 20 | . a TensorFlow checkpoint with trained weights 21 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 22 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 23 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 24 | *inputs, **kwargs: additional input for the specific GPT-2 class 25 | """ 26 | 27 | 28 | def _append_from_pretrained_docstring(docstr): 29 | def docstring_decorator(fn): 30 | fn.__doc__ = fn.__doc__ + docstr 31 | return fn 32 | return docstring_decorator 33 | 34 | 35 | def gpt2Tokenizer(*args, **kwargs): 36 | """ 37 | Instantiate a GPT-2 BPE tokenizer for OpenAI GPT-2 from a pre-trained/customized vocab file. 38 | Peculiarities: 39 | - Byte-level BPE 40 | 41 | Args: 42 | pretrained_model_name_or_path: Path to pretrained model archive 43 | or one of pre-trained vocab configs below. 44 | * gpt2 45 | Keyword args: 46 | special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...) 47 | Default: None 48 | max_len: An artificial maximum length to truncate tokenized sequences to; 49 | Effective maximum length is always the minimum of this 50 | value (if specified) and the underlying GPT-2 model's 51 | sequence length. 
52 | Default: None 53 | 54 | Example: 55 | import torch 56 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 57 | 58 | text = "Who was Jim Henson ?" 59 | indexed_tokens = tokenizer.encode(text) 60 | """ 61 | tokenizer = GPT2Tokenizer.from_pretrained(*args, **kwargs) 62 | return tokenizer 63 | 64 | 65 | @_append_from_pretrained_docstring(gpt2_docstring) 66 | def gpt2Model(*args, **kwargs): 67 | """ 68 | gpt2Model is the basic OpenAI GPT-2 Transformer model based on 69 | identical stacked masked self-attention blocks and pre-trained 70 | on a large-scale dataset using a language modeling signal. 71 | 72 | Example: 73 | # Load the tokenizer 74 | import torch 75 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 76 | 77 | # Prepare tokenized input 78 | text_1 = "Who was Jim Henson ?" 79 | text_2 = "Jim Henson was a puppeteer" 80 | indexed_tokens_1 = tokenizer.encode(text_1) 81 | indexed_tokens_2 = tokenizer.encode(text_2) 82 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 83 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 84 | 85 | # Load gpt2Model 86 | model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Model', 'gpt2') 87 | model.eval() 88 | 89 | # Predict hidden states features for each layer 90 | # past can be used to reuse precomputed hidden states in subsequent predictions 91 | with torch.no_grad(): 92 | hidden_states_1, past = model(tokens_tensor_1) 93 | hidden_states_2, past = model(tokens_tensor_2, past=past) 94 | """ 95 | model = GPT2Model.from_pretrained(*args, **kwargs) 96 | return model 97 | 98 | 99 | @_append_from_pretrained_docstring(gpt2_docstring) 100 | def gpt2LMHeadModel(*args, **kwargs): 101 | """ 102 | gpt2LMHeadModel is the OpenAI GPT-2 Transformer model with the 103 | tied (pre-trained) language modeling head on top. 104 | 105 | Example: 106 | # Load the tokenizer 107 | import torch 108 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 109 | 110 | # Prepare tokenized input 111 | text_1 = "Who was Jim Henson ?" 112 | text_2 = "Jim Henson was a puppeteer" 113 | indexed_tokens_1 = tokenizer.encode(text_1) 114 | indexed_tokens_2 = tokenizer.encode(text_2) 115 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 116 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 117 | 118 | # Load gpt2LMHeadModel 119 | model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2LMHeadModel', 'gpt2') 120 | model.eval() 121 | 122 | # Predict language modeling logits for each token 123 | # past can be used to reuse precomputed hidden states in subsequent predictions 124 | with torch.no_grad(): 125 | predictions_1, past = model(tokens_tensor_1) 126 | predictions_2, past = model(tokens_tensor_2, past=past) 127 | 128 | # Get the predicted last token 129 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 130 | predicted_token = tokenizer.decode([predicted_index]) 131 | assert predicted_token == ' who' 132 | """ 133 | model = GPT2LMHeadModel.from_pretrained(*args, **kwargs) 134 | return model 135 | 136 | 137 | @_append_from_pretrained_docstring(gpt2_docstring) 138 | def gpt2DoubleHeadsModel(*args, **kwargs): 139 | """ 140 | gpt2DoubleHeadsModel is the OpenAI GPT-2 Transformer model with the 141 | tied (pre-trained) language modeling head and a multiple choice 142 | classification head (only initialized, not pre-trained). 
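The multiple choice head scores each choice from the hidden state of the token indexed by ``mc_token_ids`` (the last token of each choice in the example below).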
143 | 144 | Example: 145 | # Load the tokenizer 146 | import torch 147 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'gpt2Tokenizer', 'gpt2') 148 | 149 | # Prepare tokenized input 150 | text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 151 | text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 152 | tokenized_text1 = tokenizer.tokenize(text1) 153 | tokenized_text2 = tokenizer.tokenize(text2) 154 | indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 155 | indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 156 | tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 157 | mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 158 | 159 | # Load gpt2DoubleHeadsModel 160 | model = torch.hub.load('huggingface/pytorch-transformers', 'gpt2DoubleHeadsModel', 'gpt2') 161 | model.eval() 162 | 163 | # Predict language modeling and multiple choice logits 164 | with torch.no_grad(): 165 | lm_logits, multiple_choice_logits, presents = model(tokens_tensor, mc_token_ids) 166 | """ 167 | model = GPT2DoubleHeadsModel.from_pretrained(*args, **kwargs) 168 | return model 169 | -------------------------------------------------------------------------------- /hubconfs/xlnet_hubconf.1.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers.tokenization_xlnet import XLNetTokenizer 2 | from pytorch_transformers.modeling_xlnet import ( 3 | XLNetConfig, 4 | XLNetModel, 5 | XLNetLMHeadModel, 6 | # XLNetForSequenceClassification 7 | ) 8 | 9 | # A lot of models share the same param doc. Use a decorator 10 | # to save typing 11 | xlnet_docstring = """ 12 | Params: 13 | pretrained_model_name_or_path: either: 14 | - a str with the name of a pre-trained model to load selected in the list of: 15 | . `xlnet-large-cased` 16 | - a path or url to a pretrained model archive containing: 17 | . `config.json` a configuration file for the model 18 | . `pytorch_model.bin` a PyTorch dump of an XLNetForPreTraining instance 19 | - a path or url to a pretrained model archive containing: 20 | . `xlnet_config.json` a configuration file for the model 21 | . `model.chkpt` a TensorFlow checkpoint 22 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 23 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 24 | state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models 25 | *inputs, **kwargs: additional input for the specific XLNet class 26 | """ 27 | 28 | 29 | def _append_from_pretrained_docstring(docstr): 30 | def docstring_decorator(fn): 31 | fn.__doc__ = fn.__doc__ + docstr 32 | return fn 33 | return docstring_decorator 34 | 35 | 36 | def xlnetTokenizer(*args, **kwargs): 37 | """ 38 | Instantiate an XLNet sentencepiece tokenizer for XLNet from a pre-trained vocab file. 39 | Peculiarities: 40 | - requires Google sentencepiece (https://github.com/google/sentencepiece) 41 | 42 | Args: 43 | pretrained_model_name_or_path: Path to pretrained model archive 44 | or one of pre-trained vocab configs below. 45 | * xlnet-large-cased 46 | Keyword args: 47 | special_tokens: Special tokens in vocabulary that are not pretrained 48 | Default: None 49 | max_len: An artificial maximum length to truncate tokenized sequences to; 50 | Effective maximum length is always the minimum of this 51 | value (if specified) and the underlying model's 52 | sequence length. 
53 | Default: None 54 | 55 | Example: 56 | import torch 57 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 58 | 59 | text = "Who was Jim Henson ?" 60 | indexed_tokens = tokenizer.encode(text) 61 | """ 62 | tokenizer = XLNetTokenizer.from_pretrained(*args, **kwargs) 63 | return tokenizer 64 | 65 | 66 | @_append_from_pretrained_docstring(xlnet_docstring) 67 | def xlnetModel(*args, **kwargs): 68 | """ 69 | xlnetModel is the basic XLNet Transformer model from 70 | "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 71 | by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 72 | 73 | Example: 74 | # Load the tokenizer 75 | import torch 76 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 77 | 78 | # Prepare tokenized input 79 | text_1 = "Who was Jim Henson ?" 80 | text_2 = "Jim Henson was a puppeteer" 81 | indexed_tokens_1 = tokenizer.encode(text_1) 82 | indexed_tokens_2 = tokenizer.encode(text_2) 83 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 84 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 85 | 86 | # Load xlnetModel 87 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetModel', 'xlnet-large-cased') 88 | model.eval() 89 | 90 | # Predict hidden states features for each layer 91 | with torch.no_grad(): 92 | hidden_states_1, mems = model(tokens_tensor_1) 93 | hidden_states_2, mems = model(tokens_tensor_2, mems=mems) 94 | """ 95 | model = XLNetModel.from_pretrained(*args, **kwargs) 96 | return model 97 | 98 | 99 | @_append_from_pretrained_docstring(xlnet_docstring) 100 | def xlnetLMHeadModel(*args, **kwargs): 101 | """ 102 | xlnetLMHeadModel is the XLNet Transformer model from 103 | "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 104 | by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 105 | with a tied (pre-trained) language modeling head on top. 106 | 107 | Example: 108 | # Load the tokenizer 109 | import torch 110 | tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 111 | 112 | # Prepare tokenized input 113 | text_1 = "Who was Jim Henson ?" 
114 | text_2 = "Jim Henson was a puppeteer" 115 | indexed_tokens_1 = tokenizer.encode(text_1) 116 | indexed_tokens_2 = tokenizer.encode(text_2) 117 | tokens_tensor_1 = torch.tensor([indexed_tokens_1]) 118 | tokens_tensor_2 = torch.tensor([indexed_tokens_2]) 119 | 120 | # Load xlnetLMHeadModel 121 | model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetLMHeadModel', 'xlnet-large-cased') 122 | model.eval() 123 | 124 | # Predict language modeling logits for each token 125 | with torch.no_grad(): 126 | predictions_1, mems = model(tokens_tensor_1) 127 | predictions_2, mems = model(tokens_tensor_2, mems=mems) 128 | 129 | # Get the predicted last token 130 | predicted_index = torch.argmax(predictions_2[0, -1, :]).item() 131 | predicted_token = tokenizer.decode([predicted_index]) 132 | assert predicted_token == ' who' 133 | """ 134 | model = XLNetLMHeadModel.from_pretrained(*args, **kwargs) 135 | return model 136 | 137 | 138 | # @_append_from_pretrained_docstring(xlnet_docstring) 139 | # def xlnetForSequenceClassification(*args, **kwargs): 140 | #     """ 141 | #     xlnetForSequenceClassification is the XLNet Transformer model from 142 | #     "XLNet: Generalized Autoregressive Pretraining for Language Understanding" 143 | #     by Zhilin Yang, Zihang Dai, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 144 | 145 | #     Example: 146 | #         # Load the tokenizer 147 | #         import torch 148 | #         tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased') 149 | 150 | #         #  Prepare tokenized input 151 | #         text1 = "Who was Jim Henson ? Jim Henson was a puppeteer" 152 | #         text2 = "Who was Jim Henson ? Jim Henson was a mysterious young man" 153 | #         tokenized_text1 = tokenizer.tokenize(text1) 154 | #         tokenized_text2 = tokenizer.tokenize(text2) 155 | #         indexed_tokens1 = tokenizer.convert_tokens_to_ids(tokenized_text1) 156 | #         indexed_tokens2 = tokenizer.convert_tokens_to_ids(tokenized_text2) 157 | #         tokens_tensor = torch.tensor([[indexed_tokens1, indexed_tokens2]]) 158 | #         mc_token_ids = torch.LongTensor([[len(tokenized_text1)-1, len(tokenized_text2)-1]]) 159 | 160 | #         # Load xlnetForSequenceClassification 161 | #         model = torch.hub.load('huggingface/pytorch-transformers', 'xlnetForSequenceClassification', 'xlnet-large-cased') 162 | #         model.eval() 163 | 164 | #         # Predict sequence class logits 165 | #         with torch.no_grad(): 166 | #             lm_logits, mems = model(tokens_tensor) 167 | #     """ 168 | #     model = XLNetForSequenceClassification.from_pretrained(*args, **kwargs) 169 | #     return model 170 | --------------------------------------------------------------------------------