├── .gitignore ├── LICENSE ├── README.md ├── configs ├── cnndm │ ├── attention │ │ └── multibranch_v2 │ │ │ └── embed496.yml │ └── test.sh ├── iwslt14.de-en │ ├── attention │ │ └── multibranch_v2 │ │ │ ├── embed160.yml │ │ │ ├── embed240.yml │ │ │ └── embed320.yml │ ├── prepare.sh │ └── test.sh ├── wmt14.en-fr │ ├── attention │ │ └── multibranch_v2 │ │ │ ├── embed200.yml │ │ │ ├── embed408.yml │ │ │ └── embed496.yml │ ├── prepare.sh │ └── test.sh └── wmt16.en-de │ ├── attention │ └── multibranch_v2 │ │ ├── embed200.yml │ │ ├── embed408.yml │ │ └── embed496.yml │ ├── prepare.sh │ └── test.sh ├── fairseq ├── __init__.py ├── binarizer.py ├── bleu.py ├── checkpoint_utils.py ├── clib │ └── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp ├── criterions │ ├── __init__.py │ ├── cross_entropy.py │ ├── fairseq_criterion.py │ └── label_smoothed_cross_entropy.py ├── data │ ├── __init__.py │ ├── append_token_dataset.py │ ├── base_wrapper_dataset.py │ ├── concat_dataset.py │ ├── data_utils.py │ ├── data_utils_fast.cpp │ ├── data_utils_fast.pyx │ ├── dictionary.py │ ├── encoders │ │ ├── __init__.py │ │ ├── fastbpe.py │ │ ├── gpt2_bpe.py │ │ ├── gpt2_bpe_utils.py │ │ ├── hf_bert_bpe.py │ │ ├── moses_tokenizer.py │ │ ├── nltk_tokenizer.py │ │ ├── sentencepiece_bpe.py │ │ ├── space_tokenizer.py │ │ └── subword_nmt_bpe.py │ ├── fairseq_dataset.py │ ├── indexed_dataset.py │ ├── iterators.py │ ├── language_pair_dataset.py │ ├── strip_token_dataset.py │ ├── token_block_utils_fast.c │ ├── token_block_utils_fast.pyx │ └── truncate_dataset.py ├── distributed_utils.py ├── file_io.py ├── file_utils.py ├── hub_utils.py ├── init.py ├── legacy_distributed_data_parallel.py ├── meters.py ├── models │ ├── __init__.py │ ├── distributed_fairseq_model.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── transformer.py │ └── transformer_multibranch_v2.py ├── modules │ ├── __init__.py │ ├── adaptive_softmax.py │ ├── cuda_utils.cu │ ├── 
dynamic_convolution.py │ ├── dynamicconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── dynamicconv_cuda.cpp │ │ ├── dynamicconv_cuda.cuh │ │ ├── dynamicconv_cuda_kernel.cu │ │ ├── dynamicconv_layer.py │ │ ├── dynamiconv_cpu.cpp │ │ └── setup.py │ ├── gelu.py │ ├── layer_norm.py │ ├── learned_positional_embedding.py │ ├── lightconv_layer │ │ ├── __init__.py │ │ ├── cuda_function_gen.py │ │ ├── lightconv_cuda.cpp │ │ ├── lightconv_cuda.cuh │ │ ├── lightconv_cuda_kernel.cu │ │ ├── lightconv_layer.py │ │ └── setup.py │ ├── lightweight_convolution.py │ ├── multibranch.py │ ├── multihead_attention.py │ ├── positional_embedding.py │ ├── sinusoidal_positional_embedding.py │ └── unfold.py ├── optim │ ├── __init__.py │ ├── adam.py │ ├── bmuf.py │ ├── fairseq_optimizer.py │ ├── fp16_optimizer.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── cosine_lr_scheduler.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_scheduler.py │ │ └── inverse_square_root_schedule.py │ └── nag.py ├── options.py ├── pdb.py ├── progress_bar.py ├── registry.py ├── search.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── __init__.py │ ├── fairseq_task.py │ └── translation.py ├── tokenizer.py ├── trainer.py └── utils.py ├── figures ├── compression.png ├── et.png ├── overview.png └── tradeoff.png ├── generate.py ├── preprocess.py ├── score.py ├── scripts ├── __init__.py ├── average_checkpoints.py ├── compare_namespaces.py ├── compound_split_bleu.sh ├── convert_dictionary.lua ├── convert_model.lua ├── count_docs.py ├── parse_profile.py ├── read_binarized.py ├── rm_pt.py ├── sacrebleu_pregen.sh ├── shard_docs.py ├── split_train_valid_docs.py ├── spm_decode.py ├── spm_encode.py └── spm_train.py ├── setup.py ├── train.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | gpt2_bpe 2 | .vscode 3 | exp 4 | mosesdecoder 5 | subword-nmt 6 | wmt14_en_fr/ 7 | wmt17_en_de/ 8 | orig 9 | /data 10 | distill_teacher/ 11 | # 
JetBrains PyCharm IDE 12 | .idea/ 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # macOS dir files 23 | .DS_Store 24 | 25 | # Distribution / packaging 26 | .Python 27 | env/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # Checkpoints 45 | checkpoints 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | .hypothesis/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # dotenv 102 | .env 103 | 104 | # virtualenv 105 | .venv 106 | venv/ 107 | ENV/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | 122 | # Generated files 123 | fairseq/temporal_convolution_tbc 124 | fairseq/modules/*_layer/*_forward.cu 125 | fairseq/modules/*_layer/*_backward.cu 126 | 127 | # data 128 | data-bin/ 129 | 130 | # reranking 
131 | examples/reranking/rerank_data 132 | 133 | profiler.py 134 | configs/**/profile.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | For Lite Transformer software 2 | Copyright (c) 2020, Zhanghao Wu, Zhijian Liu, Ji Lin, Yujun Lin, and Song Han 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | 26 | 27 | ------------------------- LICENSE FOR Fairseq ------------------------------ 28 | MIT License 29 | 30 | Copyright (c) Facebook, Inc. and its affiliates. 
31 | 32 | Permission is hereby granted, free of charge, to any person obtaining a copy 33 | of this software and associated documentation files (the "Software"), to deal 34 | in the Software without restriction, including without limitation the rights 35 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | copies of the Software, and to permit persons to whom the Software is 37 | furnished to do so, subject to the following conditions: 38 | 39 | The above copyright notice and this permission notice shall be included in all 40 | copies or substantial portions of the Software. 41 | 42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | SOFTWARE. 
49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lite Transformer 2 | 3 | ### [paper](https://arxiv.org/abs/2004.11886) | [website](https://hanlab.mit.edu/projects/litetransformer/) | [slides](https://hanlab.mit.edu/projects/litetransformer/Presentation_LiteTransformer.pdf) 4 | 5 | ``` 6 | @inproceedings{Wu2020LiteTransformer, 7 | title={Lite Transformer with Long-Short Range Attention}, 8 | author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, 9 | booktitle={International Conference on Learning Representations (ICLR)}, 10 | year={2020} 11 | } 12 | ``` 13 | 14 | ## Overview 15 | 16 | ![overview](figures/overview.png?raw=true "overview") 17 | 18 | ## How to Use 19 | 20 | ### Prerequisites 21 | 22 | * Python version >= 3.6 23 | * [PyTorch](http://pytorch.org/) version >= 1.0.0 24 | * configargparse >= 0.14 25 | * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) 26 | 27 | ### Installation 28 | 29 | 1. Codebase 30 | 31 | To install fairseq from source and develop locally: 32 | ```bash 33 | pip install --editable . 34 | ``` 35 | 36 | 2. Customized Modules 37 | 38 | We also need to build the `lightconv` and `dynamicconv` CUDA extensions for GPU support. Run each of the following snippets from the repository root. 39 | 40 | Lightconv_layer 41 | ```bash 42 | cd fairseq/modules/lightconv_layer 43 | python cuda_function_gen.py 44 | python setup.py install 45 | ``` 46 | Dynamicconv_layer 47 | ```bash 48 | cd fairseq/modules/dynamicconv_layer 49 | python cuda_function_gen.py 50 | python setup.py install 51 | ``` 52 | 53 | ### Data Preparation 54 | #### IWSLT'14 De-En 55 | We follow the data preparation in [fairseq](https://github.com/pytorch/fairseq).
To download and preprocess the data, one can run 56 | ```bash 57 | bash configs/iwslt14.de-en/prepare.sh 58 | ``` 59 | 60 | #### WMT'14 En-Fr 61 | We follow the data pre-processing in [fairseq](https://github.com/pytorch/fairseq). To download and preprocess the data, one can run 62 | ```bash 63 | bash configs/wmt14.en-fr/prepare.sh 64 | ``` 65 | 66 | #### WMT'16 En-De 67 | We follow the data pre-processing in [fairseq](https://github.com/pytorch/fairseq). One should first download the preprocessed data from the [Google Drive](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) provided by Google. To binarize the data, one can run 68 | ```bash 69 | bash configs/wmt16.en-de/prepare.sh [path to the downloaded zip file] 70 | ``` 71 | 72 | #### WIKITEXT-103 73 | As the language model task requires a lot of additional code, we place it in another branch: `language-model`. 74 | We follow the data pre-processing in [fairseq](https://github.com/pytorch/fairseq). To download and preprocess the data, one can run 75 | ```bash 76 | git checkout language-model 77 | bash configs/wikitext-103/prepare.sh 78 | ``` 79 | 80 | ### Testing 81 | 82 | For example, to test the models on WMT'14 En-Fr, one can run 83 | ```bash 84 | configs/wmt14.en-fr/test.sh [path to the model checkpoints] [gpu-id] [test|valid] 85 | ``` 86 | For instance, to evaluate Lite Transformer on GPU 0 (with the BLEU score on the test set of WMT'14 En-Fr), one can run 87 | ```bash 88 | configs/wmt14.en-fr/test.sh embed496/ 0 test 89 | ``` 90 | We provide several pretrained models in the [Models](#models) section below.
You can download the model and extract the file with 91 | ```bash 92 | tar -xzvf [filename] 93 | ``` 94 | 95 | ### Training 96 | We provide several examples of training Lite Transformer with this repo: 97 | 98 | To train Lite Transformer on WMT'14 En-Fr (with 8 GPUs), one can run 99 | ```bash 100 | python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml 101 | ``` 102 | To train Lite Transformer with fewer GPUs, e.g. 4 GPUs, one can run 103 | ```bash 104 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32 105 | ``` 106 | In general, to train a model, one can run 107 | ```bash 108 | python train.py [path to the data binary] --configs [path to config file] [override options] 109 | ``` 110 | Note that `--update-freq` should be adjusted according to the number of GPUs, so that (number of GPUs) × `--update-freq` stays at 128 (16 for 8 GPUs, 32 for 4 GPUs). 111 | 112 | ### Distributed Training (optional) 113 | 114 | To train Lite Transformer in a distributed manner, e.g. on two GPU nodes with 16 GPUs in total, one can run:
115 | ```bash 116 | # On host1 117 | python -m torch.distributed.launch \ 118 | --nproc_per_node=8 \ 119 | --nnodes=2 --node_rank=0 \ 120 | --master_addr=host1 --master_port=8080 \ 121 | train.py data/binary/wmt14_en_fr \ 122 | --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \ 123 | --distributed-no-spawn \ 124 | --update-freq 8 125 | # On host2 126 | python -m torch.distributed.launch \ 127 | --nproc_per_node=8 \ 128 | --nnodes=2 --node_rank=1 \ 129 | --master_addr=host1 --master_port=8080 \ 130 | train.py data/binary/wmt14_en_fr \ 131 | --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml \ 132 | --distributed-no-spawn \ 133 | --update-freq 8 134 | ``` 135 | 136 | ## Models 137 | We provide the checkpoints for our Lite Transformer reported in the paper: 138 | | Dataset | \#Mult-Adds | Test Score | Model and Test Set | 139 | |:--:|:--:|:--:|:--:| 140 | | [WMT'14 En-Fr](http://statmt.org/wmt14/translation-task.html#Download) | 90M | 35.3 |[download](https://drive.google.com/open?id=10Iotg0dnt9sJTqEghtNhIIwJL1R3LYBe) | 141 | | | 360M | 39.1 | [download](https://drive.google.com/open?id=10WMpIrdnDRWa_7afYJsqiiONdWlTLrJs) | 142 | | | 527M | 39.6 | [download](https://drive.google.com/open?id=10Wfv80wOTkL-hkXNyxM8IVlcroHuuUvA) | 143 | | [WMT'16 En-De](https://statmt.org/wmt16/translation-task.html#Download) | 90M | 22.5 | [download](https://drive.google.com/open?id=10ArxzUsMZ8gDe6zw5d3xTHYmeUasys1q) | 144 | | | 360M | 25.6 | [download](https://drive.google.com/open?id=10Fd1iXFiOtuwjxm1K8S2RqiEeCuDhxYn) | 145 | | | 527M | 26.5 | [download](https://drive.google.com/open?id=10HYj-rcJ4CIPp-BtpckkmYIgzH5Urrz0)| 146 | | [CNN / DailyMail](https://github.com/abisee/cnn-dailymail) | 800M | 38.3 (R-L) | [download](https://drive.google.com/open?id=14sQZ_H7HMQGhL7Ko1WkktWUvbEslOeu9)| 147 | | [WIKITEXT-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | 1147M | 22.2 (PPL) | 
[download](https://drive.google.com/file/d/14gT1j5VERgtDFfo2Ef1yOiliT9Y2eKe_/view?usp=sharing)| 148 | 149 | -------------------------------------------------------------------------------- /configs/cnndm/attention/multibranch_v2/embed496.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 | update-freq: 16 13 | 14 | keep-last-epochs: 10 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 | lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 10000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 40000 27 | 28 | dropout: 0.15 29 | attention-dropout: 0.1 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.1 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:248:4, dynamic:default:248:4] 37 | decoder-branch-type: [attn:1:248:4, dynamic:default:248:4] 38 | conv-linear: true 39 | 40 | 41 | encoder-embed-dim: 496 42 | encoder-ffn-embed-dim: 496 43 | decoder-embed-dim: 496 44 | decoder-ffn-embed-dim: 496 45 | -------------------------------------------------------------------------------- /configs/cnndm/test.sh: -------------------------------------------------------------------------------- 1 | 2 | checkpoint_path=$1 3 | output_path=$checkpoint_path/exp 4 | gpu=${2:-0} 5 | dataset=${3:-"test"} 6 | mkdir -p $output_path 7 | 8 | CUDA_VISIBLE_DEVICES=$gpu python generate.py data/binary/cnndm --path "$checkpoint_path/checkpoint_best.pt" --remove-bpe --gen-subset $dataset \ 9 | --batch-size 6 --min-len 55 --max-len-b 140 --beam 4 --lenpen 2.0 --no-repeat-ngram-size 3 > $output_path/cnn_dailymail.out 10 | 11 | 
GEN=$output_path/cnn_dailymail.out 12 | SYS=$GEN.sys 13 | REF=$GEN.ref 14 | grep ^H $GEN | cut -f3- > $SYS 15 | grep ^T $GEN | cut -f2- > $REF 16 | 17 | export CLASSPATH=`pwd`/configs/cnndm/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar 18 | 19 | cat $SYS | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $SYS.tokenized 20 | cat $REF | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $REF.target 21 | 22 | files2rouge $SYS.tokenized $REF.target | tee $output_path/rouge.result 23 | -------------------------------------------------------------------------------- /configs/iwslt14.de-en/attention/multibranch_v2/embed160.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_iwslt_de_en 2 | no-progress-bar: true 3 | 4 | optimizer: adam 5 | lr: 0.0005 6 | source-lang: de 7 | target-lang: en 8 | label-smoothing: 0.1 9 | dropout: 0.2 10 | max-tokens: 4096 11 | clip-norm: 0.0 12 | min-lr: 1e-09 13 | lr-scheduler: inverse_sqrt 14 | weight-decay: 0.0001 15 | criterion: label_smoothed_cross_entropy 16 | max-update: 50000 17 | warmup-updates: 4000 18 | warmup-init-lr: 1e-07 19 | adam-betas: (0.9, 0.98) 20 | fp16: False 21 | 22 | weight-dropout: 0.1 23 | encoder-glu: 0 24 | decoder-glu: 0 25 | encoder-branch-type: [attn:1:80:4, lightweight:default:80:4] 26 | decoder-branch-type: [attn:1:80:4, lightweight:default:80:4] 27 | conv-linear: true 28 | 29 | encoder-embed-dim: 160 30 | decoder-embed-dim: 160 31 | encoder-ffn-embed-dim: 160 32 | decoder-ffn-embed-dim: 160 33 | 34 | -------------------------------------------------------------------------------- /configs/iwslt14.de-en/attention/multibranch_v2/embed240.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_iwslt_de_en 2 | no-progress-bar: true 3 | 4 | optimizer: adam 5 | lr: 0.0005 6 | source-lang: de 7 | target-lang: en 8 | 
label-smoothing: 0.1 9 | dropout: 0.2 10 | max-tokens: 4096 11 | clip-norm: 0.0 12 | min-lr: 1e-09 13 | lr-scheduler: inverse_sqrt 14 | weight-decay: 0.0001 15 | criterion: label_smoothed_cross_entropy 16 | max-update: 50000 17 | warmup-updates: 4000 18 | warmup-init-lr: 1e-07 19 | adam-betas: (0.9, 0.98) 20 | fp16: False 21 | 22 | weight-dropout: 0.1 23 | encoder-glu: 0 24 | decoder-glu: 0 25 | encoder-branch-type: [attn:1:120:4, lightweight:default:120:4] 26 | decoder-branch-type: [attn:1:120:4, lightweight:default:120:4] 27 | conv-linear: true 28 | 29 | encoder-embed-dim: 240 30 | decoder-embed-dim: 240 31 | encoder-ffn-embed-dim: 240 32 | decoder-ffn-embed-dim: 240 33 | 34 | -------------------------------------------------------------------------------- /configs/iwslt14.de-en/attention/multibranch_v2/embed320.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_iwslt_de_en 2 | no-progress-bar: true 3 | 4 | optimizer: adam 5 | lr: 0.0005 6 | source-lang: de 7 | target-lang: en 8 | label-smoothing: 0.1 9 | dropout: 0.2 10 | max-tokens: 4096 11 | clip-norm: 0.0 12 | min-lr: 1e-09 13 | lr-scheduler: inverse_sqrt 14 | weight-decay: 0.0001 15 | criterion: label_smoothed_cross_entropy 16 | max-update: 50000 17 | warmup-updates: 4000 18 | warmup-init-lr: 1e-07 19 | adam-betas: (0.9, 0.98) 20 | fp16: False 21 | 22 | weight-dropout: 0.1 23 | encoder-glu: 0 24 | decoder-glu: 0 25 | encoder-branch-type: [attn:1:160:4, lightweight:default:160:4] 26 | decoder-branch-type: [attn:1:160:4, lightweight:default:160:4] 27 | conv-linear: true 28 | 29 | encoder-embed-dim: 320 30 | decoder-embed-dim: 320 31 | encoder-ffn-embed-dim: 320 32 | decoder-ffn-embed-dim: 320 33 | -------------------------------------------------------------------------------- /configs/iwslt14.de-en/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Adapted from 
https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 4 | 5 | echo 'Cloning Moses github repository (for tokenization scripts)...' 6 | git clone https://github.com/moses-smt/mosesdecoder.git ../mosesdecoder 7 | ln -s ../mosesdecoder 8 | 9 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 10 | git clone https://github.com/rsennrich/subword-nmt.git ../subword-nmt 11 | ln -s ../subword-nmt 12 | 13 | SCRIPTS=mosesdecoder/scripts 14 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 15 | LC=$SCRIPTS/tokenizer/lowercase.perl 16 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 17 | BPEROOT=subword-nmt 18 | BPE_TOKENS=10000 19 | 20 | URL="https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz" 21 | GZ=de-en.tgz 22 | 23 | if [ ! -d "$SCRIPTS" ]; then 24 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 25 | exit 26 | fi 27 | 28 | src=de 29 | tgt=en 30 | lang=de-en 31 | prep=data/iwslt14.de-en/iwslt14.tokenized.de-en 32 | orig=data/iwslt14.de-en/orig 33 | tmp=$prep/tmp 34 | 35 | mkdir -p $orig $tmp $prep 36 | 37 | cd $orig 38 | if [ ! -f $GZ ]; then 39 | echo "Downloading data from ${URL}..." 40 | wget "$URL" 41 | fi 42 | 43 | if [ -f $GZ ]; then 44 | echo "Data successfully downloaded." 45 | else 46 | echo "Data not successfully downloaded." 47 | exit 48 | fi 49 | 50 | tar zxvf $GZ 51 | cd - 52 | 53 | echo "pre-processing train data..." 
54 | for l in $src $tgt; do 55 | f=train.tags.$lang.$l 56 | tok=train.tags.$lang.tok.$l 57 | 58 | cat $orig/$lang/$f | \ 59 | grep -v '' | \ 60 | grep -v '' | \ 61 | grep -v '' | \ 62 | sed -e 's///g' | \ 63 | sed -e 's/<\/title>//g' | \ 64 | sed -e 's/<description>//g' | \ 65 | sed -e 's/<\/description>//g' | \ 66 | perl $TOKENIZER -threads 8 -l $l > $tmp/$tok 67 | echo "" 68 | done 69 | perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train.tags.$lang.clean 1 175 70 | for l in $src $tgt; do 71 | perl $LC < $tmp/train.tags.$lang.clean.$l > $tmp/train.tags.$lang.$l 72 | done 73 | 74 | echo "pre-processing valid/test data..." 75 | for l in $src $tgt; do 76 | for o in `ls $orig/$lang/IWSLT14.TED*.$l.xml`; do 77 | fname=${o##*/} 78 | f=$tmp/${fname%.*} 79 | echo $o $f 80 | grep '<seg id' $o | \ 81 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 82 | sed -e 's/\s*<\/seg>\s*//g' | \ 83 | sed -e "s/\’/\'/g" | \ 84 | perl $TOKENIZER -threads 8 -l $l | \ 85 | perl $LC > $f 86 | echo "" 87 | done 88 | done 89 | 90 | 91 | echo "creating train, valid, test..." 92 | for l in $src $tgt; do 93 | awk '{if (NR%23 == 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/valid.$l 94 | awk '{if (NR%23 != 0) print $0; }' $tmp/train.tags.de-en.$l > $tmp/train.$l 95 | 96 | cat $tmp/IWSLT14.TED.dev2010.de-en.$l \ 97 | $tmp/IWSLT14.TEDX.dev2012.de-en.$l \ 98 | $tmp/IWSLT14.TED.tst2010.de-en.$l \ 99 | $tmp/IWSLT14.TED.tst2011.de-en.$l \ 100 | $tmp/IWSLT14.TED.tst2012.de-en.$l \ 101 | > $tmp/test.$l 102 | done 103 | 104 | TRAIN=$tmp/train.en-de 105 | BPE_CODE=$prep/code 106 | rm -f $TRAIN 107 | for l in $src $tgt; do 108 | cat $tmp/train.$l >> $TRAIN 109 | done 110 | 111 | echo "learn_bpe.py on ${TRAIN}..." 112 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 113 | 114 | for L in $src $tgt; do 115 | for f in train.$L valid.$L test.$L; do 116 | echo "apply_bpe.py to ${f}..." 
117 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $prep/$f 118 | done 119 | done 120 | 121 | TEXT=data/iwslt14.de-en/iwslt14.tokenized.de-en 122 | fairseq-preprocess --source-lang de --target-lang en \ 123 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 124 | --destdir data/binary/iwslt14.tokenized.de-en -------------------------------------------------------------------------------- /configs/iwslt14.de-en/test.sh: -------------------------------------------------------------------------------- 1 | checkpoints_path=$1 2 | gpu=${2:-0} 3 | subset=${3:-"test"} 4 | avg_checkpoints=${4:-10} 5 | model=average_model_$avg_checkpoints.pt 6 | output_path=$checkpoints_path 7 | 8 | mkdir -p $output_path/exp 9 | 10 | python scripts/average_checkpoints.py --inputs $output_path \ 11 | --num-epoch-checkpoints $avg_checkpoints --output $output_path/$model 12 | 13 | CUDA_VISIBLE_DEVICES=$gpu python ./generate.py data/binary/iwslt14.tokenized.de-en \ 14 | --path $output_path/$model --gen-subset $subset \ 15 | --batch-size 128 --beam 4 --remove-bpe > $output_path/exp/${subset}_gen.out 16 | 17 | GEN=$output_path/exp/${subset}_gen.out 18 | 19 | SYS=$GEN.sys 20 | REF=$GEN.ref 21 | 22 | grep ^H $GEN | cut -f3- > $SYS 23 | grep ^T $GEN | cut -f2- > $REF 24 | python score.py --sys $SYS --ref $REF | tee $checkpoints_path/exp/avg_checkpoints.result 25 | -------------------------------------------------------------------------------- /configs/wmt14.en-fr/attention/multibranch_v2/embed200.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 | update-freq: 16 13 | 14 | keep-last-epochs: 20 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 
| lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 5000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 45000 27 | 28 | dropout: 0.04 29 | attention-dropout: 0.03 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.03 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:100:4, dynamic:default:100:4] 37 | decoder-branch-type: [attn:1:100:4, dynamic:default:100:4] 38 | conv-linear: true 39 | 40 | encoder-embed-dim: 200 41 | encoder-ffn-embed-dim: 200 42 | decoder-embed-dim: 200 43 | decoder-ffn-embed-dim: 200 44 | 45 | 46 | num-workers: 2 -------------------------------------------------------------------------------- /configs/wmt14.en-fr/attention/multibranch_v2/embed408.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 | update-freq: 16 13 | 14 | keep-last-epochs: 20 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 | lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 5000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 45000 27 | 28 | dropout: 0.08 29 | attention-dropout: 0.07 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.07 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:204:4, dynamic:default:204:4] 37 | decoder-branch-type: [attn:1:204:4, dynamic:default:204:4] 38 | conv-linear: true 39 | 40 | encoder-embed-dim: 408 41 | encoder-ffn-embed-dim: 408 42 | decoder-embed-dim: 408 43 | decoder-ffn-embed-dim: 408 44 | 45 | 
-------------------------------------------------------------------------------- /configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 | update-freq: 16 13 | 14 | keep-last-epochs: 20 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 | lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 5000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 45000 # max-update - warmup-updates 27 | 28 | dropout: 0.1 29 | attention-dropout: 0.08 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.08 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:248:4, dynamic:default:248:4] 37 | decoder-branch-type: [attn:1:248:4, dynamic:default:248:4] 38 | conv-linear: true 39 | 40 | encoder-embed-dim: 496 41 | encoder-ffn-embed-dim: 496 42 | decoder-embed-dim: 496 43 | decoder-ffn-embed-dim: 496 44 | -------------------------------------------------------------------------------- /configs/wmt14.en-fr/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh 3 | 4 | echo 'Cloning Moses github repository (for tokenization scripts)...' 5 | git clone https://github.com/moses-smt/mosesdecoder.git 6 | 7 | echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
8 | git clone https://github.com/rsennrich/subword-nmt.git 9 | 10 | SCRIPTS=mosesdecoder/scripts 11 | TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl 12 | CLEAN=$SCRIPTS/training/clean-corpus-n.perl 13 | NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl 14 | REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl 15 | BPEROOT=subword-nmt 16 | BPE_TOKENS=40000 17 | 18 | URLS=( 19 | "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" 20 | "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" 21 | "http://statmt.org/wmt13/training-parallel-un.tgz" 22 | "http://statmt.org/wmt14/training-parallel-nc-v9.tgz" 23 | "http://statmt.org/wmt10/training-giga-fren.tar" 24 | "http://statmt.org/wmt14/test-full.tgz" 25 | ) 26 | FILES=( 27 | "training-parallel-europarl-v7.tgz" 28 | "training-parallel-commoncrawl.tgz" 29 | "training-parallel-un.tgz" 30 | "training-parallel-nc-v9.tgz" 31 | "training-giga-fren.tar" 32 | "test-full.tgz" 33 | ) 34 | CORPORA=( 35 | "training/europarl-v7.fr-en" 36 | "commoncrawl.fr-en" 37 | "un/undoc.2000.fr-en" 38 | "training/news-commentary-v9.fr-en" 39 | "giga-fren.release2.fixed" 40 | ) 41 | 42 | if [ ! -d "$SCRIPTS" ]; then 43 | echo "Please set SCRIPTS variable correctly to point to Moses scripts." 44 | exit 1 # fail with a non-zero status; a bare exit would report success 45 | fi 46 | 47 | src=en 48 | tgt=fr 49 | lang=en-fr 50 | prep=data/wmt14_en_fr/wmt14.tokenized.en-fr 51 | orig=data/wmt14_en_fr/orig 52 | tmp=$prep/tmp 53 | 54 | mkdir -p $orig $tmp $prep 55 | 56 | cd $orig 57 | 58 | for ((i=0;i<${#URLS[@]};++i)); do 59 | file=${FILES[i]} 60 | if [ -f $file ]; then 61 | echo "$file already exists, skipping download" 62 | else 63 | url=${URLS[i]} 64 | wget "$url" 65 | if [ -f $file ]; then 66 | echo "$url successfully downloaded." 67 | else 68 | echo "$url not successfully downloaded."
69 | exit 1 # abort on a failed download (exit statuses are 0-255; -1 is not portable) 70 | fi 71 | if [ ${file: -4} == ".tgz" ]; then 72 | tar zxvf $file 73 | elif [ ${file: -4} == ".tar" ]; then 74 | tar xvf $file 75 | fi 76 | fi 77 | done 78 | 79 | gunzip giga-fren.release2.fixed.*.gz 80 | cd - 81 | 82 | echo "pre-processing train data..." 83 | for l in $src $tgt; do 84 | rm -f $tmp/train.tags.$lang.tok.$l # drop output of a previous run before appending below; -f tolerates a first run 85 | for f in "${CORPORA[@]}"; do 86 | cat $orig/$f.$l | \ 87 | perl $NORM_PUNC $l | \ 88 | perl $REM_NON_PRINT_CHAR | \ 89 | perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l 90 | done 91 | done 92 | 93 | echo "pre-processing test data..." 94 | for l in $src $tgt; do 95 | if [ "$l" == "$src" ]; then 96 | t="src" 97 | else 98 | t="ref" 99 | fi 100 | grep '<seg id' $orig/test-full/newstest2014-fren-$t.$l.sgm | \ 101 | sed -e 's/<seg id="[0-9]*">\s*//g' | \ 102 | sed -e 's/\s*<\/seg>\s*//g' | \ 103 | sed -e "s/\’/\'/g" | \ 104 | perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l 105 | echo "" 106 | done 107 | 108 | echo "splitting train and valid..." 109 | for l in $src $tgt; do 110 | awk '{if (NR%1333 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l 111 | awk '{if (NR%1333 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l 112 | done 113 | 114 | TRAIN=$tmp/train.fr-en 115 | BPE_CODE=$prep/code 116 | rm -f $TRAIN 117 | for l in $src $tgt; do 118 | cat $tmp/train.$l >> $TRAIN 119 | done 120 | 121 | echo "learn_bpe.py on ${TRAIN}..." 122 | python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE 123 | 124 | for L in $src $tgt; do 125 | for f in train.$L valid.$L test.$L; do 126 | echo "apply_bpe.py to ${f}..."
127 | python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f 128 | done 129 | done 130 | 131 | perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250 132 | perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250 133 | 134 | for L in $src $tgt; do 135 | cp $tmp/bpe.test.$L $prep/test.$L 136 | done 137 | 138 | TEXT=data/wmt14_en_fr/wmt14.tokenized.en-fr 139 | fairseq-preprocess \ 140 | --source-lang en --target-lang fr \ 141 | --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ 142 | --destdir data/binary/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \ 143 | --joined-dictionary --workers 60 144 | -------------------------------------------------------------------------------- /configs/wmt14.en-fr/test.sh: -------------------------------------------------------------------------------- 1 | checkpoints_path=$1 2 | gpu=${2:-0} 3 | subset=${3:-"test"} 4 | 5 | mkdir -p $checkpoints_path/exp 6 | 7 | CUDA_VISIBLE_DEVICES=$gpu python generate.py data/binary/wmt14_en_fr \ 8 | --path "$checkpoints_path/checkpoint_best.pt" --gen-subset $subset \ 9 | --beam 4 --batch-size 128 --remove-bpe --lenpen 0.6 > $checkpoints_path/exp/${subset}_gen.out 10 | 11 | GEN=$checkpoints_path/exp/${subset}_gen.out 12 | 13 | SYS=$GEN.sys 14 | REF=$GEN.ref 15 | 16 | grep ^H $GEN | cut -f3- > $SYS # hypothesis lines: keep tab-separated fields 3+ (presumably the text after id and score) 17 | grep ^T $GEN | cut -f2- > $REF # reference lines: keep tab-separated fields 2+ 18 | python score.py --sys $SYS --ref $REF | tee $checkpoints_path/exp/checkpoint_best.result 19 | -------------------------------------------------------------------------------- /configs/wmt16.en-de/attention/multibranch_v2/embed200.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 |
update-freq: 16 13 | 14 | keep-last-epochs: 20 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 | lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 10000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 40000 # max-update - warmup-updates 27 | 28 | dropout: 0.04 29 | attention-dropout: 0.03 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.03 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:100:4, dynamic:default:100:4] 37 | decoder-branch-type: [attn:1:100:4, dynamic:default:100:4] 38 | conv-linear: true 39 | 40 | encoder-embed-dim: 200 41 | encoder-ffn-embed-dim: 200 42 | decoder-embed-dim: 200 43 | decoder-ffn-embed-dim: 200 44 | 45 | 46 | num-workers: 4 -------------------------------------------------------------------------------- /configs/wmt16.en-de/attention/multibranch_v2/embed408.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 | update-freq: 16 13 | 14 | keep-last-epochs: 20 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 | lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 10000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 40000 # max-update - warmup-updates 27 | 28 | dropout: 0.08 29 | attention-dropout: 0.06 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.06 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:204:4, dynamic:default:204:4] 37 | decoder-branch-type: [attn:1:204:4, dynamic:default:204:4] 38 | conv-linear: true 39 | 40 | encoder-embed-dim: 408 41 | encoder-ffn-embed-dim: 408 42 | decoder-embed-dim: 408 43 |
decoder-ffn-embed-dim: 408 44 | 45 | -------------------------------------------------------------------------------- /configs/wmt16.en-de/attention/multibranch_v2/embed496.yml: -------------------------------------------------------------------------------- 1 | arch: transformer_multibranch_v2_wmt_en_de 2 | no-progress-bar: true 3 | 4 | share-all-embeddings: True 5 | log-interval: 1000 6 | optimizer: adam 7 | adam-betas: (0.9, 0.98) 8 | clip-norm: 0.0 9 | weight-decay: 0.0 10 | criterion: label_smoothed_cross_entropy 11 | label-smoothing: 0.1 12 | update-freq: 16 13 | 14 | keep-last-epochs: 20 15 | ddp-backend: no_c10d 16 | max-tokens: 4096 17 | lr-scheduler: cosine 18 | warmup-init-lr: 1e-7 19 | warmup-updates: 10000 20 | max-update: 50000 21 | lr-shrink: 1 22 | max-lr: 0.001 23 | lr: 1e-7 24 | min-lr: 1e-9 25 | t-mult: 1 26 | lr-period-updates: 40000 # max-update - warmup-updates 27 | 28 | dropout: 0.1 29 | attention-dropout: 0.08 30 | 31 | fp16: false 32 | 33 | weight-dropout: 0.08 34 | encoder-glu: 1 35 | decoder-glu: 1 36 | encoder-branch-type: [attn:1:248:4, dynamic:default:248:4] 37 | decoder-branch-type: [attn:1:248:4, dynamic:default:248:4] 38 | conv-linear: true 39 | 40 | encoder-embed-dim: 496 41 | encoder-ffn-embed-dim: 496 42 | decoder-embed-dim: 496 43 | decoder-ffn-embed-dim: 496 44 | 45 | 46 | num-workers: 4 47 | -------------------------------------------------------------------------------- /configs/wmt16.en-de/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Please download the data from the google drive https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8 3 | # Place the downloaded archive (wmt16_en_de_bpe32k.tar.gz) into data/, or pass its path as the first argument 4 | 5 | DATA_DIR=data 6 | TEXT=$DATA_DIR/wmt16_en_de_bpe32k 7 | TAR=${1:-data/wmt16_en_de_bpe32k.tar.gz} 8 | mkdir -p $TEXT 9 | if [ "$TAR" == "" ]; then TAR=data/wmt16_en_de_bpe32k.tar.gz; fi # NOTE(review): redundant - ${1:-...} on line 7 already substitutes the default when $1 is unset or empty 10 | tar -xvzf $TAR -C "$DATA_DIR" 11 | 12 | fairseq-preprocess --source-lang en --target-lang de
\ 13 | --trainpref $TEXT/train.tok.clean.bpe.32000 \ 14 | --validpref $TEXT/newstest2013.tok.bpe.32000 \ 15 | --testpref $TEXT/newstest2014.tok.bpe.32000 \ 16 | --destdir data/binary/wmt16_en_de_bpe32k \ 17 | --nwordssrc 32768 --nwordstgt 32768 \ 18 | --joined-dictionary --workers 10 19 | -------------------------------------------------------------------------------- /configs/wmt16.en-de/test.sh: -------------------------------------------------------------------------------- 1 | checkpoints_path=$1 2 | gpu=${2:-0} 3 | subset=${3:-"test"} 4 | 5 | mkdir -p $checkpoints_path/exp 6 | 7 | CUDA_VISIBLE_DEVICES=$gpu python generate.py data/binary/wmt16_en_de_bpe32k \ 8 | --path "$checkpoints_path/checkpoint_best.pt" --gen-subset $subset \ 9 | --beam 4 --batch-size 128 --remove-bpe --lenpen 0.6 > $checkpoints_path/exp/${subset}_gen.out 10 | 11 | GEN=$checkpoints_path/exp/${subset}_gen.out 12 | 13 | SYS=$GEN.sys 14 | REF=$GEN.ref 15 | 16 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS # hypotheses: keep fields 3+, rewrite intra-word hyphens as ##AT##-##AT## 17 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF # references: same hyphen rewrite so both sides are split identically 18 | python score.py --sys $SYS --ref $REF | tee $checkpoints_path/exp/checkpoint_best.result 19 | -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | __all__ = ['pdb'] 7 | __version__ = '0.8.0' 8 | 9 | import fairseq.criterions # noqa 10 | import fairseq.models # noqa 11 | import fairseq.modules # noqa 12 | import fairseq.optim # noqa 13 | import fairseq.optim.lr_scheduler # noqa 14 | import fairseq.pdb # noqa 15 | import fairseq.tasks # noqa 16 | -------------------------------------------------------------------------------- /fairseq/binarizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from collections import Counter 7 | import os 8 | 9 | from fairseq.tokenizer import tokenize_line 10 | 11 | 12 | def safe_readline(f): # readline() that tolerates a seek position inside a multi-byte character by backing up one byte at a time 13 | pos = f.tell() 14 | while True: 15 | try: 16 | return f.readline() 17 | except UnicodeDecodeError: 18 | pos -= 1 19 | f.seek(pos) # search where this character begins 20 | 21 | 22 | class Binarizer: 23 | 24 | @staticmethod 25 | def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False, 26 | offset=0, end=-1): # start reading at byte offset; stop once f.tell() passes end (end <= 0 reads to EOF) 27 | nseq, ntok = 0, 0 28 | replaced = Counter() 29 | 30 | def replaced_consumer(word, idx): 31 | if idx == dict.unk_index and word != dict.unk_word: 32 | replaced.update([word]) 33 | 34 | with open(filename, 'r', encoding='utf-8') as f: 35 | f.seek(offset) 36 | # next(f) breaks f.tell(), hence readline() must be used 37 | line = safe_readline(f) 38 | while line: 39 | if end > 0 and f.tell() > end: 40 | break 41 | ids = dict.encode_line( 42 | line=line, 43 | line_tokenizer=tokenize, 44 | add_if_not_exist=False, 45 | consumer=replaced_consumer, 46 | append_eos=append_eos, 47 | reverse_order=reverse_order, 48 | ) 49 | nseq += 1 50 | ntok += len(ids) 51 | consumer(ids) 52 | line = f.readline() 53 | return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced} # nunk counts words mapped to unk that were not literally the unk word 54 | 55 |
@staticmethod 56 | def find_offsets(filename, num_chunks): 57 | with open(filename, 'r', encoding='utf-8') as f: 58 | size = os.fstat(f.fileno()).st_size 59 | chunk_size = size // num_chunks 60 | offsets = [0 for _ in range(num_chunks + 1)] # the final entry stays 0, which binarize() treats as "read to EOF" 61 | for i in range(1, num_chunks): 62 | f.seek(chunk_size * i) 63 | safe_readline(f) # skip the partial line so each chunk starts on a line boundary 64 | offsets[i] = f.tell() 65 | return offsets 66 | -------------------------------------------------------------------------------- /fairseq/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import ctypes 7 | import math 8 | import torch 9 | 10 | try: 11 | from fairseq import libbleu 12 | except ImportError as e: 13 | import sys 14 | sys.stderr.write('ERROR: missing libbleu.so. run `pip install --editable .`\n') 15 | raise e 16 | 17 | 18 | C = ctypes.cdll.LoadLibrary(libbleu.__file__) 19 | 20 | 21 | class BleuStat(ctypes.Structure): # field order and types must mirror the bleu_stat struct in clib/libbleu/libbleu.cpp 22 | _fields_ = [ 23 | ('reflen', ctypes.c_size_t), 24 | ('predlen', ctypes.c_size_t), 25 | ('match1', ctypes.c_size_t), 26 | ('count1', ctypes.c_size_t), 27 | ('match2', ctypes.c_size_t), 28 | ('count2', ctypes.c_size_t), 29 | ('match3', ctypes.c_size_t), 30 | ('count3', ctypes.c_size_t), 31 | ('match4', ctypes.c_size_t), 32 | ('count4', ctypes.c_size_t), 33 | ] 34 | 35 | 36 | class SacrebleuScorer(object): 37 | def __init__(self): 38 | import sacrebleu 39 | self.sacrebleu = sacrebleu 40 | self.reset() 41 | 42 | def reset(self, one_init=False): 43 | if one_init: 44 | raise NotImplementedError 45 | self.ref = [] 46 | self.sys = [] 47 | 48 | def add_string(self, ref, pred): 49 | self.ref.append(ref) 50 | self.sys.append(pred) 51 | 52 | def score(self, order=4): 53 | return self.result_string(order).score # result_string() returns the sacrebleu corpus_bleu result object, not a str 54 | 55 | def result_string(self, order=4): 56 | if order != 4: 57 | raise
NotImplementedError 58 | return self.sacrebleu.corpus_bleu(self.sys, [self.ref]) 59 | 60 | 61 | class Scorer(object): 62 | def __init__(self, pad, eos, unk): 63 | self.stat = BleuStat() 64 | self.pad = pad 65 | self.eos = eos 66 | self.unk = unk 67 | self.reset() 68 | 69 | def reset(self, one_init=False): 70 | if one_init: 71 | C.bleu_one_init(ctypes.byref(self.stat)) 72 | else: 73 | C.bleu_zero_init(ctypes.byref(self.stat)) 74 | 75 | def add(self, ref, pred): 76 | if not isinstance(ref, torch.IntTensor): 77 | raise TypeError('ref must be a torch.IntTensor (got {})' 78 | .format(type(ref))) 79 | if not isinstance(pred, torch.IntTensor): 80 | raise TypeError('pred must be a torch.IntTensor (got {})' 81 | .format(type(pred))) 82 | 83 | # don't match unknown words 84 | rref = ref.clone() 85 | assert not rref.lt(0).any() # negative ids would clash with the -999 unk sentinel below 86 | rref[rref.eq(self.unk)] = -999 87 | 88 | rref = rref.contiguous().view(-1) 89 | pred = pred.contiguous().view(-1) 90 | 91 | C.bleu_add( 92 | ctypes.byref(self.stat), 93 | ctypes.c_size_t(rref.size(0)), 94 | ctypes.c_void_p(rref.data_ptr()), 95 | ctypes.c_size_t(pred.size(0)), 96 | ctypes.c_void_p(pred.data_ptr()), 97 | ctypes.c_int(self.pad), 98 | ctypes.c_int(self.eos)) 99 | 100 | def score(self, order=4): 101 | psum = sum(math.log(p) if p > 0 else float('-Inf') 102 | for p in self.precision()[:order]) 103 | return self.brevity() * math.exp(psum / order) * 100 104 | 105 | def precision(self): 106 | def ratio(a, b): 107 | return a / b if b > 0 else 0 108 | 109 | return [ 110 | ratio(self.stat.match1, self.stat.count1), 111 | ratio(self.stat.match2, self.stat.count2), 112 | ratio(self.stat.match3, self.stat.count3), 113 | ratio(self.stat.match4, self.stat.count4), 114 | ] 115 | 116 | def brevity(self): 117 | r = self.stat.reflen / self.stat.predlen if self.stat.predlen > 0 else float('inf') # no hypothesis tokens: BP = exp(-inf) = 0 instead of ZeroDivisionError 118 | return min(1, math.exp(1 - r)) 119 | 120 | def result_string(self, order=4): 121 | assert order <= 4, "BLEU scores for order > 4 aren't supported" 122 | fmt = 'BLEU{} = {:2.2f}, {:2.1f}' 123 | for _
in range(1, order): 124 | fmt += '/{:2.1f}' 125 | fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})' 126 | bleup = [p * 100 for p in self.precision()[:order]] 127 | return fmt.format(order, self.score(order=order), *bleup, 128 | self.brevity(), self.stat.predlen / self.stat.reflen if self.stat.reflen > 0 else 0., # length ratio, guarded against reflen == 0 129 | self.stat.predlen, self.stat.reflen) 130 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/libbleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include <map> 10 | #include <array> 11 | #include <cstring> 12 | #include <cstdio> 13 | 14 | typedef struct 15 | { 16 | size_t reflen; 17 | size_t predlen; 18 | size_t match1; 19 | size_t count1; 20 | size_t match2; 21 | size_t count2; 22 | size_t match3; 23 | size_t count3; 24 | size_t match4; 25 | size_t count4; 26 | } bleu_stat; 27 | 28 | // left trim (remove pad) 29 | void bleu_ltrim(size_t* len, int** sent, int pad) { 30 | size_t start = 0; 31 | while(start < *len) { 32 | if (*(*sent + start) != pad) { break; } 33 | start++; 34 | } 35 | *sent += start; 36 | *len -= start; 37 | } 38 | 39 | // right trim (remove trailing eos and pad) 40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { 41 | if (*len == 0) { return; } size_t end = *len - 1; // guard: for an empty (e.g. all-pad after ltrim) sequence *len - 1 would wrap around and read out of bounds 42 | while (end > 0) { 43 | if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } 44 | end--; 45 | } 46 | *len = end + 1; 47 | } 48 | 49 | // left and right trim 50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) { 51 | bleu_ltrim(len, sent, pad); 52 | bleu_rtrim(len, sent, pad, eos); 53 | } 54 | 55 | size_t bleu_hash(int len, int* data) { // FNV-1a-style hash over the raw bytes of len ints 56 | size_t h = 14695981039346656037ul; 57 | size_t prime = 0x100000001b3; 58 | char* b = (char*) data; 59 | size_t blen = sizeof(int) * len; 60 | 61 |
while (blen-- > 0) { 62 | h ^= *b++; 63 | h *= prime; 64 | } 65 | 66 | return h; 67 | } 68 | 69 | void bleu_addngram( // adds clipped n-gram matches; n-grams are compared by hash only, collisions are not checked 70 | size_t *ntotal, size_t *nmatch, size_t n, 71 | size_t reflen, int* ref, size_t predlen, int* pred) { 72 | 73 | if (predlen < n) { return; } 74 | 75 | predlen = predlen - n + 1; 76 | (*ntotal) += predlen; 77 | 78 | if (reflen < n) { return; } 79 | 80 | reflen = reflen - n + 1; 81 | 82 | std::map<size_t, size_t> count; 83 | while (predlen > 0) { 84 | size_t w = bleu_hash(n, pred++); 85 | count[w]++; 86 | predlen--; 87 | } 88 | 89 | while (reflen > 0) { 90 | size_t w = bleu_hash(n, ref++); 91 | if (count[w] > 0) { 92 | (*nmatch)++; 93 | count[w] -=1; 94 | } 95 | reflen--; 96 | } 97 | } 98 | 99 | extern "C" { 100 | 101 | void bleu_zero_init(bleu_stat* stat) { 102 | std::memset(stat, 0, sizeof(bleu_stat)); 103 | } 104 | 105 | void bleu_one_init(bleu_stat* stat) { 106 | bleu_zero_init(stat); 107 | stat->count1 = 0; // NOTE(review): unigram stats stay 0 while orders 2-4 start at 1 - confirm this asymmetry is intended 108 | stat->count2 = 1; 109 | stat->count3 = 1; 110 | stat->count4 = 1; 111 | stat->match1 = 0; 112 | stat->match2 = 1; 113 | stat->match3 = 1; 114 | stat->match4 = 1; 115 | } 116 | 117 | void bleu_add( 118 | bleu_stat* stat, 119 | size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { 120 | 121 | bleu_trim(&reflen, &ref, pad, eos); 122 | bleu_trim(&predlen, &pred, pad, eos); 123 | stat->reflen += reflen; 124 | stat->predlen += predlen; 125 | 126 | bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); 127 | bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); 128 | bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); 129 | bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); 130 | } 131 | 132 | } 133 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/module.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present,
Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include <Python.h> 10 | 11 | 12 | static PyMethodDef method_def[] = { 13 | {NULL, NULL, 0, NULL} 14 | }; 15 | 16 | static struct PyModuleDef module_def = { 17 | PyModuleDef_HEAD_INIT, 18 | "libbleu", /* name of module */ 19 | NULL, /* module documentation, may be NULL */ 20 | -1, /* size of per-interpreter state of the module, 21 | or -1 if the module keeps state in global variables. */ 22 | method_def 23 | }; 24 | 25 | 26 | #if PY_MAJOR_VERSION == 2 // NOTE(review): PyModuleDef / PyModule_Create are Python 3 C-API names, so this branch likely cannot build for Python 2 - confirm it is dead 27 | PyMODINIT_FUNC init_libbleu() 28 | #else 29 | PyMODINIT_FUNC PyInit_libbleu() 30 | #endif 31 | { 32 | PyObject *m = PyModule_Create(&module_def); 33 | if (!m) { 34 | return NULL; 35 | } 36 | return m; 37 | } 38 | -------------------------------------------------------------------------------- /fairseq/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.criterions.fairseq_criterion import FairseqCriterion 11 | 12 | 13 | build_criterion, register_criterion, CRITERION_REGISTRY = registry.setup_registry( 14 | '--criterion', 15 | base_class=FairseqCriterion, 16 | default='cross_entropy', 17 | ) 18 | 19 | 20 | # automatically import any Python files in the criterions/ directory 21 | for file in os.listdir(os.path.dirname(__file__)): 22 | if file.endswith('.py') and not file.startswith('_'): 23 | module = file[:file.find('.py')] 24 | importlib.import_module('fairseq.criterions.'
+ module) 25 | -------------------------------------------------------------------------------- /fairseq/criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import torch.nn.functional as F 8 | 9 | from fairseq import utils 10 | 11 | from . import FairseqCriterion, register_criterion 12 | 13 | 14 | @register_criterion('cross_entropy') 15 | class CrossEntropyCriterion(FairseqCriterion): 16 | 17 | def __init__(self, args, task): 18 | super().__init__(args, task) 19 | 20 | def forward(self, model, sample, reduce=True): 21 | """Compute the loss for the given sample. 22 | Returns a tuple with three elements: 23 | 1) the loss 24 | 2) the sample size, which is used as the denominator for the gradient 25 | 3) logging outputs to display while training 26 | """ 27 | net_output = model(**sample['net_input']) 28 | loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) 29 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 30 | logging_output = { 31 | 'loss': utils.item(loss.data) if reduce else loss.data, 32 | 'nll_loss': utils.item(loss.data) if reduce else loss.data, 33 | 'ntokens': sample['ntokens'], 34 | 'nsentences': sample['target'].size(0), 35 | 'sample_size': sample_size, 36 | } 37 | return loss, sample_size, logging_output 38 | 39 | def compute_loss(self, model, net_output, sample, reduce=True): 40 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 41 | lprobs = lprobs.view(-1, lprobs.size(-1)) 42 | target = model.get_targets(sample, net_output).view(-1) 43 | loss = F.nll_loss( 44 | lprobs, 45 | target, 46 | ignore_index=self.padding_idx, 47 | reduction='sum' if reduce else 'none', 48 | ) 49 | return loss, loss # no smoothing: the loss itself is the nll 50 |
@staticmethod 52 | def aggregate_logging_outputs(logging_outputs): 53 | """Aggregate logging outputs from data parallel training.""" 54 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 55 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 56 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 57 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 58 | agg_output = { 59 | 'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0., 60 | 'ntokens': ntokens, 61 | 'nsentences': nsentences, 62 | 'sample_size': sample_size, 63 | } 64 | if sample_size != ntokens: # sentence-averaged loss (see forward): also report a per-token nll in base 2 65 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) if ntokens > 0 else 0. 66 | return agg_output 67 | -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.nn.modules.loss import _Loss 7 | 8 | 9 | class FairseqCriterion(_Loss): 10 | 11 | def __init__(self, args, task): 12 | super().__init__() 13 | self.args = args 14 | self.task = task 15 | self.padding_idx = task.target_dictionary.pad() if task.target_dictionary is not None else -100 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add criterion-specific arguments to the parser.""" 20 | pass 21 | 22 | @classmethod 23 | def build_criterion(cls, args, task): 24 | return cls(args, task) 25 | 26 | def forward(self, model, sample, reduce=True): 27 | """Compute the loss for the given sample.
28 | 29 | Returns a tuple with three elements: 30 | 1) the loss 31 | 2) the sample size, which is used as the denominator for the gradient 32 | 3) logging outputs to display while training 33 | """ 34 | raise NotImplementedError 35 | 36 | @staticmethod 37 | def aggregate_logging_outputs(logging_outputs): 38 | """Aggregate logging outputs from data parallel training.""" 39 | raise NotImplementedError 40 | 41 | @staticmethod 42 | def grad_denom(sample_sizes): 43 | """Compute the gradient denominator for a set of sample sizes.""" 44 | return sum(sample_sizes) 45 | -------------------------------------------------------------------------------- /fairseq/criterions/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | 8 | from fairseq import utils 9 | 10 | from . import FairseqCriterion, register_criterion 11 | 12 | 13 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 14 | if target.dim() == lprobs.dim() - 1: 15 | target = target.unsqueeze(-1) 16 | nll_loss = -lprobs.gather(dim=-1, index=target) 17 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) # summed over the whole last (vocabulary) dim for the uniform smoothing term 18 | if ignore_index is not None: 19 | non_pad_mask = target.ne(ignore_index) 20 | nll_loss = nll_loss[non_pad_mask] 21 | smooth_loss = smooth_loss[non_pad_mask] 22 | else: 23 | nll_loss = nll_loss.squeeze(-1) 24 | smooth_loss = smooth_loss.squeeze(-1) 25 | if reduce: 26 | nll_loss = nll_loss.sum() 27 | smooth_loss = smooth_loss.sum() 28 | eps_i = epsilon / lprobs.size(-1) # epsilon spread uniformly over the last dim 29 | loss = (1.
- epsilon) * nll_loss + eps_i * smooth_loss 30 | return loss, nll_loss 31 | 32 | 33 | @register_criterion('label_smoothed_cross_entropy') 34 | class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): 35 | 36 | def __init__(self, args, task): 37 | super().__init__(args, task) 38 | self.eps = args.label_smoothing 39 | 40 | @staticmethod 41 | def add_args(parser): 42 | """Add criterion-specific arguments to the parser.""" 43 | # fmt: off 44 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 45 | help='epsilon for label smoothing, 0 means no label smoothing') 46 | # fmt: on 47 | 48 | def forward(self, model, sample, reduce=True): 49 | """Compute the loss for the given sample. 50 | 51 | Returns a tuple with three elements: 52 | 1) the loss 53 | 2) the sample size, which is used as the denominator for the gradient 54 | 3) logging outputs to display while training 55 | """ 56 | net_output = model(**sample['net_input']) 57 | loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) 58 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 59 | logging_output = { 60 | 'loss': utils.item(loss.data) if reduce else loss.data, 61 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 62 | 'ntokens': sample['ntokens'], 63 | 'nsentences': sample['target'].size(0), 64 | 'sample_size': sample_size, 65 | } 66 | return loss, sample_size, logging_output 67 | 68 | def compute_loss(self, model, net_output, sample, reduce=True): 69 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 70 | lprobs = lprobs.view(-1, lprobs.size(-1)) 71 | target = model.get_targets(sample, net_output).view(-1, 1) 72 | loss, nll_loss = label_smoothed_nll_loss( 73 | lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, 74 | ) 75 | return loss, nll_loss 76 | 77 | @staticmethod 78 | def aggregate_logging_outputs(logging_outputs): 79 | """Aggregate logging outputs from data parallel
training.""" 80 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 81 | nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) 82 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 83 | return { 84 | 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0., 85 | 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0., 86 | 'ntokens': ntokens, 87 | 'nsentences': nsentences, 88 | 'sample_size': sample_size, 89 | } 90 | -------------------------------------------------------------------------------- /fairseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .dictionary import Dictionary, TruncatedDictionary 7 | 8 | from .fairseq_dataset import FairseqDataset 9 | 10 | from .base_wrapper_dataset import BaseWrapperDataset 11 | 12 | from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset 13 | from .concat_dataset import ConcatDataset 14 | from .language_pair_dataset import LanguagePairDataset 15 | from .append_token_dataset import AppendTokenDataset 16 | from .strip_token_dataset import StripTokenDataset 17 | from .truncate_dataset import TruncateDataset 18 | 19 | from .iterators import ( 20 | CountingIterator, 21 | EpochBatchIterator, 22 | GroupedIterator, 23 | ShardedIterator, 24 | ) 25 | 26 | __all__ = [ 27 | 'BaseWrapperDataset', 28 | 'ConcatDataset', 29 | 'CountingIterator', 30 | 'Dictionary', 31 | 'EpochBatchIterator', 32 | 'FairseqDataset', 33 | 'GroupedIterator', 34 | 'IndexedCachedDataset', 35 | 'IndexedDataset', 36 | 'IndexedRawTextDataset', 37 | 'LanguagePairDataset', 38 | 'StripTokenDataset',
39 | 'TruncateDataset', 40 | 'AppendTokenDataset', 'MMapIndexedDataset', 'ShardedIterator', 'TruncatedDictionary', # every name imported above is public 41 | ] 42 | -------------------------------------------------------------------------------- /fairseq/data/append_token_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from . import BaseWrapperDataset 10 | 11 | 12 | class AppendTokenDataset(BaseWrapperDataset): # appends `token` to every item (and adds 1 to every size) when token is not None 13 | 14 | def __init__(self, dataset, token=None): 15 | super().__init__(dataset) 16 | self.token = token 17 | if token is not None: 18 | self._sizes = np.array(dataset.sizes) + 1 19 | else: 20 | self._sizes = dataset.sizes 21 | 22 | def __getitem__(self, idx): 23 | item = self.dataset[idx] 24 | if self.token is not None: 25 | item = torch.cat([item, item.new([self.token])]) 26 | return item 27 | 28 | @property 29 | def sizes(self): 30 | return self._sizes 31 | 32 | def num_tokens(self, index): 33 | n = self.dataset.num_tokens(index) 34 | if self.token is not None: 35 | n += 1 36 | return n 37 | 38 | def size(self, index): 39 | n = self.dataset.size(index) 40 | if self.token is not None: 41 | n += 1 42 | return n -------------------------------------------------------------------------------- /fairseq/data/base_wrapper_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from .
import FairseqDataset 9 | 10 | 11 | class BaseWrapperDataset(FairseqDataset): 12 | 13 | def __init__(self, dataset): 14 | super().__init__() 15 | self.dataset = dataset 16 | 17 | def __getitem__(self, index): 18 | return self.dataset[index] 19 | 20 | def __len__(self): 21 | return len(self.dataset) 22 | 23 | def collater(self, samples): 24 | if hasattr(self.dataset, 'collater'): 25 | return self.dataset.collater(samples) 26 | else: 27 | return default_collate(samples) 28 | 29 | @property 30 | def sizes(self): 31 | return self.dataset.sizes 32 | 33 | def num_tokens(self, index): 34 | return self.dataset.num_tokens(index) 35 | 36 | def size(self, index): 37 | return self.dataset.size(index) 38 | 39 | def ordered_indices(self): 40 | return self.dataset.ordered_indices() 41 | 42 | @property 43 | def supports_prefetch(self): 44 | return getattr(self.dataset, 'supports_prefetch', False) 45 | 46 | def prefetch(self, indices): 47 | self.dataset.prefetch(indices) 48 | 49 | def set_epoch(self, epoch): 50 | super().set_epoch(epoch) 51 | if hasattr(self.dataset, 'set_epoch'): 52 | self.dataset.set_epoch(epoch) 53 | -------------------------------------------------------------------------------- /fairseq/data/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import bisect 7 | 8 | import numpy as np 9 | from torch.utils.data.dataloader import default_collate 10 | 11 | from . 
import FairseqDataset 12 | 13 | 14 | class ConcatDataset(FairseqDataset): 15 | @staticmethod 16 | def cumsum(sequence, sample_ratios): 17 | r, s = [], 0 18 | for e, ratio in zip(sequence, sample_ratios): 19 | curr_len = int(ratio * len(e)) 20 | r.append(curr_len + s) 21 | s += curr_len 22 | return r 23 | 24 | def __init__(self, datasets, sample_ratios=1): 25 | super(ConcatDataset, self).__init__() 26 | assert len(datasets) > 0, "datasets should not be an empty iterable" 27 | self.datasets = list(datasets) 28 | if isinstance(sample_ratios, int): 29 | sample_ratios = [sample_ratios] * len(self.datasets) 30 | self.sample_ratios = sample_ratios 31 | self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) 32 | self.real_sizes = [len(d) for d in self.datasets] 33 | 34 | def __len__(self): 35 | return self.cumulative_sizes[-1] 36 | 37 | def __getitem__(self, idx): 38 | dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) 39 | return self.datasets[dataset_idx][sample_idx] 40 | 41 | def _get_dataset_and_sample_index(self, idx: int): 42 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 43 | if dataset_idx == 0: 44 | sample_idx = idx 45 | else: 46 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 47 | sample_idx = sample_idx % self.real_sizes[dataset_idx] 48 | return dataset_idx, sample_idx 49 | 50 | def collater(self, samples): 51 | # For now only supports datasets with same underlying collater implementations 52 | if hasattr(self.datasets[0], 'collater'): 53 | return self.datasets[0].collater(samples) 54 | else: 55 | return default_collate(samples) 56 | 57 | def size(self, idx: int): 58 | """ 59 | Return an example's size as a float or tuple. 
60 | """ 61 | dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) 62 | return self.datasets[dataset_idx].size(sample_idx) 63 | 64 | def num_tokens(self, index: int): 65 | return np.max(self.size(index)) 66 | 67 | def attr(self, attr: str, index: int): 68 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, index) 69 | return getattr(self.datasets[dataset_idx], attr, None) 70 | 71 | @property 72 | def sizes(self): 73 | return np.concatenate( 74 | [np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)] 75 | ) 76 | 77 | @property 78 | def supports_prefetch(self): 79 | return all(d.supports_prefetch for d in self.datasets) 80 | 81 | def ordered_indices(self): 82 | """ 83 | Returns indices sorted by length. So less padding is needed. 84 | """ 85 | return np.argsort(self.sizes) 86 | 87 | def prefetch(self, indices): 88 | frm = 0 89 | for to, ds in zip(self.cumulative_sizes, self.datasets): 90 | real_size = len(ds) 91 | if getattr(ds, 'supports_prefetch', False): 92 | ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) 93 | frm = to 94 | -------------------------------------------------------------------------------- /fairseq/data/data_utils_fast.pyx: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
import numpy as np

cimport cython
cimport numpy as np

DTYPE = np.int64
ctypedef np.int64_t DTYPE_t


cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
    # Return 1 when ``batch`` must be flushed before adding another sample.
    # A non-positive limit means "no limit" for both max_sentences (which
    # already behaved that way, since len(batch) never equals it) and
    # max_tokens.
    if len(batch) == 0:
        return 0
    if max_sentences > 0 and len(batch) == max_sentences:
        return 1
    if max_tokens > 0 and num_tokens > max_tokens:
        return 1
    return 0


@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    """Group ``indices`` (in the given order) into batches.

    A batch's cost is ``len(batch) * longest sample`` (padded token count).
    When a batch is flushed its size is rounded down to a multiple of
    ``bsz_mult`` (assumed >= 1; with cdivision a zero would be undefined),
    and the leftover samples start the next batch.
    """
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef long mod_len
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert max_tokens <= 0 or sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        # padded size of the batch if the current sample were added
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            # sample_lens also holds the current sample (not yet in batch)
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches

# ==== /fairseq/data/encoders/__init__.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import importlib
import os

from fairseq import registry


build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
    '--tokenizer',
    default=None,
)


build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry(
    '--bpe',
    default=None,
)


# Automatically import every public Python module in the encoders/ directory
# so that the @register_tokenizer / @register_bpe decorators in those modules
# run. The listing is sorted so the import order is deterministic.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith('.py') and not file.startswith('_'):
        # drop only the trailing ".py"; ``file.find('.py')`` would cut at the
        # first occurrence (e.g. "a.pyx_tools.py" -> "a")
        module = file[:-len('.py')]
        importlib.import_module('fairseq.data.encoders.' + module)

# ==== /fairseq/data/encoders/fastbpe.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import file_utils
from fairseq.data.encoders import register_bpe


@register_bpe('fastbpe')
class fastBPE(object):
    """BPE encoder backed by the ``fastBPE`` package (``--bpe=fastbpe``)."""

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--bpe-codes', type=str,
                            help='path to fastBPE BPE')
        # fmt: on

    def __init__(self, args):
        """Load the BPE codes given by ``--bpe-codes``.

        Raises:
            ValueError: if ``--bpe-codes`` was not given.
            ImportError: if the ``fastBPE`` package is not installed.
        """
        if args.bpe_codes is None:
            raise ValueError('--bpe-codes is required for --bpe=fastbpe')
        codes = file_utils.cached_path(args.bpe_codes)
        try:
            import fastBPE
        except ImportError as err:
            raise ImportError('Please install fastBPE with: pip install fastBPE') from err
        self.bpe = fastBPE.fastBPE(codes)
        self.bpe_symbol = "@@ "

    def encode(self, x: str) -> str:
        return self.bpe.apply([x])[0]

    def decode(self, x: str) -> str:
        # the appended space lets a trailing "@@" be removed as well
        return (x + ' ').replace(self.bpe_symbol, '').rstrip()

# ==== /fairseq/data/encoders/gpt2_bpe.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import file_utils
from fairseq.data.encoders import register_bpe

from .gpt2_bpe_utils import get_encoder


DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'


@register_bpe('gpt2')
class GPT2BPE(object):
    """GPT-2 byte-level BPE; encoded text is a space-separated list of ids."""

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--gpt2-encoder-json', type=str,
                            default=DEFAULT_ENCODER_JSON,
                            help='path to encoder.json')
        parser.add_argument('--gpt2-vocab-bpe', type=str,
                            default=DEFAULT_VOCAB_BPE,
                            help='path to vocab.bpe')
        # fmt: on

    def __init__(self, args):
        encoder_json = file_utils.cached_path(
            getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON)
        )
        vocab_bpe = file_utils.cached_path(
            getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE)
        )
        self.bpe = get_encoder(encoder_json, vocab_bpe)

    def encode(self, x: str) -> str:
        return ' '.join(map(str, self.bpe.encode(x)))

    def decode(self, x: str) -> str:
        return self.bpe.decode(map(int, x.split()))

    def is_beginning_of_word(self, x: str) -> bool:
        return self.decode(x).startswith(' ')

# ==== /fairseq/data/encoders/gpt2_bpe_utils.py ====
"""
Byte pair encoding utilities from GPT-2.

Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
Original license: MIT
"""

from functools import lru_cache
import json


@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping each of the 256 byte values to a printable unicode
    character.

    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    Printable bytes map to themselves; the remaining ones (whitespace/control
    characters the bpe code cannot handle) are shifted to code points >= 256.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class Encoder:
    """Byte-level BPE encoder/decoder.

    Args:
        encoder: mapping from BPE token string to integer id.
        bpe_merges: ordered list of symbol pairs; earlier pairs are merged first.
        errors: ``errors`` argument of ``bytes.decode`` used by :meth:`decode`.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        try:
            import regex as re
            self.re = re
        except ImportError:
            raise ImportError('Please install regex with: pip install regex')

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = self.re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the merges to ``token`` (a byte-encoded string).

        Returns the resulting symbols joined by single spaces; multi-symbol
        results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # merge the pair that appears earliest in the merges list
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the remainder
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Return the list of BPE token ids for ``text``."""
        bpe_tokens = []
        for token in self.re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Inverse of :meth:`encode`: turn token ids back into text."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text


def get_encoder(encoder_json_path, vocab_bpe_path):
    """Build an :class:`Encoder` from an ``encoder.json`` and a ``vocab.bpe`` file."""
    # explicit encoding: the default is locale-dependent
    with open(encoder_json_path, 'r', encoding='utf-8') as f:
        encoder = json.load(f)
    with open(vocab_bpe_path, 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # [1:-1] drops the first line (presumably a header) and the empty string
    # after the final newline -- assumes the file ends with a newline
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

# ==== /fairseq/data/encoders/hf_bert_bpe.py ====
# Copyright (c) Facebook,
# Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data.encoders import register_bpe


@register_bpe('bert')
class BertBPE(object):
    """WordPiece tokenizer from ``pytorch_transformers`` (``--bpe=bert``)."""

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--bpe-cased', action='store_true',
                            help='set for cased BPE',
                            default=False)
        parser.add_argument('--bpe-vocab-file', type=str,
                            help='bpe vocab file.')
        # fmt: on

    def __init__(self, args):
        """Build the tokenizer from ``--bpe-vocab-file`` or a pretrained BERT vocab.

        Raises:
            ImportError: if ``pytorch_transformers`` is not installed.
        """
        try:
            from pytorch_transformers import BertTokenizer
            from pytorch_transformers.tokenization_utils import clean_up_tokenization
        except ImportError as err:
            raise ImportError(
                'Please install 1.0.0 version of pytorch_transformers '
                'with: pip install pytorch-transformers'
            ) from err

        # After ``add_args`` ran, ``bpe_vocab_file`` exists on ``args`` even
        # when the flag was not given (argparse default None), so test the
        # value rather than the attribute's presence; otherwise the pretrained
        # fallback below could never be reached.
        vocab_file = getattr(args, 'bpe_vocab_file', None)
        if vocab_file is not None:
            self.bert_tokenizer = BertTokenizer(
                vocab_file,
                do_lower_case=not args.bpe_cased
            )
        else:
            vocab_file_name = 'bert-base-cased' if args.bpe_cased else 'bert-base-uncased'
            self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
        self.clean_up_tokenization = clean_up_tokenization

    def encode(self, x: str) -> str:
        return ' '.join(self.bert_tokenizer.tokenize(x))

    def decode(self, x: str) -> str:
        return self.clean_up_tokenization(
            self.bert_tokenizer.convert_tokens_to_string(x.split(' '))
        )

    def is_beginning_of_word(self, x: str) -> bool:
        # WordPiece marks word-continuation pieces with a leading "##"
        return not x.startswith('##')

# ==== /fairseq/data/encoders/moses_tokenizer.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data.encoders import register_tokenizer


@register_tokenizer('moses')
class MosesTokenizer(object):
    """Moses tokenizer/detokenizer from the ``sacremoses`` package.

    ``args.moses_source_lang`` / ``args.moses_target_lang`` are filled in
    (on the caller's ``args`` object) from ``source_lang`` / ``target_lang``,
    or ``'en'``, when unset.
    """

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--moses-source-lang', metavar='SRC',
                            help='source language')
        parser.add_argument('--moses-target-lang', metavar='TARGET',
                            help='target language')
        parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
                            help='don\'t apply dash split rules')
        parser.add_argument('--moses-no-escape', action='store_true', default=False,
                            help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
        # fmt: on

    def __init__(self, args):
        self.args = args

        defaults = (
            ('moses_source_lang', 'source_lang'),
            ('moses_target_lang', 'target_lang'),
        )
        for moses_attr, task_attr in defaults:
            if getattr(args, moses_attr, None) is None:
                setattr(args, moses_attr, getattr(args, task_attr, 'en'))

        try:
            from sacremoses import MosesTokenizer as _Tokenizer, MosesDetokenizer as _Detokenizer
            self.tok = _Tokenizer(args.moses_source_lang)
            self.detok = _Detokenizer(args.moses_target_lang)
        except ImportError:
            raise ImportError('Please install Moses tokenizer with: pip install sacremoses')

    def encode(self, x: str) -> str:
        opts = self.args
        dash_splits = not opts.moses_no_dash_splits
        escape = not opts.moses_no_escape
        return self.tok.tokenize(
            x,
            aggressive_dash_splits=dash_splits,
            return_str=True,
            escape=escape,
        )

    def decode(self, x: str) -> str:
        tokens = x.split()
        return self.detok.detokenize(tokens)

# ==== /fairseq/data/encoders/nltk_tokenizer.py ====
# Copyright (c) Facebook, Inc.
# and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.data.encoders import register_tokenizer


@register_tokenizer('nltk')
class NLTKTokenizer(object):
    """Word tokenizer backed by ``nltk.tokenize.word_tokenize``.

    Decoding is the identity: tokens are simply left space-separated.
    """

    def __init__(self, source_lang=None, target_lang=None):
        try:
            from nltk.tokenize import word_tokenize as _word_tokenize
        except ImportError:
            raise ImportError('Please install nltk with: pip install nltk')
        self.word_tokenize = _word_tokenize

    def encode(self, x: str) -> str:
        tokens = self.word_tokenize(x)
        return ' '.join(tokens)

    def decode(self, x: str) -> str:
        return x

# ==== /fairseq/data/encoders/sentencepiece_bpe.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import file_utils
from fairseq.data.encoders import register_bpe


@register_bpe('sentencepiece')
class SentencepieceBPE(object):
    """SentencePiece subword model (``--bpe=sentencepiece``)."""

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--sentencepiece-vocab', type=str,
                            help='path to sentencepiece vocab')
        # fmt: on

    def __init__(self, args):
        """Load the model given by ``--sentencepiece-vocab``.

        Raises:
            ValueError: if ``--sentencepiece-vocab`` was not given.
            ImportError: if the ``sentencepiece`` package is not installed.
        """
        # fail early with a clear message, like the fastbpe/subword_nmt encoders
        if getattr(args, 'sentencepiece_vocab', None) is None:
            raise ValueError('--sentencepiece-vocab is required for --bpe=sentencepiece')
        vocab = file_utils.cached_path(args.sentencepiece_vocab)
        try:
            import sentencepiece as spm
        except ImportError as err:
            raise ImportError('Please install sentencepiece with: pip install sentencepiece') from err
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(vocab)

    def encode(self, x: str) -> str:
        return ' '.join(self.sp.EncodeAsPieces(x))

    def decode(self, x: str) -> str:
        # drop piece separators, then turn the U+2581 word marker into spaces
        return x.replace(' ', '').replace('\u2581', ' ').strip()

# ==== /fairseq/data/encoders/space_tokenizer.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

from fairseq.data.encoders import register_tokenizer


@register_tokenizer('space')
class SpaceTokenizer(object):
    """Collapse every run of whitespace into a single space; decode is identity."""

    def __init__(self, source_lang=None, target_lang=None):
        self.space_tok = re.compile(r"\s+")

    def encode(self, x: str) -> str:
        return self.space_tok.sub(' ', x)

    def decode(self, x: str) -> str:
        return x

# ==== /fairseq/data/encoders/subword_nmt_bpe.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import file_utils
from fairseq.data.encoders import register_bpe


@register_bpe('subword_nmt')
class SubwordNMTBPE(object):
    """BPE from the ``subword-nmt`` package (``--bpe=subword_nmt``).

    Raises ValueError when ``--bpe-codes`` is missing and ImportError when
    ``subword_nmt`` is not installed.
    """

    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--bpe-codes', type=str,
                            help='path to subword NMT BPE')
        parser.add_argument('--bpe-separator', default='@@',
                            help='BPE separator')
        # fmt: on

    def __init__(self, args):
        if args.bpe_codes is None:
            raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
        codes = file_utils.cached_path(args.bpe_codes)
        try:
            from subword_nmt import apply_bpe

            # reuse subword-nmt's own CLI parser to build its settings
            cli = ['--codes', codes, '--separator', args.bpe_separator]
            parsed = apply_bpe.create_parser().parse_args(cli)
            self.bpe = apply_bpe.BPE(
                parsed.codes,
                parsed.merges,
                parsed.separator,
                None,
                parsed.glossaries,
            )
            self.bpe_symbol = parsed.separator + ' '
        except ImportError:
            raise ImportError('Please install subword_nmt with: pip install subword-nmt')

    def encode(self, x: str) -> str:
        return self.bpe.process_line(x)

    def decode(self, x: str) -> str:
        # pad so that a separator on the last token is also removed
        padded = x + ' '
        joined = padded.replace(self.bpe_symbol, '')
        return joined.rstrip()

# ==== /fairseq/data/fairseq_dataset.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch.utils.data


class FairseqDataset(torch.utils.data.Dataset):
    """A dataset that provides helpers for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        raise NotImplementedError

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self))

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        """Return attribute ``attr`` of this dataset (``None`` if missing)."""
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError

    def set_epoch(self, epoch):
        """Hook called at the start of each epoch; no-op by default."""
        pass

# ==== /fairseq/data/strip_token_dataset.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import BaseWrapperDataset


class StripTokenDataset(BaseWrapperDataset):
    """Remove every occurrence of ``id_to_strip`` from each item.

    ``sizes``/``num_tokens``/``size`` are inherited from the wrapped dataset,
    so they are upper bounds on the stripped lengths.
    """

    def __init__(self, dataset, id_to_strip):
        super().__init__(dataset)
        self.id_to_strip = id_to_strip

    def __getitem__(self, index):
        item = self.dataset[index]
        return item[item.ne(self.id_to_strip)]

# ==== /fairseq/data/truncate_dataset.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from . import BaseWrapperDataset


class TruncateDataset(BaseWrapperDataset):
    """Cut every item of the wrapped dataset to at most ``truncation_length``.

    All length accessors (``sizes``, ``num_tokens``, ``size``) report the
    truncated length, so batching (``--max-tokens``) and filtering
    (``--max-positions``) see the same lengths that ``__getitem__`` returns.
    """

    def __init__(self, dataset, truncation_length):
        super().__init__(dataset)
        assert truncation_length is not None
        self.truncation_length = truncation_length

    def __getitem__(self, index):
        item = self.dataset[index]
        item_len = item.size(0)
        if item_len > self.truncation_length:
            item = item[:self.truncation_length]
        return item

    @property
    def sizes(self):
        return np.minimum(self.dataset.sizes, self.truncation_length)

    def num_tokens(self, index):
        return min(self.dataset.num_tokens(index), self.truncation_length)

    def size(self, index):
        # np.minimum also copes with tuple-valued sizes
        return np.minimum(self.dataset.size(index), self.truncation_length)

# ==== /fairseq/file_io.py ====
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
from typing import List, Optional


try:
    from fvcore.common.file_io import PathManager as FVCorePathManager

except (ImportError, ModuleNotFoundError):
    FVCorePathManager = None


class PathManager:
    """
    Wrapper for insulating OSS I/O (using Python builtin operations) from
    fvcore's PathManager abstraction (for transparently handling various
    internal backends).

    Each method forwards to fvcore's PathManager when it could be imported
    and otherwise falls back to the standard library.
    """

    @staticmethod
    def open(
        path: str,
        mode: str = "r",
        buffering: int = -1,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ):
        if FVCorePathManager:
            return FVCorePathManager.open(
                path=path,
                mode=mode,
                buffering=buffering,
                encoding=encoding,
                errors=errors,
                newline=newline,
            )
        return open(
            path,
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
        )

    @staticmethod
    def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool:
        if FVCorePathManager:
            return FVCorePathManager.copy(
                src_path=src_path, dst_path=dst_path, overwrite=overwrite
            )
        # NOTE: the builtin fallback ignores ``overwrite`` (an existing
        # destination is replaced) and returns ``shutil.copyfile``'s result,
        # i.e. the destination path rather than a bool.
        return shutil.copyfile(src_path, dst_path)

    @staticmethod
    def get_local_path(path: str) -> str:
        if FVCorePathManager:
            return FVCorePathManager.get_local_path(path)
        return path

    @staticmethod
    def exists(path: str) -> bool:
        if FVCorePathManager:
            return FVCorePathManager.exists(path)
        return os.path.exists(path)

    @staticmethod
    def isfile(path: str) -> bool:
        if FVCorePathManager:
            return FVCorePathManager.isfile(path)
        return os.path.isfile(path)

    @staticmethod
    def ls(path: str) -> List[str]:
        if FVCorePathManager:
            return FVCorePathManager.ls(path)
        return os.listdir(path)

    @staticmethod
    def mkdirs(path: str) -> None:
        if FVCorePathManager:
            return FVCorePathManager.mkdirs(path)
        os.makedirs(path, exist_ok=True)

    @staticmethod
    def rm(path: str) -> None:
        if FVCorePathManager:
            return FVCorePathManager.rm(path)
        os.remove(path)

    @staticmethod
    def register_handler(handler) -> None:
        # no-op without fvcore
        if FVCorePathManager:
            return FVCorePathManager.register_handler(handler=handler)

# ==== /fairseq/init.py ====
import math
import torch
import torch.nn as nn

import numpy as np

import functools

# Initializer selected by ``build_init``; ``None`` means no custom
# initializer was chosen (the xavier fallback).
uniform_ = None


def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` (1 for ``x == 0``).

    Positive Python ints use exact integer arithmetic; the float ``log2``
    path can round down for large values (e.g. ``2**50 + 1``) and is kept
    only for other numeric types. Raises ValueError for negative ``x``.
    """
    if x == 0:
        return 1
    if isinstance(x, int) and x > 0:
        return 1 << (x - 1).bit_length()
    return 2**math.ceil(math.log2(x))


def build_init(args):
    """Select the module-level ``uniform_`` initializer from ``args.init_method``.

    ``args.init_method`` is set to ``'xavier'`` when absent. Parameterized
    methods use a ``name:value`` form, e.g. ``'gain:2.0'`` or
    ``'xavier_origin_ratio:4'``. Unrecognized names fall back to xavier
    (``uniform_ = None``) with a printed warning.
    """
    global uniform_
    args.init_method = getattr(args, 'init_method', 'xavier')
    if args.init_method == 'xavier':
        # the default: no custom initializer, nothing to warn about
        uniform_ = None
    elif args.init_method == 'kaiming':
        uniform_ = kaiming_uniform_
    elif args.init_method == 'kaiming_fanout':
        uniform_ = functools.partial(kaiming_uniform_, mode='fan_out')
    elif args.init_method == 'xavier1_2':
        uniform_ = xavier_uniform1_2_
    elif 'xavier_origin_ratio' in args.init_method:
        origin_ffn_ratio = float(args.init_method.split(':')[1])
        uniform_ = functools.partial(xavier_uniform_origin_ratio_, ratio=origin_ffn_ratio)
    elif args.init_method == 'xavier2exp':
        uniform_ = xavier_uniform_2exp_
    elif args.init_method == 'xavier2exp_ratio':
        uniform_ = xavier_uniform_2exp_same_ratio_
    elif 'gain' in args.init_method:
        gain = float(args.init_method.split(':')[1])
        print("initialization gain:", gain)
        uniform_ = functools.partial(xavier_uniform_gain_, gain=gain)
    elif 'xavier_non_linear' in args.init_method:
        gain_ratio = float(args.init_method.split(':')[1])
        uniform_ = functools.partial(xavier_uniform_non_linear_, gain_ratio=gain_ratio)
    else:
        print("[WARNING] Fallback to xavier initializer")
        uniform_ = None


def xavier_uniform_non_linear_(tensor, gain_ratio=1., non_linear='linear'):
    """Xavier-uniform init with gain ``gain_ratio * calculate_gain(non_linear)``."""
    return nn.init.xavier_uniform_(tensor, gain=gain_ratio * nn.init.calculate_gain(non_linear))


def xavier_uniform_origin_ratio_(tensor, gain=1., ratio=2, **kwargs):
    """Xavier-uniform init that pretends ``fan_out == ratio * fan_in``.

    The tensor's real fan-out is ignored; extra kwargs are ignored.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    fan_out = ratio * fan_in
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -a, a)


def kaiming_uniform_(tensor, non_linear, mode='fan_in'):
    """Kaiming-uniform init for nonlinearity ``non_linear``."""
    return nn.init.kaiming_uniform_(tensor, mode=mode, nonlinearity=non_linear)


def xavier_uniform_gain_(tensor, gain=1., **kwargs):
    """Plain xavier-uniform init with an explicit ``gain``; extra kwargs ignored."""
    return nn.init.xavier_uniform_(tensor, gain)


def xavier_uniform1_2_(tensor, gain=1., **kwargs):
    """Xavier-uniform init with ``fan_out`` raised to at least ``2 * fan_in``."""
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    if fan_out < 2 * fan_in:
        fan_out = 2 * fan_in
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -a, a)


def xavier_uniform_2exp_(tensor, gain=1., **kwargs):
    """Xavier-uniform init with ``fan_in`` rounded up to a power of two and
    ``fan_out = 2 * fan_in``."""
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    fan_in = next_power_of_2(fan_in)
    fan_out = 2 * fan_in
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -a, a)


def xavier_uniform_2exp_same_ratio_(tensor, gain=1., **kwargs):
    """Xavier-uniform init with ``fan_in`` rounded up to a power of two while
    keeping the tensor's original ``fan_out / fan_in`` ratio."""
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    ratio = fan_out / fan_in
    fan_in = next_power_of_2(fan_in)
    fan_out = fan_in * ratio
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return nn.init._no_grad_uniform_(tensor, -a, a)

# ==== /fairseq/meters.py ====
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class TimeMeter(object):
    """Computes the average occurrence of some event per second"""
    def __init__(self, init=0):
        self.reset(init)

    def reset(self, init=0):
        # ``init`` seconds are treated as already elapsed
        self.init = init
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.time() - self.start)


class StopwatchMeter(object):
    """Computes the sum/avg duration of some event in seconds"""
    def __init__(self):
        self.reset()

    def start(self):
        self.start_time = time.time()

    def stop(self, n=1):
        """Stop timing and count the elapsed span as ``n`` events.

        Does nothing if ``start`` was not called since the last ``stop``.
        """
        if self.start_time is not None:
            delta = time.time() - self.start_time
            self.sum += delta
            self.n += n
            self.start_time = None

    def reset(self):
        self.sum = 0
        self.n = 0
        self.start_time = None

    @property
    def avg(self):
        # avoid ZeroDivisionError before any event has been recorded
        return self.sum / self.n if self.n > 0 else self.sum
# --------------------------------------------------------------------------------
# /fairseq/models/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import importlib
import os

from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
    BaseFairseqModel,
    FairseqEncoderModel,
    FairseqEncoderDecoderModel,
    FairseqLanguageModel,
    FairseqModel,
    FairseqMultiModel,
)

from .distributed_fairseq_model import DistributedFairseqModel


MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}


__all__ = [
    'BaseFairseqModel',
    'DistributedFairseqModel',
    'FairseqDecoder',
    'FairseqEncoder',
    'FairseqEncoderDecoderModel',
    'FairseqEncoderModel',
    'FairseqIncrementalDecoder',
    'FairseqLanguageModel',
    'FairseqModel',
    'FairseqMultiModel',
]


def build_model(args, task):
    """Build the model class registered for ``args.arch`` via its ``build_model``."""
    return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)


def register_model(name):
    """
    New model types can be added to fairseq with the :func:`register_model`
    function decorator.

    For example::

        @register_model('lstm')
        class LSTM(FairseqEncoderDecoderModel):
            (...)

    .. note:: All models must implement the :class:`BaseFairseqModel` interface.
        Typically you will extend :class:`FairseqEncoderDecoderModel` for
        sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
        language modeling tasks.

    Args:
        name (str): the name of the model

    Raises:
        ValueError: if *name* is already registered or the decorated class
            does not extend :class:`BaseFairseqModel`.
    """

    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        if not issubclass(cls, BaseFairseqModel):
            raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
        MODEL_REGISTRY[name] = cls
        return cls

    return register_model_cls


def register_model_architecture(model_name, arch_name):
    """
    New model architectures can be added to fairseq with the
    :func:`register_model_architecture` function decorator. After registration,
    model architectures can be selected with the ``--arch`` command-line
    argument.

    For example::

        @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
        def lstm_luong_wmt_en_de(args):
            args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
            (...)

    The decorated function should take a single argument *args*, which is a
    :class:`argparse.Namespace` of arguments parsed from the command-line. The
    decorated function should modify these arguments in-place to match the
    desired architecture.

    Args:
        model_name (str): the name of the Model (Model must already be
            registered)
        arch_name (str): the name of the model architecture (``--arch``)

    Raises:
        ValueError: if *model_name* is unknown, *arch_name* is already
            registered, or the decorated object is not callable.
    """

    def register_model_arch_fn(fn):
        if model_name not in MODEL_REGISTRY:
            raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
        if arch_name in ARCH_MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
        if not callable(fn):
            raise ValueError('Model architecture must be callable ({})'.format(arch_name))
        ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
        ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name)
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return register_model_arch_fn


# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if not file.startswith('_') and not file.startswith('.') and (file.endswith('.py') or os.path.isdir(path)):
        model_name = file[:file.find('.py')] if file.endswith('.py') else file
        module = importlib.import_module('fairseq.models.' + model_name)

        # extra `model_parser` for sphinx
        if model_name in MODEL_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_archs = parser.add_argument_group('Named architectures')
            group_archs.add_argument('--arch', choices=ARCH_MODEL_INV_REGISTRY[model_name])
            group_args = parser.add_argument_group('Additional command-line arguments')
            MODEL_REGISTRY[model_name].add_args(group_args)
            globals()[model_name + '_parser'] = parser

# --------------------------------------------------------------------------------
# /fairseq/models/distributed_fairseq_model.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect

import torch.nn as nn

from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
from fairseq.models import BaseFairseqModel


def DistributedFairseqModel(args, model):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap

    Returns:
        an instance of a subclass of the selected DDP class
        (``nn.parallel.DistributedDataParallel`` for ``--ddp-backend=c10d``,
        ``LegacyDistributedDataParallel`` for ``no_c10d``) whose attribute
        lookups fall back to the wrapped *model*.

    Raises:
        ValueError: if ``args.ddp_backend`` is neither ``'c10d'`` nor
            ``'no_c10d'``.
    """
    # determine which DDP class to extend
    assert isinstance(model, nn.Module)
    if args.ddp_backend == 'c10d':
        ddp_class = nn.parallel.DistributedDataParallel
        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=False,
            bucket_cap_mb=args.bucket_cap_mb,
        )
        # Maintain backward compatibility: only pass the optional keyword
        # arguments that the installed DistributedDataParallel accepts.
        # ``inspect.getargspec`` was removed in Python 3.11 (and raised
        # ValueError on annotated / keyword-only signatures before that);
        # ``inspect.signature`` reports the same constructor parameter names.
        ddp_params = inspect.signature(ddp_class).parameters
        if 'check_reduction' in ddp_params:
            init_kwargs['check_reduction'] = True
        if 'find_unused_parameters' in ddp_params:
            init_kwargs['find_unused_parameters'] = args.find_unused_parameters
    elif args.ddp_backend == 'no_c10d':
        ddp_class = LegacyDistributedDataParallel
        init_kwargs = dict(
            module=model,
            world_size=args.distributed_world_size,
            buffer_size=2**28,
        )
    else:
        raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

    class _DistributedFairseqModel(ddp_class):
        """Extend DistributedDataParallel to check for missing
        attributes in the wrapped module."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def __getattr__(self, name):
            # look on the wrapped model first so callers can use model
            # attributes directly on the DDP wrapper
            wrapped_module = super().__getattr__('module')
            if hasattr(wrapped_module, name):
                return getattr(wrapped_module, name)
            return super().__getattr__(name)

    return _DistributedFairseqModel(**init_kwargs)

# --------------------------------------------------------------------------------
# /fairseq/models/fairseq_decoder.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn

from fairseq import utils


class FairseqDecoder(nn.Module):
    """Base class for decoders.

    Subclasses provide :meth:`extract_features` (hidden states) and
    :meth:`output_layer` (projection to the output space); :meth:`forward`
    simply chains the two.
    """

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary
        # flipped by prepare_for_onnx_export_(); forwarded to the softmax
        # helpers used in get_normalized_probs()
        self.onnx_trace = False

    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        features, extra = self.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return self.output_layer(features), extra

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def output_layer(self, features, **kwargs):
        """
        Project features to the default output size, e.g., vocabulary size.

        Args:
            features (Tensor): features returned by *extract_features*.
        """
        raise NotImplementedError

    def get_normalized_probs(self, net_output, log_probs, sample):
        """Get normalized probabilities (or log probs) from a net's output.

        When the decoder owns a non-None ``adaptive_softmax`` it is used (with
        ``sample['target']`` when a sample is given); otherwise a (log-)softmax
        is applied over the last dimension of ``net_output[0]``.
        """
        adaptive_softmax = getattr(self, 'adaptive_softmax', None)
        if adaptive_softmax is not None:
            target = None
            if sample is not None:
                assert 'target' in sample
                target = sample['target']
            out = adaptive_softmax.get_log_prob(net_output[0], target=target)
            return out if log_probs else out.exp_()

        logits = net_output[0]
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(logits, dim=-1, onnx_trace=self.onnx_trace)

    def max_positions(self):
        """Maximum input length supported by the decoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict(self, state_dict):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

    def prepare_for_onnx_export_(self):
        """Switch the softmax helpers into their ONNX-traceable mode."""
        self.onnx_trace = True

# --------------------------------------------------------------------------------
# /fairseq/models/fairseq_encoder.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self, dictionary):
        super().__init__()
        self.dictionary = dictionary

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Encode a batch of source sentences; subclasses must implement this.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`
        """
        raise NotImplementedError

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        raise NotImplementedError

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return 1e6  # an arbitrary large number

    def upgrade_state_dict(self, state_dict):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict

# --------------------------------------------------------------------------------
# /fairseq/models/fairseq_incremental_decoder.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import FairseqDecoder


class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the incremental
    decoder interface allows :func:`forward` functions to take an extra keyword
    argument (*incremental_state*) that can be used to cache state across
    time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of beams.

    To learn more about how incremental decoding works, refer to `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder incremental state.

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        visited = set()

        def _reorder_child(module):
            # self.apply() also visits this decoder; calling ourselves again
            # would recurse forever, so skip self. A submodule reachable more
            # than once is reordered only on its first visit.
            if module is self or module in visited:
                return
            if hasattr(module, 'reorder_incremental_state'):
                visited.add(module)
                module.reorder_incremental_state(incremental_state, new_order)

        self.apply(_reorder_child)

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, '_beam_size', -1) == beam_size:
            return  # already propagated for this beam size

        visited = set()

        def _set_child_beam(module):
            # same self / duplicate-visit guards as reorder_incremental_state
            if module is self or module in visited:
                return
            if hasattr(module, 'set_beam_size'):
                visited.add(module)
                module.set_beam_size(beam_size)

        self.apply(_set_child_beam)
        self._beam_size = beam_size

# --------------------------------------------------------------------------------
# /fairseq/modules/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
5 | 6 | from .dynamic_convolution import DynamicConv, DynamicConv1dTBC 7 | from .gelu import gelu, gelu_accurate 8 | from .layer_norm import LayerNorm 9 | from .lightweight_convolution import LightweightConv, LightweightConv1dTBC 10 | from .multihead_attention import MultiheadAttention 11 | from .positional_embedding import PositionalEmbedding 12 | from .learned_positional_embedding import LearnedPositionalEmbedding 13 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 14 | from .multibranch import MultiBranch 15 | from .adaptive_softmax import AdaptiveSoftmax 16 | 17 | __all__ = [ 18 | 'AdaptiveSoftmax', 19 | 'DynamicConv1dTBC', 20 | 'DynamicConv', 21 | 'gelu', 22 | 'gelu_accurate', 23 | 'LayerNorm', 24 | 'LightweightConv1dTBC', 25 | 'LightweightConv', 26 | 'MultiheadAttention', 27 | 'MultiBranch', 28 | 'PositionalEmbedding', 29 | 'LearnedPositionalEmbedding', 30 | 'SinusoidalPositionalEmbedding', 31 | ] 32 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .dynamicconv_layer import DynamicconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #include <torch/extension.h> 9 | #include <vector> 10 | 11 | std::vector<at::Tensor> dynamicconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector<at::Tensor> dynamicconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector<at::Tensor> dynamicconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return dynamicconv_cuda_forward(input, filters, 36 | padding_l); 37 | } 38 | 39 | std::vector<at::Tensor> dynamicconv_backward( 40 | at::Tensor gradOutput, 41 | int padding_l, 42 | at::Tensor input, 43 | at::Tensor filters) { 44 | 45 | CHECK_INPUT(gradOutput); 46 | CHECK_INPUT(input); 47 | CHECK_INPUT(filters); 48 | 49 | return dynamicconv_cuda_backward(gradOutput, padding_l, 50 | input, filters); 51 | } 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); 55 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); 56 | } 57 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #include <ATen/ATen.h> 9 | #include <c10/cuda/CUDAStream.h> 10 | 11 | #include <cuda.h> 12 | #include <cuda_fp16.h> 13 | #include <cuda_runtime.h> 14 | 15 | #include <algorithm> 16 | #include <functional> 17 | #include <iostream> 18 | #include <stdexcept> 19 | #include <utility> 20 | #include <vector> 21 | 22 | #include <stdlib.h> 23 | #include <assert.h> 24 | #include <math.h> 25 | 26 | #define SHFL_MASK 0xffffffff 27 | 28 | template<int FS, int SB, int padding_l, typename scalar_t> 29 | __global__ 30 | void dynamicconv_forward_kernel(const scalar_t* input, 31 | const scalar_t* weight, 32 | int minibatch, 33 | int sequenceLength, 34 | int numFeatures, 35 | int numFiltersInBlock, 36 | int numHeads, 37 | scalar_t* output); 38 | 39 | template<int FS, int SB, int padding_l, typename scalar_t> 40 | __global__ 41 | void dynamicconv_backward_kernel( 42 | const scalar_t* gradOutput, // B * C * T 43 | const scalar_t* input, // B * C * T 44 | const scalar_t* weight, 45 | int minibatch, 46 | int sequenceLength, 47 | int numFeatures, 48 | int numFiltersInBlock, 49 | int numHeads, 50 | scalar_t* gradWeight, 51 | scalar_t* gradInput); // B * H * k * T 52 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | #include <vector> 3 | 4 | std::vector<float*> dynamicconv_cpu_forward( 5 | float* input, 6 | float* filters, 7 | int padding_l); 8 | 9 | std::vector<float*> dynamicconv_cpu_backward( 10 | float* gradOutput, 11 | int padding_l, 12 | float* input, 13 | float* filters); 14 | 15 | std::vector<float*> dynamicconv_forward( 16 | float* input, 17 | float* filters, 18 | int padding_l) { 19 | 20 | return dynamicconv_cpu_forward(input, filters, padding_l); 21 | } 22 | 23 | std::vector<float*> dynamicconv_backward( 24 | float* gradOutput, 25 | int padding_l, 26 | float* input, 
27 | float* filters) { 28 | 29 | return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); 34 | m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); 35 | } 36 | -------------------------------------------------------------------------------- /fairseq/modules/dynamicconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='dynamicconv_layer', 12 | ext_modules=[ 13 | CUDAExtension( 14 | name='dynamicconv_cuda', 15 | sources=[ 16 | 'dynamicconv_cuda.cpp', 17 | 'dynamicconv_cuda_kernel.cu', 18 | ], 19 | ), 20 | ], 21 | cmdclass={ 22 | 'build_ext': BuildExtension 23 | }) 24 | -------------------------------------------------------------------------------- /fairseq/modules/gelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
"""
See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
the corresponding GitHub repo: https://github.com/hendrycks/GELUs
"""

import math

import torch


def gelu_accurate(x):
    """tanh-based GELU: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))."""
    if not hasattr(gelu_accurate, "_a"):
        # cache sqrt(2 / pi) on the function object after the first call
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    """erf-based GELU, ``x * 0.5 * (1 + erf(x / sqrt(2)))``, in the dtype of *x*.

    Uses ``torch.nn.functional.gelu`` when available. Reduced-precision inputs
    (e.g. fp16) are evaluated in float32 and cast back; float64 inputs are
    evaluated in float64 so no precision is lost.
    """
    if hasattr(torch.nn.functional, 'gelu'):
        if x.dtype == torch.float64:
            # x.float() would silently round float64 inputs to float32
            # precision (breaking e.g. double-precision gradient checks)
            return torch.nn.functional.gelu(x)
        return torch.nn.functional.gelu(x.float()).type_as(x)
    else:
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# --------------------------------------------------------------------------------
# /fairseq/modules/layer_norm.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a layer-norm module.

    Returns apex's ``FusedLayerNorm`` when *export* is False, CUDA is available
    and apex can be imported; otherwise ``torch.nn.LayerNorm``. The arguments
    are forwarded unchanged to whichever class is chosen.
    """
    if not export and torch.cuda.is_available():
        try:
            from apex.normalization import FusedLayerNorm
            # construction stays inside the try: an ImportError raised while
            # building the fused module also falls back to torch.nn.LayerNorm
            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
        except ImportError:
            pass
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)

# --------------------------------------------------------------------------------
# /fairseq/modules/learned_positional_embedding.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn

from fairseq import utils


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # toggled externally for ONNX export; passed to utils.make_positions
        self.onnx_trace = False

    def forward(self, input, incremental_state=None, positions=None):
        """Input is expected to be of size [bsz x seqlen].

        Embeds *positions* when given (only allowed when ``padding_idx`` is
        None). Otherwise positions are derived from *input*: a single shared
        position during incremental decoding, or ``utils.make_positions``
        over the whole batch.
        """
        assert (
            (positions is None) or (self.padding_idx is None)
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is not None:
            return super().forward(positions)

        if incremental_state is None:
            positions = utils.make_positions(
                input.data, self.padding_idx, onnx_trace=self.onnx_trace,
            )
        else:
            # a single decoding step: every token shares the same position.
            # The int() cast is needed for some ONNX export paths.
            step_position = int(self.padding_idx + input.size(1))
            positions = input.data.new(1, 1).fill_(step_position)
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        if self.padding_idx is None:
            return self.num_embeddings
        # position ids are offset past padding_idx, which reserves those rows
        return self.num_embeddings - self.padding_idx - 1

# --------------------------------------------------------------------------------
# /fairseq/modules/lightconv_layer/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .lightconv_layer import LightconvLayer # noqa 7 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_cuda.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include <torch/extension.h> 9 | #include <vector> 10 | 11 | std::vector<at::Tensor> lightconv_cuda_forward( 12 | at::Tensor input, 13 | at::Tensor filters, 14 | int padding_l); 15 | 16 | std::vector<at::Tensor> lightconv_cuda_backward( 17 | at::Tensor gradOutput, 18 | int padding_l, 19 | at::Tensor input, 20 | at::Tensor filters); 21 | 22 | 23 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 26 | 27 | std::vector<at::Tensor> lightconv_forward( 28 | at::Tensor input, 29 | at::Tensor filters, 30 | int padding_l) { 31 | 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(filters); 34 | 35 | return lightconv_cuda_forward(input, filters, padding_l); 36 | } 37 | 38 | std::vector<at::Tensor> lightconv_backward( 39 | at::Tensor gradOutput, 40 | int padding_l, 41 | at::Tensor input, 42 | at::Tensor filters) { 43 | 44 | CHECK_INPUT(gradOutput); 45 | CHECK_INPUT(input); 46 | CHECK_INPUT(filters); 47 | 48 | return lightconv_cuda_backward(gradOutput, padding_l, input, filters); 49 | } 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); 53 | m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); 54 | } 
55 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_cuda.cuh: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include <ATen/ATen.h> 9 | #include <c10/cuda/CUDAStream.h> 10 | 11 | #include <cuda.h> 12 | #include <cuda_runtime.h> 13 | 14 | #include <algorithm> 15 | #include <functional> 16 | #include <iostream> 17 | #include <stdexcept> 18 | #include <utility> 19 | #include <vector> 20 | 21 | #include <stdlib.h> 22 | #include <assert.h> 23 | 24 | #define SHFL_MASK 0xffffffff 25 | 26 | template<int FS, int SB, int padding_l, typename scalar_t> 27 | __global__ 28 | void lightconv_forward_kernel(const scalar_t* input, 29 | const scalar_t* filters, 30 | int minibatch, int sequenceLength, 31 | int numFeatures, int numFiltersInBlock, 32 | scalar_t* output); 33 | 34 | template<int FS, int SB, int padding_l, typename scalar_t> 35 | __global__ 36 | void lightconv_grad_wrt_input_kernel( 37 | const scalar_t* input, 38 | const scalar_t* filters, 39 | int minibatch, 40 | int sequenceLength, 41 | int numFeatures, 42 | int numFiltersInBlock, 43 | scalar_t* output); 44 | 45 | template<int FS, int SB, int padding_l, typename scalar_t> 46 | __global__ 47 | void lightconv_grad_wrt_weights_firstpass_short_kernel( 48 | const scalar_t* input, 49 | const scalar_t* gradInput, 50 | int minibatch, 51 | int sequenceLength, 52 | int numFeatures, 53 | int numFiltersInBlock, 54 | int numHeads, 55 | float* output); 56 | 57 | template<int FS, int SB, typename scalar_t> 58 | __global__ 59 | void lightconv_grad_wrt_weights_secondpass_short_kernel( 60 | const float* input, 61 | const int minibatch, 62 | const int numFiltersInBlock, 63 | scalar_t* output); 64 | 65 | template<int 
FS, int SB, int padding_l, typename scalar_t> 66 | __global__ 67 | void lightconv_grad_wrt_weights_firstpass_kernel( 68 | const scalar_t* input, 69 | const scalar_t* gradInput, 70 | int minibatch, 71 | int sequenceLength, 72 | int numFeatures, 73 | int numFiltersInBlock, 74 | float* output); 75 | 76 | template<int FS, int SB, typename scalar_t> 77 | __global__ 78 | void lightconv_grad_wrt_weights_secondpass_kernel( 79 | const float* input, 80 | const int minibatch, 81 | const int numFiltersInBlock, 82 | scalar_t* output); 83 | 84 | -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/lightconv_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch import nn 8 | from torch.autograd import Function 9 | import torch.nn.functional as F 10 | 11 | import lightconv_cuda 12 | from fairseq import utils 13 | 14 | 15 | class lightconvFunction(Function): 16 | 17 | @staticmethod 18 | def forward(ctx, x, weights, padding_l): 19 | ctx.padding_l = padding_l 20 | outputs = lightconv_cuda.forward(x, weights, padding_l) 21 | variables = [x, weights] 22 | ctx.save_for_backward(*variables) 23 | return outputs[0] 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | outputs = lightconv_cuda.backward( 28 | grad_output.contiguous(), 29 | ctx.padding_l, 30 | *ctx.saved_variables) 31 | grad_input, grad_weights = outputs 32 | return grad_input, grad_weights, None 33 | 34 | 35 | class LightconvLayer(nn.Module): 36 | def __init__( 37 | self, 38 | input_size, 39 | kernel_size=1, 40 | padding_l=None, 41 | weight_softmax=False, 42 | num_heads=1, 43 | weight_dropout=0., 44 | bias=False, 45 | with_linear=False, 46 | out_dim=None): 47 | super(LightconvLayer, 
self).__init__() 48 | self.embed_dim = input_size 49 | self.input_size = input_size 50 | self.kernel_size = kernel_size 51 | self.padding_l = padding_l 52 | self.num_heads = num_heads 53 | self.weight_softmax = weight_softmax 54 | self.weight_dropout = weight_dropout 55 | out_dim = input_size if out_dim is None else out_dim 56 | 57 | self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size)) 58 | if bias: 59 | self.bias = nn.Parameter(torch.Tensor(input_size)) 60 | else: 61 | self.bias = None 62 | 63 | self.linear1 = Linear(input_size, input_size) if with_linear else None 64 | self.linear2 = Linear(input_size, out_dim) if with_linear else None 65 | 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | nn.init.xavier_uniform_(self.weight) 70 | if self.bias is not None: 71 | nn.init.constant_(self.bias, 0.) 72 | 73 | def forward(self, x, incremental_state=None): 74 | if self.linear1 is not None: 75 | x = self.linear1(x) 76 | 77 | # during inference time, incremental BMM is faster 78 | if incremental_state is not None: 79 | T, B, C = x.size() 80 | K, H = self.kernel_size, self.num_heads 81 | R = C // H 82 | input_buffer = self._get_input_buffer(incremental_state) 83 | if input_buffer is None: 84 | input_buffer = x.new() 85 | x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) 86 | if self.kernel_size > 1: 87 | self._set_input_buffer(incremental_state, x_unfold[:, :, :, -self.kernel_size+1:]) 88 | x_unfold = x_unfold.view(T*B*H, R, -1) 89 | 90 | weight = self.weight 91 | if self.weight_softmax: 92 | weight = F.softmax(weight.float(), dim=1).type_as(weight) 93 | 94 | weight = weight[:, -x_unfold.size(2):] 95 | 96 | K = weight.size(1) 97 | 98 | weight = weight.view(1, H, K).expand(T*B, H, K).contiguous().view(T*B*H, K, 1) 99 | 100 | weight = F.dropout(weight, self.weight_dropout, training=self.training) 101 | output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 102 | output = output.view(T, B, C) 103 | if self.linear2 is not None: 104 | 
output = self.linear2(output) 105 | 106 | # during training time, use CUDA kernel 107 | else: 108 | x = x.permute(1, 2, 0).contiguous() 109 | weight = self.weight 110 | if self.weight_softmax: 111 | weight = F.softmax(self.weight, -1) 112 | if self.weight_dropout: 113 | weight = F.dropout(weight, self.weight_dropout, training=self.training) 114 | output = lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1) 115 | if self.linear2 is not None: 116 | output = self.linear2(output) 117 | 118 | return output 119 | 120 | def reorder_incremental_state(self, incremental_state, new_order): 121 | input_buffer = self._get_input_buffer(incremental_state) 122 | if input_buffer is not None: 123 | input_buffer = input_buffer.index_select(1, new_order) 124 | self._set_input_buffer(incremental_state, input_buffer) 125 | 126 | def _get_input_buffer(self, incremental_state): 127 | return utils.get_incremental_state(self, incremental_state, 'input_buffer') 128 | 129 | def _set_input_buffer(self, incremental_state, new_buffer): 130 | return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) 131 | 132 | def half(self): 133 | print("HALF") 134 | return self._apply(lambda t: t.half() if t.is_floating_point() else t) 135 | 136 | 137 | def Linear(in_features, out_features, bias=True): 138 | m = nn.Linear(in_features, out_features, bias) 139 | nn.init.xavier_uniform_(m.weight) 140 | if bias: 141 | nn.init.constant_(m.bias, 0.) 142 | return m -------------------------------------------------------------------------------- /fairseq/modules/lightconv_layer/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import setup 8 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 9 | 10 | setup( 11 | name='lightconv_layer', 12 | ext_modules=[ 13 | CUDAExtension('lightconv_cuda', [ 14 | 'lightconv_cuda.cpp', 15 | 'lightconv_cuda_kernel.cu', 16 | ]), 17 | ], 18 | cmdclass={ 19 | 'build_ext': BuildExtension 20 | }) 21 | -------------------------------------------------------------------------------- /fairseq/modules/multibranch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | from . import MultiheadAttention 8 | 9 | class MultiBranch(nn.Module): 10 | def __init__(self, branches, embed_dim_list): 11 | super().__init__() 12 | self.branches = nn.ModuleList(branches) 13 | self.embed_dim_list = embed_dim_list 14 | 15 | def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False, attn_mask=None): 16 | tgt_len, bsz, embed_size = query.size() 17 | assert sum(self.embed_dim_list) == embed_size 18 | out = [] 19 | attn = None 20 | start = 0 21 | for idx, embed_dim in enumerate(self.embed_dim_list): 22 | branch = self.branches[idx] 23 | branch_type = type(branch) 24 | 25 | q = query[...,start:start+embed_dim] 26 | if key is not None: 27 | assert value is not None 28 | k, v = key[..., start:start+embed_dim], value[..., start:start+embed_dim] 29 | start += embed_dim 30 | 31 | if branch_type == MultiheadAttention: 32 | x, attn = branch(q, k, v, key_padding_mask, incremental_state, need_weights, static_kv, attn_mask) 33 | else: 34 | mask = key_padding_mask 35 | if mask is not None: 36 | q = q.masked_fill(mask.transpose(0, 1).unsqueeze(2), 0) 37 | x = branch(q.contiguous(), incremental_state=incremental_state) 38 | out.append(x) 39 | 40 | out = torch.cat(out, dim=-1) 41 | return out, attn 
-------------------------------------------------------------------------------- /fairseq/modules/positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | 8 | from .learned_positional_embedding import LearnedPositionalEmbedding 9 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding 10 | 11 | 12 | def PositionalEmbedding( 13 | num_embeddings: int, 14 | embedding_dim: int, 15 | padding_idx: int, 16 | learned: bool = False, 17 | ): 18 | if learned: 19 | # if padding_idx is specified then offset the embedding ids by 20 | # this index and adjust num_embeddings appropriately 21 | # TODO: The right place for this offset would be inside 22 | # LearnedPositionalEmbedding. Move this there for a cleaner implementation. 23 | if padding_idx is not None: 24 | num_embeddings = num_embeddings + padding_idx + 1 25 | m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) 26 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 27 | if padding_idx is not None: 28 | nn.init.constant_(m.weight[padding_idx], 0) 29 | else: 30 | m = SinusoidalPositionalEmbedding( 31 | embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1, 32 | ) 33 | return m 34 | -------------------------------------------------------------------------------- /fairseq/modules/sinusoidal_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.onnx.operators 11 | 12 | from fairseq import utils 13 | 14 | 15 | class SinusoidalPositionalEmbedding(nn.Module): 16 | """This module produces sinusoidal positional embeddings of any length. 17 | 18 | Padding symbols are ignored. 19 | """ 20 | 21 | def __init__(self, embedding_dim, padding_idx, init_size=1024): 22 | super().__init__() 23 | self.embedding_dim = embedding_dim 24 | self.padding_idx = padding_idx 25 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 26 | init_size, 27 | embedding_dim, 28 | padding_idx, 29 | ) 30 | self.onnx_trace = False 31 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 32 | 33 | def prepare_for_onnx_export_(self): 34 | self.onnx_trace = True 35 | 36 | @staticmethod 37 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 38 | """Build sinusoidal embeddings. 39 | 40 | This matches the implementation in tensor2tensor, but differs slightly 41 | from the description in Section 3.5 of "Attention Is All You Need". 
42 | """ 43 | half_dim = embedding_dim // 2 44 | emb = math.log(10000) / (half_dim - 1) 45 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 46 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 47 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 48 | if embedding_dim % 2 == 1: 49 | # zero pad 50 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 51 | if padding_idx is not None: 52 | emb[padding_idx, :] = 0 53 | return emb 54 | 55 | def forward(self, input, incremental_state=None, timestep=None, **kwargs): 56 | """Input is expected to be of size [bsz x seqlen].""" 57 | bsz, seq_len = torch.onnx.operators.shape_as_tensor(input) 58 | max_pos = self.padding_idx + 1 + seq_len 59 | if self.weights is None or max_pos > self.weights.size(0): 60 | # recompute/expand embeddings if needed 61 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 62 | max_pos, 63 | self.embedding_dim, 64 | self.padding_idx, 65 | ) 66 | self.weights = self.weights.to(self._float_tensor) 67 | 68 | if incremental_state is not None: 69 | # positions is the same for every token when decoding a single step 70 | pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len 71 | if self.onnx_trace: 72 | return self.weights.index_select(index=self.padding_idx + pos, dim=0).unsqueeze(1).repeat(bsz, 1, 1) 73 | return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) 74 | 75 | positions = utils.make_positions(input, self.padding_idx, onnx_trace=self.onnx_trace) 76 | if self.onnx_trace: 77 | flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) 78 | embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1]))) 79 | embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape) 80 | return embeddings 81 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 82 | 83 | def 
max_positions(self): 84 | """Maximum number of supported positions.""" 85 | return int(1e5) # an arbitrary large number 86 | -------------------------------------------------------------------------------- /fairseq/modules/unfold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def unfold1d(x, kernel_size, padding_l, pad_value=0): 10 | '''unfold T x B x C to T x B x C x K''' 11 | if kernel_size > 1: 12 | T, B, C = x.size() 13 | x = F.pad(x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value) 14 | x = x.as_strided((T, B, C, kernel_size), (B*C, C, 1, B*C)) 15 | else: 16 | x = x.unsqueeze(3) 17 | return x 18 | -------------------------------------------------------------------------------- /fairseq/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.optim.fairseq_optimizer import FairseqOptimizer 11 | from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer 12 | from fairseq.optim.bmuf import FairseqBMUF # noqa 13 | 14 | 15 | __all__ = [ 16 | 'FairseqOptimizer', 17 | 'FP16Optimizer', 18 | 'MemoryEfficientFP16Optimizer', 19 | ] 20 | 21 | 22 | build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry( 23 | '--optimizer', 24 | base_class=FairseqOptimizer, 25 | default='nag', 26 | ) 27 | 28 | 29 | # automatically import any Python files in the optim/ directory 30 | for file in os.listdir(os.path.dirname(__file__)): 31 | if file.endswith('.py') and not file.startswith('_'): 32 | module = file[:file.find('.py')] 33 | importlib.import_module('fairseq.optim.' + module) 34 | -------------------------------------------------------------------------------- /fairseq/optim/fairseq_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | 8 | import torch 9 | 10 | 11 | class FairseqOptimizer(object): 12 | 13 | def __init__(self, args): 14 | super().__init__() 15 | self.args = args 16 | 17 | @staticmethod 18 | def add_args(parser): 19 | """Add optimizer-specific arguments to the parser.""" 20 | pass 21 | 22 | @property 23 | def optimizer(self): 24 | """Return a torch.optim.optimizer.Optimizer instance.""" 25 | if not hasattr(self, '_optimizer'): 26 | raise NotImplementedError 27 | if not isinstance(self._optimizer, torch.optim.Optimizer): 28 | raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') 29 | return self._optimizer 30 | 31 | @property 32 | def optimizer_config(self): 33 | """ 34 | Return a kwarg dictionary that will be used to override optimizer 35 | args stored in checkpoints. This allows us to load a checkpoint and 36 | resume training using a different set of optimizer args, e.g., with a 37 | different learning rate. 38 | """ 39 | raise NotImplementedError 40 | 41 | @property 42 | def params(self): 43 | """Return an iterable of the parameters held by the optimizer.""" 44 | for param_group in self.optimizer.param_groups: 45 | for p in param_group['params']: 46 | yield p 47 | 48 | def __getstate__(self): 49 | return self._optimizer.__getstate__() 50 | 51 | def get_lr(self): 52 | """Return the current learning rate.""" 53 | return self.optimizer.param_groups[0]['lr'] 54 | 55 | def set_lr(self, lr): 56 | """Set the learning rate.""" 57 | for param_group in self.optimizer.param_groups: 58 | param_group['lr'] = lr 59 | 60 | def state_dict(self): 61 | """Return the optimizer's state dict.""" 62 | return self.optimizer.state_dict() 63 | 64 | def load_state_dict(self, state_dict, optimizer_overrides=None): 65 | """Load an optimizer state dict. 66 | 67 | In general we should prefer the configuration of the existing optimizer 68 | instance (e.g., learning rate) over that found in the state_dict. 
This 69 | allows us to resume training from a checkpoint using a new set of 70 | optimizer args. 71 | """ 72 | self.optimizer.load_state_dict(state_dict) 73 | 74 | if optimizer_overrides is not None and len(optimizer_overrides) > 0: 75 | # override learning rate, momentum, etc. with latest values 76 | for group in self.optimizer.param_groups: 77 | group.update(optimizer_overrides) 78 | 79 | def backward(self, loss): 80 | """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" 81 | loss.backward() 82 | 83 | def multiply_grads(self, c): 84 | """Multiplies grads by a constant *c*.""" 85 | for p in self.params: 86 | if p.grad is not None: 87 | p.grad.data.mul_(c) 88 | 89 | def clip_grad_norm(self, max_norm): 90 | """Clips gradient norm.""" 91 | if max_norm > 0: 92 | return torch.nn.utils.clip_grad_norm_(self.params, max_norm) 93 | else: 94 | return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params if p.grad is not None)) 95 | 96 | def step(self, closure=None): 97 | """Performs a single optimization step.""" 98 | self.optimizer.step(closure) 99 | 100 | def zero_grad(self): 101 | """Clears the gradients of all optimized parameters.""" 102 | for p in self.params: 103 | p.grad = None 104 | self.optimizer.zero_grad() 105 | 106 | @property 107 | def supports_memory_efficient_fp16(self): 108 | if hasattr(self.optimizer, 'supports_memory_efficient_fp16'): 109 | return self.optimizer.supports_memory_efficient_fp16 110 | return False 111 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import os 8 | 9 | from fairseq import registry 10 | from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler 11 | 12 | 13 | build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( 14 | '--lr-scheduler', 15 | base_class=FairseqLRScheduler, 16 | default='fixed', 17 | ) 18 | 19 | # automatically import any Python files in the optim/lr_scheduler/ directory 20 | for file in os.listdir(os.path.dirname(__file__)): 21 | if file.endswith('.py') and not file.startswith('_'): 22 | module = file[:file.find('.py')] 23 | importlib.import_module('fairseq.optim.lr_scheduler.' + module) 24 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/cosine_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | 8 | from . import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('cosine') 12 | class CosineSchedule(FairseqLRScheduler): 13 | """Assign LR based on a cyclical schedule that follows the cosine function. 14 | 15 | See https://arxiv.org/pdf/1608.03983.pdf for details. 16 | 17 | We also support a warmup phase where we linearly increase the learning rate 18 | from some initial learning rate (``--warmup-init-lr``) until the configured 19 | max learning rate (``--max-lr``). 
20 | 21 | During warmup:: 22 | 23 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 24 | lr = lrs[update_num] 25 | 26 | After warmup:: 27 | 28 | lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) 29 | 30 | where ``t_curr`` is current percentage of updates within the current period 31 | range and ``t_i`` is the current period range, which is scaled by ``t_mul`` 32 | after every iteration. 33 | """ 34 | 35 | def __init__(self, args, optimizer): 36 | super().__init__(args, optimizer) 37 | if len(args.lr) > 1: 38 | raise ValueError( 39 | 'Cannot use a fixed learning rate schedule with cosine.' 40 | ' Consider --lr-scheduler=fixed instead.' 41 | ) 42 | 43 | warmup_end_lr = args.max_lr 44 | if args.warmup_init_lr < 0: 45 | args.warmup_init_lr = args.lr[0] 46 | 47 | self.min_lr = args.lr[0] 48 | self.max_lr = args.max_lr 49 | 50 | assert self.max_lr > self.min_lr, 'max_lr must be more than lr' 51 | 52 | self.t_mult = args.t_mult 53 | self.period = args.lr_period_updates 54 | 55 | if self.period <= 0: 56 | assert args.max_update >= 0, 'Either --max_update or --lr-period-updates must be set' 57 | self.period = args.max_update - args.warmup_updates 58 | 59 | if args.warmup_updates > 0: 60 | # linearly warmup for the first args.warmup_updates 61 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 62 | else: 63 | self.lr_step = 1 64 | 65 | self.warmup_updates = args.warmup_updates 66 | self.lr_shrink = args.lr_shrink 67 | 68 | # initial learning rate 69 | self.lr = args.warmup_init_lr 70 | self.optimizer.set_lr(self.lr) 71 | 72 | @staticmethod 73 | def add_args(parser): 74 | """Add arguments to the parser for this LR scheduler.""" 75 | # fmt: off 76 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 77 | help='warmup the learning rate linearly for the first N updates') 78 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 79 | help='initial learning rate during 
warmup phase; default is args.lr') 80 | parser.add_argument('--max-lr', type=float, metavar='LR', 81 | help='max learning rate, must be more than args.lr') 82 | parser.add_argument('--t-mult', default=1, type=float, metavar='LR', 83 | help='factor to grow the length of each period') 84 | parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR', 85 | help='initial number of updates per period') 86 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', 87 | help='shrink factor for annealing') 88 | # fmt: on 89 | 90 | def step(self, epoch, val_loss=None): 91 | """Update the learning rate at the end of the given epoch.""" 92 | super().step(epoch, val_loss) 93 | # we don't change the learning rate at epoch boundaries 94 | return self.optimizer.get_lr() 95 | 96 | def step_update(self, num_updates): 97 | """Update the learning rate after each update.""" 98 | if num_updates < self.args.warmup_updates: 99 | self.lr = self.args.warmup_init_lr + num_updates * self.lr_step 100 | else: 101 | curr_updates = num_updates - self.args.warmup_updates 102 | if self.t_mult != 1: 103 | i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult)) 104 | t_i = self.t_mult ** i * self.period 105 | t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period 106 | else: 107 | i = math.floor(curr_updates / self.period) 108 | t_i = self.period 109 | t_curr = curr_updates - (self.period * i) 110 | 111 | lr_shrink = self.lr_shrink ** i 112 | min_lr = self.min_lr * lr_shrink 113 | max_lr = self.max_lr * lr_shrink 114 | 115 | self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i)) 116 | 117 | self.optimizer.set_lr(self.lr) 118 | return self.lr 119 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .. import FairseqOptimizer 7 | 8 | 9 | class FairseqLRScheduler(object): 10 | 11 | def __init__(self, args, optimizer): 12 | super().__init__() 13 | if not isinstance(optimizer, FairseqOptimizer): 14 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 15 | self.args = args 16 | self.optimizer = optimizer 17 | self.best = None 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add arguments to the parser for this LR scheduler.""" 22 | pass 23 | 24 | def state_dict(self): 25 | """Return the LR scheduler state dict.""" 26 | return {'best': self.best} 27 | 28 | def load_state_dict(self, state_dict): 29 | """Load an LR scheduler state dict.""" 30 | self.best = state_dict['best'] 31 | 32 | def step(self, epoch, val_loss=None): 33 | """Update the learning rate at the end of the given epoch.""" 34 | if val_loss is not None: 35 | if self.best is None: 36 | self.best = val_loss 37 | else: 38 | self.best = min(self.best, val_loss) 39 | 40 | def step_update(self, num_updates): 41 | """Update the learning rate after each update.""" 42 | return self.optimizer.get_lr() 43 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fixed_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . 
import FairseqLRScheduler, register_lr_scheduler 7 | 8 | 9 | @register_lr_scheduler('fixed') 10 | class FixedSchedule(FairseqLRScheduler): 11 | """Decay the LR on a fixed schedule.""" 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__(args, optimizer) 15 | 16 | # set defaults 17 | args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0 18 | 19 | self.lr = args.lr[0] 20 | if args.warmup_updates > 0: 21 | self.warmup_factor = 1. / args.warmup_updates 22 | else: 23 | self.warmup_factor = 1 24 | 25 | @staticmethod 26 | def add_args(parser): 27 | """Add arguments to the parser for this LR scheduler.""" 28 | # fmt: off 29 | parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', 30 | help='force annealing at specified epoch') 31 | parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', 32 | help='shrink factor for annealing, lr_new = (lr * lr_shrink)') 33 | parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', 34 | help='warmup the learning rate linearly for the first N updates') 35 | # fmt: on 36 | 37 | def get_next_lr(self, epoch): 38 | lrs = self.args.lr 39 | if self.args.force_anneal is None or epoch < self.args.force_anneal: 40 | # use fixed LR schedule 41 | next_lr = lrs[min(epoch, len(lrs) - 1)] 42 | else: 43 | # annneal based on lr_shrink 44 | next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal) 45 | return next_lr 46 | 47 | def step(self, epoch, val_loss=None): 48 | """Update the learning rate at the end of the given epoch.""" 49 | super().step(epoch, val_loss) 50 | self.lr = self.get_next_lr(epoch) 51 | self.optimizer.set_lr(self.warmup_factor * self.lr) 52 | return self.optimizer.get_lr() 53 | 54 | def step_update(self, num_updates): 55 | """Update the learning rate after each update.""" 56 | if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates: 57 | self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates) 58 | 
self.optimizer.set_lr(self.warmup_factor * self.lr) 59 | return self.optimizer.get_lr() 60 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/inverse_square_root_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from . import FairseqLRScheduler, register_lr_scheduler 7 | 8 | 9 | @register_lr_scheduler('inverse_sqrt') 10 | class InverseSquareRootSchedule(FairseqLRScheduler): 11 | """Decay the LR based on the inverse square root of the update number. 12 | We also support a warmup phase where we linearly increase the learning rate 13 | from some initial learning rate (``--warmup-init-lr``) until the configured 14 | learning rate (``--lr``). Thereafter we decay proportional to the number of 15 | updates, with a decay factor set to align with the configured learning rate. 16 | During warmup:: 17 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 18 | lr = lrs[update_num] 19 | After warmup:: 20 | decay_factor = args.lr * sqrt(args.warmup_updates) 21 | lr = decay_factor / sqrt(update_num) 22 | """ 23 | 24 | def __init__(self, args, optimizer): 25 | super().__init__(args, optimizer) 26 | if len(args.lr) > 1: 27 | raise ValueError( 28 | 'Cannot use a fixed learning rate schedule with inverse_sqrt.' 29 | ' Consider --lr-scheduler=fixed instead.' 30 | ) 31 | warmup_end_lr = args.lr[0] 32 | if args.warmup_init_lr < 0: 33 | args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr 34 | 35 | # linearly warmup for the first args.warmup_updates 36 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 37 | 38 | # then, decay prop. 
to the inverse square root of the update number 39 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 40 | 41 | # initial learning rate 42 | self.lr = args.warmup_init_lr 43 | self.optimizer.set_lr(self.lr) 44 | 45 | @staticmethod 46 | def add_args(parser): 47 | """Add arguments to the parser for this LR scheduler.""" 48 | # fmt: off 49 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 50 | help='warmup the learning rate linearly for the first N updates') 51 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 52 | help='initial learning rate during warmup phase; default is args.lr') 53 | # fmt: on 54 | 55 | def step(self, epoch, val_loss=None): 56 | """Update the learning rate at the end of the given epoch.""" 57 | super().step(epoch, val_loss) 58 | # we don't change the learning rate at epoch boundaries 59 | return self.optimizer.get_lr() 60 | 61 | def step_update(self, num_updates): 62 | """Update the learning rate after each update.""" 63 | if num_updates < self.args.warmup_updates: 64 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 65 | else: 66 | self.lr = self.decay_factor * num_updates**-0.5 67 | self.optimizer.set_lr(self.lr) 68 | return self.lr 69 | -------------------------------------------------------------------------------- /fairseq/optim/nag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | from . 
import FairseqOptimizer, register_optimizer 10 | 11 | 12 | @register_optimizer('nag') 13 | class FairseqNAG(FairseqOptimizer): 14 | def __init__(self, args, params): 15 | super().__init__(args) 16 | self._optimizer = NAG(params, **self.optimizer_config) 17 | 18 | @staticmethod 19 | def add_args(parser): 20 | """Add optimizer-specific arguments to the parser.""" 21 | # fmt: off 22 | parser.add_argument('--momentum', default=0.99, type=float, metavar='M', 23 | help='momentum factor') 24 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 25 | help='weight decay') 26 | # fmt: on 27 | 28 | @property 29 | def optimizer_config(self): 30 | """ 31 | Return a kwarg dictionary that will be used to override optimizer 32 | args stored in checkpoints. This allows us to load a checkpoint and 33 | resume training using a different set of optimizer args, e.g., with a 34 | different learning rate. 35 | """ 36 | return { 37 | 'lr': self.args.lr[0], 38 | 'momentum': self.args.momentum, 39 | 'weight_decay': self.args.weight_decay, 40 | } 41 | 42 | 43 | class NAG(Optimizer): 44 | def __init__(self, params, lr=required, momentum=0, weight_decay=0): 45 | defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) 46 | super(NAG, self).__init__(params, defaults) 47 | 48 | @property 49 | def supports_memory_efficient_fp16(self): 50 | return True 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | Arguments: 55 | closure (callable, optional): A closure that reevaluates the model 56 | and returns the loss. 
57 | """ 58 | loss = None 59 | if closure is not None: 60 | loss = closure() 61 | 62 | for group in self.param_groups: 63 | weight_decay = group['weight_decay'] 64 | momentum = group['momentum'] 65 | lr = group['lr'] 66 | lr_old = group.get('lr_old', lr) 67 | lr_correct = lr / lr_old 68 | 69 | for p in group['params']: 70 | if p.grad is None: 71 | continue 72 | 73 | p_data_fp32 = p.data.float() 74 | 75 | d_p = p.grad.data.float() 76 | param_state = self.state[p] 77 | if 'momentum_buffer' not in param_state: 78 | param_state['momentum_buffer'] = torch.zeros_like(d_p) 79 | else: 80 | param_state['momentum_buffer'] = param_state['momentum_buffer'].type_as(d_p) 81 | 82 | buf = param_state['momentum_buffer'] 83 | 84 | if weight_decay != 0: 85 | p_data_fp32.mul_(1 - lr * weight_decay) 86 | p_data_fp32.add_(momentum * momentum * lr_correct, buf) 87 | p_data_fp32.add_(-(1 + momentum) * lr, d_p) 88 | 89 | buf.mul_(momentum * lr_correct).add_(-lr, d_p) 90 | 91 | p.data.copy_(p_data_fp32) 92 | 93 | group['lr_old'] = lr 94 | 95 | return loss 96 | -------------------------------------------------------------------------------- /fairseq/pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import multiprocessing 7 | import os 8 | import pdb 9 | import sys 10 | 11 | 12 | __all__ = ['set_trace'] 13 | 14 | 15 | _stdin = [None] 16 | _stdin_lock = multiprocessing.Lock() 17 | try: 18 | _stdin_fd = sys.stdin.fileno() 19 | except Exception: 20 | _stdin_fd = None 21 | 22 | 23 | class MultiprocessingPdb(pdb.Pdb): 24 | """A Pdb wrapper that works in a multiprocessing environment. 
25 | 26 | Usage: `from fairseq import pdb; pdb.set_trace()` 27 | """ 28 | 29 | def __init__(self): 30 | pdb.Pdb.__init__(self, nosigint=True) 31 | 32 | def _cmdloop(self): 33 | stdin_bak = sys.stdin 34 | with _stdin_lock: 35 | try: 36 | if _stdin_fd is not None: 37 | if not _stdin[0]: 38 | _stdin[0] = os.fdopen(_stdin_fd) 39 | sys.stdin = _stdin[0] 40 | self.cmdloop() 41 | finally: 42 | sys.stdin = stdin_bak 43 | 44 | 45 | def set_trace(): 46 | pdb = MultiprocessingPdb() 47 | pdb.set_trace(sys._getframe().f_back) 48 | -------------------------------------------------------------------------------- /fairseq/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | 8 | 9 | REGISTRIES = {} 10 | 11 | 12 | def setup_registry( 13 | registry_name: str, 14 | base_class=None, 15 | default=None, 16 | ): 17 | assert registry_name.startswith('--') 18 | registry_name = registry_name[2:].replace('-', '_') 19 | 20 | REGISTRY = {} 21 | REGISTRY_CLASS_NAMES = set() 22 | 23 | # maintain a registry of all registries 24 | if registry_name in REGISTRIES: 25 | return # registry already exists 26 | REGISTRIES[registry_name] = { 27 | 'registry': REGISTRY, 28 | 'default': default, 29 | } 30 | 31 | def build_x(args, *extra_args, **extra_kwargs): 32 | choice = getattr(args, registry_name, None) 33 | if choice is None: 34 | return None 35 | cls = REGISTRY[choice] 36 | if hasattr(cls, 'build_' + registry_name): 37 | builder = getattr(cls, 'build_' + registry_name) 38 | else: 39 | builder = cls 40 | set_defaults(args, cls) 41 | return builder(args, *extra_args, **extra_kwargs) 42 | 43 | def register_x(name): 44 | 45 | def register_x_cls(cls): 46 | if name in REGISTRY: 47 | raise ValueError('Cannot register duplicate {} 
({})'.format(registry_name, name)) 48 | if cls.__name__ in REGISTRY_CLASS_NAMES: 49 | raise ValueError( 50 | 'Cannot register {} with duplicate class name ({})'.format( 51 | registry_name, cls.__name__, 52 | ) 53 | ) 54 | if base_class is not None and not issubclass(cls, base_class): 55 | raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__)) 56 | REGISTRY[name] = cls 57 | REGISTRY_CLASS_NAMES.add(cls.__name__) 58 | return cls 59 | 60 | return register_x_cls 61 | 62 | return build_x, register_x, REGISTRY 63 | 64 | 65 | def set_defaults(args, cls): 66 | """Helper to set default arguments based on *add_args*.""" 67 | if not hasattr(cls, 'add_args'): 68 | return 69 | parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False) 70 | cls.add_args(parser) 71 | # copied from argparse.py: 72 | defaults = argparse.Namespace() 73 | for action in parser._actions: 74 | if action.dest is not argparse.SUPPRESS: 75 | if not hasattr(defaults, action.dest): 76 | if action.default is not argparse.SUPPRESS: 77 | setattr(defaults, action.dest, action.default) 78 | for key, default_value in vars(defaults).items(): 79 | if not hasattr(args, key): 80 | setattr(args, key, default_value) 81 | -------------------------------------------------------------------------------- /fairseq/sequence_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import sys 8 | 9 | from fairseq import utils 10 | 11 | 12 | class SequenceScorer(object): 13 | """Scores the target for a given source sentence.""" 14 | 15 | def __init__(self, tgt_dict, softmax_batch=None): 16 | self.pad = tgt_dict.pad() 17 | self.softmax_batch = softmax_batch or sys.maxsize 18 | assert self.softmax_batch > 0 19 | 20 | @torch.no_grad() 21 | def generate(self, models, sample, **kwargs): 22 | """Score a batch of translations.""" 23 | net_input = sample['net_input'] 24 | 25 | def batch_for_softmax(dec_out, target): 26 | # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) 27 | first, rest = dec_out[0], dec_out[1:] 28 | bsz, tsz, dim = first.shape 29 | if bsz * tsz < self.softmax_batch: 30 | yield dec_out, target, True 31 | else: 32 | flat = first.contiguous().view(1, -1, dim) 33 | flat_tgt = target.contiguous().view(flat.shape[:-1]) 34 | s = 0 35 | while s < flat.size(1): 36 | e = s + self.softmax_batch 37 | yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False 38 | s = e 39 | 40 | def gather_target_probs(probs, target): 41 | probs = probs.gather( 42 | dim=2, 43 | index=target.unsqueeze(-1), 44 | ) 45 | return probs 46 | 47 | orig_target = sample['target'] 48 | 49 | # compute scores for each model in the ensemble 50 | avg_probs = None 51 | avg_attn = None 52 | for model in models: 53 | model.eval() 54 | decoder_out = model.forward(**net_input) 55 | attn = decoder_out[1] 56 | 57 | batched = batch_for_softmax(decoder_out, orig_target) 58 | probs, idx = None, 0 59 | for bd, tgt, is_single in batched: 60 | sample['target'] = tgt 61 | curr_prob = model.get_normalized_probs(bd, log_probs=len(models) == 1, sample=sample).data 62 | if is_single: 63 | probs = gather_target_probs(curr_prob, orig_target) 64 | else: 65 | if probs is None: 66 | probs = curr_prob.new(orig_target.numel()) 67 | step = curr_prob.size(0) * curr_prob.size(1) 68 | end = step + idx 69 | tgt_probs = 
gather_target_probs(curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt) 70 | probs[idx:end] = tgt_probs.view(-1) 71 | idx = end 72 | sample['target'] = orig_target 73 | 74 | probs = probs.view(sample['target'].shape) 75 | 76 | if avg_probs is None: 77 | avg_probs = probs 78 | else: 79 | avg_probs.add_(probs) 80 | if attn is not None and torch.is_tensor(attn): 81 | attn = attn.data 82 | if avg_attn is None: 83 | avg_attn = attn 84 | else: 85 | avg_attn.add_(attn) 86 | if len(models) > 1: 87 | avg_probs.div_(len(models)) 88 | avg_probs.log_() 89 | if avg_attn is not None: 90 | avg_attn.div_(len(models)) 91 | 92 | bsz = avg_probs.size(0) 93 | hypos = [] 94 | start_idxs = sample['start_indices'] if 'start_indices' in sample else [0] * bsz 95 | for i in range(bsz): 96 | # remove padding from ref 97 | ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \ 98 | if sample['target'] is not None else None 99 | tgt_len = ref.numel() 100 | avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len] 101 | score_i = avg_probs_i.sum() / tgt_len 102 | if avg_attn is not None: 103 | avg_attn_i = avg_attn[i, start_idxs[i]:] 104 | _, alignment = avg_attn_i.max(dim=0) 105 | else: 106 | avg_attn_i = alignment = None 107 | hypos.append([{ 108 | 'tokens': ref, 109 | 'score': score_i, 110 | 'attention': avg_attn_i, 111 | 'alignment': alignment, 112 | 'positional_scores': avg_probs_i, 113 | }]) 114 | return hypos 115 | -------------------------------------------------------------------------------- /fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import argparse 7 | import importlib 8 | import os 9 | 10 | from .fairseq_task import FairseqTask 11 | 12 | TASK_REGISTRY = {} 13 | TASK_CLASS_NAMES = set() 14 | 15 | 16 | def setup_task(args, **kwargs): 17 | return TASK_REGISTRY[args.task].setup_task(args, **kwargs) 18 | 19 | 20 | def register_task(name): 21 | """ 22 | New tasks can be added to fairseq with the 23 | :func:`~fairseq.tasks.register_task` function decorator. 24 | 25 | For example:: 26 | 27 | @register_task('classification') 28 | class ClassificationTask(FairseqTask): 29 | (...) 30 | 31 | .. note:: 32 | 33 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` 34 | interface. 35 | 36 | Please see the 37 | 38 | Args: 39 | name (str): the name of the task 40 | """ 41 | 42 | def register_task_cls(cls): 43 | if name in TASK_REGISTRY: 44 | raise ValueError('Cannot register duplicate task ({})'.format(name)) 45 | if not issubclass(cls, FairseqTask): 46 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) 47 | if cls.__name__ in TASK_CLASS_NAMES: 48 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) 49 | TASK_REGISTRY[name] = cls 50 | TASK_CLASS_NAMES.add(cls.__name__) 51 | return cls 52 | 53 | return register_task_cls 54 | 55 | 56 | # automatically import any Python files in the tasks/ directory 57 | for file in os.listdir(os.path.dirname(__file__)): 58 | if file.endswith('.py') and not file.startswith('_'): 59 | task_name = file[:file.find('.py')] 60 | importlib.import_module('fairseq.tasks.' 
+ task_name) 61 | 62 | # expose `task_parser` for sphinx 63 | if task_name in TASK_REGISTRY: 64 | parser = argparse.ArgumentParser(add_help=False) 65 | group_task = parser.add_argument_group('Task name') 66 | # fmt: off 67 | group_task.add_argument('--task', metavar=task_name, 68 | help='Enable this task with: ``--task=' + task_name + '``') 69 | # fmt: on 70 | group_args = parser.add_argument_group('Additional command-line arguments') 71 | TASK_REGISTRY[task_name].add_args(group_args) 72 | globals()[task_name + '_parser'] = parser 73 | 74 | 75 | def get_task(name): 76 | return TASK_REGISTRY[name] 77 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import re 7 | 8 | SPACE_NORMALIZER = re.compile(r"\s+") 9 | 10 | 11 | def tokenize_line(line): 12 | line = SPACE_NORMALIZER.sub(" ", line) 13 | line = line.strip() 14 | return line.split() 15 | -------------------------------------------------------------------------------- /figures/compression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/lite-transformer/d2d98c396d45b4088cb4fd573f199ed6c3454b85/figures/compression.png -------------------------------------------------------------------------------- /figures/et.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/lite-transformer/d2d98c396d45b4088cb4fd573f199ed6c3454b85/figures/et.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mit-han-lab/lite-transformer/d2d98c396d45b4088cb4fd573f199ed6c3454b85/figures/overview.png -------------------------------------------------------------------------------- /figures/tradeoff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/lite-transformer/d2d98c396d45b4088cb4fd573f199ed6c3454b85/figures/tradeoff.png -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | BLEU scoring of generated translations against reference translations. 8 | """ 9 | 10 | import argparse 11 | import os 12 | import sys 13 | 14 | from fairseq import bleu 15 | from fairseq.data import dictionary 16 | 17 | 18 | def get_parser(): 19 | parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') 20 | # fmt: off 21 | parser.add_argument('-s', '--sys', default='-', help='system output') 22 | parser.add_argument('-r', '--ref', required=True, help='references') 23 | parser.add_argument('-o', '--order', default=4, metavar='N', 24 | type=int, help='consider ngrams up to this order') 25 | parser.add_argument('--ignore-case', action='store_true', 26 | help='case-insensitive scoring') 27 | parser.add_argument('--sacrebleu', action='store_true', 28 | help='score with sacrebleu') 29 | parser.add_argument('--sentence-bleu', action='store_true', 30 | help='report sentence-level BLEUs (i.e., with +1 smoothing)') 31 | # fmt: on 32 | return parser 33 | 34 | 35 | def main(): 36 | parser = get_parser() 37 | args = parser.parse_args() 38 | print(args) 39 | 40 | assert args.sys == '-' or os.path.exists(args.sys), \ 41 | 
"System output file {} does not exist".format(args.sys) 42 | assert os.path.exists(args.ref), \ 43 | "Reference file {} does not exist".format(args.ref) 44 | 45 | dict = dictionary.Dictionary() 46 | 47 | def readlines(fd): 48 | for line in fd.readlines(): 49 | if args.ignore_case: 50 | yield line.lower() 51 | else: 52 | yield line 53 | 54 | if args.sacrebleu: 55 | import sacrebleu 56 | 57 | def score(fdsys): 58 | with open(args.ref) as fdref: 59 | print(sacrebleu.corpus_bleu(fdsys, [fdref])) 60 | elif args.sentence_bleu: 61 | def score(fdsys): 62 | with open(args.ref) as fdref: 63 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 64 | for i, (sys_tok, ref_tok) in enumerate(zip(readlines(fdsys), readlines(fdref))): 65 | scorer.reset(one_init=True) 66 | sys_tok = dict.encode_line(sys_tok) 67 | ref_tok = dict.encode_line(ref_tok) 68 | scorer.add(ref_tok, sys_tok) 69 | print(i, scorer.result_string(args.order)) 70 | else: 71 | def score(fdsys): 72 | with open(args.ref) as fdref: 73 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 74 | for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): 75 | sys_tok = dict.encode_line(sys_tok) 76 | ref_tok = dict.encode_line(ref_tok) 77 | scorer.add(ref_tok, sys_tok) 78 | print(scorer.result_string(args.order)) 79 | 80 | if args.sys == '-': 81 | score(sys.stdin) 82 | else: 83 | with open(args.sys, 'r') as f: 84 | score(f) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/lite-transformer/d2d98c396d45b4088cb4fd573f199ed6c3454b85/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import collections 9 | import torch 10 | import os 11 | import re 12 | 13 | 14 | def average_checkpoints(inputs): 15 | """Loads checkpoints from inputs and returns a model with averaged weights. 16 | 17 | Args: 18 | inputs: An iterable of string paths of checkpoints to load from. 19 | 20 | Returns: 21 | A dict of string keys mapping to various values. The 'model' key 22 | from the returned dict should correspond to an OrderedDict mapping 23 | string parameter names to torch Tensors. 24 | """ 25 | params_dict = collections.OrderedDict() 26 | params_keys = None 27 | new_state = None 28 | num_models = len(inputs) 29 | 30 | for f in inputs: 31 | state = torch.load( 32 | f, 33 | map_location=( 34 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu') 35 | ), 36 | ) 37 | # Copies over the settings from the first checkpoint 38 | if new_state is None: 39 | new_state = state 40 | 41 | model_params = state['model'] 42 | 43 | model_params_keys = list(model_params.keys()) 44 | if params_keys is None: 45 | params_keys = model_params_keys 46 | elif params_keys != model_params_keys: 47 | raise KeyError( 48 | 'For checkpoint {}, expected list of params: {}, ' 49 | 'but found: {}'.format(f, params_keys, model_params_keys) 50 | ) 51 | 52 | for k in params_keys: 53 | p = model_params[k] 54 | if isinstance(p, torch.HalfTensor): 55 | p = p.float() 56 | if k not in params_dict: 57 | params_dict[k] = p.clone() 58 | # NOTE: clone() is needed in case of p is a shared parameter 59 | else: 60 | params_dict[k] += p 61 | 62 | averaged_params = collections.OrderedDict() 63 | for k, v in params_dict.items(): 64 | averaged_params[k] = v 65 | averaged_params[k].div_(num_models) 66 | new_state['model'] = averaged_params 67 | return new_state 68 | 69 | 70 | def 
last_n_checkpoints(paths, n, update_based, upper_bound=None): 71 | assert len(paths) == 1 72 | path = paths[0] 73 | if update_based: 74 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') 75 | else: 76 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt') 77 | files = os.listdir(path) 78 | 79 | entries = [] 80 | for f in files: 81 | m = pt_regexp.fullmatch(f) 82 | if m is not None: 83 | sort_key = int(m.group(1)) 84 | if upper_bound is None or sort_key <= upper_bound: 85 | entries.append((sort_key, m.group(0))) 86 | if len(entries) < n: 87 | raise Exception('Found {} checkpoint files but need at least {}', len(entries), n) 88 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser( 93 | description='Tool to average the params of input checkpoints to ' 94 | 'produce a new checkpoint', 95 | ) 96 | # fmt: off 97 | parser.add_argument('--inputs', required=True, nargs='+', 98 | help='Input checkpoint file paths.') 99 | parser.add_argument('--output', required=True, metavar='FILE', 100 | help='Write the new checkpoint containing the averaged weights to this path.') 101 | num_group = parser.add_mutually_exclusive_group() 102 | num_group.add_argument('--num-epoch-checkpoints', type=int, 103 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' 104 | 'and average last this many of them.') 105 | num_group.add_argument('--num-update-checkpoints', type=int, 106 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' 107 | 'and average last this many of them.') 108 | parser.add_argument('--checkpoint-upper-bound', type=int, 109 | help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, ' 110 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.') 111 | # fmt: on 112 | args = parser.parse_args() 
113 | print(args) 114 | 115 | num = None 116 | is_update_based = False 117 | if args.num_update_checkpoints is not None: 118 | num = args.num_update_checkpoints 119 | is_update_based = True 120 | elif args.num_epoch_checkpoints is not None: 121 | num = args.num_epoch_checkpoints 122 | 123 | assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \ 124 | '--checkpoint-upper-bound requires --num-epoch-checkpoints' 125 | assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \ 126 | 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints' 127 | 128 | if num is not None: 129 | args.inputs = last_n_checkpoints( 130 | args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound, 131 | ) 132 | print('averaging checkpoints: ', args.inputs) 133 | 134 | new_state = average_checkpoints(args.inputs) 135 | torch.save(new_state, args.output) 136 | print('Finished writing averaged checkpoint to {}.'.format(args.output)) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /scripts/compare_namespaces.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Helper script to compare two argparse.Namespace objects.""" 3 | 4 | from argparse import Namespace # noqa 5 | 6 | 7 | def main(): 8 | 9 | ns1 = eval(input('Namespace 1: ')) 10 | ns2 = eval(input('Namespace 2: ')) 11 | 12 | def keys(ns): 13 | ks = set() 14 | for k in dir(ns): 15 | if not k.startswith('_'): 16 | ks.add(k) 17 | return ks 18 | 19 | k1 = keys(ns1) 20 | k2 = keys(ns2) 21 | 22 | def print_keys(ks, ns1, ns2=None): 23 | for k in ks: 24 | if ns2 is None: 25 | print('{}\t{}'.format(k, getattr(ns1, k, None))) 26 | else: 27 | print('{}\t{}\t{}'.format(k, getattr(ns1, k, None), getattr(ns2, k, None))) 28 | 29 | print('Keys unique to namespace 1:') 30 | print_keys(k1 - k2, ns1) 31 | print() 32 
| 33 | print('Keys unique to namespace 2:') 34 | print_keys(k2 - k1, ns2) 35 | print() 36 | 37 | print('Overlapping keys with different values:') 38 | ks = [k for k in k1 & k2 if getattr(ns1, k, 'None') != getattr(ns2, k, 'None')] 39 | print_keys(ks, ns1, ns2) 40 | print() 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /scripts/compound_split_bleu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 1 ]; then 4 | echo "usage: $0 GENERATE_PY_OUTPUT" 5 | exit 1 6 | fi 7 | 8 | GEN=$1 9 | 10 | SYS=$GEN.sys 11 | REF=$GEN.ref 12 | 13 | if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then 14 | echo "not done generating" 15 | exit 16 | fi 17 | 18 | grep ^H $GEN | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS 19 | grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF 20 | fairseq-score --sys $SYS --ref $REF 21 | -------------------------------------------------------------------------------- /scripts/convert_dictionary.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) Facebook, Inc. and its affiliates. 2 | -- 3 | -- This source code is licensed under the MIT license found in the 4 | -- LICENSE file in the root directory of this source tree. 5 | -- 6 | -- Usage: convert_dictionary.lua <dict.th7> 7 | require 'fairseq' 8 | require 'torch' 9 | require 'paths' 10 | 11 | if #arg < 1 then 12 | print('usage: convert_dictionary.lua <dict.th7>') 13 | os.exit(1) 14 | end 15 | if not paths.filep(arg[1]) then 16 | print('error: file does not exit: ' .. 
arg[1]) 17 | os.exit(1) 18 | end 19 | 20 | dict = torch.load(arg[1]) 21 | dst = paths.basename(arg[1]):gsub('.th7', '.txt') 22 | assert(dst:match('.txt$')) 23 | 24 | f = io.open(dst, 'w') 25 | for idx, symbol in ipairs(dict.index_to_symbol) do 26 | if idx > dict.cutoff then 27 | break 28 | end 29 | f:write(symbol) 30 | f:write(' ') 31 | f:write(dict.index_to_freq[idx]) 32 | f:write('\n') 33 | end 34 | f:close() 35 | -------------------------------------------------------------------------------- /scripts/convert_model.lua: -------------------------------------------------------------------------------- 1 | -- Copyright (c) Facebook, Inc. and its affiliates. 2 | -- 3 | -- This source code is licensed under the MIT license found in the 4 | -- LICENSE file in the root directory of this source tree. 5 | -- 6 | -- Usage: convert_model.lua <model_epoch1.th7> 7 | require 'torch' 8 | local fairseq = require 'fairseq' 9 | 10 | model = torch.load(arg[1]) 11 | 12 | function find_weight_norm(container, module) 13 | for _, wn in ipairs(container:listModules()) do 14 | if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then 15 | return wn 16 | end 17 | end 18 | end 19 | 20 | function push_state(dict, key, module) 21 | if torch.type(module) == 'nn.Linear' then 22 | local wn = find_weight_norm(model.module, module) 23 | assert(wn) 24 | dict[key .. '.weight_v'] = wn.v:float() 25 | dict[key .. '.weight_g'] = wn.g:float() 26 | elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then 27 | local wn = find_weight_norm(model.module, module) 28 | assert(wn) 29 | local v = wn.v:float():view(wn.viewOut):transpose(2, 3) 30 | dict[key .. '.weight_v'] = v 31 | dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1) 32 | else 33 | dict[key .. '.weight'] = module.weight:float() 34 | end 35 | if module.bias then 36 | dict[key .. 
'.bias'] = module.bias:float() 37 | end 38 | end 39 | 40 | encoder_dict = {} 41 | decoder_dict = {} 42 | combined_dict = {} 43 | 44 | function encoder_state(encoder) 45 | luts = encoder:findModules('nn.LookupTable') 46 | push_state(encoder_dict, 'embed_tokens', luts[1]) 47 | push_state(encoder_dict, 'embed_positions', luts[2]) 48 | 49 | fcs = encoder:findModules('nn.Linear') 50 | assert(#fcs >= 2) 51 | local nInputPlane = fcs[1].weight:size(1) 52 | push_state(encoder_dict, 'fc1', table.remove(fcs, 1)) 53 | push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs)) 54 | 55 | for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do 56 | push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module) 57 | if nInputPlane ~= module.weight:size(3) / 2 then 58 | push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 59 | end 60 | nInputPlane = module.weight:size(3) / 2 61 | end 62 | assert(#fcs == 0) 63 | end 64 | 65 | function decoder_state(decoder) 66 | luts = decoder:findModules('nn.LookupTable') 67 | push_state(decoder_dict, 'embed_tokens', luts[1]) 68 | push_state(decoder_dict, 'embed_positions', luts[2]) 69 | 70 | fcs = decoder:findModules('nn.Linear') 71 | local nInputPlane = fcs[1].weight:size(1) 72 | push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) 73 | push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) 74 | push_state(decoder_dict, 'fc3', fcs[#fcs]) 75 | 76 | table.remove(fcs, #fcs) 77 | table.remove(fcs, #fcs) 78 | 79 | for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do 80 | if nInputPlane ~= module.weight:size(3) / 2 then 81 | push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) 82 | end 83 | nInputPlane = module.weight:size(3) / 2 84 | 85 | local prefix = 'attention.' .. tostring(i - 1) 86 | push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) 87 | push_state(decoder_dict, prefix .. 
'.out_projection', table.remove(fcs, 1)) 88 | push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) 89 | end 90 | assert(#fcs == 0) 91 | end 92 | 93 | 94 | _encoder = model.module.modules[2] 95 | _decoder = model.module.modules[3] 96 | 97 | encoder_state(_encoder) 98 | decoder_state(_decoder) 99 | 100 | for k, v in pairs(encoder_dict) do 101 | combined_dict['encoder.' .. k] = v 102 | end 103 | for k, v in pairs(decoder_dict) do 104 | combined_dict['decoder.' .. k] = v 105 | end 106 | 107 | 108 | torch.save('state_dict.t7', combined_dict) 109 | -------------------------------------------------------------------------------- /scripts/count_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Count the number of documents and average number of lines and tokens per 8 | document in a large file. Documents should be separated by a single empty line. 
9 | """ 10 | 11 | import argparse 12 | import gzip 13 | import sys 14 | 15 | import numpy as np 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('input') 21 | parser.add_argument('--gzip', action='store_true') 22 | args = parser.parse_args() 23 | 24 | def gopen(): 25 | if args.gzip: 26 | return gzip.open(args.input, 'r') 27 | else: 28 | return open(args.input, 'r', encoding='utf-8') 29 | 30 | num_lines = [] 31 | num_toks = [] 32 | with gopen() as h: 33 | num_docs = 1 34 | num_lines_in_doc = 0 35 | num_toks_in_doc = 0 36 | for i, line in enumerate(h): 37 | if len(line.strip()) == 0: # empty line indicates new document 38 | num_docs += 1 39 | num_lines.append(num_lines_in_doc) 40 | num_toks.append(num_toks_in_doc) 41 | num_lines_in_doc = 0 42 | num_toks_in_doc = 0 43 | else: 44 | num_lines_in_doc += 1 45 | num_toks_in_doc += len(line.rstrip().split()) 46 | if i % 1000000 == 0: 47 | print(i, file=sys.stderr, end="", flush=True) 48 | elif i % 100000 == 0: 49 | print(".", file=sys.stderr, end="", flush=True) 50 | print(file=sys.stderr, flush=True) 51 | 52 | print("found {} docs".format(num_docs)) 53 | print("average num lines per doc: {}".format(np.mean(num_lines))) 54 | print("average num toks per doc: {}".format(np.mean(num_toks))) 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/parse_profile.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | 4 | def main(args): 5 | significant = [] 6 | with open(args.file, 'r') as infile: 7 | ops = defaultdict(list) 8 | for line in infile.readlines(): 9 | line = line.strip().split() 10 | try: 11 | op, time = line[0], float(line[-1][:-len('us')]) 12 | except: 13 | continue 14 | ops[op].append(time) 15 | if time > 10000: 16 | significant.append((op, len(ops[op]), time)) 17 | print(significant) 18 | 
# print(len(ops[args.operator_name]), ops[args.operator_name]) 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('-op', '--operator-name', type=str) 23 | parser.add_argument('file', type=str) 24 | args = parser.parse_args() 25 | main(args) -------------------------------------------------------------------------------- /scripts/read_binarized.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | 9 | from fairseq.data import data_utils, Dictionary, indexed_dataset 10 | 11 | 12 | def get_parser(): 13 | parser = argparse.ArgumentParser( 14 | description='writes text from binarized file to stdout') 15 | # fmt: off 16 | parser.add_argument('--dataset-impl', help='dataset implementation', 17 | choices=indexed_dataset.get_available_dataset_impl()) 18 | parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) 19 | parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') 20 | # fmt: on 21 | 22 | return parser 23 | 24 | 25 | def main(): 26 | parser = get_parser() 27 | args = parser.parse_args() 28 | 29 | dictionary = Dictionary.load(args.dict) if args.dict is not None else None 30 | dataset = data_utils.load_indexed_dataset( 31 | args.input, 32 | dictionary, 33 | dataset_impl=args.dataset_impl, 34 | default='lazy', 35 | ) 36 | 37 | for tensor_line in dataset: 38 | if dictionary is None: 39 | line = ' '.join([str(int(x)) for x in tensor_line]) 40 | else: 41 | line = dictionary.string(tensor_line) 42 | 43 | print(line) 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /scripts/rm_pt.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import os 9 | import re 10 | import shutil 11 | import sys 12 | 13 | 14 | pt_regexp = re.compile(r'checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt') 15 | pt_regexp_epoch_based = re.compile(r'checkpoint(\d+)\.pt') 16 | pt_regexp_update_based = re.compile(r'checkpoint_\d+_(\d+)\.pt') 17 | 18 | 19 | def parse_checkpoints(files): 20 | entries = [] 21 | for f in files: 22 | m = pt_regexp_epoch_based.fullmatch(f) 23 | if m is not None: 24 | entries.append((int(m.group(1)), m.group(0))) 25 | else: 26 | m = pt_regexp_update_based.fullmatch(f) 27 | if m is not None: 28 | entries.append((int(m.group(1)), m.group(0))) 29 | return entries 30 | 31 | 32 | def last_n_checkpoints(files, n): 33 | entries = parse_checkpoints(files) 34 | return [x[1] for x in sorted(entries, reverse=True)[:n]] 35 | 36 | 37 | def every_n_checkpoints(files, n): 38 | entries = parse_checkpoints(files) 39 | return [x[1] for x in sorted(sorted(entries)[::-n])] 40 | 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser( 44 | description=( 45 | 'Recursively delete checkpoint files from `root_dir`, ' 46 | 'but preserve checkpoint_best.pt and checkpoint_last.pt' 47 | ) 48 | ) 49 | parser.add_argument('root_dirs', nargs='*') 50 | parser.add_argument('--save-last', type=int, default=0, help='number of last checkpoints to save') 51 | parser.add_argument('--save-every', type=int, default=0, help='interval of checkpoints to save') 52 | parser.add_argument('--preserve-test', action='store_true', 53 | help='preserve checkpoints in dirs that start with test_ prefix (default: delete them)') 54 | parser.add_argument('--delete-best', action='store_true', help='delete checkpoint_best.pt') 55 | 
parser.add_argument('--delete-last', action='store_true', help='delete checkpoint_last.pt') 56 | parser.add_argument('--no-dereference', action='store_true', help='don\'t dereference symlinks') 57 | args = parser.parse_args() 58 | 59 | files_to_desymlink = [] 60 | files_to_preserve = [] 61 | files_to_delete = [] 62 | for root_dir in args.root_dirs: 63 | for root, _subdirs, files in os.walk(root_dir): 64 | if args.save_last > 0: 65 | to_save = last_n_checkpoints(files, args.save_last) 66 | else: 67 | to_save = [] 68 | if args.save_every > 0: 69 | to_save += every_n_checkpoints(files, args.save_every) 70 | for file in files: 71 | if not pt_regexp.fullmatch(file): 72 | continue 73 | full_path = os.path.join(root, file) 74 | if ( 75 | ( 76 | not os.path.basename(root).startswith('test_') 77 | or args.preserve_test 78 | ) 79 | and ( 80 | (file == 'checkpoint_last.pt' and not args.delete_last) 81 | or (file == 'checkpoint_best.pt' and not args.delete_best) 82 | or file in to_save 83 | ) 84 | ): 85 | if os.path.islink(full_path) and not args.no_dereference: 86 | files_to_desymlink.append(full_path) 87 | else: 88 | files_to_preserve.append(full_path) 89 | else: 90 | files_to_delete.append(full_path) 91 | 92 | if len(files_to_desymlink) == 0 and len(files_to_delete) == 0: 93 | print('Nothing to do.') 94 | sys.exit(0) 95 | 96 | files_to_desymlink = sorted(files_to_desymlink) 97 | files_to_preserve = sorted(files_to_preserve) 98 | files_to_delete = sorted(files_to_delete) 99 | 100 | print('Operations to perform (in order):') 101 | if len(files_to_desymlink) > 0: 102 | for file in files_to_desymlink: 103 | print(' - preserve (and dereference symlink): ' + file) 104 | if len(files_to_preserve) > 0: 105 | for file in files_to_preserve: 106 | print(' - preserve: ' + file) 107 | if len(files_to_delete) > 0: 108 | for file in files_to_delete: 109 | print(' - delete: ' + file) 110 | while True: 111 | resp = input('Continue? 
(Y/N): ') 112 | if resp.strip().lower() == 'y': 113 | break 114 | elif resp.strip().lower() == 'n': 115 | sys.exit(0) 116 | 117 | print('Executing...') 118 | if len(files_to_desymlink) > 0: 119 | for file in files_to_desymlink: 120 | realpath = os.path.realpath(file) 121 | print('rm ' + file) 122 | os.remove(file) 123 | print('cp {} {}'.format(realpath, file)) 124 | shutil.copyfile(realpath, file) 125 | if len(files_to_delete) > 0: 126 | for file in files_to_delete: 127 | print('rm ' + file) 128 | os.remove(file) 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /scripts/sacrebleu_pregen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -ne 4 ]; then 4 | echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" 5 | exit 1 6 | fi 7 | 8 | TESTSET=$1 9 | SRCLANG=$2 10 | TGTLANG=$3 11 | 12 | GEN=$4 13 | 14 | echo 'Cloning Moses github repository (for tokenization scripts)...' 15 | git clone https://github.com/moses-smt/mosesdecoder.git 16 | 17 | SCRIPTS=mosesdecoder/scripts 18 | DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl 19 | 20 | grep ^H $GEN \ 21 | | sed 's/^H\-//' \ 22 | | sort -n -k 1 \ 23 | | cut -f 3 \ 24 | | perl $DETOKENIZER -l $TGTLANG \ 25 | | sed "s/ - /-/g" \ 26 | > $GEN.sorted.detok 27 | 28 | sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok 29 | -------------------------------------------------------------------------------- /scripts/shard_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Split a large file into shards while respecting document boundaries. 
Documents 8 | should be separated by a single empty line. 9 | """ 10 | 11 | import argparse 12 | import contextlib 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('input') 18 | parser.add_argument('--num-shards', type=int) 19 | args = parser.parse_args() 20 | 21 | assert args.num_shards is not None and args.num_shards > 1 22 | 23 | with open(args.input, 'r', encoding='utf-8') as h: 24 | with contextlib.ExitStack() as stack: 25 | outputs = [ 26 | stack.enter_context(open(args.input + ".shard" + str(i), "w", encoding="utf-8")) 27 | for i in range(args.num_shards) 28 | ] 29 | 30 | doc = [] 31 | first_doc = [True]*args.num_shards 32 | def output_doc(i): 33 | if not first_doc[i]: 34 | outputs[i].write("\n") 35 | first_doc[i] = False 36 | for line in doc: 37 | outputs[i].write(line) 38 | doc.clear() 39 | 40 | num_docs = 0 41 | for line in h: 42 | if line.strip() == "": # empty line indicates new document 43 | output_doc(num_docs % args.num_shards) 44 | num_docs += 1 45 | else: 46 | doc.append(line) 47 | output_doc(num_docs % args.num_shards) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /scripts/split_train_valid_docs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Split a large file into a train and valid set while respecting document 8 | boundaries. Documents should be separated by a single empty line. 
9 | """ 10 | 11 | import argparse 12 | import random 13 | import sys 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('input') 19 | parser.add_argument('sample_output', help='train output file') 20 | parser.add_argument('remainder_output', help='valid output file') 21 | parser.add_argument('-k', type=int, help="remainder size") 22 | parser.add_argument('--lines', action='store_true', 23 | help='split lines instead of docs') 24 | args = parser.parse_args() 25 | 26 | assert args.k is not None 27 | 28 | sample = [] 29 | remainder = [] 30 | num_docs = [0] 31 | 32 | def update_sample(doc): 33 | if len(sample) < args.k: 34 | sample.append(doc.copy()) 35 | else: 36 | i = num_docs[0] 37 | j = random.randrange(i + 1) 38 | if j < args.k: 39 | remainder.append(sample[j]) 40 | sample[j] = doc.copy() 41 | else: 42 | remainder.append(doc.copy()) 43 | num_docs[0] += 1 44 | doc.clear() 45 | 46 | with open(args.input, 'r', encoding='utf-8') as h: 47 | doc = [] 48 | for i, line in enumerate(h): 49 | if line.strip() == "": # empty line indicates new document 50 | update_sample(doc) 51 | else: 52 | doc.append(line) 53 | if args.lines: 54 | update_sample(doc) 55 | if i % 1000000 == 0: 56 | print(i, file=sys.stderr, end="", flush=True) 57 | elif i % 100000 == 0: 58 | print(".", file=sys.stderr, end="", flush=True) 59 | if len(doc) > 0: 60 | update_sample(doc) 61 | print(file=sys.stderr, flush=True) 62 | 63 | assert len(sample) == args.k 64 | 65 | with open(args.sample_output, 'w', encoding='utf-8') as out: 66 | first = True 67 | for doc in sample: 68 | if not first and not args.lines: 69 | out.write("\n") 70 | first = False 71 | for line in doc: 72 | out.write(line) 73 | 74 | with open(args.remainder_output, 'w', encoding='utf-8') as out: 75 | first = True 76 | for doc in remainder: 77 | if not first and not args.lines: 78 | out.write("\n") 79 | first = False 80 | for line in doc: 81 | out.write(line) 82 | 83 | 84 | if __name__ == '__main__': 
85 | main() 86 | -------------------------------------------------------------------------------- /scripts/spm_decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model", required=True, 18 | help="sentencepiece model to use for decoding") 19 | parser.add_argument("--input", required=True, help="input file to decode") 20 | parser.add_argument("--input_format", choices=["piece", "id"], default="piece") 21 | args = parser.parse_args() 22 | 23 | sp = spm.SentencePieceProcessor() 24 | sp.Load(args.model) 25 | 26 | if args.input_format == "piece": 27 | def decode(l): 28 | return "".join(sp.DecodePieces(l)) 29 | elif args.input_format == "id": 30 | def decode(l): 31 | return "".join(sp.DecodeIds(l)) 32 | else: 33 | raise NotImplementedError 34 | 35 | def tok2int(tok): 36 | # remap reference-side <unk> (represented as <<unk>>) to 0 37 | return int(tok) if tok != "<<unk>>" else 0 38 | 39 | with open(args.input, "r", encoding="utf-8") as h: 40 | for line in h: 41 | print(decode(list(map(tok2int, line.rstrip().split())))) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /scripts/spm_encode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import argparse 11 | import contextlib 12 | import sys 13 | 14 | import sentencepiece as spm 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--model", required=True, 20 | help="sentencepiece model to use for encoding") 21 | parser.add_argument("--inputs", nargs="+", default=['-'], 22 | help="input files to filter/encode") 23 | parser.add_argument("--outputs", nargs="+", default=['-'], 24 | help="path to save encoded outputs") 25 | parser.add_argument("--output_format", choices=["piece", "id"], default="piece") 26 | parser.add_argument("--min-len", type=int, metavar="N", 27 | help="filter sentence pairs with fewer than N tokens") 28 | parser.add_argument("--max-len", type=int, metavar="N", 29 | help="filter sentence pairs with more than N tokens") 30 | args = parser.parse_args() 31 | 32 | assert len(args.inputs) == len(args.outputs), \ 33 | "number of input and output paths should match" 34 | 35 | sp = spm.SentencePieceProcessor() 36 | sp.Load(args.model) 37 | 38 | if args.output_format == "piece": 39 | def encode(l): 40 | return sp.EncodeAsPieces(l) 41 | elif args.output_format == "id": 42 | def encode(l): 43 | return list(map(str, sp.EncodeAsIds(l))) 44 | else: 45 | raise NotImplementedError 46 | 47 | if args.min_len is not None or args.max_len is not None: 48 | def valid(line): 49 | return ( 50 | (args.min_len is None or len(line) >= args.min_len) 51 | and (args.max_len is None or len(line) <= args.max_len) 52 | ) 53 | else: 54 | def valid(lines): 55 | return True 56 | 57 | with contextlib.ExitStack() as stack: 58 | inputs = [ 59 | stack.enter_context(open(input, "r", encoding="utf-8")) \ 60 | if input != "-" else sys.stdin 61 | for input in args.inputs 62 | ] 63 | outputs = [ 64 | 
stack.enter_context(open(output, "w", encoding="utf-8")) \ 65 | if output != "-" else sys.stdout 66 | for output in args.outputs 67 | ] 68 | 69 | stats = { 70 | "num_empty": 0, 71 | "num_filtered": 0, 72 | } 73 | 74 | def encode_line(line): 75 | line = line.strip() 76 | if len(line) > 0: 77 | line = encode(line) 78 | if valid(line): 79 | return line 80 | else: 81 | stats["num_filtered"] += 1 82 | else: 83 | stats["num_empty"] += 1 84 | return None 85 | 86 | for i, lines in enumerate(zip(*inputs), start=1): 87 | enc_lines = list(map(encode_line, lines)) 88 | if not any(enc_line is None for enc_line in enc_lines): 89 | for enc_line, output_h in zip(enc_lines, outputs): 90 | print(" ".join(enc_line), file=output_h) 91 | if i % 10000 == 0: 92 | print("processed {} lines".format(i), file=sys.stderr) 93 | 94 | print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) 95 | print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /scripts/spm_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | import sys 11 | 12 | import sentencepiece as spm 13 | 14 | 15 | if __name__ == "__main__": 16 | spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup, find_packages, Extension 8 | import sys 9 | 10 | 11 | if sys.version_info < (3,): 12 | sys.exit('Sorry, Python3 is required for fairseq.') 13 | 14 | with open('README.md') as f: 15 | readme = f.read() 16 | 17 | if sys.platform == 'darwin': 18 | extra_compile_args = ['-stdlib=libc++', '-O3'] 19 | extra_link_args = ['-stdlib=libc++'] 20 | else: 21 | extra_compile_args = ['-std=c++11', '-O3'] 22 | extra_link_args = ['-std=c++11'] 23 | 24 | bleu = Extension( 25 | 'fairseq.libbleu', 26 | sources=[ 27 | 'fairseq/clib/libbleu/libbleu.cpp', 28 | 'fairseq/clib/libbleu/module.cpp', 29 | ], 30 | extra_compile_args=extra_compile_args, 31 | ) 32 | 33 | 34 | def get_cython_modules(): 35 | token_block_utils = Extension( 36 | "fairseq.data.token_block_utils_fast", 37 | ["fairseq/data/token_block_utils_fast.pyx"], 38 | extra_compile_args=extra_compile_args, 39 | extra_link_args=extra_link_args, 40 | ) 41 | data_utils_fast = Extension( 42 | "fairseq.data.data_utils_fast", 43 | ["fairseq/data/data_utils_fast.pyx"], 44 | language="c++", 45 | extra_compile_args=extra_compile_args, 46 | extra_link_args=extra_link_args, 47 | ) 48 | return [token_block_utils, data_utils_fast] 49 | 50 | 51 | def my_build_ext(pars): 52 | """ 53 | Delay loading of numpy headers. 
54 | More details: https://stackoverflow.com/questions/54117786/add-numpy-get-include-argument-to-setuptools-without-preinstalled-numpy 55 | """ 56 | from setuptools.command.build_ext import build_ext as _build_ext 57 | 58 | class build_ext(_build_ext): 59 | def finalize_options(self): 60 | _build_ext.finalize_options(self) 61 | __builtins__.__NUMPY_SETUP__ = False 62 | import numpy 63 | self.include_dirs.append(numpy.get_include()) 64 | return build_ext(pars) 65 | 66 | 67 | setup( 68 | name='fairseq', 69 | version='0.8.0', 70 | description='Facebook AI Research Sequence-to-Sequence Toolkit', 71 | url='https://github.com/pytorch/fairseq', 72 | classifiers=[ 73 | 'Intended Audience :: Science/Research', 74 | 'License :: OSI Approved :: MIT License', 75 | 'Programming Language :: Python :: 3.5', 76 | 'Programming Language :: Python :: 3.6', 77 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 78 | ], 79 | long_description=readme, 80 | long_description_content_type='text/markdown', 81 | setup_requires=[ 82 | 'numpy', 83 | 'cython', 84 | 'setuptools>=18.0', 85 | ], 86 | install_requires=[ 87 | 'cffi', 88 | 'fastBPE', 89 | 'numpy', 90 | 'regex', 91 | 'sacrebleu', 92 | 'torch', 93 | 'tqdm', 94 | ], 95 | packages=find_packages(exclude=['scripts', 'tests']), 96 | ext_modules=get_cython_modules() + [bleu], 97 | test_suite='tests', 98 | entry_points={ 99 | 'console_scripts': [ 100 | 'fairseq-eval-lm = eval_lm:cli_main', 101 | 'fairseq-generate = generate:cli_main', 102 | 'fairseq-preprocess = preprocess:cli_main', 103 | 'fairseq-score = score:main', 104 | 'fairseq-train = train:cli_main', 105 | ], 106 | }, 107 | cmdclass={'build_ext': my_build_ext}, 108 | zip_safe=False, 109 | ) -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | #!/usr/bin/env python3 -u 3 | # Copyright (c) Facebook, Inc. 
and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | 10 | from fairseq import checkpoint_utils, options, progress_bar, utils 11 | 12 | 13 | def main(args, override_args=None): 14 | utils.import_user_module(args) 15 | 16 | use_fp16 = args.fp16 17 | use_cuda = torch.cuda.is_available() and not args.cpu 18 | 19 | if override_args is not None: 20 | overrides = vars(override_args) 21 | overrides.update(eval(getattr(override_args, 'model_overrides', '{}'))) 22 | else: 23 | overrides = None 24 | 25 | # Load ensemble 26 | print('| loading model(s) from {}'.format(args.path)) 27 | models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( 28 | [args.path], 29 | arg_overrides=overrides, 30 | ) 31 | model = models[0] 32 | 33 | # Move models to GPU 34 | for model in models: 35 | if use_fp16: 36 | model.half() 37 | if use_cuda: 38 | model.cuda() 39 | 40 | # Print args 41 | print(model_args) 42 | 43 | # Build criterion 44 | criterion = task.build_criterion(model_args) 45 | criterion.eval() 46 | 47 | # Load valid dataset (we load training data below, based on the latest checkpoint) 48 | for subset in args.valid_subset.split(','): 49 | try: 50 | task.load_dataset(subset, combine=False, epoch=0) 51 | dataset = task.dataset(subset) 52 | except KeyError: 53 | raise Exception('Cannot find dataset: ' + subset) 54 | 55 | # Initialize data iterator 56 | itr = task.get_batch_iterator( 57 | dataset=dataset, 58 | max_tokens=args.max_tokens, 59 | max_sentences=args.max_sentences, 60 | max_positions=utils.resolve_max_positions( 61 | task.max_positions(), 62 | *[m.max_positions() for m in models], 63 | ), 64 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 65 | required_batch_size_multiple=args.required_batch_size_multiple, 66 | seed=args.seed, 67 | num_workers=args.num_workers, 68 | ).next_epoch_itr(shuffle=False) 69 | progress = 
progress_bar.build_progress_bar( 70 | args, itr, 71 | prefix='valid on \'{}\' subset'.format(subset), 72 | no_progress_bar='simple' 73 | ) 74 | 75 | log_outputs = [] 76 | for i, sample in enumerate(progress): 77 | sample = utils.move_to_cuda(sample) if use_cuda else sample 78 | _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) 79 | progress.log(log_output, step=i) 80 | log_outputs.append(log_output) 81 | 82 | log_output = task.aggregate_logging_outputs(log_outputs, criterion) 83 | 84 | progress.print(log_output, tag=subset, step=i) 85 | 86 | 87 | def cli_main(): 88 | parser = options.get_validation_parser() 89 | args = options.parse_args_and_arch(parser) 90 | 91 | # only override args that are explicitly given on the command line 92 | override_parser = options.get_validation_parser() 93 | override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) 94 | 95 | main(args, override_args) 96 | 97 | 98 | if __name__ == '__main__': 99 | cli_main() 100 | --------------------------------------------------------------------------------