├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── constants.py ├── data └── gap │ ├── gap-development.bert.tsv │ ├── gap-development.tsv │ ├── gap-test.bert.tsv │ ├── gap-test.tsv │ ├── gap-validation.bert.tsv │ └── gap-validation.tsv ├── distributed_train.py ├── eval_lm.py ├── fairseq.gif ├── fairseq ├── __init__.py ├── bleu.py ├── clib │ └── libbleu │ │ ├── libbleu.cpp │ │ └── module.cpp ├── criterions │ ├── __init__.py │ ├── adaptive_loss.py │ ├── cross_entropy.py │ ├── fairseq_criterion.py │ ├── gatekeeper_gap_bert.py │ └── label_smoothed_cross_entropy.py ├── data │ ├── __init__.py │ ├── bert_reader.py │ ├── data_utils.py │ ├── dictionary.py │ ├── fairseq_dataset.py │ ├── gap_reader.py │ ├── indexed_dataset.py │ ├── language_pair_dataset.py │ ├── monolingual_dataset.py │ ├── monolingual_gap_bert_dataset.py │ ├── token_block_dataset.py │ └── token_block_dataset_gap_bert.py ├── distributed_utils.py ├── fp16_trainer.py ├── meters.py ├── models │ ├── __init__.py │ ├── composite_encoder.py │ ├── fairseq_decoder.py │ ├── fairseq_encoder.py │ ├── fairseq_incremental_decoder.py │ ├── fairseq_model.py │ ├── fconv.py │ ├── fconv_self_att.py │ ├── gap_evaluator.py │ ├── lstm.py │ ├── lstm_cache.py │ ├── pronouns.py │ ├── refreader_gap_bert.py │ └── transformer.py ├── modules │ ├── __init__.py │ ├── adaptive_softmax.py │ ├── beamable_mm.py │ ├── character_token_embedder.py │ ├── conv_tbc.py │ ├── downsampled_multihead_attention.py │ ├── grad_multiply.py │ ├── highway.py │ ├── learned_positional_embedding.py │ ├── linearized_convolution.py │ ├── multihead_attention.py │ ├── scalar_bias.py │ └── sinusoidal_positional_embedding.py ├── multiprocessing_pdb.py ├── optim │ ├── __init__.py │ ├── adagrad.py │ ├── adam.py │ ├── fairseq_optimizer.py │ ├── lr_scheduler │ │ ├── __init__.py │ │ ├── fairseq_lr_scheduler.py │ │ ├── fixed_schedule.py │ │ ├── inverse_square_root_schedule.py │ │ ├── reduce_lr_on_plateau.py │ │ └── reduce_lr_on_plateau_patience.py │ ├── nag.py │ ├── rmsprop.py │ └── sgd.py ├── options.py ├── progress_bar.py ├── sequence_generator.py ├── sequence_scorer.py ├── tasks │ ├── __init__.py │ ├── fairseq_task.py │ ├── language_modeling.py │ └── language_modeling_refreader_gap_bert.py ├── tokenizer.py ├── trainer.py └── utils.py ├── gap_scorer.py ├── generate.py ├── interactive.py ├── preprocess.py ├── score.py ├── setup.py ├── train.py └── txt2dict.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch) 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 
24 | 25 | ## Coding Style 26 | We try to follow the PEP style guidelines and encourage you to as well. 27 | 28 | ## License 29 | By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed 30 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For fairseq software 4 | 5 | Copyright (c) 2017-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the fairseq software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 
13 | 
14 | The license granted hereunder will terminate, automatically and without notice,
15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate
16 | directly or indirectly, or take a direct financial interest in, any Patent
17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate
18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or
19 | in part from any software, technology, product or service of Facebook or any of
20 | its subsidiaries or corporate affiliates, or (iii) against any party relating
21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its
22 | subsidiaries or corporate affiliates files a lawsuit alleging patent
23 | infringement against you in the first instance, and you respond by filing a
24 | patent infringement counterclaim in that lawsuit against that party that is
25 | unrelated to the Software, the license granted hereunder will not terminate
26 | under section (i) of this paragraph due to such counterclaim.
27 | 
28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is
29 | necessarily infringed by the Software standing alone.
30 | 
31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
32 | or contributory infringement or inducement to infringe any patent, including a
33 | cross-claim or counterclaim.
34 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | RefReader
2 | ========
3 | 
4 | This repository hosts the implementation of [the referential reader](https://www.aclweb.org/anthology/P19-1593), built on a fork of [fairseq](https://github.com/pytorch/fairseq).
5 | 
6 | ```
7 | @inproceedings{liu-etal-2019-referential,
8 |     title = "The Referential Reader: A Recurrent Entity Network for Anaphora Resolution",
9 |     author = "Liu, Fei and Zettlemoyer, Luke and Eisenstein, Jacob",
10 |     booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
11 |     month = jul,
12 |     year = "2019",
13 |     address = "Florence, Italy",
14 |     publisher = "Association for Computational Linguistics",
15 |     url = "https://www.aclweb.org/anthology/P19-1593",
16 |     doi = "10.18653/v1/P19-1593",
17 |     pages = "5918--5925",
18 | }
19 | ```
20 | 
21 | # Environment
22 | 
23 | ```
24 | Python 3.6
25 | PyTorch 1.1
26 | tqdm 4.32.1
27 | ```
28 | 
29 | # Prepare GAP text for BERT pre-processing
30 | Replace `${REFREADER_PATH}` with the path to your local copy of the RefReader repo.
31 | ```
32 | $ tail -n +2 ${REFREADER_PATH}/data/gap/gap-development.tsv | cut -d$'\t' -f2 > ${REFREADER_PATH}/data/gap/gap-development.txt
33 | $ tail -n +2 ${REFREADER_PATH}/data/gap/gap-test.tsv | cut -d$'\t' -f2 > ${REFREADER_PATH}/data/gap/gap-test.txt
34 | $ tail -n +2 ${REFREADER_PATH}/data/gap/gap-validation.tsv | cut -d$'\t' -f2 > ${REFREADER_PATH}/data/gap/gap-validation.txt
35 | ```
36 | 
37 | # Prepare GAP text for fairseq pre-processing
38 | Extract text from the (BERT-)tokenized GAP tsv files (credit goes to [Mandar Joshi](https://homes.cs.washington.edu/~mandar90/) for preparing these files):
39 | ```
40 | $ tail -n +2 ${REFREADER_PATH}/data/gap/gap-development.bert.tsv | cut -d$'\t' -f2 > ${REFREADER_PATH}/data/gap/gap-development.bert.txt
41 | $ tail -n +2 ${REFREADER_PATH}/data/gap/gap-test.bert.tsv | cut -d$'\t' -f2 > ${REFREADER_PATH}/data/gap/gap-test.bert.txt
42 | $ tail -n +2 ${REFREADER_PATH}/data/gap/gap-validation.bert.tsv | cut -d$'\t' -f2 > ${REFREADER_PATH}/data/gap/gap-validation.bert.txt
43 | ```
44 | Construct the dictionary:
45 | ```
46 | $ python txt2dict.py ${REFREADER_PATH}/data/gap/gap-development.bert.txt ${REFREADER_PATH}/data/gap/gap-test.bert.txt ${REFREADER_PATH}/data/gap/gap-validation.bert.txt ${REFREADER_PATH}/data/gap/gap-bert.dict
47 | ```
48 | 
49 | # Fairseq pre-processing
50 | ```
51 | $ python preprocess.py --only-source --trainpref ${REFREADER_PATH}/data/gap/gap-development.bert.txt --validpref ${REFREADER_PATH}/data/gap/gap-validation.bert.txt --testpref ${REFREADER_PATH}/data/gap/gap-test.bert.txt --destdir ${REFREADER_PATH}/data/gap-bert-bin/ --srcdict ${REFREADER_PATH}/data/gap/gap-bert.dict
52 | 
53 | ```
54 | Then copy and rename the files:
55 | ```
56 | $ cp ${REFREADER_PATH}/data/gap/gap-development.bert.tsv ${REFREADER_PATH}/data/gap-bert-bin/gap-train.bert.tsv
57 | $ cp ${REFREADER_PATH}/data/gap/gap-validation.bert.tsv ${REFREADER_PATH}/data/gap-bert-bin/gap-valid.bert.tsv
58 | $ cp ${REFREADER_PATH}/data/gap/gap-test.bert.tsv ${REFREADER_PATH}/data/gap-bert-bin/gap-test.bert.tsv
59 | ```
60 | 
61 | # Extract BERT features for GAP text
62 | + Clone the [BERT GitHub repo](https://github.com/google-research/bert) and follow the instructions in its [README.md](https://github.com/google-research/bert/blob/master/README.md) to configure the environment.
63 | + Download the [BERT pre-trained model](https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip).
64 | + Replace `${BERT_PATH}` with the path to your local copy of the BERT repo, and `${BERT_MODEL}` (used in the commands below) with the directory into which you unzipped the pre-trained model; see the sketch that follows.
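For concreteness, one possible setup looks like the following (all three locations are your own choices; the only assumption here is that `${BERT_MODEL}` holds the unzipped model directory):
```
$ export REFREADER_PATH=/path/to/refreader
$ export BERT_PATH=/path/to/bert
$ export BERT_MODEL=/path/to/bert-models   # must contain cased_L-12_H-768_A-12/
```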
65 | 
66 | 
67 | - Extract BERT features for `gap-development.txt`, writing them to `gap-train.bert.jsonl`:
68 | ```
69 | $ python ${BERT_PATH}/extract_features.py --input_file=${REFREADER_PATH}/data/gap/gap-development.txt --output_file=${REFREADER_PATH}/data/gap-bert-bin/gap-train.bert.jsonl --vocab_file=${BERT_MODEL}/cased_L-12_H-768_A-12/vocab.txt --bert_config_file=${BERT_MODEL}/cased_L-12_H-768_A-12/bert_config.json --init_checkpoint=${BERT_MODEL}/cased_L-12_H-768_A-12/bert_model.ckpt --layers=-1,-2,-3,-4 --max_seq_length=512 --batch_size=1 --do_lower_case=False
70 | ```
71 | - Extract BERT features for `gap-validation.txt`, writing them to `gap-valid.bert.jsonl`:
72 | ```
73 | $ python ${BERT_PATH}/extract_features.py --input_file=${REFREADER_PATH}/data/gap/gap-validation.txt --output_file=${REFREADER_PATH}/data/gap-bert-bin/gap-valid.bert.jsonl --vocab_file=${BERT_MODEL}/cased_L-12_H-768_A-12/vocab.txt --bert_config_file=${BERT_MODEL}/cased_L-12_H-768_A-12/bert_config.json --init_checkpoint=${BERT_MODEL}/cased_L-12_H-768_A-12/bert_model.ckpt --layers=-1,-2,-3,-4 --max_seq_length=512 --batch_size=1 --do_lower_case=False
74 | ```
75 | - Extract BERT features for `gap-test.txt`, writing them to `gap-test.bert.jsonl`:
76 | ```
77 | $ python ${BERT_PATH}/extract_features.py --input_file=${REFREADER_PATH}/data/gap/gap-test.txt --output_file=${REFREADER_PATH}/data/gap-bert-bin/gap-test.bert.jsonl --vocab_file=${BERT_MODEL}/cased_L-12_H-768_A-12/vocab.txt --bert_config_file=${BERT_MODEL}/cased_L-12_H-768_A-12/bert_config.json --init_checkpoint=${BERT_MODEL}/cased_L-12_H-768_A-12/bert_model.ckpt --layers=-1,-2,-3,-4 --max_seq_length=512 --batch_size=1 --do_lower_case=False
78 | ```
79 | 
80 | # Training
81 | ```
82 | $ python train.py --task language_modeling_refreader_gap_bert --arch refreader_gap_bert data/gap-bert-bin/ --restore-file=NoFile --distributed-world-size=1 --max-tokens 5000 --lr 1e-3 --optimizer adam --train-subset train --valid-subset valid --criterion gatekeeper_gap_bert --no-epoch-checkpoints --save-dir ${SAVE_DIR}
83 | ```
84 | 
85 | When training ends, the final output should look like this:
86 | ```
87 | | epoch 035 | valid on 'valid' subset | valid_loss -78.4106 | valid_ppl 24505.21 | num_updates 2975 | best 80 | best_threshold 0.07 | f@0.04 78.4106 | mf@0.04 78.5908 | ff@0.04 78.2383
88 | ```
89 | + the last three items are the current model's F-scores on the `valid` set at a threshold of `0.04`, in three categories:
90 |   + overall: `f@0.04 78.4106`
91 |   + masculine: `mf@0.04 78.5908`
92 |   + feminine: `ff@0.04 78.2383`
93 | + `best 80` is the best overall F-score on the `valid` set up to the current epoch (a snapshot of this best-performing model is saved in `${SAVE_DIR}`)
94 | + `best_threshold 0.07` is the threshold at which the `best 80` validation F-score was achieved
95 | 
96 | # Predict
97 | ```
98 | $ python train.py --task language_modeling_refreader_gap_bert --arch refreader_gap_bert data/gap-bert-bin/ --distributed-world-size=1 --max-tokens 5000 --lr 1e-3 --optimizer adam --train-subset train --valid-subset test --criterion gatekeeper_gap_bert --no-epoch-checkpoints --restore-dir ${SAVE_DIR} --restore-file checkpoint_best.pt --threshold ${BEST_THRESHOLD} --no-save --no-train
99 | ```
100 | Here, `${BEST_THRESHOLD}` is the `best_threshold` value reported at the end of training.
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the
"License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Constants. 18 | """ 19 | 20 | from enum import Enum 21 | 22 | 23 | class Gender(Enum): 24 | UNKNOWN = 0 25 | MASCULINE = 1 26 | FEMININE = 2 27 | 28 | 29 | # Mapping of (lowercased) pronoun form to gender value. Note that reflexives 30 | # are not included in GAP, so do not appear here. 31 | PRONOUNS = { 32 | 'she': Gender.FEMININE, 33 | 'her': Gender.FEMININE, 34 | 'hers': Gender.FEMININE, 35 | 'he': Gender.MASCULINE, 36 | 'his': Gender.MASCULINE, 37 | 'him': Gender.MASCULINE, 38 | } 39 | 40 | # Fieldnames used in the gold dataset .tsv file. 41 | GOLD_FIELDNAMES = [ 42 | 'ID', 'Text', 'Pronoun', 'Pronoun-offset', 'A', 'A-offset', 'A-coref', 'B', 43 | 'B-offset', 'B-coref', 'URL' 44 | ] 45 | 46 | # Fieldnames expected in system output .tsv files. 47 | SYSTEM_FIELDNAMES = ['ID', 'A-coref', 'B-coref'] 48 | -------------------------------------------------------------------------------- /distributed_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import os 10 | import socket 11 | import subprocess 12 | 13 | from train import main as single_process_main 14 | from fairseq import distributed_utils, options 15 | 16 | 17 | def main(args): 18 | if args.distributed_init_method is None and args.distributed_port > 0: 19 | # We can determine the init method automatically for Slurm. 
20 | node_list = os.environ.get('SLURM_JOB_NODELIST') 21 | if node_list is not None: 22 | try: 23 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list]) 24 | args.distributed_init_method = 'tcp://{host}:{port}'.format( 25 | host=hostnames.split()[0].decode('utf-8'), 26 | port=args.distributed_port) 27 | args.distributed_rank = int(os.environ.get('SLURM_PROCID')) 28 | args.device_id = int(os.environ.get('SLURM_LOCALID')) 29 | except subprocess.CalledProcessError as e: # scontrol failed 30 | raise e 31 | except FileNotFoundError as e: # Slurm is not installed 32 | pass 33 | if args.distributed_init_method is None: 34 | raise ValueError('--distributed-init-method or --distributed-port ' 35 | 'must be specified for distributed training') 36 | 37 | args.distributed_rank = distributed_utils.distributed_init(args) 38 | print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) 39 | single_process_main(args) 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = options.get_training_parser() 44 | args = options.parse_args_and_arch(parser) 45 | main(args) 46 | -------------------------------------------------------------------------------- /eval_lm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from fairseq import data, options, progress_bar, tasks, utils 13 | from fairseq.meters import StopwatchMeter, TimeMeter 14 | from fairseq.sequence_scorer import SequenceScorer 15 | 16 | 17 | def main(args): 18 | assert args.path is not None, '--path required for evaluation!' 19 | 20 | args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024) 21 | print(args) 22 | 23 | use_cuda = torch.cuda.is_available() and not args.cpu 24 | 25 | # Load dataset splits 26 | task = tasks.setup_task(args) 27 | task.load_dataset(args.gen_subset) 28 | print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) 29 | 30 | # Load ensemble 31 | print('| loading model(s) from {}'.format(args.path)) 32 | models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task) 33 | 34 | # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) 35 | for model in models: 36 | model.make_generation_fast_() 37 | if args.fp16: 38 | model.half() 39 | 40 | assert len(models) > 0 41 | 42 | itr = data.EpochBatchIterator( 43 | dataset=task.dataset(args.gen_subset), 44 | max_tokens=args.max_tokens or 36000, 45 | max_sentences=args.max_sentences, 46 | max_positions=models[0].max_positions(), 47 | num_shards=args.num_shards, 48 | shard_id=args.shard_id, 49 | ignore_invalid_inputs=True, 50 | ).next_epoch_itr(shuffle=False) 51 | 52 | gen_timer = StopwatchMeter() 53 | scorer = SequenceScorer(models, task.target_dictionary) 54 | if use_cuda: 55 | scorer.cuda() 56 | 57 | score_sum = 0. 
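    # Running totals of token log-probabilities and token counts. When
    # --remove-bpe is set, scores of BPE continuation pieces are folded into
    # the following token below, so perplexity is computed over whole words.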
58 | count = 0 59 | 60 | if args.remove_bpe is not None: 61 | bpe_cont = args.remove_bpe.rstrip() 62 | bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont)) 63 | bpe_len = len(bpe_cont) 64 | else: 65 | bpe_toks = None 66 | bpe_len = 0 67 | 68 | with progress_bar.build_progress_bar(args, itr) as t: 69 | results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer) 70 | wps_meter = TimeMeter() 71 | for _, src_tokens, __, hypos in results: 72 | for hypo in hypos: 73 | pos_scores = hypo['positional_scores'] 74 | 75 | skipped_toks = 0 76 | if bpe_toks is not None: 77 | for i in range(len(hypo['tokens']) - 1): 78 | if hypo['tokens'][i].item() in bpe_toks: 79 | skipped_toks += 1 80 | pos_scores[i + 1] += pos_scores[i] 81 | pos_scores[i] = 0 82 | 83 | inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf')) 84 | if inf_scores.any(): 85 | print('| Skipping tokens with inf scores:', 86 | task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()])) 87 | pos_scores = pos_scores[(~inf_scores).nonzero()] 88 | score_sum += pos_scores.sum() 89 | count += pos_scores.numel() - skipped_toks 90 | 91 | if args.output_word_probs: 92 | w = '' 93 | word_prob = [] 94 | for i in range(len(hypo['tokens'])): 95 | w_ind = hypo['tokens'][i].item() 96 | w += task.dictionary[w_ind] 97 | if bpe_toks is not None and w_ind in bpe_toks: 98 | w = w[:-bpe_len] 99 | else: 100 | word_prob.append((w, pos_scores[i].item())) 101 | w = '' 102 | print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob)) 103 | 104 | wps_meter.update(src_tokens.size(0)) 105 | t.log({'wps': round(wps_meter.avg)}) 106 | 107 | avg_nll_loss = -score_sum / count 108 | print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg)) 109 | print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss))) 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = options.get_eval_lm_parser() 114 | args = options.parse_args_and_arch(parser) 115 | main(args) 116 | -------------------------------------------------------------------------------- /fairseq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liufly/refreader/25d371fc08d89174cfdac1c7e29984d8cb3beff2/fairseq.gif -------------------------------------------------------------------------------- /fairseq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .multiprocessing_pdb import pdb 9 | 10 | __all__ = ['pdb'] 11 | -------------------------------------------------------------------------------- /fairseq/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
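# Corpus-level BLEU, backed by the C implementation in fairseq/clib/libbleu
# and called through ctypes; BleuStat below mirrors the C-side struct.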
7 | 8 | import ctypes 9 | import math 10 | import torch 11 | 12 | try: 13 | from fairseq import libbleu 14 | except ImportError as e: 15 | import sys 16 | sys.stderr.write('ERROR: missing libbleu.so. run `python setup.py install`\n') 17 | raise e 18 | 19 | 20 | C = ctypes.cdll.LoadLibrary(libbleu.__file__) 21 | 22 | 23 | class BleuStat(ctypes.Structure): 24 | _fields_ = [ 25 | ('reflen', ctypes.c_size_t), 26 | ('predlen', ctypes.c_size_t), 27 | ('match1', ctypes.c_size_t), 28 | ('count1', ctypes.c_size_t), 29 | ('match2', ctypes.c_size_t), 30 | ('count2', ctypes.c_size_t), 31 | ('match3', ctypes.c_size_t), 32 | ('count3', ctypes.c_size_t), 33 | ('match4', ctypes.c_size_t), 34 | ('count4', ctypes.c_size_t), 35 | ] 36 | 37 | 38 | class Scorer(object): 39 | def __init__(self, pad, eos, unk): 40 | self.stat = BleuStat() 41 | self.pad = pad 42 | self.eos = eos 43 | self.unk = unk 44 | self.reset() 45 | 46 | def reset(self, one_init=False): 47 | if one_init: 48 | C.bleu_one_init(ctypes.byref(self.stat)) 49 | else: 50 | C.bleu_zero_init(ctypes.byref(self.stat)) 51 | 52 | def add(self, ref, pred): 53 | if not isinstance(ref, torch.IntTensor): 54 | raise TypeError('ref must be a torch.IntTensor (got {})' 55 | .format(type(ref))) 56 | if not isinstance(pred, torch.IntTensor): 57 | raise TypeError('pred must be a torch.IntTensor(got {})' 58 | .format(type(pred))) 59 | 60 | # don't match unknown words 61 | rref = ref.clone() 62 | assert not rref.lt(0).any() 63 | rref[rref.eq(self.unk)] = -999 64 | 65 | rref = rref.contiguous().view(-1) 66 | pred = pred.contiguous().view(-1) 67 | 68 | C.bleu_add( 69 | ctypes.byref(self.stat), 70 | ctypes.c_size_t(rref.size(0)), 71 | ctypes.c_void_p(rref.data_ptr()), 72 | ctypes.c_size_t(pred.size(0)), 73 | ctypes.c_void_p(pred.data_ptr()), 74 | ctypes.c_int(self.pad), 75 | ctypes.c_int(self.eos)) 76 | 77 | def score(self, order=4): 78 | psum = sum(math.log(p) if p > 0 else float('-Inf') 79 | for p in self.precision()[:order]) 80 | return self.brevity() * math.exp(psum / order) * 100 81 | 82 | def precision(self): 83 | def ratio(a, b): 84 | return a / b if b > 0 else 0 85 | 86 | return [ 87 | ratio(self.stat.match1, self.stat.count1), 88 | ratio(self.stat.match2, self.stat.count2), 89 | ratio(self.stat.match3, self.stat.count3), 90 | ratio(self.stat.match4, self.stat.count4), 91 | ] 92 | 93 | def brevity(self): 94 | r = self.stat.reflen / self.stat.predlen 95 | return min(1, math.exp(1 - r)) 96 | 97 | def result_string(self, order=4): 98 | assert order <= 4, "BLEU scores for order > 4 aren't supported" 99 | fmt = 'BLEU{} = {:2.2f}, {:2.1f}' 100 | for _ in range(1, order): 101 | fmt += '/{:2.1f}' 102 | fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})' 103 | bleup = [p * 100 for p in self.precision()[:order]] 104 | return fmt.format(order, self.score(order=order), *bleup, 105 | self.brevity(), self.stat.predlen/self.stat.reflen, 106 | self.stat.predlen, self.stat.reflen) 107 | -------------------------------------------------------------------------------- /fairseq/clib/libbleu/libbleu.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 |  */
8 | 
9 | #include <map>
10 | #include <array>
11 | #include <cstring>
12 | #include <cstdio>
13 | 
14 | typedef struct
15 | {
16 |     size_t reflen;
17 |     size_t predlen;
18 |     size_t match1;
19 |     size_t count1;
20 |     size_t match2;
21 |     size_t count2;
22 |     size_t match3;
23 |     size_t count3;
24 |     size_t match4;
25 |     size_t count4;
26 | } bleu_stat;
27 | 
28 | // left trim (remove pad)
29 | void bleu_ltrim(size_t* len, int** sent, int pad) {
30 |     size_t start = 0;
31 |     while(start < *len) {
32 |         if (*(*sent + start) != pad) { break; }
33 |         start++;
34 |     }
35 |     *sent += start;
36 |     *len -= start;
37 | }
38 | 
39 | // right trim remove (eos)
40 | void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
41 |     size_t end = *len - 1;
42 |     while (end > 0) {
43 |         if (*(*sent + end) != eos && *(*sent + end) != pad) { break; }
44 |         end--;
45 |     }
46 |     *len = end + 1;
47 | }
48 | 
49 | // left and right trim
50 | void bleu_trim(size_t* len, int** sent, int pad, int eos) {
51 |     bleu_ltrim(len, sent, pad);
52 |     bleu_rtrim(len, sent, pad, eos);
53 | }
54 | 
55 | size_t bleu_hash(int len, int* data) {
56 |     size_t h = 14695981039346656037ul;
57 |     size_t prime = 0x100000001b3;
58 |     char* b = (char*) data;
59 |     size_t blen = sizeof(int) * len;
60 | 
61 |     while (blen-- > 0) {
62 |         h ^= *b++;
63 |         h *= prime;
64 |     }
65 | 
66 |     return h;
67 | }
68 | 
69 | void bleu_addngram(
70 |     size_t *ntotal, size_t *nmatch, size_t n,
71 |     size_t reflen, int* ref, size_t predlen, int* pred) {
72 | 
73 |     if (predlen < n) { return; }
74 | 
75 |     predlen = predlen - n + 1;
76 |     (*ntotal) += predlen;
77 | 
78 |     if (reflen < n) { return; }
79 | 
80 |     reflen = reflen - n + 1;
81 | 
82 |     std::map<size_t, size_t> count;
83 |     while (predlen > 0) {
84 |         size_t w = bleu_hash(n, pred++);
85 |         count[w]++;
86 |         predlen--;
87 |     }
88 | 
89 |     while (reflen > 0) {
90 |         size_t w = bleu_hash(n, ref++);
91 |         if (count[w] > 0) {
92 |             (*nmatch)++;
93 |             count[w] -=1;
94 |         }
95 |         reflen--;
96 |     }
97 | }
98 | 
99 | extern "C" {
100 | 
101 | void bleu_zero_init(bleu_stat* stat) {
102 |     std::memset(stat, 0, sizeof(bleu_stat));
103 | }
104 | 
105 | void bleu_one_init(bleu_stat* stat) {
106 |     bleu_zero_init(stat);
107 |     stat->count1 = 0;
108 |     stat->count2 = 1;
109 |     stat->count3 = 1;
110 |     stat->count4 = 1;
111 |     stat->match1 = 0;
112 |     stat->match2 = 1;
113 |     stat->match3 = 1;
114 |     stat->match4 = 1;
115 | }
116 | 
117 | void bleu_add(
118 |     bleu_stat* stat,
119 |     size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) {
120 | 
121 |     bleu_trim(&reflen, &ref, pad, eos);
122 |     bleu_trim(&predlen, &pred, pad, eos);
123 |     stat->reflen += reflen;
124 |     stat->predlen += predlen;
125 | 
126 |     bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
127 |     bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
128 |     bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
129 |     bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
130 | }
131 | 
132 | }
133 | 
--------------------------------------------------------------------------------
/fairseq/clib/libbleu/module.cpp:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright 2017-present, Facebook, Inc.
3 |  * All rights reserved.
4 |  *
5 |  * This source code is licensed under the license found in the
6 |  * LICENSE file in the root directory of this source tree.
7 |  */
8 | 
9 | #include <Python.h>
10 | 
11 | 
12 | static PyMethodDef method_def[] = {
13 |     {NULL, NULL, 0, NULL}
14 | };
15 | 
16 | static struct PyModuleDef module_def = {
17 |     PyModuleDef_HEAD_INIT,
18 |     "libbleu",   /* name of module */
19 |     NULL,        /* module documentation, may be NULL */
20 |     -1,          /* size of per-interpreter state of the module,
21 |                     or -1 if the module keeps state in global variables. */
22 |     method_def
23 | };
24 | 
25 | 
26 | #if PY_MAJOR_VERSION == 2
27 | PyMODINIT_FUNC init_libbleu()
28 | #else
29 | PyMODINIT_FUNC PyInit_libbleu()
30 | #endif
31 | {
32 |     PyObject *m = PyModule_Create(&module_def);
33 |     if (!m) {
34 |         return NULL;
35 |     }
36 |     return m;
37 | }
38 | 
--------------------------------------------------------------------------------
/fairseq/criterions/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | import importlib
9 | import os
10 | 
11 | from .fairseq_criterion import FairseqCriterion
12 | 
13 | 
14 | CRITERION_REGISTRY = {}
15 | CRITERION_CLASS_NAMES = set()
16 | 
17 | 
18 | def build_criterion(args, task):
19 |     return CRITERION_REGISTRY[args.criterion](args, task)
20 | 
21 | 
22 | def register_criterion(name):
23 |     """Decorator to register a new criterion."""
24 | 
25 |     def register_criterion_cls(cls):
26 |         if name in CRITERION_REGISTRY:
27 |             raise ValueError('Cannot register duplicate criterion ({})'.format(name))
28 |         if not issubclass(cls, FairseqCriterion):
29 |             raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__))
30 |         if cls.__name__ in CRITERION_CLASS_NAMES:
31 |             # We use the criterion class name as a unique identifier in
32 |             # checkpoints, so all criterions must have unique class names.
33 |             raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__))
34 |         CRITERION_REGISTRY[name] = cls
35 |         CRITERION_CLASS_NAMES.add(cls.__name__)
36 |         return cls
37 | 
38 |     return register_criterion_cls
39 | 
40 | 
41 | # automatically import any Python files in the criterions/ directory
42 | for file in os.listdir(os.path.dirname(__file__)):
43 |     if file.endswith('.py') and not file.startswith('_'):
44 |         module = file[:file.find('.py')]
45 |         importlib.import_module('fairseq.criterions.' + module)
46 | 
--------------------------------------------------------------------------------
/fairseq/criterions/adaptive_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | 
9 | import math
10 | import torch.nn.functional as F
11 | 
12 | from fairseq import utils
13 | from .
import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('adaptive_loss') 17 | class AdaptiveLoss(FairseqCriterion): 18 | """This is an implementation of the loss function accompanying the adaptive softmax approximation for 19 | graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" 20 | (http://arxiv.org/abs/1609.04309).""" 21 | 22 | def __init__(self, args, task): 23 | super().__init__(args, task) 24 | 25 | def forward(self, model, sample, reduce=True): 26 | """Compute the loss for the given sample. 27 | 28 | Returns a tuple with three elements: 29 | 1) the loss 30 | 2) the sample size, which is used as the denominator for the gradient 31 | 3) logging outputs to display while training 32 | """ 33 | 34 | assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None 35 | adaptive_softmax = model.decoder.adaptive_softmax 36 | 37 | net_output = model(**sample['net_input']) 38 | target = model.get_targets(sample, net_output).view(-1) 39 | 40 | bsz = target.size(0) 41 | 42 | logits, target = adaptive_softmax(net_output[0], target) 43 | assert len(target) == len(logits) 44 | 45 | loss = net_output[0].new(1 if reduce else bsz).zero_() 46 | 47 | for i in range(len(target)): 48 | if target[i] is not None: 49 | assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1)) 50 | loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx, 51 | reduce=reduce) 52 | 53 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 54 | logging_output = { 55 | 'loss': utils.item(loss.data) if reduce else loss.data, 56 | 'ntokens': sample['ntokens'], 57 | 'sample_size': sample_size, 58 | } 59 | return loss, sample_size, logging_output 60 | 61 | @staticmethod 62 | def aggregate_logging_outputs(logging_outputs): 63 | """Aggregate logging outputs from data parallel training.""" 64 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 65 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 66 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 67 | agg_output = { 68 | 'loss': loss_sum / sample_size / math.log(2), 69 | 'sample_size': sample_size, 70 | } 71 | if sample_size != ntokens: 72 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 73 | return agg_output 74 | -------------------------------------------------------------------------------- /fairseq/criterions/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from . import FairseqCriterion, register_criterion 14 | 15 | 16 | @register_criterion('cross_entropy') 17 | class CrossEntropyCriterion(FairseqCriterion): 18 | 19 | def __init__(self, args, task): 20 | super().__init__(args, task) 21 | 22 | def forward(self, model, sample, reduce=True): 23 | """Compute the loss for the given sample. 
24 | 25 | Returns a tuple with three elements: 26 | 1) the loss 27 | 2) the sample size, which is used as the denominator for the gradient 28 | 3) logging outputs to display while training 29 | """ 30 | net_output = model(**sample['net_input']) 31 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 32 | lprobs = lprobs.view(-1, lprobs.size(-1)) 33 | target = model.get_targets(sample, net_output).view(-1) 34 | loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, 35 | reduce=reduce) 36 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 37 | logging_output = { 38 | 'loss': utils.item(loss.data) if reduce else loss.data, 39 | 'ntokens': sample['ntokens'], 40 | 'sample_size': sample_size, 41 | } 42 | return loss, sample_size, logging_output 43 | 44 | @staticmethod 45 | def aggregate_logging_outputs(logging_outputs): 46 | """Aggregate logging outputs from data parallel training.""" 47 | loss_sum = sum(log.get('loss', 0) for log in logging_outputs) 48 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 49 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 50 | agg_output = { 51 | 'loss': loss_sum / sample_size / math.log(2), 52 | 'sample_size': sample_size, 53 | } 54 | if sample_size != ntokens: 55 | agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) 56 | return agg_output 57 | -------------------------------------------------------------------------------- /fairseq/criterions/fairseq_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.nn.modules.loss import _Loss 9 | 10 | 11 | class FairseqCriterion(_Loss): 12 | 13 | def __init__(self, args, task): 14 | super().__init__() 15 | self.args = args 16 | self.task = task 17 | self.padding_idx = task.target_dictionary.pad() 18 | 19 | @staticmethod 20 | def add_args(parser): 21 | """Add criterion-specific arguments to the parser.""" 22 | pass 23 | 24 | def forward(self, model, sample, reduce=True): 25 | """Compute the loss for the given sample. 26 | 27 | Returns a tuple with three elements: 28 | 1) the loss 29 | 2) the sample size, which is used as the denominator for the gradient 30 | 3) logging outputs to display while training 31 | """ 32 | raise NotImplementedError 33 | 34 | @staticmethod 35 | def aggregate_logging_outputs(logging_outputs): 36 | """Aggregate logging outputs from data parallel training.""" 37 | raise NotImplementedError 38 | 39 | @staticmethod 40 | def grad_denom(sample_sizes): 41 | """Compute the gradient denominator for a set of sample sizes.""" 42 | return sum(sample_sizes) 43 | -------------------------------------------------------------------------------- /fairseq/criterions/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
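# Label-smoothed cross-entropy: the one-hot NLL objective is interpolated with
# a uniform distribution over the vocabulary, with interpolation weight eps
# given by --label-smoothing (see forward() below).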
7 | 8 | import math 9 | 10 | from fairseq import utils 11 | 12 | from . import FairseqCriterion, register_criterion 13 | 14 | 15 | @register_criterion('label_smoothed_cross_entropy') 16 | class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): 17 | 18 | def __init__(self, args, task): 19 | super().__init__(args, task) 20 | self.eps = args.label_smoothing 21 | 22 | @staticmethod 23 | def add_args(parser): 24 | """Add criterion-specific arguments to the parser.""" 25 | parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', 26 | help='epsilon for label smoothing, 0 means no label smoothing') 27 | 28 | def forward(self, model, sample, reduce=True): 29 | """Compute the loss for the given sample. 30 | 31 | Returns a tuple with three elements: 32 | 1) the loss 33 | 2) the sample size, which is used as the denominator for the gradient 34 | 3) logging outputs to display while training 35 | """ 36 | net_output = model(**sample['net_input']) 37 | lprobs = model.get_normalized_probs(net_output, log_probs=True) 38 | lprobs = lprobs.view(-1, lprobs.size(-1)) 39 | target = model.get_targets(sample, net_output).view(-1, 1) 40 | non_pad_mask = target.ne(self.padding_idx) 41 | nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] 42 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] 43 | if reduce: 44 | nll_loss = nll_loss.sum() 45 | smooth_loss = smooth_loss.sum() 46 | eps_i = self.eps / lprobs.size(-1) 47 | loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss 48 | 49 | sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] 50 | logging_output = { 51 | 'loss': utils.item(loss.data) if reduce else loss.data, 52 | 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 53 | 'ntokens': sample['ntokens'], 54 | 'sample_size': sample_size, 55 | } 56 | return loss, sample_size, logging_output 57 | 58 | @staticmethod 59 | def aggregate_logging_outputs(logging_outputs): 60 | """Aggregate logging outputs from data parallel training.""" 61 | ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) 62 | sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) 63 | return { 64 | 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2), 65 | 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2), 66 | 'sample_size': sample_size, 67 | } 68 | -------------------------------------------------------------------------------- /fairseq/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | from .dictionary import Dictionary 9 | from .fairseq_dataset import FairseqDataset 10 | from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset # noqa: F401 11 | from .language_pair_dataset import LanguagePairDataset 12 | from .monolingual_dataset import MonolingualDataset 13 | from .monolingual_gap_bert_dataset import MonolingualGapBertDataset 14 | from .token_block_dataset import TokenBlockDataset 15 | from .token_block_dataset_gap_bert import TokenBlockGapBertDataset 16 | 17 | from .data_utils import EpochBatchIterator 18 | 19 | from .gap_reader import GAP_Reader 20 | from .bert_reader import Bert_Reader -------------------------------------------------------------------------------- /fairseq/data/bert_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import numpy as np 4 | 5 | class Bert_Reader: 6 | 7 | def __init__(self, in_bert_json_file_path): 8 | self.in_bert_json_file_path = in_bert_json_file_path 9 | 10 | def read(self): 11 | data = [] 12 | with open(self.in_bert_json_file_path, 'r') as f: 13 | for line in f: 14 | line = line.strip() 15 | bert_json = json.loads(line) 16 | tokens = [feature['token'] for feature in bert_json['features']] 17 | layers = [] 18 | nb_layers = len(bert_json['features'][0]['layers']) 19 | for i in range(nb_layers): 20 | cur_layer = [] 21 | for feature in bert_json['features']: 22 | assert -(i+1) == feature['layers'][i]['index'] 23 | values = feature['layers'][i]['values'] 24 | values = np.array(values) 25 | cur_layer.append(values) 26 | assert len(cur_layer) == len(tokens) 27 | layers.append(cur_layer) 28 | layers = np.array(layers) 29 | # [nb_layers (4), seq_len, emb_size] 30 | data.append((tokens, layers)) 31 | return data 32 | 33 | if __name__ == "__main__": 34 | in_bert_json_file_path = "data/gap-bert-bin/gap-train.bert.jsonl" 35 | reader = Bert_Reader(in_bert_json_file_path) 36 | bert_data = reader.read() -------------------------------------------------------------------------------- /fairseq/data/dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
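# Maps symbols to consecutive integer ids; ids 0-3 are reserved for the
# '<Lua heritage>', '<pad>', '</s>' and '<unk>' special symbols added in
# __init__ below.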
7 | 
8 | from collections import Counter
9 | import os
10 | 
11 | import torch
12 | 
13 | 
14 | class Dictionary(object):
15 |     """A mapping from symbols to consecutive integers"""
16 |     def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
17 |         self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
18 |         self.symbols = []
19 |         self.count = []
20 |         self.indices = {}
21 |         # dictionary indexing starts at 1 for consistency with Lua
22 |         self.add_symbol('<Lua heritage>')
23 |         self.pad_index = self.add_symbol(pad)
24 |         self.eos_index = self.add_symbol(eos)
25 |         self.unk_index = self.add_symbol(unk)
26 |         self.nspecial = len(self.symbols)
27 | 
28 |     def __eq__(self, other):
29 |         return self.indices == other.indices
30 | 
31 |     def __getitem__(self, idx):
32 |         if idx < len(self.symbols):
33 |             return self.symbols[idx]
34 |         return self.unk_word
35 | 
36 |     def __len__(self):
37 |         """Returns the number of symbols in the dictionary"""
38 |         return len(self.symbols)
39 | 
40 |     def index(self, sym):
41 |         """Returns the index of the specified symbol"""
42 |         if sym in self.indices:
43 |             return self.indices[sym]
44 |         return self.unk_index
45 | 
46 |     def string(self, tensor, bpe_symbol=None, escape_unk=False):
47 |         """Helper for converting a tensor of token indices to a string.
48 | 
49 |         Can optionally remove BPE symbols or escape <unk> words.
50 |         """
51 |         if torch.is_tensor(tensor) and tensor.dim() == 2:
52 |             return '\n'.join(self.string(t) for t in tensor)
53 | 
54 |         def token_string(i):
55 |             if i == self.unk():
56 |                 return self.unk_string(escape_unk)
57 |             else:
58 |                 return self[i]
59 | 
60 |         sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
61 |         if bpe_symbol is not None:
62 |             sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
63 |         return sent
64 | 
65 |     def unk_string(self, escape=False):
66 |         """Return unknown string, optionally escaped as: <<unk>>"""
67 |         if escape:
68 |             return '<{}>'.format(self.unk_word)
69 |         else:
70 |             return self.unk_word
71 | 
72 |     def add_symbol(self, word, n=1):
73 |         """Adds a word to the dictionary"""
74 |         if word in self.indices:
75 |             idx = self.indices[word]
76 |             self.count[idx] = self.count[idx] + n
77 |             return idx
78 |         else:
79 |             idx = len(self.symbols)
80 |             self.indices[word] = idx
81 |             self.symbols.append(word)
82 |             self.count.append(n)
83 |             return idx
84 | 
85 |     def update(self, new_dict):
86 |         """Updates counts from new dictionary."""
87 |         for word in new_dict.symbols:
88 |             idx2 = new_dict.indices[word]
89 |             if word in self.indices:
90 |                 idx = self.indices[word]
91 |                 self.count[idx] = self.count[idx] + new_dict.count[idx2]
92 |             else:
93 |                 idx = len(self.symbols)
94 |                 self.indices[word] = idx
95 |                 self.symbols.append(word)
96 |                 self.count.append(new_dict.count[idx2])
97 | 
98 |     def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
99 |         """Sort symbols by frequency in descending order, ignoring special ones.
100 | 
101 |         Args:
102 |             - threshold defines the minimum word count
103 |             - nwords defines the total number of words in the final dictionary,
104 |                 including special symbols
105 |             - padding_factor can be used to pad the dictionary size to be a
106 |                 multiple of 8, which is important on some hardware (e.g., Nvidia
107 |                 Tensor Cores).
108 | """ 109 | if nwords <= 0: 110 | nwords = len(self) 111 | 112 | new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial))) 113 | new_symbols = self.symbols[:self.nspecial] 114 | new_count = self.count[:self.nspecial] 115 | 116 | c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:]))) 117 | for symbol, count in c.most_common(nwords - self.nspecial): 118 | if count >= threshold: 119 | new_indices[symbol] = len(new_symbols) 120 | new_symbols.append(symbol) 121 | new_count.append(count) 122 | else: 123 | break 124 | 125 | threshold_nwords = len(new_symbols) 126 | if padding_factor > 1: 127 | i = 0 128 | while threshold_nwords % padding_factor != 0: 129 | symbol = 'madeupword{:04d}'.format(i) 130 | new_indices[symbol] = len(new_symbols) 131 | new_symbols.append(symbol) 132 | new_count.append(0) 133 | i += 1 134 | threshold_nwords += 1 135 | 136 | assert len(new_symbols) % padding_factor == 0 137 | assert len(new_symbols) == len(new_indices) 138 | 139 | self.count = list(new_count) 140 | self.symbols = list(new_symbols) 141 | self.indices = new_indices 142 | 143 | def pad(self): 144 | """Helper to get index of pad symbol""" 145 | return self.pad_index 146 | 147 | def eos(self): 148 | """Helper to get index of end-of-sentence symbol""" 149 | return self.eos_index 150 | 151 | def unk(self): 152 | """Helper to get index of unk symbol""" 153 | return self.unk_index 154 | 155 | @classmethod 156 | def load(cls, f, ignore_utf_errors=False): 157 | """Loads the dictionary from a text file with the format: 158 | 159 | ``` 160 | 161 | 162 | ... 163 | ``` 164 | """ 165 | if isinstance(f, str): 166 | try: 167 | if not ignore_utf_errors: 168 | with open(f, 'r', encoding='utf-8') as fd: 169 | return cls.load(fd) 170 | else: 171 | with open(f, 'r', encoding='utf-8', errors='ignore') as fd: 172 | return cls.load(fd) 173 | except FileNotFoundError as fnfe: 174 | raise fnfe 175 | except Exception: 176 | raise Exception("Incorrect encoding detected in {}, please " 177 | "rebuild the dataset".format(f)) 178 | 179 | d = cls() 180 | for line in f.readlines(): 181 | idx = line.rfind(' ') 182 | word = line[:idx] 183 | count = int(line[idx+1:]) 184 | d.indices[word] = len(d.symbols) 185 | d.symbols.append(word) 186 | d.count.append(count) 187 | return d 188 | 189 | def save(self, f): 190 | """Stores dictionary into a text file""" 191 | if isinstance(f, str): 192 | os.makedirs(os.path.dirname(f), exist_ok=True) 193 | with open(f, 'w', encoding='utf-8') as fd: 194 | return self.save(fd) 195 | for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]): 196 | print('{} {}'.format(symbol, count), file=f) 197 | 198 | def dummy_sentence(self, length): 199 | t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() 200 | t[-1] = self.eos() 201 | return t 202 | -------------------------------------------------------------------------------- /fairseq/data/fairseq_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 
8 | import torch.utils.data
9 | 
10 | 
11 | class FairseqDataset(torch.utils.data.Dataset):
12 |     """A dataset that provides helpers for batching."""
13 | 
14 |     def __getitem__(self, index):
15 |         raise NotImplementedError
16 | 
17 |     def __len__(self):
18 |         raise NotImplementedError
19 | 
20 |     def collater(self, samples):
21 |         """Merge a list of samples to form a mini-batch."""
22 |         raise NotImplementedError
23 | 
24 |     def get_dummy_batch(self, num_tokens, max_positions):
25 |         """Return a dummy batch with a given number of tokens."""
26 |         raise NotImplementedError
27 | 
28 |     def num_tokens(self, index):
29 |         """Return an example's length (number of tokens), used for batching."""
30 |         raise NotImplementedError
31 | 
32 |     def ordered_indices(self):
33 |         """Ordered indices for batching."""
34 |         raise NotImplementedError
35 | 
36 |     def valid_size(self, index, max_positions):
37 |         """Check if an example's size is valid according to max_positions."""
38 |         raise NotImplementedError
39 | 
--------------------------------------------------------------------------------
/fairseq/data/gap_reader.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import sys
3 | import random
4 | import numpy as np
5 | import string
6 | import re
7 | 
8 | from collections import namedtuple
9 | 
10 | # Fieldnames used in the gold dataset .tsv file.
11 | GOLD_FIELDNAMES = [
12 |     'ID', 'Text', 'Pronoun', 'Pronoun-offset', 'A', 'A-offset', 'A-coref', 'B',
13 |     'B-offset', 'B-coref', 'URL'
14 | ]
15 | 
16 | # Fieldnames expected in system output .tsv files.
17 | SYSTEM_FIELDNAMES = ['ID', 'A-coref', 'B-coref']
18 | 
19 | GAP_Record = namedtuple("GAP_Record", ["example_id", "text", "pronoun",
20 |                                        "pronoun_offset_start", "pronoun_offset_end",
21 |                                        "a", "a_offset_start", "a_offset_end", "a_coref",
22 |                                        "b", "b_offset_start", "b_offset_end", "b_coref"])
23 | 
24 | class GAP_Reader:
25 | 
26 |     def __init__(self, in_tsv_file_path, is_gold=True):
27 |         self.in_tsv_file_path = in_tsv_file_path
28 |         self.is_gold = is_gold
29 | 
30 |     def read(self):
31 |         data = []
32 |         fieldnames = GOLD_FIELDNAMES if self.is_gold else SYSTEM_FIELDNAMES
33 |         with open(self.in_tsv_file_path, 'r') as in_tsv_file:
34 |             reader = csv.DictReader(in_tsv_file, fieldnames=fieldnames, delimiter='\t')
35 |             if self.is_gold:
36 |                 next(reader, None)
37 |             for i, row in enumerate(reader):
38 |                 example_id = row['ID']
39 | 
40 |                 text, pronoun, pronoun_offset_start, pronoun_offset_end = None, None, None, None
41 |                 a, a_offset_start, a_offset_end = None, None, None
42 |                 b, b_offset_start, b_offset_end = None, None, None
43 | 
44 |                 if self.is_gold:
45 |                     text = row['Text']
46 |                     pronoun = row['Pronoun']
47 |                     pronoun_offset_start, pronoun_offset_end = row['Pronoun-offset'].split(":")
48 |                     pronoun_offset_start = int(pronoun_offset_start)
49 |                     pronoun_offset_end = int(pronoun_offset_end) - 1
50 |                     assert pronoun_offset_start == pronoun_offset_end
51 | 
52 |                     a = row['A']
53 |                     a_offset_start, a_offset_end = row['A-offset'].split(":")
54 |                     a_offset_start = int(a_offset_start)
55 |                     a_offset_end = int(a_offset_end) - 1  # -1 to be inclusive
56 |                     assert a_offset_start <= a_offset_end
57 |                     assert a == ' '.join(text.split(' ')[a_offset_start:a_offset_end+1])
58 | 
59 |                     b = row['B']
60 |                     b_offset_start, b_offset_end = row['B-offset'].split(":")
61 |                     b_offset_start = int(b_offset_start)
62 |                     b_offset_end = int(b_offset_end) - 1  # -1 to be inclusive
63 |                     assert b_offset_start <= b_offset_end
64 |                     assert b == ' '.join(text.split(' ')[b_offset_start:b_offset_end+1])
65 | 
66 |                     assert a_offset_start < b_offset_start
67 |                     assert a_offset_end < b_offset_end
68 | 
69 |                     assert row['A-coref'].upper() in ['TRUE', 'FALSE']
70 |                     a_coref = True if row['A-coref'].upper() == 'TRUE' else False
71 |                     assert row['B-coref'].upper() in ['TRUE', 'FALSE']
72 |                     b_coref = True if row['B-coref'].upper() == 'TRUE' else False
73 | 
74 |                 # data.append((
75 |                 #     example_id, text,
76 |                 #     (pronoun, pronoun_offset_start, pronoun_offset_end),
77 |                 #     (a, a_offset_start, a_offset_end, a_coref),
78 |                 #     (b, b_offset_start, b_offset_end, b_coref),
79 |                 # ))
80 |                 data.append(GAP_Record(
81 |                     example_id, text,
82 |                     pronoun, pronoun_offset_start, pronoun_offset_end,
83 |                     a, a_offset_start, a_offset_end, a_coref,
84 |                     b, b_offset_start, b_offset_end, b_coref
85 |                 ))
86 |         return data
87 | 
88 | if __name__ == "__main__":
89 |     gap_reader = GAP_Reader('data/gap/gap-test.tok.tsv')
90 |     a_cnt, b_cnt = 0, 0
91 |     # read() returns GAP_Record namedtuples, so unpack the fields by name
92 |     for record in gap_reader.read():
93 |         ac, bc = record.a_coref, record.b_coref
94 |         assert (ac and (not bc)) or ((not ac) and bc) or ((not ac) and (not bc))
95 |         if ac:
96 |             a_cnt += 1
97 |         if bc:
98 |             b_cnt += 1
99 |     print(a_cnt, b_cnt)
--------------------------------------------------------------------------------
/fairseq/data/indexed_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 | 
8 | import os
9 | import struct
10 | 
11 | import numpy as np
12 | import torch
13 | 
14 | from fairseq.tokenizer import Tokenizer
15 | 
16 | 
17 | def read_longs(f, n):
18 |     a = np.empty(n, dtype=np.int64)
19 |     f.readinto(a)
20 |     return a
21 | 
22 | 
23 | def write_longs(f, a):
24 |     f.write(np.array(a, dtype=np.int64))
25 | 
26 | 
27 | dtypes = {
28 |     1: np.uint8,
29 |     2: np.int8,
30 |     3: np.int16,
31 |     4: np.int32,
32 |     5: np.int64,
33 |     6: np.float,
34 |     7: np.double,
35 | }
36 | 
37 | 
38 | def code(dtype):
39 |     for k in dtypes.keys():
40 |         if dtypes[k] == dtype:
41 |             return k
42 | 
43 | 
44 | def index_file_path(prefix_path):
45 |     return prefix_path + '.idx'
46 | 
47 | 
48 | def data_file_path(prefix_path):
49 |     return prefix_path + '.bin'
50 | 
51 | 
52 | class IndexedDataset(torch.utils.data.Dataset):
53 |     """Loader for TorchNet IndexedDataset"""
54 | 
55 |     def __init__(self, path, fix_lua_indexing=False):
56 |         super().__init__()
57 |         self.fix_lua_indexing = fix_lua_indexing
58 |         with open(index_file_path(path), 'rb') as f:
59 |             magic = f.read(8)
60 |             assert magic == b'TNTIDX\x00\x00'
61 |             version = f.read(8)
62 |             assert struct.unpack('<Q', version) == (1,)
63 |             code, self.element_size = struct.unpack('<QQ', f.read(16))
64 |             self.dtype = dtypes[code]
65 |             self.size, self.s = struct.unpack('<QQ', f.read(16))
66 |             self.dim_offsets = read_longs(f, self.size + 1)
67 |             self.data_offsets = read_longs(f, self.size + 1)
68 |             self.sizes = read_longs(f, self.s)
69 |         self.read_data(path)
70 | 
71 |     def read_data(self, path):
72 |         self.data_file = open(data_file_path(path), 'rb', buffering=0)
73 | 
74 |     def check_index(self, i):
75 |         if i < 0 or i >= self.size:
76 |             raise IndexError('index out of range')
77 | 
78 |     def __del__(self):
79 |         self.data_file.close()
80 | 
81 |     def __getitem__(self, i):
82 |         self.check_index(i)
83 |         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
84 |         a = np.empty(tensor_size, dtype=self.dtype)
85 |         self.data_file.seek(self.data_offsets[i] * self.element_size)
86 |         self.data_file.readinto(a)
87 |         item = torch.from_numpy(a).long()
88 |         if self.fix_lua_indexing:
89 |             item -= 1  # subtract 1 for 0-based indexing
90 |         return item
91 | 
92 |     def __len__(self):
93 |         return self.size
94 | 
95 |     @staticmethod
96 |     def exists(path):
97 |         return (
98 |             os.path.exists(index_file_path(path)) and
99 |
99 |             os.path.exists(data_file_path(path))
100 |         )
101 |
102 |
103 | class IndexedInMemoryDataset(IndexedDataset):
104 |     """Loader for TorchNet IndexedDataset, keeps all the data in memory"""
105 |
106 |     def read_data(self, path):
107 |         self.data_file = open(data_file_path(path), 'rb')
108 |         self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
109 |         self.data_file.readinto(self.buffer)
110 |         self.data_file.close()
111 |         if self.fix_lua_indexing:
112 |             self.buffer -= 1  # subtract 1 for 0-based indexing
113 |
114 |     def __del__(self):
115 |         pass
116 |
117 |     def __getitem__(self, i):
118 |         self.check_index(i)
119 |         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
120 |         a = np.empty(tensor_size, dtype=self.dtype)
121 |         np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
122 |         return torch.from_numpy(a).long()
123 |
124 |
125 | class IndexedRawTextDataset(IndexedDataset):
126 |     """Takes a text file as input and binarizes it in memory at instantiation.
127 |     Original lines are also kept in memory"""
128 |
129 |     def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
130 |         self.tokens_list = []
131 |         self.lines = []
132 |         self.sizes = []
133 |         self.append_eos = append_eos
134 |         self.reverse_order = reverse_order
135 |         self.read_data(path, dictionary)
136 |         self.size = len(self.tokens_list)
137 |
138 |     def read_data(self, path, dictionary):
139 |         with open(path, 'r') as f:
140 |             for line in f:
141 |                 self.lines.append(line.strip('\n'))
142 |                 tokens = Tokenizer.tokenize(
143 |                     line, dictionary, add_if_not_exist=False,
144 |                     append_eos=self.append_eos, reverse_order=self.reverse_order,
145 |                 ).long()
146 |                 self.tokens_list.append(tokens)
147 |                 self.sizes.append(len(tokens))
148 |         self.sizes = np.array(self.sizes)
149 |
150 |     def __getitem__(self, i):
151 |         self.check_index(i)
152 |         return self.tokens_list[i]
153 |
154 |     def get_original_text(self, i):
155 |         self.check_index(i)
156 |         return self.lines[i]
157 |
158 |     def __del__(self):
159 |         pass
160 |
161 |     def __len__(self):
162 |         return self.size
163 |
164 |     @staticmethod
165 |     def exists(path):
166 |         return os.path.exists(path)
167 |
168 |
169 | class IndexedDatasetBuilder(object):
170 |     element_sizes = {
171 |         np.uint8: 1,
172 |         np.int8: 1,
173 |         np.int16: 2,
174 |         np.int32: 4,
175 |         np.int64: 8,
176 |         np.float32: 4,
177 |         np.double: 8
178 |     }
179 |
180 |     def __init__(self, out_file, dtype=np.int32):
181 |         self.out_file = open(out_file, 'wb')
182 |         self.dtype = dtype
183 |         self.data_offsets = [0]
184 |         self.dim_offsets = [0]
185 |         self.sizes = []
186 |         self.element_size = self.element_sizes[self.dtype]
187 |
188 |     def add_item(self, tensor):
189 |         # +1 for Lua compatibility
190 |         nbytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype))
191 |         self.data_offsets.append(self.data_offsets[-1] + nbytes // self.element_size)
192 |         for s in tensor.size():
193 |             self.sizes.append(s)
194 |         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
195 |
196 |     def finalize(self, index_file):
197 |         self.out_file.close()
198 |         index = open(index_file, 'wb')
199 |         index.write(b'TNTIDX\x00\x00')
200 |         index.write(struct.pack('<Q', 1))
201 |         index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
202 |         index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
203 |         write_longs(index, self.dim_offsets)
204 |         write_longs(index, self.data_offsets)
205 |         write_longs(index, self.sizes)
206 |         index.close()
--------------------------------------------------------------------------------
/fairseq/data/token_block_dataset_gap_bert.py:
--------------------------------------------------------------------------------
40 |
41 |     def _increase_offsets(offsets):
42 |         return (offsets[0] + 1, offsets[1] + 1)
43 |
44 |     def _increase_offsets_corefs(corefs):
45 |         seqlen = corefs.shape[0]
46 |         ret = np.zeros((seqlen, seqlen))
47 |         ret[1:, 1:] = corefs[:-1, :-1]
48 |         return ret
49 |
50 |     def _increase_offsets(data):
51 |         return GAP_Record(
52 |             data.example_id,
data.text, 54 | data.pronoun, 55 | data.pronoun_offset_start + 1, 56 | data.pronoun_offset_end + 1, 57 | data.a, 58 | data.a_offset_start + 1, 59 | data.a_offset_end + 1, 60 | data.a_coref, 61 | data.b, 62 | data.b_offset_start + 1, 63 | data.b_offset_end + 1, 64 | data.b_coref 65 | ) 66 | 67 | def _increase_offsets_bert(bert_weights): 68 | bert_weights_shape = bert_weights.shape 69 | eos_padding = np.zeros((bert_weights_shape[0], 1, bert_weights_shape[2])) 70 | return np.concatenate([eos_padding, bert_weights], axis=1) 71 | 72 | return ( 73 | torch.LongTensor(source_token), 74 | _increase_offsets(self.gap_data[index]), 75 | torch.FloatTensor(_increase_offsets_corefs(self.gap_corefs[index])), 76 | torch.FloatTensor(_increase_offsets_bert(self.gap_bert_weights[index])), 77 | token_item, 78 | ) -------------------------------------------------------------------------------- /fairseq/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import pickle 9 | 10 | import torch.distributed 11 | 12 | from fairseq import utils 13 | 14 | 15 | def is_master(args): 16 | return args.distributed_rank == 0 17 | 18 | 19 | def distributed_init(args): 20 | if args.distributed_world_size == 1: 21 | raise ValueError('Cannot initialize distributed with distributed_world_size=1') 22 | 23 | print('| distributed init (rank {}): {}'.format( 24 | args.distributed_rank, args.distributed_init_method), flush=True) 25 | if args.distributed_init_method.startswith('tcp://'): 26 | torch.distributed.init_process_group( 27 | backend=args.distributed_backend, init_method=args.distributed_init_method, 28 | world_size=args.distributed_world_size, rank=args.distributed_rank) 29 | else: 30 | torch.distributed.init_process_group( 31 | backend=args.distributed_backend, init_method=args.distributed_init_method, 32 | world_size=args.distributed_world_size) 33 | 34 | args.distributed_rank = torch.distributed.get_rank() 35 | if not is_master(args): 36 | suppress_output() 37 | 38 | return args.distributed_rank 39 | 40 | 41 | def suppress_output(): 42 | """Suppress printing on the current device. 
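    distributed_init calls this on every non-master rank, so by default only rank 0 writes to stdout.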
Force printing with `force=True`.""" 43 | import builtins as __builtin__ 44 | builtin_print = __builtin__.print 45 | 46 | def print(*args, **kwargs): 47 | if 'force' in kwargs: 48 | force = kwargs.pop('force') 49 | if force: 50 | builtin_print(*args, **kwargs) 51 | 52 | __builtin__.print = print 53 | 54 | 55 | def all_gather_list(data, max_size=4096): 56 | """Gathers arbitrary data from all nodes into a list.""" 57 | world_size = torch.distributed.get_world_size() 58 | if not hasattr(all_gather_list, '_in_buffer') or \ 59 | max_size != all_gather_list._in_buffer.size(): 60 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 61 | all_gather_list._out_buffers = [ 62 | torch.cuda.ByteTensor(max_size) 63 | for i in range(world_size) 64 | ] 65 | in_buffer = all_gather_list._in_buffer 66 | out_buffers = all_gather_list._out_buffers 67 | 68 | enc = pickle.dumps(data) 69 | enc_size = len(enc) 70 | if enc_size + 2 > max_size: 71 | raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) 72 | assert max_size < 255*256 73 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 74 | in_buffer[1] = enc_size % 255 75 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 76 | 77 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 78 | 79 | result = [] 80 | for i in range(world_size): 81 | out_buffer = out_buffers[i] 82 | size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1]) 83 | result.append( 84 | pickle.loads(bytes(out_buffer[2:size+2].tolist())) 85 | ) 86 | return result 87 | -------------------------------------------------------------------------------- /fairseq/fp16_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | """ 9 | Train a network on multiple GPUs. 10 | """ 11 | 12 | import torch 13 | 14 | from fairseq import optim, utils 15 | from fairseq.meters import AverageMeter 16 | from fairseq.optim import lr_scheduler 17 | from fairseq.trainer import Trainer 18 | 19 | 20 | class DynamicLossScaler: 21 | 22 | def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000): 23 | self.loss_scale = init_scale 24 | self.scale_factor = scale_factor 25 | self.scale_window = scale_window 26 | self._iter = 0 27 | self._last_overflow_iter = -1 28 | 29 | def update_scale(self, overflow): 30 | if overflow: 31 | self.loss_scale /= self.scale_factor 32 | self._last_overflow_iter = self._iter 33 | elif (self._iter - self._last_overflow_iter) % self.scale_window == 0: 34 | self.loss_scale *= self.scale_factor 35 | self._iter += 1 36 | 37 | @staticmethod 38 | def has_overflow(grad_norm): 39 | # detect inf and nan 40 | if grad_norm == float('inf') or grad_norm != grad_norm: 41 | return True 42 | return False 43 | 44 | 45 | class FP16Trainer(Trainer): 46 | """Modified trainer for FP16. 47 | 48 | We maintain two copies of the model's parameters, both in FP16 and FP32. 49 | We do forward/backward with FP16 and compute the loss + optimize with FP32. 
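    Concretely: the loss is multiplied by a dynamic scale factor before
    backward(), the resulting FP16 grads are flattened into a single FP32
    buffer where they are unscaled and clipped, the optimizer steps on the
    FP32 copy, and the updated values are copied back into the FP16 model.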
50 | """ 51 | 52 | def __init__(self, args, task, model, criterion): 53 | super().__init__(args, task, model, criterion) 54 | 55 | # convert model to FP16 (but keep criterion FP32) 56 | self.model.half() 57 | 58 | # dynamically scale loss to reduce overflow 59 | self.scaler = DynamicLossScaler(init_scale=2.**7) 60 | self.meters['loss_scale'] = AverageMeter() 61 | 62 | def _build_optimizer(self): 63 | # create FP32 copy of parameters and grads 64 | params = [p for p in self.model.parameters() if p.requires_grad] 65 | total_param_size = sum(p.data.numel() for p in params) 66 | self.fp32_params = params[0].new(0).float().new(total_param_size) 67 | offset = 0 68 | for p in params: 69 | numel = p.data.numel() 70 | self.fp32_params[offset:offset+numel].copy_(p.data.view(-1)) 71 | offset += numel 72 | self.fp32_params = torch.nn.Parameter(self.fp32_params) 73 | self.fp32_params.grad = self.fp32_params.data.new(total_param_size) 74 | 75 | # create optimizer using the copied FP32 params 76 | self._optimizer = optim.build_optimizer(self.args, [self.fp32_params]) 77 | self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) 78 | 79 | def save_checkpoint(self, filename, extra_state): 80 | """Save all training state in a checkpoint file.""" 81 | extra_state['loss_scale'] = self.scaler.loss_scale 82 | super().save_checkpoint(filename, extra_state) 83 | 84 | def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None): 85 | """Load all training state from a checkpoint file.""" 86 | extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides) 87 | if extra_state is not None and 'loss_scale' in extra_state: 88 | self.scaler.loss_scale = extra_state['loss_scale'] 89 | return extra_state 90 | 91 | def zero_grad(self): 92 | # zero both the FP16 and FP32 grads 93 | self.model.zero_grad() # FP16 94 | self.optimizer.zero_grad() # FP32 95 | 96 | def _backward(self, loss): 97 | self.meters['loss_scale'].reset() 98 | self.meters['loss_scale'].update(self.scaler.loss_scale) 99 | if loss is not None: 100 | # dynamically rescale loss to stay in FP16 range 101 | loss = loss * self.scaler.loss_scale 102 | return super()._backward(loss) 103 | 104 | def _all_reduce_and_rescale(self, grad_denom): 105 | # undo effect of dynamic loss scaling on gradients 106 | grad_denom *= self.scaler.loss_scale 107 | 108 | if self.args.distributed_world_size > 1: 109 | # flatten grads into a single buffer 110 | flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads) 111 | 112 | # scale gradients to avoid overflow in all-reduce 113 | flat_grads.div_(self.args.distributed_world_size) 114 | grad_denom /= self.args.distributed_world_size 115 | 116 | # all-reduce flat grads 117 | torch.distributed.all_reduce(flat_grads) 118 | 119 | # copy grads back to FP32 120 | self.fp32_params.grad.data.copy_(flat_grads) 121 | else: 122 | # single worker: copy grads directly to FP32 123 | self._get_flat_grads(out=self.fp32_params.grad.data) 124 | 125 | # rescale and clip grads 126 | self.fp32_params.grad.data.div_(grad_denom) 127 | grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm) 128 | 129 | # detect overflow and adjust loss scale 130 | overflow = DynamicLossScaler.has_overflow(grad_norm) 131 | self.scaler.update_scale(overflow) 132 | if overflow: 133 | if self.scaler.loss_scale <= self.args.min_loss_scale: 134 | raise Exception(( 135 | 'Minimum loss scale reached ({}). 
Your loss is probably exploding. ' 136 | 'Try lowering the learning rate, using gradient clipping or ' 137 | 'increasing the batch size.' 138 | ).format(self.args.min_loss_scale)) 139 | raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) 140 | 141 | return grad_norm 142 | 143 | def _opt(self): 144 | # take an optimization step using the FP32 params and grads 145 | super()._opt() 146 | 147 | # copy FP32 params back into FP16 model 148 | offset = 0 149 | for p in self.model.parameters(): 150 | if not p.requires_grad: 151 | continue 152 | numel = p.data.numel() 153 | p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) 154 | offset += numel 155 | -------------------------------------------------------------------------------- /fairseq/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import time 9 | 10 | 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | self.val = val 24 | self.sum += val * n 25 | self.count += n 26 | self.avg = self.sum / self.count 27 | 28 | 29 | class TimeMeter(object): 30 | """Computes the average occurrence of some event per second""" 31 | def __init__(self, init=0): 32 | self.reset(init) 33 | 34 | def reset(self, init=0): 35 | self.init = init 36 | self.start = time.time() 37 | self.n = 0 38 | 39 | def update(self, val=1): 40 | self.n += val 41 | 42 | @property 43 | def avg(self): 44 | return self.n / self.elapsed_time 45 | 46 | @property 47 | def elapsed_time(self): 48 | return self.init + (time.time() - self.start) 49 | 50 | 51 | class StopwatchMeter(object): 52 | """Computes the sum/avg duration of some event in seconds""" 53 | def __init__(self): 54 | self.reset() 55 | 56 | def start(self): 57 | self.start_time = time.time() 58 | 59 | def stop(self, n=1): 60 | if self.start_time is not None: 61 | delta = time.time() - self.start_time 62 | self.sum += delta 63 | self.n += n 64 | self.start_time = None 65 | 66 | def reset(self): 67 | self.sum = 0 68 | self.n = 0 69 | self.start_time = None 70 | 71 | @property 72 | def avg(self): 73 | return self.sum / self.n 74 | -------------------------------------------------------------------------------- /fairseq/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
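# Illustrative usage sketch for the registries defined below (the model and
# architecture names here are hypothetical, not part of this codebase):
#
#     @register_model('toy_lstm')
#     class ToyLSTM(FairseqLanguageModel):
#         ...
#
#     @register_model_architecture('toy_lstm', 'toy_lstm_big')
#     def toy_lstm_big(args):
#         args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 2048)
#
# build_model(args, task) resolves args.arch ('toy_lstm_big') to the registered
# model class; the architecture function supplies hyperparameter defaults
# (lstm_cache.py, for example, calls its own preset inside build_model).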
7 |
8 | import importlib
9 | import os
10 |
11 | from .fairseq_decoder import FairseqDecoder, GapBertDecoder  # noqa: F401
12 | from .fairseq_encoder import FairseqEncoder  # noqa: F401
13 | from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
14 | from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
15 |
16 | from .composite_encoder import CompositeEncoder  # noqa: F401
17 |
18 |
19 | MODEL_REGISTRY = {}
20 | ARCH_MODEL_REGISTRY = {}
21 | ARCH_CONFIG_REGISTRY = {}
22 |
23 |
24 | def build_model(args, task):
25 |     return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
26 |
27 |
28 | def register_model(name):
29 |     """Decorator to register a new model (e.g., LSTM)."""
30 |
31 |     def register_model_cls(cls):
32 |         if name in MODEL_REGISTRY:
33 |             raise ValueError('Cannot register duplicate model ({})'.format(name))
34 |         if not issubclass(cls, BaseFairseqModel):
35 |             raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
36 |         MODEL_REGISTRY[name] = cls
37 |         return cls
38 |
39 |     return register_model_cls
40 |
41 |
42 | def register_model_architecture(model_name, arch_name):
43 |     """Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
44 |
45 |     def register_model_arch_fn(fn):
46 |         if model_name not in MODEL_REGISTRY:
47 |             raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
48 |         if arch_name in ARCH_MODEL_REGISTRY:
49 |             raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
50 |         if not callable(fn):
51 |             raise ValueError('Model architecture must be callable ({})'.format(arch_name))
52 |         ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
53 |         ARCH_CONFIG_REGISTRY[arch_name] = fn
54 |         return fn
55 |
56 |     return register_model_arch_fn
57 |
58 |
59 | # automatically import any Python files in the models/ directory
60 | for file in os.listdir(os.path.dirname(__file__)):
61 |     if file.endswith('.py') and not file.startswith('_'):
62 |         module = file[:file.find('.py')]
63 |         importlib.import_module('fairseq.models.' + module)
64 |
--------------------------------------------------------------------------------
/fairseq/models/composite_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | from . import FairseqEncoder
9 |
10 |
11 | class CompositeEncoder(FairseqEncoder):
12 |     """
13 |     Encoder class that forwards to multiple encoders, e.g., for a fusion model or question answering.
14 |     Accepts a dictionary of encoders; the first encoder's dictionary is used for initialization.
15 |     """
16 |
17 |     def __init__(self, encoders):
18 |         super().__init__(next(iter(encoders.values())).dictionary)
19 |         self.encoders = encoders
20 |         for key in self.encoders:
21 |             self.add_module(key, self.encoders[key])
22 |
23 |     def forward(self, src_tokens, src_lengths):
24 |         encoder_out = {}
25 |         for key in self.encoders:
26 |             encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
27 |         return encoder_out
28 |
29 |     def reorder_encoder_out(self, encoder_out, new_order):
30 |         """Reorder encoder output according to new_order."""
31 |         for key in self.encoders:
32 |             encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
33 |         return encoder_out
34 |
35 |     def max_positions(self):
36 |         return min([self.encoders[key].max_positions() for key in self.encoders])
37 |
38 |     def upgrade_state_dict(self, state_dict):
39 |         for key in self.encoders:
40 |             self.encoders[key].upgrade_state_dict(state_dict)
41 |         return state_dict
42 |
--------------------------------------------------------------------------------
/fairseq/models/fairseq_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class FairseqDecoder(nn.Module):
13 |     """Base class for decoders."""
14 |
15 |     def __init__(self, dictionary):
16 |         super().__init__()
17 |         self.dictionary = dictionary
18 |
19 |     def forward(self, prev_output_tokens, encoder_out):
20 |         raise NotImplementedError
21 |
22 |     def get_normalized_probs(self, net_output, log_probs, sample):
23 |         """Get normalized probabilities (or log probs) from a net's output."""
24 |
25 |         if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
26 |             assert sample is not None and 'target' in sample
27 |             out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
28 |             return out.exp_() if not log_probs else out
29 |
30 |         logits = net_output[0].float()
31 |         if log_probs:
32 |             return F.log_softmax(logits, dim=-1)
33 |         else:
34 |             return F.softmax(logits, dim=-1)
35 |
36 |     def max_positions(self):
37 |         """Maximum input length supported by the decoder."""
38 |         raise NotImplementedError
39 |
40 |     def upgrade_state_dict(self, state_dict):
41 |         return state_dict
42 |
43 | class GapBertDecoder(FairseqDecoder):
44 |     """
45 |     A decoder that computes an additional gate loss from soft supervision.
46 |     """
47 |     def __init__(self, dictionary):
48 |         super().__init__(dictionary)
49 |         self._epoch = 0
50 |
51 |     def gate_loss(self):
52 |         raise NotImplementedError()
53 |
54 |     def set_epoch(self, epoch):
55 |         self._epoch = epoch
56 |
57 |
--------------------------------------------------------------------------------
/fairseq/models/fairseq_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class FairseqEncoder(nn.Module): 12 | """Base class for encoders.""" 13 | 14 | def __init__(self, dictionary): 15 | super().__init__() 16 | self.dictionary = dictionary 17 | 18 | def forward(self, src_tokens, src_lengths): 19 | raise NotImplementedError 20 | 21 | def reorder_encoder_out(self, encoder_out, new_order): 22 | """Reorder encoder output according to new_order.""" 23 | raise NotImplementedError 24 | 25 | def max_positions(self): 26 | """Maximum input length supported by the encoder.""" 27 | raise NotImplementedError 28 | 29 | def upgrade_state_dict(self, state_dict): 30 | return state_dict 31 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_incremental_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . import FairseqDecoder 9 | 10 | 11 | class FairseqIncrementalDecoder(FairseqDecoder): 12 | """Base class for incremental decoders.""" 13 | 14 | def __init__(self, dictionary): 15 | super().__init__(dictionary) 16 | 17 | def forward(self, prev_output_tokens, encoder_out, incremental_state=None): 18 | raise NotImplementedError 19 | 20 | def reorder_incremental_state(self, incremental_state, new_order): 21 | """Reorder incremental state. 22 | 23 | This should be called when the order of the input has changed from the 24 | previous time step. A typical use case is beam search, where the input 25 | order changes between time steps based on the selection of beams. 26 | """ 27 | def apply_reorder_incremental_state(module): 28 | if module != self and hasattr(module, 'reorder_incremental_state'): 29 | module.reorder_incremental_state( 30 | incremental_state, 31 | new_order, 32 | ) 33 | self.apply(apply_reorder_incremental_state) 34 | 35 | def set_beam_size(self, beam_size): 36 | """Sets the beam size in the decoder and all children.""" 37 | if getattr(self, '_beam_size', -1) != beam_size: 38 | def apply_set_beam_size(module): 39 | if module != self and hasattr(module, 'set_beam_size'): 40 | module.set_beam_size(beam_size) 41 | self.apply(apply_set_beam_size) 42 | self._beam_size = beam_size 43 | -------------------------------------------------------------------------------- /fairseq/models/fairseq_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import torch.nn as nn 10 | 11 | from . 
import FairseqDecoder, FairseqEncoder 12 | 13 | 14 | class BaseFairseqModel(nn.Module): 15 | """Base class for fairseq models.""" 16 | 17 | def __init__(self): 18 | super().__init__() 19 | self._is_generation_fast = False 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add model-specific arguments to the parser.""" 24 | pass 25 | 26 | @classmethod 27 | def build_model(cls, args, task): 28 | """Build a new model instance.""" 29 | raise NotImplementedError 30 | 31 | def get_targets(self, sample, net_output): 32 | """Get targets from either the sample or the net's output.""" 33 | return sample['target'] 34 | 35 | def get_normalized_probs(self, net_output, log_probs, sample=None): 36 | """Get normalized probabilities (or log probs) from a net's output.""" 37 | return self.decoder.get_normalized_probs(net_output, log_probs, sample) 38 | 39 | def max_positions(self): 40 | """Maximum length supported by the model.""" 41 | raise NotImplementedError 42 | 43 | def max_decoder_positions(self): 44 | """Maximum length supported by the decoder.""" 45 | return self.decoder.max_positions() 46 | 47 | def load_state_dict(self, state_dict, strict=True): 48 | """Copies parameters and buffers from state_dict into this module and 49 | its descendants. 50 | 51 | Overrides the method in nn.Module; compared with that method this 52 | additionally "upgrades" state_dicts from old checkpoints. 53 | """ 54 | self.upgrade_state_dict(state_dict) 55 | super().load_state_dict(state_dict, strict) 56 | 57 | def upgrade_state_dict(self, state_dict): 58 | assert state_dict is not None 59 | 60 | def do_upgrade(m, prefix): 61 | if len(prefix) > 0: 62 | prefix += '.' 63 | 64 | for n, c in m.named_children(): 65 | name = prefix + n 66 | if hasattr(c, 'upgrade_state_dict_named'): 67 | c.upgrade_state_dict_named(state_dict, name) 68 | elif hasattr(c, 'upgrade_state_dict'): 69 | c.upgrade_state_dict(state_dict) 70 | do_upgrade(c, name) 71 | 72 | do_upgrade(self, '') 73 | 74 | def make_generation_fast_(self, **kwargs): 75 | """Optimize model for faster generation.""" 76 | if self._is_generation_fast: 77 | return # only apply once 78 | self._is_generation_fast = True 79 | 80 | # remove weight norm from all modules in the network 81 | def apply_remove_weight_norm(module): 82 | try: 83 | nn.utils.remove_weight_norm(module) 84 | except ValueError: # this module didn't have weight norm 85 | return 86 | 87 | self.apply(apply_remove_weight_norm) 88 | 89 | def apply_make_generation_fast_(module): 90 | if module != self and hasattr(module, 'make_generation_fast_'): 91 | module.make_generation_fast_(**kwargs) 92 | 93 | self.apply(apply_make_generation_fast_) 94 | 95 | def train(mode): 96 | if mode: 97 | raise RuntimeError('cannot train after make_generation_fast') 98 | 99 | # this model should no longer be used for training 100 | self.eval() 101 | self.train = train 102 | 103 | 104 | class FairseqModel(BaseFairseqModel): 105 | """Base class for encoder-decoder models.""" 106 | 107 | def __init__(self, encoder, decoder): 108 | super().__init__() 109 | 110 | self.encoder = encoder 111 | self.decoder = decoder 112 | assert isinstance(self.encoder, FairseqEncoder) 113 | assert isinstance(self.decoder, FairseqDecoder) 114 | 115 | def forward(self, src_tokens, src_lengths, prev_output_tokens): 116 | encoder_out = self.encoder(src_tokens, src_lengths) 117 | decoder_out = self.decoder(prev_output_tokens, encoder_out) 118 | return decoder_out 119 | 120 | def max_positions(self): 121 | """Maximum length supported by the model.""" 122 | return 
(self.encoder.max_positions(), self.decoder.max_positions()) 123 | 124 | 125 | class FairseqLanguageModel(BaseFairseqModel): 126 | """Base class for decoder-only models.""" 127 | 128 | def __init__(self, decoder): 129 | super().__init__() 130 | self.decoder = decoder 131 | assert isinstance(self.decoder, FairseqDecoder) 132 | 133 | def forward(self, src_tokens): 134 | return self.decoder(src_tokens) 135 | 136 | def max_positions(self): 137 | """Maximum length supported by the model.""" 138 | return self.decoder.max_positions() 139 | 140 | # class FairseqLanguageModelSpan(FairseqLanguageModel): 141 | # def __init__(self, decoder): 142 | # super().__init__(decoder) 143 | 144 | # def forward(self, src_tokens, src_ner=None, src_chunk=None): 145 | # return self.decoder(src_tokens, src_ner, src_chunk) 146 | -------------------------------------------------------------------------------- /fairseq/models/gap_evaluator.py: -------------------------------------------------------------------------------- 1 | from gap_scorer import Annotation, calculate_scores, make_scorecard_simple 2 | from constants import PRONOUNS, Gender 3 | 4 | class GAPEvaluator: 5 | 6 | def __init__(self): 7 | pass 8 | 9 | def eval(self, gold_annotations, system_annotations): 10 | scores = calculate_scores(gold_annotations, system_annotations) 11 | return make_scorecard_simple(scores) -------------------------------------------------------------------------------- /fairseq/models/lstm_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import LongTensor 5 | 6 | from fairseq import options, utils 7 | 8 | from . import ( 9 | FairseqIncrementalDecoder, FairseqLanguageModel, 10 | GapBertDecoder, 11 | register_model, register_model_architecture, 12 | ) 13 | 14 | from .fconv import Embedding 15 | from .lstm import LSTM, Linear, LSTMCell 16 | 17 | 18 | @register_model('lstm_cache_lm') 19 | class LSTMCacheLanguageModel(FairseqLanguageModel): 20 | def __init__(self, decoder): 21 | super().__init__(decoder) 22 | 23 | @staticmethod 24 | def add_args(parser): 25 | """Add model-specific arguments to the parser.""" 26 | parser.add_argument('--dropout', default=0.1, type=float, metavar='D', 27 | help='dropout probability') 28 | parser.add_argument('--decoder-embed-dim',type=int,metavar='N', 29 | help='decoder embedding dimension') 30 | parser.add_argument('--decoder-out-embed-dim',type=int,metavar='N', 31 | help='decoder output embedding dimension') 32 | parser.add_argument('--decoder-layers',type=int,metavar='N', 33 | help='number of layers in decoder LSTM') 34 | @classmethod 35 | def build_model(cls, args, task): 36 | """Build a new model instance.""" 37 | # make sure all arguments are present in older models 38 | lstm_lm_basic(args) 39 | decoder = LSTMCacheDecoder( 40 | dictionary=task.target_dictionary, 41 | embed_dim=args.decoder_embed_dim, 42 | hidden_size=args.decoder_out_embed_dim, 43 | num_layers=args.decoder_layers, 44 | dropout=args.dropout 45 | ) 46 | return LSTMCacheLanguageModel(decoder) 47 | 48 | 49 | class LSTMCacheDecoder(GapBertDecoder): 50 | """ 51 | An LSTM decoder with a cache, using nn.LSTM and no attention 52 | """ 53 | def __init__( 54 | self, dictionary, embed_dim=512, hidden_size=512, num_layers=1, 55 | dropout=0.1, start_at_zeros=False, pretrained_embed=None, use_cache=True): 56 | super().__init__(dictionary) 57 | self.num_layers = num_layers 58 | self.dropout = dropout 59 | 
self.hidden_size = hidden_size
60 |         self.embed_dim = embed_dim
61 |
62 |         num_embeddings = len(dictionary)
63 |         self.padding_idx = dictionary.pad()
64 |         if pretrained_embed is None:
65 |             self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
66 |         else:
67 |             self.embed_tokens = pretrained_embed
68 |
69 |         self.layers = nn.ModuleList([
70 |             LSTMCell(
71 |                 input_size=embed_dim if layer == 0 else hidden_size,
72 |                 hidden_size=hidden_size
73 |             )
74 |             for layer in range(num_layers)
75 |         ])
76 |
77 |         self.output_units = hidden_size
78 |         self.out_projection = nn.Linear(self.output_units, len(dictionary))
79 |
80 |         self.init_hiddens = nn.Parameter(torch.FloatTensor(num_layers, hidden_size))
81 |         self.init_cells = nn.Parameter(torch.FloatTensor(num_layers, hidden_size))
82 |
83 |         torch.nn.init.normal_(self.init_hiddens)
84 |         torch.nn.init.normal_(self.init_cells)
85 |
86 |         if use_cache:
87 |             # self.cache_gate = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size),
88 |             #                                 nn.ReLU(),
89 |             #                                 nn.Linear(self.hidden_size, 1),
90 |             #                                 nn.Sigmoid())
91 |             self.alpha = 0.
92 |             self.theta = 0.5
93 |
94 |     def new_zeros(self, *args):
95 |         return self.init_hiddens.data.new_zeros(args)
96 |
97 |     def new_ones(self, *args):
98 |         return self.init_hiddens.data.new_ones(args)
99 |
100 |     def reset_init_params(self):
101 |         for i in range(self.num_layers):
102 |             nn.init.normal_(self.init_hiddens[i])
103 |             nn.init.normal_(self.init_cells[i])
104 |
105 |     def forward(self, prev_output_tokens, incremental_state=None):
106 |         bsz, seqlen = prev_output_tokens.size()
107 |
108 |         x = self._embed_tokens(prev_output_tokens, incremental_state)
109 |         x = F.dropout(x, p=self.dropout, training=self.training)
110 |         x = x.transpose(0, 1)
111 |
112 |         tok_tensor = self.new_zeros(bsz, seqlen, len(self.dictionary.indices))
113 |         indexer = torch.arange(bsz).long()
114 |
115 |         prev_hiddens = [layer[0] for layer in self.init_hiddens.repeat(bsz, 1, 1).transpose(0, 1).split(1)]
116 |         prev_cells = [layer[0] for layer in self.init_cells.repeat(bsz, 1, 1).transpose(0, 1).split(1)]
117 |
118 |         # hold init_hidden to start
119 |         h_states = [self.init_hiddens.repeat(bsz, 1, 1).transpose(0, 1)]
120 |         for j, x_j in enumerate(x):
121 |             tok_tensor[indexer, j, prev_output_tokens[:, j]] = 1.
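            # tok_tensor[b, j] is a one-hot vocabulary vector for the token at
            # step j; the cache term below (a bmm of the attention weights with
            # tok_tensor) turns attention over past hidden states into an
            # unnormalized copy distribution over previously seen tokens.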
122 |             h_in = x_j
123 |             for i, rnn in enumerate(self.layers):
124 |                 prev_hiddens[i], prev_cells[i] = rnn(h_in, (prev_hiddens[i], prev_cells[i]))
125 |                 h_in = F.dropout(prev_hiddens[i], p=self.dropout, training=self.training)
126 |             h_states.append(h_in.view(1, bsz, -1))
127 |
128 |         # back to B x T x C
129 |         h_states = torch.cat(h_states[:-1], 0).transpose(1, 0)
130 |         rnn_lm = self.out_projection(h_states)
131 |
132 |         # result: B x T x T
133 |         pre_attn = torch.exp(self.theta * torch.bmm(h_states, h_states.transpose(2, 1)) + self.alpha)
134 |         attn = pre_attn.data.new_zeros(bsz, seqlen, seqlen)
135 |         for b in range(bsz):
136 |             attn[b, :, :] = torch.triu(pre_attn[b, :, :], diagonal=1)
137 |         # rows will be source
138 |         # cols will be target
139 |         cache_lm = torch.bmm(attn.transpose(2, 1), tok_tensor)
140 |
141 |         if not self.training:
142 |             return cache_lm + rnn_lm, None
143 |
144 |         return rnn_lm, None
145 |
146 |     def _embed_tokens(self, tokens, incremental_state):
147 |         if incremental_state is not None:
148 |             # keep only the last token for incremental forward pass
149 |             tokens = tokens[:, -1:]
150 |         return self.embed_tokens(tokens)
151 |
152 |     # sequence length is not limited by this model, so report a very large bound
153 |     def max_positions(self):
154 |         return int(1e8)
155 |
156 |     def gate_loss(self):
157 |         return self.new_zeros(1)
158 |
159 | @register_model_architecture('lstm_cache_lm', 'lstm_cache_lm')
160 | def lstm_lm_basic(args):
161 |     args.dropout = getattr(args, 'dropout', 0.1)
162 |     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
163 |     args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1024)
164 |     args.decoder_layers = getattr(args, 'decoder_layers', 1)
165 |     return args
166 |
--------------------------------------------------------------------------------
/fairseq/models/pronouns.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | # name list obtained from: https://www.ssa.gov/oact/babynames/decades/century.html
4 | # accessed on Nov 6th, 2018
5 |
6 | class PronounLexicon():
7 |     def __init__(self, lexfile='pronouns.tsv'):
8 |         self.lexicon = defaultdict(list)
9 |         with open(lexfile) as fin:
10 |             for line in fin:
11 |                 if len(line) > 2:
12 |                     word = line.split()[0]
13 |                     feats = dict(x.split('=') for x in line.split()[1].split(','))
14 |                     for feat, val in feats.items():
15 |                         self.lexicon['='.join([feat, val])].append(word)
16 |         print(f"Read lexicon from {lexfile}:\n{self.lexicon}")
17 |
18 |     def make_lex(self, feature, dictionary):
19 |         '''
20 |         given a fairseq dictionary, export a list of word idxs that match a desired feature
21 |         '''
22 |         return [idx for word, idx in dictionary.indices.items() if word.lower() in self.lexicon[feature]]
23 |
24 |     def all_word_idxs(self, dictionary):
25 |         return [idx for word, idx in dictionary.indices.items() if word.lower() in self.all_words()]
26 |
27 |     def all_words(self):
28 |         output = set()
29 |         for subset in self.lexicon.values():
30 |             for word in subset:
31 |                 output.add(word)
32 |         return output
33 |
34 |     def get_feature_set(self, feature_set):
35 |         output = set()
36 |         for t in feature_set:
37 |             output |= set(self.lexicon[t])
38 |         return output
39 |
40 |     def annotate_feature_chunk_end(self, sentence, chunk_tags, feature_set):
41 |         pronoun_lexicons = self.get_feature_set(feature_set)
42 |         assert len(sentence) == len(chunk_tags)
43 |         output = [0 for _ in range(len(sentence))]
44 |         for i, (token, chunk_tag) in enumerate(zip(sentence, chunk_tags)):
45 |             if token.lower() in pronoun_lexicons:
46 |                 if chunk_tag == 'O' or chunk_tag[:2] == 'U-':
47 |                     output[i] = 1
48 |                 else:
49 |                     chunk_type = chunk_tag[2:]
50 |                     for j in range(i, len(sentence)):
51 |                         end_chunk = chunk_tags[j]
52 |                         assert end_chunk[2:] == chunk_type
53 |                         if end_chunk[:2] == 'L-':
54 |                             output[j] = 1
55 |                             break
56 |         return output
57 |
58 | def find_gaps(sentence):
59 |     gaps = []
60 |     prev, cur = -1, 0
61 |     while cur < len(sentence):
62 |         if sentence[cur] == 1:
63 |             if prev != -1:
64 |                 gaps.append(cur - prev)
65 |             prev = cur
66 |         cur += 1
67 |     return gaps
68 |
69 | if __name__ == '__main__':
70 |     lex = PronounLexicon()
71 |     all_words = lex.all_words()
72 |     in_file_path = "data/CBTest/data/cbt_train.txt"
73 |     all_lens = []
74 |     all_gaps = []
75 |     with open(in_file_path) as f:
76 |         for line in f:
77 |             line = line.strip()
78 |             marked_sentence = [1 if w in all_words else 0 for w in line.split(' ')]
79 |             all_lens.append(len(marked_sentence))
80 |             # print(marked_sentence)
81 |             gaps = find_gaps(marked_sentence)
82 |             # print(gaps)
83 |             all_gaps.extend(gaps)
84 |     import numpy as np
85 |     print(np.mean(all_lens), np.std(all_lens))
86 |     print(np.mean(all_gaps), np.std(all_gaps))
87 |
88 |     # l = 32 covers 81.5% of the sentences
89 |     # l = 64 covers 98.4% of the sentences
90 |     l = 64
91 |     print(len(list(filter(lambda x: x <= l, all_lens))) / float(len(all_lens)))
92 |
93 |     # l = 10 covers 82.7% of the gaps
94 |     # l = 20 covers 97.2% of the gaps
95 |     # l = 30 covers 99.4% of the gaps
96 |     l = 20
97 |     print(len(list(filter(lambda x: x <= l, all_gaps))) / float(len(all_gaps)))
98 |
--------------------------------------------------------------------------------
/fairseq/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | from .adaptive_softmax import AdaptiveSoftmax
9 | from .beamable_mm import BeamableMM
10 | from .character_token_embedder import CharacterTokenEmbedder
11 | from .conv_tbc import ConvTBC
12 | from .downsampled_multihead_attention import DownsampledMultiHeadAttention
13 | from .grad_multiply import GradMultiply
14 | from .highway import Highway
15 | from .learned_positional_embedding import LearnedPositionalEmbedding
16 | from .linearized_convolution import LinearizedConvolution
17 | from .multihead_attention import MultiheadAttention
18 | from .scalar_bias import ScalarBias
19 | from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
20 |
21 | __all__ = [
22 |     'AdaptiveSoftmax',
23 |     'BeamableMM',
24 |     'CharacterTokenEmbedder',
25 |     'ConvTBC',
26 |     'DownsampledMultiHeadAttention',
27 |     'GradMultiply',
28 |     'Highway',
29 |     'LearnedPositionalEmbedding',
30 |     'LinearizedConvolution',
31 |     'MultiheadAttention',
32 |     'ScalarBias',
33 |     'SinusoidalPositionalEmbedding',
34 | ]
35 |
--------------------------------------------------------------------------------
/fairseq/modules/adaptive_softmax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | 14 | class AdaptiveSoftmax(nn.Module): 15 | """ 16 | This is an implementation of the efficient softmax approximation for 17 | graphical processing units (GPU), described in the paper "Efficient softmax 18 | approximation for GPUs" (http://arxiv.org/abs/1609.04309). 19 | """ 20 | 21 | def __init__(self, vocab_size, input_dim, cutoff, dropout): 22 | super().__init__() 23 | 24 | if vocab_size > cutoff[-1]: 25 | cutoff = cutoff + [vocab_size] 26 | else: 27 | assert vocab_size == cutoff[ 28 | -1], 'cannot specify cutoff smaller than vocab size' 29 | 30 | output_dim = cutoff[0] + len(cutoff) - 1 31 | 32 | self.vocab_size = vocab_size 33 | self.cutoff = cutoff 34 | self.dropout = dropout 35 | self.input_dim = input_dim 36 | 37 | self.lsm = nn.LogSoftmax(dim=1) 38 | self.head = nn.Linear(input_dim, output_dim, bias=False) 39 | self._make_tail(True) 40 | 41 | def init_weights(m): 42 | if hasattr(m, 'weight'): 43 | nn.init.xavier_uniform_(m.weight) 44 | 45 | self.apply(init_weights) 46 | 47 | self.register_buffer('version', torch.LongTensor([1])) 48 | # versions prior to 1 had a bug that offset indices on the head by 1 49 | self.buggy_offset = 0 50 | 51 | def _make_tail(self, fix_exponent): 52 | extra_denom = 1 if fix_exponent else 0 53 | 54 | self.tail = nn.ModuleList() 55 | for i in range(len(self.cutoff) - 1): 56 | self.tail.append( 57 | nn.Sequential( 58 | nn.Linear(self.input_dim, self.input_dim // 4 ** (i + extra_denom), bias=False), 59 | nn.Dropout(self.dropout), 60 | nn.Linear(self.input_dim // 4 ** (i + extra_denom), self.cutoff[i + 1] - self.cutoff[i], bias=False) 61 | ) 62 | ) 63 | 64 | def upgrade_state_dict_named(self, state_dict, name): 65 | version_name = name + '.version' 66 | if version_name not in state_dict: 67 | self.buggy_offset = 1 68 | self._make_tail(False) 69 | state_dict[version_name] = torch.LongTensor([1]) 70 | 71 | def adapt_target(self, target): 72 | """ 73 | In order to be efficient, the AdaptiveSoftMax does not compute the 74 | scores for all the word of the vocabulary for all the examples. It is 75 | thus necessary to call the method adapt_target of the AdaptiveSoftMax 76 | layer inside each forward pass. 
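        For example (illustrative numbers, assuming buggy_offset == 0), with
        cutoff = [10000, 50000, vocab_size] a target id of 12345 falls into
        the first tail cluster: its entry in the head target becomes
        cutoff[0] + 0 = 10000 (the cluster token) and its tail target becomes
        12345 - 10000 = 2345.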
77 | """ 78 | 79 | target = target.view(-1) 80 | new_target = [target.clone()] 81 | target_idxs = [] 82 | 83 | for i in range(len(self.cutoff) - 1): 84 | mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1])) 85 | new_target[0][mask] = self.cutoff[0] + i - self.buggy_offset 86 | 87 | if mask.any(): 88 | target_idxs.append(mask.nonzero().squeeze(1)) 89 | new_target.append(target[mask].add(-self.cutoff[i])) 90 | else: 91 | target_idxs.append(None) 92 | new_target.append(None) 93 | 94 | return new_target, target_idxs 95 | 96 | def forward(self, input, target): 97 | """ 98 | Args: 99 | input: (b x t x d) 100 | target: (b x t) 101 | Returns: 102 | 2 lists: output for each cutoff section and new targets by cut off 103 | """ 104 | 105 | input = input.contiguous().view(-1, input.size(-1)) 106 | input = F.dropout(input, p=self.dropout, training=self.training) 107 | 108 | new_target, target_idxs = self.adapt_target(target) 109 | output = [self.head(input)] 110 | 111 | for i in range(len(target_idxs)): 112 | if target_idxs[i] is not None: 113 | output.append(self.tail[i](input.index_select(0, target_idxs[i]))) 114 | else: 115 | output.append(None) 116 | 117 | return output, new_target 118 | 119 | def get_log_prob(self, input, target): 120 | """ 121 | Computes the log probabilities for all the words of the vocabulary, 122 | given a 2D tensor of hidden vectors. 123 | """ 124 | 125 | bsz, length, dim = input.size() 126 | input = input.contiguous().view(-1, dim) 127 | 128 | if target is not None: 129 | _, target_idxs = self.adapt_target(target) 130 | else: 131 | target_idxs = None 132 | 133 | head_y = self.head(input) 134 | log_probs = head_y.new_zeros(input.size(0), self.vocab_size) 135 | 136 | head_sz = self.cutoff[0] + len(self.tail) 137 | log_probs[:, :head_sz] = self.lsm(head_y) 138 | tail_priors = log_probs[:, self.cutoff[0] - self.buggy_offset: head_sz - self.buggy_offset].clone() 139 | 140 | for i in range(len(self.tail)): 141 | start = self.cutoff[i] 142 | end = self.cutoff[i + 1] 143 | 144 | if target_idxs is None: 145 | tail_out = log_probs[:, start:end] 146 | tail_out.copy_(self.tail[i](input)) 147 | log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None]) 148 | elif target_idxs[i] is not None: 149 | idxs = target_idxs[i] 150 | tail_out = log_probs[idxs, start:end] 151 | tail_out.copy_(self.tail[i](input[idxs])) 152 | log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None]) 153 | 154 | log_probs = log_probs.view(bsz, length, -1) 155 | return log_probs 156 | -------------------------------------------------------------------------------- /fairseq/modules/beamable_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class BeamableMM(nn.Module): 13 | """This module provides an optimized MM for beam decoding with attention. 14 | 15 | It leverage the fact that the source-side of the input is replicated beam 16 | times and the target-side of the input is of width one. This layer speeds up 17 | inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} 18 | with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. 
19 | """ 20 | def __init__(self, beam_size=None): 21 | super(BeamableMM, self).__init__() 22 | self.beam_size = beam_size 23 | 24 | def forward(self, input1, input2): 25 | if ( 26 | not self.training and # test mode 27 | self.beam_size is not None and # beam size is set 28 | input1.dim() == 3 and # only support batched input 29 | input1.size(1) == 1 # single time step update 30 | ): 31 | bsz, beam = input1.size(0), self.beam_size 32 | 33 | # bsz x 1 x nhu --> bsz/beam x beam x nhu 34 | input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) 35 | 36 | # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu 37 | input2 = input2.unfold(0, beam, beam)[:, :, :, 0] 38 | 39 | # use non batched operation if bsz = beam 40 | if input1.size(0) == 1: 41 | output = torch.mm(input1[0, :, :], input2[0, :, :]) 42 | else: 43 | output = input1.bmm(input2) 44 | return output.view(bsz, 1, -1) 45 | else: 46 | return input1.bmm(input2) 47 | 48 | def set_beam_size(self, beam_size): 49 | self.beam_size = beam_size 50 | -------------------------------------------------------------------------------- /fairseq/modules/character_token_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from torch import nn 13 | from torch.nn.utils.rnn import pad_sequence 14 | 15 | from typing import List, Tuple 16 | 17 | from .highway import Highway 18 | from fairseq.data import Dictionary 19 | 20 | 21 | class CharacterTokenEmbedder(torch.nn.Module): 22 | def __init__( 23 | self, 24 | vocab: Dictionary, 25 | filters: List[Tuple[int, int]], 26 | char_embed_dim: int, 27 | word_embed_dim: int, 28 | highway_layers: int, 29 | max_char_len: int = 50, 30 | ): 31 | super(CharacterTokenEmbedder, self).__init__() 32 | 33 | self.embedding_dim = word_embed_dim 34 | self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0) 35 | self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim)) 36 | self.eos_idx, self.unk_idx = 0, 1 37 | 38 | self.convolutions = nn.ModuleList() 39 | for width, out_c in filters: 40 | self.convolutions.append( 41 | nn.Conv1d(char_embed_dim, out_c, kernel_size=width) 42 | ) 43 | 44 | final_dim = sum(f[1] for f in filters) 45 | 46 | self.highway = Highway(final_dim, highway_layers) 47 | self.projection = nn.Linear(final_dim, word_embed_dim) 48 | 49 | self.set_vocab(vocab, max_char_len) 50 | self.reset_parameters() 51 | 52 | def set_vocab(self, vocab, max_char_len): 53 | word_to_char = torch.LongTensor(len(vocab), max_char_len) 54 | 55 | truncated = 0 56 | for i in range(len(vocab)): 57 | if i < vocab.nspecial: 58 | char_idxs = [0] * max_char_len 59 | else: 60 | chars = vocab[i].encode() 61 | # +1 for padding 62 | char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars)) 63 | if len(char_idxs) > max_char_len: 64 | truncated += 1 65 | char_idxs = char_idxs[:max_char_len] 66 | word_to_char[i] = torch.LongTensor(char_idxs) 67 | 68 | if truncated > 0: 69 | print('Truncated {} words longer than {} characters'.format(truncated, max_char_len)) 70 | 71 | self.vocab = vocab 72 | self.word_to_char = word_to_char 73 | 74 | @property 75 | def padding_idx(self): 76 | 
return self.vocab.pad() 77 | 78 | def reset_parameters(self): 79 | nn.init.xavier_normal_(self.char_embeddings.weight) 80 | nn.init.xavier_normal_(self.symbol_embeddings) 81 | nn.init.xavier_normal_(self.projection.weight) 82 | nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.) 83 | nn.init.constant_(self.projection.bias, 0.) 84 | 85 | def forward( 86 | self, 87 | words: torch.Tensor, 88 | ): 89 | self.word_to_char = self.word_to_char.type_as(words) 90 | 91 | flat_words = words.view(-1) 92 | word_embs = self._convolve(self.word_to_char[flat_words]) 93 | 94 | pads = flat_words.eq(self.vocab.pad()) 95 | if pads.any(): 96 | word_embs[pads] = 0 97 | 98 | eos = flat_words.eq(self.vocab.eos()) 99 | if eos.any(): 100 | word_embs[eos] = self.symbol_embeddings[self.eos_idx] 101 | 102 | unk = flat_words.eq(self.vocab.unk()) 103 | if unk.any(): 104 | word_embs[unk] = self.symbol_embeddings[self.unk_idx] 105 | 106 | return word_embs.view(words.size() + (-1,)) 107 | 108 | def _convolve( 109 | self, 110 | char_idxs: torch.Tensor, 111 | ): 112 | char_embs = self.char_embeddings(char_idxs) 113 | char_embs = char_embs.transpose(1, 2) # BTC -> BCT 114 | 115 | conv_result = [] 116 | 117 | for i, conv in enumerate(self.convolutions): 118 | x = conv(char_embs) 119 | x, _ = torch.max(x, -1) 120 | x = F.relu(x) 121 | conv_result.append(x) 122 | 123 | conv_result = torch.cat(conv_result, dim=-1) 124 | conv_result = self.highway(conv_result) 125 | 126 | return self.projection(conv_result) 127 | -------------------------------------------------------------------------------- /fairseq/modules/conv_tbc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | from torch.nn.modules.utils import _single 10 | 11 | 12 | class ConvTBC(torch.nn.Module): 13 | """1D convolution over an input of shape (time x batch x channel) 14 | 15 | The implementation uses gemm to perform the convolution. This implementation 16 | is faster than cuDNN for small kernel sizes. 17 | """ 18 | def __init__(self, in_channels, out_channels, kernel_size, padding=0): 19 | super(ConvTBC, self).__init__() 20 | self.in_channels = in_channels 21 | self.out_channels = out_channels 22 | self.kernel_size = _single(kernel_size) 23 | self.padding = _single(padding) 24 | 25 | self.weight = torch.nn.Parameter(torch.Tensor( 26 | self.kernel_size[0], in_channels, out_channels)) 27 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) 28 | 29 | def forward(self, input): 30 | return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0]) 31 | 32 | def __repr__(self): 33 | s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}' 34 | ', padding={padding}') 35 | if self.bias is None: 36 | s += ', bias=False' 37 | s += ')' 38 | return s.format(name=self.__class__.__name__, **self.__dict__) 39 | -------------------------------------------------------------------------------- /fairseq/modules/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import torch
9 |
10 |
11 | class GradMultiply(torch.autograd.Function):
12 |     @staticmethod
13 |     def forward(ctx, x, scale):
14 |         ctx.scale = scale
15 |         res = x.new(x)
16 |         return res
17 |
18 |     @staticmethod
19 |     def backward(ctx, grad):
20 |         return grad * ctx.scale, None
21 |
--------------------------------------------------------------------------------
/fairseq/modules/highway.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | from torch import nn
12 |
13 |
14 | class Highway(torch.nn.Module):
15 |     """
16 |     A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
17 |     Adopted from the AllenNLP implementation.
18 |     """
19 |
20 |     def __init__(
21 |         self,
22 |         input_dim: int,
23 |         num_layers: int = 1
24 |     ):
25 |         super(Highway, self).__init__()
26 |         self.input_dim = input_dim
27 |         self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
28 |                                      for _ in range(num_layers)])
29 |         self.activation = nn.ReLU()
30 |
31 |         self.reset_parameters()
32 |
33 |     def reset_parameters(self):
34 |         for layer in self.layers:
35 |             # As per comment in AllenNLP:
36 |             # We should bias the highway layer to just carry its input forward. We do that by
37 |             # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
38 |             # be high, so we will carry the input forward. The bias on `B(x)` is the second half
39 |             # of the bias vector in each Linear layer.
40 |             nn.init.constant_(layer.bias[self.input_dim:], 1)
41 |
42 |             nn.init.constant_(layer.bias[:self.input_dim], 0)
43 |             nn.init.xavier_normal_(layer.weight)
44 |
45 |     def forward(
46 |         self,
47 |         x: torch.Tensor
48 |     ):
49 |         for layer in self.layers:
50 |             projection = layer(x)
51 |             proj_x, gate = projection.chunk(2, dim=-1)
52 |             proj_x = self.activation(proj_x)
53 |             gate = torch.sigmoid(gate)
54 |             x = gate * x + (1 - gate) * proj_x
55 |         return x
56 |
--------------------------------------------------------------------------------
/fairseq/modules/learned_positional_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import torch.nn as nn
9 |
10 | from fairseq import utils
11 |
12 |
13 | class LearnedPositionalEmbedding(nn.Embedding):
14 |     """This module learns positional embeddings up to a fixed maximum size.
15 |
16 |     Padding symbols are ignored, but it is necessary to specify whether padding
17 |     is added on the left side (left_pad=True) or right side (left_pad=False).
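    For example, with padding_idx = 1 and left_pad = False, an input batch
    [[w1, w2, w3, pad]] is assigned positions [[2, 3, 4, 1]]: real tokens are
    numbered upwards from padding_idx + 1 and padding positions map back to
    padding_idx.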
18 | """ 19 | 20 | def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad): 21 | super().__init__(num_embeddings, embedding_dim, padding_idx) 22 | self.left_pad = left_pad 23 | 24 | def forward(self, input, incremental_state=None): 25 | """Input is expected to be of size [bsz x seqlen].""" 26 | if incremental_state is not None: 27 | # positions is the same for every token when decoding a single step 28 | positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1)) 29 | else: 30 | positions = utils.make_positions(input.data, self.padding_idx, self.left_pad) 31 | return super().forward(positions) 32 | 33 | def max_positions(self): 34 | """Maximum number of supported positions.""" 35 | return self.num_embeddings - self.padding_idx - 1 36 | -------------------------------------------------------------------------------- /fairseq/modules/linearized_convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from fairseq import utils 12 | 13 | from .conv_tbc import ConvTBC 14 | 15 | 16 | class LinearizedConvolution(ConvTBC): 17 | """An optimized version of nn.Conv1d. 18 | 19 | At training time, this module uses ConvTBC, which is an optimized version 20 | of Conv1d. At inference time, it optimizes incremental generation (i.e., 21 | one time step at a time) by replacing the convolutions with linear layers. 22 | Note that the input order changes from training to inference. 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 26 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 27 | self._linearized_weight = None 28 | self.register_backward_hook(self._clear_linearized_weight) 29 | 30 | def forward(self, input, incremental_state=None): 31 | """ 32 | Input: 33 | Time x Batch x Channel during training 34 | Batch x Time x Channel during inference 35 | Args: 36 | incremental_state: Used to buffer signal; if not None, then input is 37 | expected to contain a single frame. If the input order changes 38 | between time steps, call reorder_incremental_state. 
39 | """ 40 | if incremental_state is None: 41 | output = super().forward(input) 42 | if self.kernel_size[0] > 1 and self.padding[0] > 0: 43 | # remove future timesteps added by padding 44 | output = output[:-self.padding[0], :, :] 45 | return output 46 | 47 | # reshape weight 48 | weight = self._get_linearized_weight() 49 | kw = self.kernel_size[0] 50 | 51 | bsz = input.size(0) # input: bsz x len x dim 52 | if kw > 1: 53 | input = input.data 54 | input_buffer = self._get_input_buffer(incremental_state) 55 | if input_buffer is None: 56 | input_buffer = input.new(bsz, kw, input.size(2)).zero_() 57 | self._set_input_buffer(incremental_state, input_buffer) 58 | else: 59 | # shift buffer 60 | input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone() 61 | # append next input 62 | input_buffer[:, -1, :] = input[:, -1, :] 63 | input = input_buffer 64 | with torch.no_grad(): 65 | output = F.linear(input.view(bsz, -1), weight, self.bias) 66 | return output.view(bsz, 1, -1) 67 | 68 | def reorder_incremental_state(self, incremental_state, new_order): 69 | input_buffer = self._get_input_buffer(incremental_state) 70 | if input_buffer is not None: 71 | input_buffer = input_buffer.index_select(0, new_order) 72 | self._set_input_buffer(incremental_state, input_buffer) 73 | 74 | def _get_input_buffer(self, incremental_state): 75 | return utils.get_incremental_state(self, incremental_state, 'input_buffer') 76 | 77 | def _set_input_buffer(self, incremental_state, new_buffer): 78 | return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer) 79 | 80 | def _get_linearized_weight(self): 81 | if self._linearized_weight is None: 82 | kw = self.kernel_size[0] 83 | weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() 84 | assert weight.size() == (self.out_channels, kw, self.in_channels) 85 | self._linearized_weight = weight.view(self.out_channels, -1) 86 | return self._linearized_weight 87 | 88 | def _clear_linearized_weight(self, *args): 89 | self._linearized_weight = None 90 | -------------------------------------------------------------------------------- /fairseq/modules/scalar_bias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | # 8 | 9 | import torch 10 | 11 | 12 | class ScalarBias(torch.autograd.Function): 13 | """ 14 | Adds a vector of scalars, used in self-attention mechanism to allow 15 | the model to optionally attend to this vector instead of the past 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, input, dim, bias_init): 20 | size = list(input.size()) 21 | size[dim] += 1 22 | output = input.new(*size).fill_(bias_init) 23 | output.narrow(dim, 1, size[dim] - 1).copy_(input) 24 | ctx.dim = dim 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad): 29 | return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None 30 | 31 | 32 | def scalar_bias(input, dim, bias_init=0): 33 | return ScalarBias.apply(input, dim, bias_init) 34 | -------------------------------------------------------------------------------- /fairseq/modules/sinusoidal_positional_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from fairseq import utils 14 | 15 | 16 | class SinusoidalPositionalEmbedding(nn.Module): 17 | """This module produces sinusoidal positional embeddings of any length. 18 | 19 | Padding symbols are ignored, but it is necessary to specify whether padding 20 | is added on the left side (left_pad=True) or right side (left_pad=False). 21 | """ 22 | 23 | def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024): 24 | super().__init__() 25 | self.embedding_dim = embedding_dim 26 | self.padding_idx = padding_idx 27 | self.left_pad = left_pad 28 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 29 | init_size, 30 | embedding_dim, 31 | padding_idx, 32 | ) 33 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 34 | 35 | @staticmethod 36 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 37 | """Build sinusoidal embeddings. 38 | 39 | This matches the implementation in tensor2tensor, but differs slightly 40 | from the description in Section 3.5 of "Attention Is All You Need". 41 | """ 42 | half_dim = embedding_dim // 2 43 | emb = math.log(10000) / (half_dim - 1) 44 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 45 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 46 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 47 | if embedding_dim % 2 == 1: 48 | # zero pad 49 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 50 | if padding_idx is not None: 51 | emb[padding_idx, :] = 0 52 | return emb 53 | 54 | def forward(self, input, incremental_state=None): 55 | """Input is expected to be of size [bsz x seqlen].""" 56 | # recompute/expand embeddings if needed 57 | bsz, seq_len = input.size() 58 | max_pos = self.padding_idx + 1 + seq_len 59 | if self.weights is None or max_pos > self.weights.size(0): 60 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 61 | max_pos, 62 | self.embedding_dim, 63 | self.padding_idx, 64 | ) 65 | self.weights = self.weights.type_as(self._float_tensor) 66 | 67 | if incremental_state is not None: 68 | # positions is the same for every token when decoding a single step 69 | return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1) 70 | 71 | positions = utils.make_positions(input.data, self.padding_idx, self.left_pad) 72 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() 73 | 74 | def max_positions(self): 75 | """Maximum number of supported positions.""" 76 | return int(1e5) # an arbitrary large number 77 | -------------------------------------------------------------------------------- /fairseq/multiprocessing_pdb.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
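# --- Example (added): a numeric check of the sinusoid construction above.
# For position pos and channel i < half_dim, get_embedding returns
# sin(pos * exp(-i * log(10000) / (half_dim - 1))), with the matching cosines
# in channels half_dim..2*half_dim-1. The sizes below are illustrative.
import math
import torch
from fairseq.modules.sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

emb = SinusoidalPositionalEmbedding.get_embedding(10, embedding_dim=8, padding_idx=None)
pos, i = 3, 1
freq = math.exp(-i * math.log(10000) / (8 // 2 - 1))
assert torch.isclose(emb[pos, i], torch.tensor(math.sin(pos * freq)))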
7 |
8 | import multiprocessing
9 | import os
10 | import pdb
11 | import sys
12 |
13 |
14 | class MultiprocessingPdb(pdb.Pdb):
15 |     """A Pdb wrapper that works in a multiprocessing environment.
16 |
17 |     Usage: `from fairseq import pdb; pdb.set_trace()`
18 |     """
19 |
20 |     _stdin_fd = sys.stdin.fileno()
21 |     _stdin = None
22 |     _stdin_lock = multiprocessing.Lock()
23 |
24 |     def __init__(self):
25 |         pdb.Pdb.__init__(self, nosigint=True)
26 |
27 |     def _cmdloop(self):
28 |         stdin_bak = sys.stdin
29 |         with self._stdin_lock:
30 |             try:
31 |                 if not self._stdin:
32 |                     self._stdin = os.fdopen(self._stdin_fd)
33 |                 sys.stdin = self._stdin
34 |                 self.cmdloop()
35 |             finally:
36 |                 sys.stdin = stdin_bak
37 |
38 |
39 | pdb = MultiprocessingPdb()
40 |
--------------------------------------------------------------------------------
/fairseq/optim/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import importlib
9 | import os
10 |
11 | from .fairseq_optimizer import FairseqOptimizer
12 |
13 |
14 | OPTIMIZER_REGISTRY = {}
15 | OPTIMIZER_CLASS_NAMES = set()
16 |
17 |
18 | def build_optimizer(args, params):
19 |     params = filter(lambda p: p.requires_grad, params)
20 |     return OPTIMIZER_REGISTRY[args.optimizer](args, params)
21 |
22 |
23 | def register_optimizer(name):
24 |     """Decorator to register a new optimizer."""
25 |
26 |     def register_optimizer_cls(cls):
27 |         if name in OPTIMIZER_REGISTRY:
28 |             raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
29 |         if not issubclass(cls, FairseqOptimizer):
30 |             raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
31 |         if cls.__name__ in OPTIMIZER_CLASS_NAMES:
32 |             # We use the optimizer class name as a unique identifier in
33 |             # checkpoints, so all optimizers must have unique class names.
34 |             raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
35 |         OPTIMIZER_REGISTRY[name] = cls
36 |         OPTIMIZER_CLASS_NAMES.add(cls.__name__)
37 |         return cls
38 |
39 |     return register_optimizer_cls
40 |
41 |
42 | # automatically import any Python files in the optim/ directory
43 | for file in os.listdir(os.path.dirname(__file__)):
44 |     if file.endswith('.py') and not file.startswith('_'):
45 |         module = file[:file.find('.py')]
46 |         importlib.import_module('fairseq.optim.' + module)
47 |
--------------------------------------------------------------------------------
/fairseq/optim/adagrad.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import torch.optim
9 |
10 | from .
import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('adagrad') 14 | class Adagrad(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'weight_decay': self.args.weight_decay, 30 | } 31 | -------------------------------------------------------------------------------- /fairseq/optim/adam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import torch 10 | import torch.optim 11 | 12 | from . import FairseqOptimizer, register_optimizer 13 | 14 | 15 | @register_optimizer('adam') 16 | class FairseqAdam(FairseqOptimizer): 17 | def __init__(self, args, params): 18 | super().__init__(args, params) 19 | self._optimizer = Adam(params, **self.optimizer_config) 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add optimizer-specific arguments to the parser.""" 24 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', 25 | help='betas for Adam optimizer') 26 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', 27 | help='epsilon for Adam optimizer') 28 | parser.add_argument('--amsgrad', dest='amsgrad', action='store_true') 29 | parser.add_argument('--no-amsgrad', dest='amsgrad', action='store_false') 30 | parser.set_defaults(amsgrad=False) 31 | 32 | @property 33 | def optimizer_config(self): 34 | """ 35 | Return a kwarg dictionary that will be used to override optimizer 36 | args stored in checkpoints. This allows us to load a checkpoint and 37 | resume training using a different set of optimizer args, e.g., with a 38 | different learning rate. 39 | """ 40 | return { 41 | 'lr': self.args.lr[0], 42 | 'betas': eval(self.args.adam_betas), 43 | 'eps': self.args.adam_eps, 44 | 'weight_decay': self.args.weight_decay, 45 | 'amsgrad': self.args.amsgrad, 46 | } 47 | 48 | 49 | class Adam(torch.optim.Optimizer): 50 | """Implements Adam algorithm. 51 | 52 | This implementation is modified from torch.optim.Adam based on: 53 | `Fixed Weight Decay Regularization in Adam` 54 | (see https://arxiv.org/abs/1711.05101) 55 | 56 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
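    In update form, one step computes (a sketch matching the code below):

        exp_avg    = beta1 * exp_avg + (1 - beta1) * grad
        exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad**2
        step_size  = lr * sqrt(1 - beta2**t) / (1 - beta1**t)
        p          = p - step_size * exp_avg / (sqrt(exp_avg_sq) + eps)

    with weight decay applied directly to the parameters (decoupled, per the
    paper above) instead of being folded into the gradient.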
57 | 58 | Arguments: 59 | params (iterable): iterable of parameters to optimize or dicts defining 60 | parameter groups 61 | lr (float, optional): learning rate (default: 1e-3) 62 | betas (Tuple[float, float], optional): coefficients used for computing 63 | running averages of gradient and its square (default: (0.9, 0.999)) 64 | eps (float, optional): term added to the denominator to improve 65 | numerical stability (default: 1e-8) 66 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 67 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 68 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 69 | 70 | .. _Adam\: A Method for Stochastic Optimization: 71 | https://arxiv.org/abs/1412.6980 72 | .. _On the Convergence of Adam and Beyond: 73 | https://openreview.net/forum?id=ryQu7f-RZ 74 | """ 75 | 76 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 77 | weight_decay=0, amsgrad=False): 78 | defaults = dict(lr=lr, betas=betas, eps=eps, 79 | weight_decay=weight_decay, amsgrad=amsgrad) 80 | super(Adam, self).__init__(params, defaults) 81 | 82 | def step(self, closure=None): 83 | """Performs a single optimization step. 84 | 85 | Arguments: 86 | closure (callable, optional): A closure that reevaluates the model 87 | and returns the loss. 88 | """ 89 | loss = None 90 | if closure is not None: 91 | loss = closure() 92 | 93 | for group in self.param_groups: 94 | for p in group['params']: 95 | if p.grad is None: 96 | continue 97 | grad = p.grad.data 98 | if grad.is_sparse: 99 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 100 | amsgrad = group['amsgrad'] 101 | 102 | state = self.state[p] 103 | 104 | # State initialization 105 | if len(state) == 0: 106 | state['step'] = 0 107 | # Exponential moving average of gradient values 108 | state['exp_avg'] = torch.zeros_like(p.data) 109 | # Exponential moving average of squared gradient values 110 | state['exp_avg_sq'] = torch.zeros_like(p.data) 111 | if amsgrad: 112 | # Maintains max of all exp. moving avg. of sq. grad. values 113 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 114 | 115 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 116 | if amsgrad: 117 | max_exp_avg_sq = state['max_exp_avg_sq'] 118 | beta1, beta2 = group['betas'] 119 | 120 | state['step'] += 1 121 | 122 | # Decay the first and second moment running average coefficient 123 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 124 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 125 | if amsgrad: 126 | # Maintains the maximum of all 2nd moment running avg. till now 127 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 128 | # Use the max. for normalizing running avg. of gradient 129 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 130 | else: 131 | denom = exp_avg_sq.sqrt().add_(group['eps']) 132 | 133 | bias_correction1 = 1 - beta1 ** state['step'] 134 | bias_correction2 = 1 - beta2 ** state['step'] 135 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 136 | 137 | if group['weight_decay'] != 0: 138 | p.data.add_(-group['weight_decay'] * group['lr'], p.data) 139 | 140 | p.data.addcdiv_(-step_size, exp_avg, denom) 141 | 142 | return loss 143 | -------------------------------------------------------------------------------- /fairseq/optim/fairseq_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
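# --- Example (added): a one-step numeric check of the Adam class above.
# An illustrative sketch: with lr=0.1 and the default betas and eps, a
# parameter at 1.0 with gradient 1.0 moves by almost exactly -lr on the first
# step, because the bias-corrected ratio exp_avg / (sqrt(exp_avg_sq) + eps)
# is ~1 at t=1.
import torch
from fairseq.optim.adam import Adam

p = torch.nn.Parameter(torch.ones(1))
opt = Adam([p], lr=0.1)
p.grad = torch.ones(1)
opt.step()
assert abs(p.item() - 0.9) < 1e-3  # 1.0 - ~0.1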
3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | 10 | import torch.optim 11 | 12 | 13 | class FairseqOptimizer(object): 14 | 15 | def __init__(self, args, params): 16 | super().__init__() 17 | self.args = args 18 | self.params = params 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | """Add optimizer-specific arguments to the parser.""" 23 | pass 24 | 25 | @property 26 | def optimizer(self): 27 | """Return a torch.optim.optimizer.Optimizer instance.""" 28 | if not hasattr(self, '_optimizer'): 29 | raise NotImplementedError 30 | if not isinstance(self._optimizer, torch.optim.Optimizer): 31 | raise ValueError('_optimizer must be an instance of torch.optim.Optimizer') 32 | return self._optimizer 33 | 34 | @property 35 | def optimizer_config(self): 36 | """ 37 | Return a kwarg dictionary that will be used to override optimizer 38 | args stored in checkpoints. This allows us to load a checkpoint and 39 | resume training using a different set of optimizer args, e.g., with a 40 | different learning rate. 41 | """ 42 | raise NotImplementedError 43 | 44 | def get_lr(self): 45 | """Return the current learning rate.""" 46 | return self.optimizer.param_groups[0]['lr'] 47 | 48 | def set_lr(self, lr): 49 | """Set the learning rate.""" 50 | for param_group in self.optimizer.param_groups: 51 | param_group['lr'] = lr 52 | 53 | def state_dict(self): 54 | """Return the optimizer's state dict.""" 55 | return self.optimizer.state_dict() 56 | 57 | def load_state_dict(self, state_dict, optimizer_overrides=None): 58 | """Load an optimizer state dict. 59 | 60 | In general we should prefer the configuration of the existing optimizer 61 | instance (e.g., learning rate) over that found in the state_dict. This 62 | allows us to resume training from a checkpoint using a new set of 63 | optimizer args. 64 | """ 65 | self.optimizer.load_state_dict(state_dict) 66 | 67 | if optimizer_overrides is not None and len(optimizer_overrides) > 0: 68 | # override learning rate, momentum, etc. with latest values 69 | for group in self.optimizer.param_groups: 70 | group.update(optimizer_overrides) 71 | 72 | def backward(self, loss): 73 | loss.backward() 74 | 75 | def multiply_grads(self, c): 76 | """Multiplies grads by a constant ``c``.""" 77 | for p in self.params: 78 | p.grad.data.mul_(c) 79 | 80 | def clip_grad_norm(self, max_norm): 81 | """Clips gradient norm.""" 82 | if max_norm > 0: 83 | return torch.nn.utils.clip_grad_norm_(self.params, max_norm) 84 | else: 85 | return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params)) 86 | 87 | def step(self, closure=None): 88 | """Performs a single optimization step.""" 89 | return self.optimizer.step(closure) 90 | 91 | def zero_grad(self): 92 | """Clears the gradients of all optimized parameters.""" 93 | return self.optimizer.zero_grad() 94 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
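# --- Example (added): how the optimizer_overrides hook above is meant to be
# used when resuming. An illustrative sketch only; `optimizer` is a built
# FairseqOptimizer and `ckpt` a hypothetical checkpoint dict:
#
#     optimizer.load_state_dict(ckpt['optimizer'], optimizer_overrides={'lr': 5e-4})
#
# Every param group's 'lr' is then overwritten with 5e-4, so training resumes
# with the new rate rather than the one stored in the checkpoint.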
7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_lr_scheduler import FairseqLRScheduler 12 | 13 | 14 | LR_SCHEDULER_REGISTRY = {} 15 | 16 | 17 | def build_lr_scheduler(args, optimizer): 18 | return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer) 19 | 20 | 21 | def register_lr_scheduler(name): 22 | """Decorator to register a new LR scheduler.""" 23 | 24 | def register_lr_scheduler_cls(cls): 25 | if name in LR_SCHEDULER_REGISTRY: 26 | raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name)) 27 | if not issubclass(cls, FairseqLRScheduler): 28 | raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__)) 29 | LR_SCHEDULER_REGISTRY[name] = cls 30 | return cls 31 | 32 | return register_lr_scheduler_cls 33 | 34 | 35 | # automatically import any Python files in the optim/lr_scheduler/ directory 36 | for file in os.listdir(os.path.dirname(__file__)): 37 | if file.endswith('.py') and not file.startswith('_'): 38 | module = file[:file.find('.py')] 39 | importlib.import_module('fairseq.optim.lr_scheduler.' + module) 40 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from .. import FairseqOptimizer 9 | 10 | 11 | class FairseqLRScheduler(object): 12 | 13 | def __init__(self, args, optimizer): 14 | super().__init__() 15 | if not isinstance(optimizer, FairseqOptimizer): 16 | raise ValueError('optimizer must be an instance of FairseqOptimizer') 17 | self.args = args 18 | self.optimizer = optimizer 19 | self.best = None 20 | 21 | @staticmethod 22 | def add_args(parser): 23 | """Add arguments to the parser for this LR scheduler.""" 24 | pass 25 | 26 | def state_dict(self): 27 | """Return the LR scheduler state dict.""" 28 | return {'best': self.best} 29 | 30 | def load_state_dict(self, state_dict): 31 | """Load an LR scheduler state dict.""" 32 | self.best = state_dict['best'] 33 | 34 | def step(self, epoch, val_loss=None): 35 | """Update the learning rate at the end of the given epoch.""" 36 | if val_loss is not None: 37 | if self.best is None: 38 | self.best = val_loss 39 | else: 40 | self.best = min(self.best, val_loss) 41 | 42 | def step_update(self, num_updates): 43 | """Update the learning rate after each update.""" 44 | return self.optimizer.get_lr() 45 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/fixed_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from . 
import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('fixed')
12 | class FixedSchedule(FairseqLRScheduler):
13 |     """Decay the LR on a fixed schedule."""
14 |
15 |     def __init__(self, args, optimizer):
16 |         super().__init__(args, optimizer)
17 |
18 |         # set defaults
19 |         args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
20 |
21 |         self.lr = args.lr[0]
22 |         if args.warmup_updates > 0:
23 |             self.warmup_factor = 1. / args.warmup_updates
24 |         else:
25 |             self.warmup_factor = 1
26 |
27 |     @staticmethod
28 |     def add_args(parser):
29 |         """Add arguments to the parser for this LR scheduler."""
30 |         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
31 |                             help='force annealing at specified epoch')
32 |         parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
33 |                             help='warmup the learning rate linearly for the first N updates')
34 |
35 |     def get_next_lr(self, epoch):
36 |         lrs = self.args.lr
37 |         if self.args.force_anneal is None or epoch < self.args.force_anneal:
38 |             # use fixed LR schedule
39 |             next_lr = lrs[min(epoch, len(lrs) - 1)]
40 |         else:
41 |             # anneal based on lr_shrink
42 |             next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
43 |         return next_lr
44 |
45 |     def step(self, epoch, val_loss=None):
46 |         """Update the learning rate at the end of the given epoch."""
47 |         super().step(epoch, val_loss)
48 |         self.lr = self.get_next_lr(epoch)
49 |         self.optimizer.set_lr(self.warmup_factor * self.lr)
50 |         return self.optimizer.get_lr()
51 |
52 |     def step_update(self, num_updates):
53 |         """Update the learning rate after each update."""
54 |         if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
55 |             self.warmup_factor = num_updates / float(self.args.warmup_updates)
56 |             self.optimizer.set_lr(self.warmup_factor * self.lr)
57 |         return self.optimizer.get_lr()
58 |
--------------------------------------------------------------------------------
/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | from . import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('inverse_sqrt')
12 | class InverseSquareRootSchedule(FairseqLRScheduler):
13 |     """Decay the LR based on the inverse square root of the update number.
14 |
15 |     We also support a warmup phase where we linearly increase the learning rate
16 |     from some initial learning rate (`--warmup-init-lr`) until the configured
17 |     learning rate (`--lr`). Thereafter we decay proportional to the number of
18 |     updates, with a decay factor set to align with the configured learning rate.
19 |
20 |     During warmup:
21 |
22 |         lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
23 |         lr = lrs[update_num]
24 |
25 |     After warmup:
26 |
27 |         lr = decay_factor / sqrt(update_num)
28 |
29 |     where
30 |
31 |         decay_factor = args.lr * sqrt(args.warmup_updates)
32 |     """
33 |
34 |     def __init__(self, args, optimizer):
35 |         super().__init__(args, optimizer)
36 |         if len(args.lr) > 1:
37 |             raise ValueError(
38 |                 'Cannot use a fixed learning rate schedule with inverse_sqrt.'
39 |                 ' Consider --lr-scheduler=fixed instead.'
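# (Added worked example of this schedule, with illustrative settings
# --lr 5e-4 --warmup-updates 4000 --warmup-init-lr 1e-7: lr climbs linearly
# from 1e-7 to 5e-4 over the first 4000 updates; afterwards
# decay_factor = 5e-4 * sqrt(4000) ~ 0.0316, so lr(4000) = 0.0316/sqrt(4000)
# = 5e-4, continuous at the boundary, and lr(16000) ~ 2.5e-4: the rate halves
# for every 4x increase in updates.)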
40 | ) 41 | warmup_end_lr = args.lr[0] 42 | if args.warmup_init_lr < 0: 43 | args.warmup_init_lr = warmup_end_lr 44 | 45 | # linearly warmup for the first args.warmup_updates 46 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 47 | 48 | # then, decay prop. to the inverse square root of the update number 49 | self.decay_factor = warmup_end_lr * args.warmup_updates**0.5 50 | 51 | # initial learning rate 52 | self.lr = args.warmup_init_lr 53 | self.optimizer.set_lr(self.lr) 54 | 55 | @staticmethod 56 | def add_args(parser): 57 | """Add arguments to the parser for this LR scheduler.""" 58 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 59 | help='warmup the learning rate linearly for the first N updates') 60 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 61 | help='initial learning rate during warmup phase; default is args.lr') 62 | 63 | def step(self, epoch, val_loss=None): 64 | """Update the learning rate at the end of the given epoch.""" 65 | super().step(epoch, val_loss) 66 | # we don't change the learning rate at epoch boundaries 67 | return self.optimizer.get_lr() 68 | 69 | def step_update(self, num_updates): 70 | """Update the learning rate after each update.""" 71 | if num_updates < self.args.warmup_updates: 72 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 73 | else: 74 | self.lr = self.decay_factor * num_updates**-0.5 75 | self.optimizer.set_lr(self.lr) 76 | return self.lr 77 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim.lr_scheduler 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('reduce_lr_on_plateau') 14 | class ReduceLROnPlateau(FairseqLRScheduler): 15 | """Decay the LR by a factor every time the validation loss plateaus.""" 16 | 17 | def __init__(self, args, optimizer): 18 | super().__init__(args, optimizer) 19 | if len(args.lr) > 1: 20 | raise ValueError( 21 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' 22 | ' Consider --lr-scheduler=fixed instead.' 
23 | ) 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer.optimizer, patience=0, factor=args.lr_shrink) 26 | 27 | def state_dict(self): 28 | """Return the LR scheduler state dict.""" 29 | return { 30 | 'best': self.lr_scheduler.best, 31 | 'last_epoch': self.lr_scheduler.last_epoch, 32 | } 33 | 34 | def load_state_dict(self, state_dict): 35 | """Load an LR scheduler state dict.""" 36 | self.lr_scheduler.best = state_dict['best'] 37 | if 'last_epoch' in state_dict: 38 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 39 | 40 | def step(self, epoch, val_loss=None): 41 | """Update the learning rate at the end of the given epoch.""" 42 | if val_loss is not None: 43 | self.lr_scheduler.step(val_loss, epoch) 44 | else: 45 | self.lr_scheduler.last_epoch = epoch 46 | return self.optimizer.get_lr() 47 | -------------------------------------------------------------------------------- /fairseq/optim/lr_scheduler/reduce_lr_on_plateau_patience.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim.lr_scheduler 9 | 10 | from . import FairseqLRScheduler, register_lr_scheduler 11 | 12 | 13 | @register_lr_scheduler('reduce_lr_on_plateau_patience') 14 | class ReduceLROnPlateauPatience(FairseqLRScheduler): 15 | """Decay the LR by a factor every time the validation loss plateaus.""" 16 | 17 | def __init__(self, args, optimizer): 18 | super().__init__(args, optimizer) 19 | if len(args.lr) > 1: 20 | raise ValueError( 21 | 'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.' 22 | ' Consider --lr-scheduler=fixed instead.' 23 | ) 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer.optimizer, patience=args.lr_scheduler_patience, 26 | factor=args.lr_shrink) 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | """Add arguments to the parser for this LR scheduler.""" 31 | parser.add_argument('--lr-scheduler-patience', default=3, type=int, 32 | help='lr scheduler patience') 33 | 34 | def state_dict(self): 35 | """Return the LR scheduler state dict.""" 36 | return { 37 | 'best': self.lr_scheduler.best, 38 | 'last_epoch': self.lr_scheduler.last_epoch, 39 | } 40 | 41 | def load_state_dict(self, state_dict): 42 | """Load an LR scheduler state dict.""" 43 | self.lr_scheduler.best = state_dict['best'] 44 | if 'last_epoch' in state_dict: 45 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 46 | 47 | def step(self, epoch, val_loss=None): 48 | """Update the learning rate at the end of the given epoch.""" 49 | if val_loss is not None: 50 | self.lr_scheduler.step(val_loss, epoch) 51 | else: 52 | self.lr_scheduler.last_epoch = epoch 53 | return self.optimizer.get_lr() 54 | -------------------------------------------------------------------------------- /fairseq/optim/nag.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
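# --- Added note: the update rule implemented by the NAG class below (a sketch
# read off the code, not an official derivation). With momentum mu, rate lr,
# and correction c = lr / lr_old applied when the rate changes between steps:
#
#     p   <- p * (1 - lr * weight_decay)
#     p   <- p + mu^2 * c * buf - (1 + mu) * lr * grad
#     buf <- mu * c * buf - lr * grad
#
# This is Nesterov momentum rearranged so that `p` always stores the
# lookahead parameters.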
7 | 8 | from torch.optim.optimizer import Optimizer, required 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('nag') 14 | class FairseqNAG(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = NAG(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | 33 | 34 | class NAG(Optimizer): 35 | def __init__(self, params, lr=required, momentum=0, weight_decay=0): 36 | defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) 37 | super(NAG, self).__init__(params, defaults) 38 | 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | loss = closure() 49 | 50 | for group in self.param_groups: 51 | weight_decay = group['weight_decay'] 52 | momentum = group['momentum'] 53 | lr = group['lr'] 54 | lr_old = group.get('lr_old', lr) 55 | lr_correct = lr / lr_old 56 | 57 | for p in group['params']: 58 | if p.grad is None: 59 | continue 60 | 61 | d_p = p.grad.data 62 | param_state = self.state[p] 63 | if 'momentum_buffer' not in param_state: 64 | param_state['momentum_buffer'] = d_p.clone().zero_() 65 | 66 | buf = param_state['momentum_buffer'] 67 | 68 | if weight_decay != 0: 69 | p.data.mul_(1 - lr * weight_decay) 70 | p.data.add_(momentum * momentum * lr_correct, buf) 71 | p.data.add_(-(1 + momentum) * lr, d_p) 72 | 73 | buf.mul_(momentum * lr_correct).add_(-lr, d_p) 74 | 75 | group['lr_old'] = lr 76 | 77 | return loss 78 | -------------------------------------------------------------------------------- /fairseq/optim/rmsprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from torch.optim.optimizer import Optimizer, required 9 | from torch.optim import RMSprop 10 | 11 | from . 
import FairseqOptimizer, register_optimizer 12 | 13 | 14 | @register_optimizer('rmsprop') 15 | class FairseqRMSprop(FairseqOptimizer): 16 | def __init__(self, args, params): 17 | super().__init__(args, params) 18 | self._optimizer = RMSprop(params, **self.optimizer_config) 19 | 20 | @staticmethod 21 | def add_args(parser): 22 | """Add optimizer-specific arguments to the parser.""" 23 | parser.add_argument('--rmsprop-alpha', type=float, default=0.99, metavar='D', 24 | help='alpha for RMSprop optimizer') 25 | parser.add_argument('--rmsprop-eps', type=float, default=1e-8, metavar='D', 26 | help='epsilon for RMSprop optimizer') 27 | parser.add_argument('--rmsprop-centered', dest='rmsprop_centered', action='store_true') 28 | parser.add_argument('--no-rmsprop-centered', dest='rmsprop_centered', action='store_false') 29 | parser.set_defaults(rmsprop_centered=False) 30 | 31 | 32 | @property 33 | def optimizer_config(self): 34 | """ 35 | Return a kwarg dictionary that will be used to override optimizer 36 | args stored in checkpoints. This allows us to load a checkpoint and 37 | resume training using a different set of optimizer args, e.g., with a 38 | different learning rate. 39 | """ 40 | return { 41 | 'lr': self.args.lr[0], 42 | 'alpha': self.args.rmsprop_alpha, 43 | 'eps': self.args.rmsprop_eps, 44 | 'weight_decay': self.args.weight_decay, 45 | 'momentum': self.args.momentum, 46 | 'centered': self.args.rmsprop_centered, 47 | } -------------------------------------------------------------------------------- /fairseq/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import torch.optim 9 | 10 | from . import FairseqOptimizer, register_optimizer 11 | 12 | 13 | @register_optimizer('sgd') 14 | class SGD(FairseqOptimizer): 15 | def __init__(self, args, params): 16 | super().__init__(args, params) 17 | self._optimizer = torch.optim.SGD(params, **self.optimizer_config) 18 | 19 | @property 20 | def optimizer_config(self): 21 | """ 22 | Return a kwarg dictionary that will be used to override optimizer 23 | args stored in checkpoints. This allows us to load a checkpoint and 24 | resume training using a different set of optimizer args, e.g., with a 25 | different learning rate. 26 | """ 27 | return { 28 | 'lr': self.args.lr[0], 29 | 'momentum': self.args.momentum, 30 | 'weight_decay': self.args.weight_decay, 31 | } 32 | -------------------------------------------------------------------------------- /fairseq/progress_bar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | """ 9 | Wrapper around various loggers and progress bars (e.g., tqdm). 
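The flavour is selected with --log-format: 'json' prints machine-readable
stats lines, 'none' is silent, 'simple' prints periodically (and is used
automatically when stderr is not a TTY), and 'tqdm' (the default) renders an
interactive bar; see build_progress_bar below.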
10 | """ 11 | 12 | from collections import OrderedDict 13 | import json 14 | from numbers import Number 15 | import sys 16 | 17 | from tqdm import tqdm 18 | 19 | from fairseq.meters import AverageMeter 20 | 21 | 22 | def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'): 23 | if args.log_format is None: 24 | args.log_format = no_progress_bar if args.no_progress_bar else default 25 | 26 | if args.log_format == 'tqdm' and not sys.stderr.isatty(): 27 | args.log_format = 'simple' 28 | 29 | if args.log_format == 'json': 30 | bar = json_progress_bar(iterator, epoch, prefix, args.log_interval) 31 | elif args.log_format == 'none': 32 | bar = noop_progress_bar(iterator, epoch, prefix) 33 | elif args.log_format == 'simple': 34 | bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval) 35 | elif args.log_format == 'tqdm': 36 | bar = tqdm_progress_bar(iterator, epoch, prefix) 37 | else: 38 | raise ValueError('Unknown log format: {}'.format(args.log_format)) 39 | return bar 40 | 41 | 42 | class progress_bar(object): 43 | """Abstract class for progress bars.""" 44 | def __init__(self, iterable, epoch=None, prefix=None): 45 | self.iterable = iterable 46 | self.epoch = epoch 47 | self.prefix = '' 48 | if epoch is not None: 49 | self.prefix += '| epoch {:03d}'.format(epoch) 50 | if prefix is not None: 51 | self.prefix += ' | {}'.format(prefix) 52 | 53 | def __enter__(self): 54 | return self 55 | 56 | def __exit__(self, *exc): 57 | return False 58 | 59 | def __iter__(self): 60 | raise NotImplementedError 61 | 62 | def log(self, stats): 63 | """Log intermediate stats according to log_interval.""" 64 | raise NotImplementedError 65 | 66 | def print(self, stats): 67 | """Print end-of-epoch stats.""" 68 | raise NotImplementedError 69 | 70 | def _str_commas(self, stats): 71 | return ', '.join(key + '=' + stats[key].strip() 72 | for key in stats.keys()) 73 | 74 | def _str_pipes(self, stats): 75 | return ' | '.join(key + ' ' + stats[key].strip() 76 | for key in stats.keys()) 77 | 78 | def _format_stats(self, stats): 79 | postfix = OrderedDict(stats) 80 | # Preprocess stats according to datatype 81 | for key in postfix.keys(): 82 | # Number: limit the length of the string 83 | if isinstance(postfix[key], Number): 84 | postfix[key] = '{:g}'.format(postfix[key]) 85 | # Meter: display both current and average value 86 | elif isinstance(postfix[key], AverageMeter): 87 | postfix[key] = '{:.2f} ({:.2f})'.format( 88 | postfix[key].val, postfix[key].avg) 89 | # Else for any other type, try to get the string conversion 90 | elif not isinstance(postfix[key], str): 91 | postfix[key] = str(postfix[key]) 92 | # Else if it's a string, don't need to preprocess anything 93 | return postfix 94 | 95 | 96 | class json_progress_bar(progress_bar): 97 | """Log output in JSON format.""" 98 | 99 | def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): 100 | super().__init__(iterable, epoch, prefix) 101 | self.log_interval = log_interval 102 | self.stats = None 103 | 104 | def __iter__(self): 105 | size = float(len(self.iterable)) 106 | for i, obj in enumerate(self.iterable): 107 | yield obj 108 | if self.stats is not None and i > 0 and \ 109 | self.log_interval is not None and i % self.log_interval == 0: 110 | update = self.epoch - 1 + float(i / size) if self.epoch is not None else None 111 | stats = self._format_stats(self.stats, epoch=self.epoch, update=update) 112 | print(json.dumps(stats), flush=True) 113 | 114 | def log(self, stats): 115 | """Log 
intermediate stats according to log_interval.""" 116 | self.stats = stats 117 | 118 | def print(self, stats): 119 | """Print end-of-epoch stats.""" 120 | self.stats = stats 121 | stats = self._format_stats(self.stats, epoch=self.epoch) 122 | print(json.dumps(stats), flush=True) 123 | 124 | def _format_stats(self, stats, epoch=None, update=None): 125 | postfix = OrderedDict() 126 | if epoch is not None: 127 | postfix['epoch'] = epoch 128 | if update is not None: 129 | postfix['update'] = update 130 | # Preprocess stats according to datatype 131 | for key in stats.keys(): 132 | # Meter: display both current and average value 133 | if isinstance(stats[key], AverageMeter): 134 | postfix[key] = stats[key].val 135 | postfix[key + '_avg'] = stats[key].avg 136 | else: 137 | postfix[key] = stats[key] 138 | return postfix 139 | 140 | 141 | class noop_progress_bar(progress_bar): 142 | """No logging.""" 143 | 144 | def __init__(self, iterable, epoch=None, prefix=None): 145 | super().__init__(iterable, epoch, prefix) 146 | 147 | def __iter__(self): 148 | for obj in self.iterable: 149 | yield obj 150 | 151 | def log(self, stats): 152 | """Log intermediate stats according to log_interval.""" 153 | pass 154 | 155 | def print(self, stats): 156 | """Print end-of-epoch stats.""" 157 | pass 158 | 159 | 160 | class simple_progress_bar(progress_bar): 161 | """A minimal logger for non-TTY environments.""" 162 | 163 | def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): 164 | super().__init__(iterable, epoch, prefix) 165 | self.log_interval = log_interval 166 | self.stats = None 167 | 168 | def __iter__(self): 169 | size = len(self.iterable) 170 | for i, obj in enumerate(self.iterable): 171 | yield obj 172 | if self.stats is not None and i > 0 and \ 173 | self.log_interval is not None and i % self.log_interval == 0: 174 | postfix = self._str_commas(self.stats) 175 | print('{}: {:5d} / {:d} {}'.format(self.prefix, i, size, postfix), 176 | flush=True) 177 | 178 | def log(self, stats): 179 | """Log intermediate stats according to log_interval.""" 180 | self.stats = self._format_stats(stats) 181 | 182 | def print(self, stats): 183 | """Print end-of-epoch stats.""" 184 | postfix = self._str_pipes(self._format_stats(stats)) 185 | print('{} | {}'.format(self.prefix, postfix), flush=True) 186 | 187 | 188 | class tqdm_progress_bar(progress_bar): 189 | """Log to tqdm.""" 190 | 191 | def __init__(self, iterable, epoch=None, prefix=None): 192 | super().__init__(iterable, epoch, prefix) 193 | self.tqdm = tqdm(iterable, self.prefix, leave=False) 194 | 195 | def __iter__(self): 196 | return iter(self.tqdm) 197 | 198 | def log(self, stats): 199 | """Log intermediate stats according to log_interval.""" 200 | self.tqdm.set_postfix(self._format_stats(stats), refresh=False) 201 | 202 | def print(self, stats): 203 | """Print end-of-epoch stats.""" 204 | postfix = self._str_pipes(self._format_stats(stats)) 205 | self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix)) 206 | -------------------------------------------------------------------------------- /fairseq/sequence_scorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
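# --- Added note: how the SequenceScorer below aggregates scores (a reading of
# the code, stated for clarity). For an ensemble it averages the models' token
# probabilities and then takes the log, and each sentence score is the mean
# token log-probability of the reference:
#
#     score_i = pos_scores_i.sum() / tgt_len
#
# so exp(-score_i) recovers that sentence's per-token perplexity under the
# ensemble.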
7 | 8 | import torch 9 | 10 | from fairseq import utils 11 | 12 | 13 | class SequenceScorer(object): 14 | """Scores the target for a given source sentence.""" 15 | 16 | def __init__(self, models, tgt_dict): 17 | self.models = models 18 | self.pad = tgt_dict.pad() 19 | 20 | def cuda(self): 21 | for model in self.models: 22 | model.cuda() 23 | return self 24 | 25 | def score_batched_itr(self, data_itr, cuda=False, timer=None): 26 | """Iterate over a batched dataset and yield scored translations.""" 27 | for sample in data_itr: 28 | s = utils.move_to_cuda(sample) if cuda else sample 29 | if timer is not None: 30 | timer.start() 31 | pos_scores, attn = self.score(s) 32 | for i, id in enumerate(s['id'].data): 33 | # remove padding from ref 34 | src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad) 35 | ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None 36 | tgt_len = ref.numel() 37 | pos_scores_i = pos_scores[i][:tgt_len] 38 | score_i = pos_scores_i.sum() / tgt_len 39 | if attn is not None: 40 | attn_i = attn[i] 41 | _, alignment = attn_i.max(dim=0) 42 | else: 43 | attn_i = alignment = None 44 | hypos = [{ 45 | 'tokens': ref, 46 | 'score': score_i, 47 | 'attention': attn_i, 48 | 'alignment': alignment, 49 | 'positional_scores': pos_scores_i, 50 | }] 51 | if timer is not None: 52 | timer.stop(s['ntokens']) 53 | # return results in the same format as SequenceGenerator 54 | yield id, src, ref, hypos 55 | 56 | def score(self, sample): 57 | """Score a batch of translations.""" 58 | net_input = sample['net_input'] 59 | 60 | # compute scores for each model in the ensemble 61 | avg_probs = None 62 | avg_attn = None 63 | for model in self.models: 64 | with torch.no_grad(): 65 | model.eval() 66 | decoder_out = model.forward(**net_input) 67 | attn = decoder_out[1] 68 | 69 | probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data 70 | if avg_probs is None: 71 | avg_probs = probs 72 | else: 73 | avg_probs.add_(probs) 74 | if attn is not None: 75 | attn = attn.data 76 | if avg_attn is None: 77 | avg_attn = attn 78 | else: 79 | avg_attn.add_(attn) 80 | avg_probs.div_(len(self.models)) 81 | avg_probs.log_() 82 | if avg_attn is not None: 83 | avg_attn.div_(len(self.models)) 84 | avg_probs = avg_probs.gather( 85 | dim=2, 86 | index=sample['target'].data.unsqueeze(-1), 87 | ) 88 | return avg_probs.squeeze(2), avg_attn 89 | -------------------------------------------------------------------------------- /fairseq/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
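# --- Example (added): the registration pattern defined below. New tasks
# register under a name and are selected with --task; 'my_task' and MyTask
# are hypothetical:
#
#     @register_task('my_task')
#     class MyTask(FairseqTask):
#         @classmethod
#         def setup_task(cls, args, **kwargs):
#             return cls(args)
#
# setup_task(args) then dispatches through TASK_REGISTRY[args.task].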
7 | 8 | import importlib 9 | import os 10 | 11 | from .fairseq_task import FairseqTask 12 | 13 | 14 | TASK_REGISTRY = {} 15 | TASK_CLASS_NAMES = set() 16 | 17 | 18 | def setup_task(args): 19 | return TASK_REGISTRY[args.task].setup_task(args) 20 | 21 | 22 | def register_task(name): 23 | """Decorator to register a new task.""" 24 | 25 | def register_task_cls(cls): 26 | if name in TASK_REGISTRY: 27 | raise ValueError('Cannot register duplicate task ({})'.format(name)) 28 | if not issubclass(cls, FairseqTask): 29 | raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__)) 30 | if cls.__name__ in TASK_CLASS_NAMES: 31 | raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__)) 32 | TASK_REGISTRY[name] = cls 33 | TASK_CLASS_NAMES.add(cls.__name__) 34 | return cls 35 | 36 | return register_task_cls 37 | 38 | 39 | # automatically import any Python files in the tasks/ directory 40 | for file in os.listdir(os.path.dirname(__file__)): 41 | if file.endswith('.py') and not file.startswith('_'): 42 | module = file[:file.find('.py')] 43 | importlib.import_module('fairseq.tasks.' + module) 44 | -------------------------------------------------------------------------------- /fairseq/tasks/fairseq_task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | 9 | class FairseqTask(object): 10 | """ 11 | A Task defines the data format, stores shared state (e.g., dictionaries) and 12 | provides helpers for building the model/criterion and calculating the loss. 13 | """ 14 | 15 | @staticmethod 16 | def add_args(parser): 17 | """Add task-specific arguments to the parser.""" 18 | pass 19 | 20 | def __init__(self, args): 21 | self.args = args 22 | self.datasets = {} 23 | 24 | @classmethod 25 | def setup_task(cls, args, **kwargs): 26 | raise NotImplementedError 27 | 28 | def load_dataset(self, split, combine=False): 29 | raise NotImplementedError 30 | 31 | def dataset(self, split): 32 | """Return a dataset split.""" 33 | from fairseq.data import FairseqDataset 34 | if split not in self.datasets: 35 | raise KeyError('Dataset not loaded: ' + split) 36 | if not isinstance(self.datasets[split], FairseqDataset): 37 | raise TypeError('Datasets are expected to be of type FairseqDataset') 38 | return self.datasets[split] 39 | 40 | def build_model(self, args): 41 | from fairseq import models 42 | return models.build_model(args, self) 43 | 44 | def build_criterion(self, args): 45 | from fairseq import criterions 46 | return criterions.build_criterion(args, self) 47 | 48 | def get_loss(self, model, criterion, sample): 49 | return criterion(model, sample) 50 | 51 | @property 52 | def source_dictionary(self): 53 | raise NotImplementedError 54 | 55 | @property 56 | def target_dictionary(self): 57 | raise NotImplementedError 58 | -------------------------------------------------------------------------------- /fairseq/tasks/language_modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. 
An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import itertools 9 | import numpy as np 10 | import os 11 | 12 | from torch.utils.data import ConcatDataset 13 | 14 | from fairseq.data import ( 15 | Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset, 16 | MonolingualDataset, TokenBlockDataset, 17 | ) 18 | 19 | from . import FairseqTask, register_task 20 | 21 | 22 | @register_task('language_modeling') 23 | class LanguageModelingTask(FairseqTask): 24 | 25 | @staticmethod 26 | def add_args(parser): 27 | """Add task-specific arguments to the parser.""" 28 | parser.add_argument('data', metavar='DIR', help='path to data directory') 29 | parser.add_argument('--sample-break-mode', metavar='VAL', 30 | choices=['none', 'complete', 'eos', 'cbt_booktitle'], 31 | help='If omitted or "none", fills each sample with tokens-per-sample ' 32 | 'tokens. If set to "complete", splits samples only at the end ' 33 | 'of sentence, but may include multiple sentences per sample. ' 34 | 'If set to "eos", includes only one sentence per sample.') 35 | parser.add_argument('--tokens-per-sample', default=1024, type=int, metavar='N', 36 | help='max number of tokens per sample for LM dataset') 37 | parser.add_argument('--raw-text', default=False, action='store_true', 38 | help='load raw text dataset') 39 | 40 | def __init__(self, args, dictionary): 41 | super().__init__(args) 42 | self.dictionary = dictionary 43 | 44 | @classmethod 45 | def setup_task(cls, args, **kwargs): 46 | dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt')) 47 | print('| dictionary: {} types'.format(len(dictionary))) 48 | return cls(args, dictionary) 49 | 50 | def load_dataset(self, split, combine=False): 51 | """Load a dataset split.""" 52 | 53 | loaded_datasets = [] 54 | 55 | for k in itertools.count(): 56 | split_k = split + (str(k) if k > 0 else '') 57 | path = os.path.join(self.args.data, split_k) 58 | 59 | if self.args.raw_text and IndexedRawTextDataset.exists(path): 60 | ds = IndexedRawTextDataset(path, self.dictionary) 61 | tokens = [t for l in ds.tokens_list for t in l] 62 | elif not self.args.raw_text and IndexedInMemoryDataset.exists(path): 63 | ds = IndexedInMemoryDataset(path, fix_lua_indexing=True) 64 | tokens = ds.buffer 65 | else: 66 | if k > 0: 67 | break 68 | else: 69 | raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data)) 70 | 71 | cbt_booktitle_idx = None 72 | if self.args.sample_break_mode == 'cbt_booktitle': 73 | if self.dictionary.index('_BOOK_TITLE_') != self.dictionary.unk(): 74 | cbt_booktitle_idx = self.dictionary.index('_BOOK_TITLE_') 75 | 76 | loaded_datasets.append( 77 | TokenBlockDataset( 78 | tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode, 79 | include_targets=True, cbt_booktitle_idx=cbt_booktitle_idx, 80 | )) 81 | 82 | print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1]))) 83 | 84 | if not combine: 85 | break 86 | 87 | if len(loaded_datasets) == 1: 88 | dataset = loaded_datasets[0] 89 | sizes = dataset.sizes 90 | else: 91 | dataset = ConcatDataset(loaded_datasets) 92 | sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) 93 | 94 | self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False) 95 | 96 | @property 97 | def target_dictionary(self): 98 | return self.dictionary 99 | -------------------------------------------------------------------------------- /fairseq/tokenizer.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from collections import Counter 9 | import re 10 | 11 | import torch 12 | 13 | 14 | SPACE_NORMALIZER = re.compile(r"\s+")  # raw string, so "\s" is not treated as an invalid escape sequence 15 | 16 | 17 | def tokenize_line(line): 18 | line = SPACE_NORMALIZER.sub(" ", line) 19 | line = line.strip() 20 | return line.split() 21 | 22 | 23 | class Tokenizer: 24 | 25 | @staticmethod 26 | def add_file_to_dictionary(filename, dict, tokenize): 27 | with open(filename, 'r') as f: 28 | for line in f: 29 | for word in tokenize(line): 30 | dict.add_symbol(word) 31 | dict.add_symbol(dict.eos_word) 32 | 33 | @staticmethod 34 | def binarize(filename, dict, consumer, tokenize=tokenize_line, 35 | append_eos=True, reverse_order=False): 36 | nseq, ntok = 0, 0 37 | replaced = Counter() 38 | 39 | def replaced_consumer(word, idx): 40 | if idx == dict.unk_index and word != dict.unk_word: 41 | replaced.update([word]) 42 | 43 | with open(filename, 'r') as f: 44 | for line in f: 45 | ids = Tokenizer.tokenize( 46 | line=line, 47 | dict=dict, 48 | tokenize=tokenize, 49 | add_if_not_exist=False, 50 | consumer=replaced_consumer, 51 | append_eos=append_eos, 52 | reverse_order=reverse_order, 53 | ) 54 | nseq += 1 55 | 56 | consumer(ids) 57 | ntok += len(ids) 58 | return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)} 59 | 60 | @staticmethod 61 | def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, 62 | consumer=None, append_eos=True, reverse_order=False): 63 | words = tokenize(line) 64 | if reverse_order: 65 | words = list(reversed(words)) 66 | nwords = len(words) 67 | ids = torch.IntTensor(nwords + 1 if append_eos else nwords) 68 | 69 | for i, word in enumerate(words): 70 | if add_if_not_exist: 71 | idx = dict.add_symbol(word) 72 | else: 73 | idx = dict.index(word) 74 | if consumer is not None: 75 | consumer(word, idx) 76 | ids[i] = idx 77 | if append_eos: 78 | ids[nwords] = dict.eos_index 79 | return ids 80 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import torch 10 | 11 | from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils 12 | from fairseq.meters import StopwatchMeter, TimeMeter 13 | from fairseq.sequence_generator import SequenceGenerator 14 | from fairseq.sequence_scorer import SequenceScorer 15 | 16 | 17 | def main(args): 18 | assert args.path is not None, '--path required for generation!' 
19 | assert not args.sampling or args.nbest == args.beam, \ 20 | '--sampling requires --nbest to be equal to --beam' 21 | assert args.replace_unk is None or args.raw_text, \ 22 | '--replace-unk requires a raw text dataset (--raw-text)' 23 | 24 | if args.max_tokens is None and args.max_sentences is None: 25 | args.max_tokens = 12000 26 | print(args) 27 | 28 | use_cuda = torch.cuda.is_available() and not args.cpu 29 | 30 | # Load dataset splits 31 | task = tasks.setup_task(args) 32 | task.load_dataset(args.gen_subset) 33 | print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) 34 | 35 | # Set dictionaries 36 | src_dict = task.source_dictionary 37 | tgt_dict = task.target_dictionary 38 | 39 | # Load ensemble 40 | print('| loading model(s) from {}'.format(args.path)) 41 | models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides)) 42 | 43 | # Optimize ensemble for generation 44 | for model in models: 45 | model.make_generation_fast_( 46 | beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, 47 | need_attn=args.print_alignment, 48 | ) 49 | if args.fp16: 50 | model.half() 51 | 52 | # Load alignment dictionary for unknown word replacement 53 | # (None if no unknown word replacement, empty if no path to align dictionary) 54 | align_dict = utils.load_align_dict(args.replace_unk) 55 | 56 | # Load dataset (possibly sharded) 57 | itr = data.EpochBatchIterator( 58 | dataset=task.dataset(args.gen_subset), 59 | max_tokens=args.max_tokens, 60 | max_sentences=args.max_sentences, 61 | max_positions=models[0].max_positions(), 62 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 63 | required_batch_size_multiple=8, 64 | num_shards=args.num_shards, 65 | shard_id=args.shard_id, 66 | ).next_epoch_itr(shuffle=False) 67 | 68 | # Initialize generator 69 | gen_timer = StopwatchMeter() 70 | if args.score_reference: 71 | translator = SequenceScorer(models, task.target_dictionary) 72 | else: 73 | translator = SequenceGenerator( 74 | models, task.target_dictionary, beam_size=args.beam, 75 | stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized), 76 | len_penalty=args.lenpen, unk_penalty=args.unkpen, 77 | sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len, 78 | ) 79 | 80 | if use_cuda: 81 | translator.cuda() 82 | 83 | # Generate and compute BLEU score 84 | scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) 85 | num_sentences = 0 86 | has_target = True 87 | with progress_bar.build_progress_bar(args, itr) as t: 88 | if args.score_reference: 89 | translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer) 90 | else: 91 | translations = translator.generate_batched_itr( 92 | t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b, 93 | cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size, 94 | ) 95 | 96 | wps_meter = TimeMeter() 97 | for sample_id, src_tokens, target_tokens, hypos in translations: 98 | # Process input and ground truth 99 | has_target = target_tokens is not None 100 | target_tokens = target_tokens.int().cpu() if has_target else None 101 | 102 | # Either retrieve the original sentences or regenerate them from tokens. 
103 | if align_dict is not None: 104 | src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) 105 | target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) 106 | else: 107 | src_str = src_dict.string(src_tokens, args.remove_bpe) 108 | if has_target: 109 | target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) 110 | 111 | if not args.quiet: 112 | print('S-{}\t{}'.format(sample_id, src_str)) 113 | if has_target: 114 | print('T-{}\t{}'.format(sample_id, target_str)) 115 | 116 | # Process top predictions 117 | for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]): 118 | hypo_tokens, hypo_str, alignment = utils.post_process_prediction( 119 | hypo_tokens=hypo['tokens'].int().cpu(), 120 | src_str=src_str, 121 | alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, 122 | align_dict=align_dict, 123 | tgt_dict=tgt_dict, 124 | remove_bpe=args.remove_bpe, 125 | ) 126 | 127 | if not args.quiet: 128 | print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str)) 129 | print('P-{}\t{}'.format( 130 | sample_id, 131 | ' '.join(map( 132 | lambda x: '{:.4f}'.format(x), 133 | hypo['positional_scores'].tolist(), 134 | )) 135 | )) 136 | 137 | if args.print_alignment: 138 | print('A-{}\t{}'.format( 139 | sample_id, 140 | ' '.join(map(lambda x: str(utils.item(x)), alignment)) 141 | )) 142 | 143 | # Score only the top hypothesis 144 | if has_target and i == 0: 145 | if align_dict is not None or args.remove_bpe is not None: 146 | # Convert back to tokens for evaluation with unk replacement and/or without BPE 147 | target_tokens = tokenizer.Tokenizer.tokenize( 148 | target_str, tgt_dict, add_if_not_exist=True) 149 | scorer.add(target_tokens, hypo_tokens) 150 | 151 | wps_meter.update(src_tokens.size(0)) 152 | t.log({'wps': round(wps_meter.avg)}) 153 | num_sentences += 1 154 | 155 | print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( 156 | num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) 157 | if has_target: 158 | print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) 159 | 160 | 161 | if __name__ == '__main__': 162 | parser = options.get_generation_parser() 163 | args = options.parse_args_and_arch(parser) 164 | main(args) 165 | -------------------------------------------------------------------------------- /interactive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
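# interactive.py translates sentences read from stdin with a trained model.
# A minimal invocation sketch (illustrative only: the data directory and
# checkpoint path are placeholders, and the positional data argument is
# assumed to come from the standard generation parser):
#
#   echo "hello world" | python interactive.py data-bin/my-corpus \
#       --path checkpoints/model.pt --beam 5 --buffer-size 16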
8 | 9 | from collections import namedtuple 10 | import numpy as np 11 | import sys 12 | 13 | import torch 14 | 15 | from fairseq import data, options, tasks, tokenizer, utils 16 | from fairseq.sequence_generator import SequenceGenerator 17 | 18 | 19 | Batch = namedtuple('Batch', 'srcs tokens lengths') 20 | Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') 21 | 22 | 23 | def buffered_read(buffer_size): 24 | buffer = [] 25 | for src_str in sys.stdin: 26 | buffer.append(src_str.strip()) 27 | if len(buffer) >= buffer_size: 28 | yield buffer 29 | buffer = [] 30 | 31 | if len(buffer) > 0: 32 | yield buffer 33 | 34 | 35 | def make_batches(lines, args, src_dict, max_positions): 36 | tokens = [ 37 | tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long() 38 | for src_str in lines 39 | ] 40 | lengths = np.array([t.numel() for t in tokens]) 41 | itr = data.EpochBatchIterator( 42 | dataset=data.LanguagePairDataset(tokens, lengths, src_dict), 43 | max_tokens=args.max_tokens, 44 | max_sentences=args.max_sentences, 45 | max_positions=max_positions, 46 | ).next_epoch_itr(shuffle=False) 47 | for batch in itr: 48 | yield Batch( 49 | srcs=[lines[i] for i in batch['id']], 50 | tokens=batch['net_input']['src_tokens'], 51 | lengths=batch['net_input']['src_lengths'], 52 | ), batch['id'] 53 | 54 | 55 | def main(args): 56 | if args.buffer_size < 1: 57 | args.buffer_size = 1 58 | if args.max_tokens is None and args.max_sentences is None: 59 | args.max_sentences = 1 60 | 61 | assert not args.sampling or args.nbest == args.beam, \ 62 | '--sampling requires --nbest to be equal to --beam' 63 | assert not args.max_sentences or args.max_sentences <= args.buffer_size, \ 64 | '--max-sentences/--batch-size cannot be larger than --buffer-size' 65 | 66 | print(args) 67 | 68 | use_cuda = torch.cuda.is_available() and not args.cpu 69 | 70 | # Setup task, e.g., translation 71 | task = tasks.setup_task(args) 72 | 73 | # Load ensemble 74 | print('| loading model(s) from {}'.format(args.path)) 75 | model_paths = args.path.split(':') 76 | models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides)) 77 | 78 | # Set dictionaries 79 | src_dict = task.source_dictionary 80 | tgt_dict = task.target_dictionary 81 | 82 | # Optimize ensemble for generation 83 | for model in models: 84 | model.make_generation_fast_( 85 | beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, 86 | need_attn=args.print_alignment, 87 | ) 88 | if args.fp16: 89 | model.half() 90 | 91 | # Initialize generator 92 | translator = SequenceGenerator( 93 | models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop), 94 | normalize_scores=(not args.unnormalized), len_penalty=args.lenpen, 95 | unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk, 96 | minlen=args.min_len, sampling_temperature=args.sampling_temperature 97 | ) 98 | 99 | if use_cuda: 100 | translator.cuda() 101 | 102 | # Load alignment dictionary for unknown word replacement 103 | # (None if no unknown word replacement, empty if no path to align dictionary) 104 | align_dict = utils.load_align_dict(args.replace_unk) 105 | 106 | def make_result(src_str, hypos): 107 | result = Translation( 108 | src_str='O\t{}'.format(src_str), 109 | hypos=[], 110 | pos_scores=[], 111 | alignments=[], 112 | ) 113 | 114 | # Process top predictions 115 | for hypo in hypos[:min(len(hypos), args.nbest)]: 116 | hypo_tokens, hypo_str, alignment = 
utils.post_process_prediction( 117 | hypo_tokens=hypo['tokens'].int().cpu(), 118 | src_str=src_str, 119 | alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, 120 | align_dict=align_dict, 121 | tgt_dict=tgt_dict, 122 | remove_bpe=args.remove_bpe, 123 | ) 124 | result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str)) 125 | result.pos_scores.append('P\t{}'.format( 126 | ' '.join(map( 127 | lambda x: '{:.4f}'.format(x), 128 | hypo['positional_scores'].tolist(), 129 | )) 130 | )) 131 | result.alignments.append( 132 | 'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))) 133 | if args.print_alignment else None 134 | ) 135 | return result 136 | 137 | def process_batch(batch): 138 | tokens = batch.tokens 139 | lengths = batch.lengths 140 | 141 | if use_cuda: 142 | tokens = tokens.cuda() 143 | lengths = lengths.cuda() 144 | 145 | translations = translator.generate( 146 | tokens, 147 | lengths, 148 | maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b), 149 | ) 150 | 151 | return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)] 152 | 153 | if args.buffer_size > 1: 154 | print('| Sentence buffer size:', args.buffer_size) 155 | print('| Type the input sentence and press return:') 156 | for inputs in buffered_read(args.buffer_size): 157 | indices = [] 158 | results = [] 159 | for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()): 160 | indices.extend(batch_indices) 161 | results += process_batch(batch) 162 | 163 | for i in np.argsort(indices): 164 | result = results[i] 165 | print(result.src_str) 166 | for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments): 167 | print(hypo) 168 | print(pos_scores) 169 | if align is not None: 170 | print(align) 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = options.get_generation_parser(interactive=True) 175 | args = options.parse_args_and_arch(parser) 176 | main(args) 177 | -------------------------------------------------------------------------------- /score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
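# score.py computes corpus BLEU for a system output against a reference file,
# as implemented below. A usage sketch (file names are illustrative only):
#
#   python score.py --sys generated.txt --ref reference.txt --order 4
#
# When --sys is omitted (or given as '-'), the system output is read from stdin.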
8 | # 9 | 10 | import argparse 11 | import os 12 | import sys 13 | 14 | from fairseq import bleu, tokenizer 15 | from fairseq.data import dictionary 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.') 20 | parser.add_argument('-s', '--sys', default='-', help='system output') 21 | parser.add_argument('-r', '--ref', required=True, help='references') 22 | parser.add_argument('-o', '--order', default=4, metavar='N', 23 | type=int, help='consider ngrams up to this order') 24 | parser.add_argument('--ignore-case', action='store_true', 25 | help='case-insensitive scoring') 26 | 27 | args = parser.parse_args() 28 | print(args) 29 | 30 | assert args.sys == '-' or os.path.exists(args.sys), \ 31 | "System output file {} does not exist".format(args.sys) 32 | assert os.path.exists(args.ref), \ 33 | "Reference file {} does not exist".format(args.ref) 34 | 35 | dict = dictionary.Dictionary() 36 | 37 | def readlines(fd): 38 | for line in fd.readlines(): 39 | if args.ignore_case: 40 | line = line.lower()  # lower-case in place so each input line is yielded exactly once 41 | yield line 42 | 43 | def score(fdsys): 44 | with open(args.ref) as fdref: 45 | scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) 46 | for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): 47 | sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict) 48 | ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict) 49 | scorer.add(ref_tok, sys_tok) 50 | print(scorer.result_string(args.order)) 51 | 52 | if args.sys == '-': 53 | score(sys.stdin) 54 | else: 55 | with open(args.sys, 'r') as f: 56 | score(f) 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 
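# setup.py packages fairseq and compiles the libbleu C++ extension declared
# below. A typical source install with standard setuptools commands (nothing
# repo-specific is assumed here):
#
#   pip install -r requirements.txt
#   python setup.py build develop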
8 | 9 | from setuptools import setup, find_packages, Extension 10 | import sys 11 | 12 | 13 | if sys.version_info < (3,): 14 | sys.exit('Sorry, Python3 is required for fairseq.') 15 | 16 | with open('README.md') as f: 17 | readme = f.read() 18 | 19 | with open('LICENSE') as f: 20 | license = f.read() 21 | 22 | with open('requirements.txt') as f: 23 | reqs = f.read() 24 | 25 | 26 | bleu = Extension( 27 | 'fairseq.libbleu', 28 | sources=[ 29 | 'fairseq/clib/libbleu/libbleu.cpp', 30 | 'fairseq/clib/libbleu/module.cpp', 31 | ], 32 | extra_compile_args=['-std=c++11'], 33 | ) 34 | 35 | 36 | setup( 37 | name='fairseq', 38 | version='0.5.0', 39 | description='Facebook AI Research Sequence-to-Sequence Toolkit', 40 | long_description=readme, 41 | license=license, 42 | install_requires=reqs.strip().split('\n'), 43 | packages=find_packages(), 44 | ext_modules=[bleu], 45 | test_suite='tests', 46 | ) 47 | -------------------------------------------------------------------------------- /txt2dict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import Counter 3 | 4 | def read_file(in_file_path): 5 | sentences = [] 6 | with open(in_file_path) as f: 7 | for line in f: 8 | line = line.strip() 9 | sentences.append(line.split(' ')) 10 | return sentences 11 | 12 | def write_dict(cnt, out_file_path): 13 | with open(out_file_path, 'w') as f:  # write-only; the read access of 'w+' was unused 14 | for k, v in cnt.most_common():  # most-frequent first, matching fairseq's frequency-sorted dict.txt convention 15 | f.write(f"{k} {v}\n") 16 | 17 | if __name__ == "__main__": 18 | all_sentences = [] 19 | for in_file_path in sys.argv[1:-1]:  # every argument except the last names an input text file 20 | all_sentences.extend(read_file(in_file_path)) 21 | 22 | cnt = Counter([word for sentence in all_sentences for word in sentence]) 23 | 24 | out_file_path = sys.argv[-1]  # the last argument is the output dictionary path 25 | write_dict(cnt, out_file_path) 26 | 27 | --------------------------------------------------------------------------------