├── .gitignore ├── metrics ├── vocab.xls ├── sentence-bleu ├── sari.py ├── calc_sent_sim.py ├── calc_sent_sim_zh.py ├── hsk_freq.py ├── calc_metrics.sh ├── calc_bleu.py └── calc_bleu_zh.py ├── pretrained_model ├── MASS │ └── README.md └── MASS_zh │ └── README.md ├── data └── textbook │ └── processed │ ├── test.bin │ ├── test.idx │ ├── train.bin │ ├── train.idx │ ├── valid.bin │ ├── valid.idx │ ├── test.noise-prim.prim.bin │ ├── test.noise-prim.prim.idx │ ├── train.prim-None.prim.bin │ ├── train.prim-None.prim.idx │ ├── valid.prim-None.prim.bin │ └── valid.prim-None.prim.idx ├── requirements-train.txt ├── mass ├── __init__.py ├── bert_dictionary.py ├── learned_positional_embedding.py ├── masked_dataset.py ├── masked_s2s.py ├── language_pair_dataset.py ├── multihead_attention.py ├── trainer.py └── sequence_generator.py ├── requirements-eval.txt ├── run ├── train_oxford_oald_multi_task.sh ├── train_cwn_textbook_multi_task.sh ├── evaluate_cwn_textbook.sh ├── data_process_zh.sh ├── evaluate_oxford_oald.sh └── data_process.sh ├── README.md ├── encode-zh.py ├── encode.py ├── train.py ├── tokenization_bert.py └── tokenization_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | -------------------------------------------------------------------------------- /metrics/vocab.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/metrics/vocab.xls -------------------------------------------------------------------------------- /pretrained_model/MASS/README.md: -------------------------------------------------------------------------------- 1 | # File List 2 | 3 | - dict.txt 4 | - mass-base-uncased.pt -------------------------------------------------------------------------------- /pretrained_model/MASS_zh/README.md: -------------------------------------------------------------------------------- 1 | # File List 2 | 3 | - dict.txt 4 | - mass-base-chinese.pt -------------------------------------------------------------------------------- /metrics/sentence-bleu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/metrics/sentence-bleu -------------------------------------------------------------------------------- /data/textbook/processed/test.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/test.bin -------------------------------------------------------------------------------- /data/textbook/processed/test.idx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/test.idx -------------------------------------------------------------------------------- /data/textbook/processed/train.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/train.bin -------------------------------------------------------------------------------- /data/textbook/processed/train.idx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/train.idx 
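The `.bin`/`.idx` pairs listed here are fairseq-binarized splits produced by `run/data_process_zh.sh` for the textbook corpus. Below is a minimal inspection sketch, assuming fairseq 0.9.0 (as pinned in requirements-train.txt) and that the preprocessing script has written `dict.txt` alongside these files:

```python
# Minimal sketch (not part of the repo): peek into one binarized split.
from fairseq.data import Dictionary, data_utils

dictionary = Dictionary.load('data/textbook/processed/dict.txt')
dataset = data_utils.load_indexed_dataset(
    'data/textbook/processed/test', dictionary, dataset_impl='mmap')
print(len(dataset))                   # number of binarized sentences
print(dictionary.string(dataset[0]))  # first sentence mapped back to symbols
```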
-------------------------------------------------------------------------------- /data/textbook/processed/valid.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/valid.bin -------------------------------------------------------------------------------- /data/textbook/processed/valid.idx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/valid.idx -------------------------------------------------------------------------------- /requirements-train.txt: -------------------------------------------------------------------------------- 1 | blingfire==0.0.16 2 | fairseq==0.9.0 3 | torch==1.4.0 4 | torchtext==0.5.0 5 | torchvision==0.5.0 -------------------------------------------------------------------------------- /mass/__init__.py: -------------------------------------------------------------------------------- 1 | from . import masked_s2s 2 | from . import s2s_model 3 | from . import translation 4 | from . import modifyed_lsce 5 | -------------------------------------------------------------------------------- /data/textbook/processed/test.noise-prim.prim.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/test.noise-prim.prim.bin -------------------------------------------------------------------------------- /data/textbook/processed/test.noise-prim.prim.idx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/test.noise-prim.prim.idx -------------------------------------------------------------------------------- /data/textbook/processed/train.prim-None.prim.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/train.prim-None.prim.bin -------------------------------------------------------------------------------- /data/textbook/processed/train.prim-None.prim.idx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/train.prim-None.prim.idx -------------------------------------------------------------------------------- /data/textbook/processed/valid.prim-None.prim.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/valid.prim-None.prim.bin -------------------------------------------------------------------------------- /data/textbook/processed/valid.prim-None.prim.idx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blcuicall/SimpDefiner/HEAD/data/textbook/processed/valid.prim-None.prim.idx -------------------------------------------------------------------------------- /requirements-eval.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | sentence-transformers==1.0.4 3 | jieba==0.42.1 4 | nltk==3.4.3 5 | pandas==1.3.1 6 | scipy==1.6.2 7 | xlrd==2.0.1 -------------------------------------------------------------------------------- /metrics/sari.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding -*- 2 | import sys 3 | from easse.sari import corpus_sari 4 | 5 | 6 | def read_file(file): 7 | sents = [] 8 | with open(file) as fr: 9 | for line in fr: 10 | sents.append(line.strip()) 11 | return sents 12 | 13 | 14 | def main(*argv): 15 | if not argv: 16 | argv = sys.argv[1:] 17 | assert len(argv) == 3 18 | orig_file = argv[0] 19 | sys_file = argv[1] 20 | refs_file = argv[2] 21 | orig_sents = read_file(orig_file) 22 | sys_sents = read_file(sys_file) 23 | refs_sents = read_file(refs_file) 24 | assert len(orig_sents) == len(sys_sents) == len(refs_sents) 25 | sari = corpus_sari(orig_sents=orig_sents, 26 | sys_sents=sys_sents, 27 | refs_sents=[refs_sents]) 28 | print(sari) 29 | 30 | 31 | if __name__ == '__main__': 32 | sys.exit(main()) 33 | -------------------------------------------------------------------------------- /metrics/calc_sent_sim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import argparse 4 | import scipy 5 | from sentence_transformers import SentenceTransformer 6 | 7 | 8 | def main(args): 9 | embedder = SentenceTransformer('distiluse-base-multilingual-cased') 10 | 11 | hyp_data = [] 12 | tgt_data = [] 13 | with open(args.out_path) as fr_out, open(args.tgt_path) as fr_tgt: 14 | for hyp, tgt in zip(fr_out, fr_tgt): 15 | hyp_data.append(hyp) 16 | tgt_data.append(tgt) 17 | 18 | assert len(hyp_data) == len(tgt_data) 19 | total_sim = 0 20 | for hyp, tgt in zip(hyp_data, tgt_data): 21 | hyp_embedding = embedder.encode([hyp]) 22 | tgt_embedding = embedder.encode([tgt]) 23 | sim = 1 - scipy.spatial.distance.cdist(hyp_embedding, tgt_embedding, 'cosine')[0][0] 24 | total_sim += sim 25 | 26 | avg_sim = total_sim / len(hyp_data) 27 | print(f"Semantic Score: {avg_sim}") 28 | return 0 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--out_path', type=str, help='path of reference file') 34 | parser.add_argument('--tgt_path', type=str, help='path of target file') 35 | args = parser.parse_args() 36 | sys.exit(main(args)) 37 | -------------------------------------------------------------------------------- /metrics/calc_sent_sim_zh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import argparse 4 | import scipy 5 | import jieba 6 | from sentence_transformers import SentenceTransformer 7 | 8 | 9 | def main(args): 10 | embedder = SentenceTransformer('distiluse-base-multilingual-cased') 11 | 12 | hyp_data = [] 13 | tgt_data = [] 14 | with open(args.out_path) as fr_out, open(args.tgt_path) as fr_tgt: 15 | for hyp, tgt in zip(fr_out, fr_tgt): 16 | hyp_data.append(' '.join(jieba.lcut(hyp.replace(' ', '')))) 17 | tgt_data.append(' '.join(jieba.lcut(tgt.replace(' ', '')))) 18 | 19 | assert len(hyp_data) == len(tgt_data) 20 | total_sim = 0 21 | for hyp, tgt in zip(hyp_data, tgt_data): 22 | hyp_embedding = embedder.encode([hyp]) 23 | tgt_embedding = embedder.encode([tgt]) 24 | sim = 1 - scipy.spatial.distance.cdist(hyp_embedding, tgt_embedding, 'cosine')[0][0] 25 | total_sim += sim 26 | 27 | avg_sim = total_sim / len(hyp_data) 28 | print(f"Semantic Score: {avg_sim}") 29 | return 0 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--out_path', type=str, help='path of reference file') 35 | parser.add_argument('--tgt_path', 
type=str, help='path of target file') 36 | args = parser.parse_args() 37 | sys.exit(main(args)) 38 | -------------------------------------------------------------------------------- /run/train_oxford_oald_multi_task.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | set -e 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | SUMM_DIR=data/oxford_public/processed 6 | DAE_DIR=data/oald/processed 7 | PRETRAINED_MODEL_PATH=pretrained_model/MASS/mass-base-uncased.pt 8 | SAVE_DIR=checkpoints/sdgf-811 9 | mkdir -p $SAVE_DIR 10 | 11 | python train.py \ 12 | $SUMM_DIR:$DAE_DIR \ 13 | --user-dir mass --task translation_mix --arch transformer_mix_base \ 14 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 15 | --lr 3e-4 --min-lr 1e-09 \ 16 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 500 \ 17 | --weight-decay 0.0 \ 18 | --seed 1111 \ 19 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 20 | --update-freq 4 --max-tokens 3072 \ 21 | --ddp-backend=no_c10d --max-epoch 20 \ 22 | --max-source-positions 512 --max-target-positions 512 \ 23 | --skip-invalid-size-inputs-valid-test \ 24 | --dropout 0.2 \ 25 | --load-from-pretrained-model $PRETRAINED_MODEL_PATH \ 26 | --model_lang_pairs src-tgt oald-oald --lang-pairs src-tgt --dae-styles oald \ 27 | --lambda-parallel-config 0.8 --lambda-denoising-config 0.1 --lambda-lm-config 0.1 \ 28 | --max-word-shuffle-distance 5 \ 29 | --word-dropout-prob 0.2 \ 30 | --word-blanking-prob 0.2 \ 31 | --divide-decoder-self-attn-norm True \ 32 | --divide-decoder-final-norm True \ 33 | --divide-decoder-encoder-attn-query True \ 34 | --save-dir $SAVE_DIR \ 35 | 2>&1 | tee $SAVE_DIR/training.log 36 | -------------------------------------------------------------------------------- /run/train_cwn_textbook_multi_task.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | set -e 3 | export CUDA_VISIBLE_DEVICES=0,1,2,3 4 | 5 | SUMM_DIR=data/cwn/processed 6 | DAE_DIR=data/textbook/processed 7 | PRETRAINED_MODEL_PATH=pretrained_model/MASS-zh/mass-base-chinese.pt 8 | SAVE_DIR=checkpoints/sdgf-zh-811 9 | mkdir -p $SAVE_DIR 10 | 11 | python train.py \ 12 | $SUMM_DIR:$DAE_DIR \ 13 | --user-dir mass --task translation_mix --arch transformer_mix_base \ 14 | --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ 15 | --lr 3e-4 --min-lr 1e-09 \ 16 | --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 500 \ 17 | --weight-decay 0.0 \ 18 | --seed 1111 \ 19 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 20 | --update-freq 4 --max-tokens 3072 \ 21 | --ddp-backend=no_c10d --max-epoch 50 \ 22 | --max-source-positions 512 --max-target-positions 512 \ 23 | --skip-invalid-size-inputs-valid-test \ 24 | --dropout 0.2 \ 25 | --load-from-pretrained-model $PRETRAINED_MODEL_PATH \ 26 | --model_lang_pairs src-tgt textbook-textbook --lang-pairs src-tgt --dae-styles textbook \ 27 | --lambda-parallel-config 0.8 --lambda-denoising-config 0.1 --lambda-lm-config 0.1 \ 28 | --max-word-shuffle-distance 5 \ 29 | --word-dropout-prob 0.2 \ 30 | --word-blanking-prob 0.2 \ 31 | --divide-decoder-self-attn-norm True \ 32 | --divide-decoder-final-norm True \ 33 | --divide-decoder-encoder-attn-query True \ 34 | --save-dir $SAVE_DIR \ 35 | 2>&1 | tee $SAVE_DIR/training.log 36 | -------------------------------------------------------------------------------- /run/evaluate_cwn_textbook.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # Read arguments 3 | set -e 4 | export CUDA_VISIBLE_DEVICES=0 5 | POSITIONAL=() 6 | while [[ $# -gt 0 ]]; do 7 | key="$1" 8 | case $key in 9 | --model_dir) 10 | MODEL_DIR="$2" 11 | shift 2 12 | ;; 13 | *) 14 | POSITIONAL+=("$1") 15 | shift 16 | ;; 17 | esac 18 | done 19 | set -- "${POSITIONAL[@]}" 20 | 21 | DATA_DIR=data/cwn/processed 22 | MODEL=$MODEL_DIR/checkpoint_best.pt 23 | 24 | fairseq-generate $DATA_DIR \ 25 | --path $MODEL \ 26 | --user-dir mass \ 27 | --task translation_mix \ 28 | --model_lang_pairs src-tgt textbook-textbook \ 29 | --lang-pairs src-tgt \ 30 | --dae-styles textbook \ 31 | --batch-size 128 \ 32 | --skip-invalid-size-inputs-valid-test \ 33 | --beam 5 \ 34 | --lenpen 1.0 \ 35 | --min-len 2 \ 36 | --max-len-b 30 \ 37 | --unkpen 3 \ 38 | --no-repeat-ngram-size 3 \ 39 | 2>&1 | tee $MODEL_DIR/output_src_tgt.txt 40 | 41 | cp $DATA_DIR/test.src-tgt.src.bin $DATA_DIR/test.src-textbook.src.bin 42 | cp $DATA_DIR/test.src-tgt.src.idx $DATA_DIR/test.src-textbook.src.idx 43 | cp $DATA_DIR/test.src-tgt.tgt.bin $DATA_DIR/test.src-textbook.textbook.bin 44 | cp $DATA_DIR/test.src-tgt.tgt.idx $DATA_DIR/test.src-textbook.textbook.idx 45 | cp $DATA_DIR/dict.tgt.txt $DATA_DIR/dict.textbook.txt 46 | 47 | fairseq-generate $DATA_DIR \ 48 | --path $MODEL \ 49 | --user-dir mass \ 50 | --task translation_mix \ 51 | --model_lang_pairs src-tgt textbook-textbook \ 52 | --lang-pairs src-textbook \ 53 | --dae-styles textbook \ 54 | --batch-size 128 \ 55 | --skip-invalid-size-inputs-valid-test \ 56 | --beam 5 \ 57 | --lenpen 1.0 \ 58 | --min-len 2 \ 59 | --max-len-b 30 \ 60 | --unkpen 3 \ 61 | --no-repeat-ngram-size 3 \ 62 | 2>&1 | tee $MODEL_DIR/output_src_textbook.txt 63 | -------------------------------------------------------------------------------- /run/data_process_zh.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | set -e 3 | DATA_DIR=data/cwn/raw 4 | OUT_DIR=data/cwn/processed 5 | 6 | for SPLIT in train valid test; do 7 | python encode-zh.py \ 8 | --inputs $DATA_DIR/${SPLIT}.src \ 9 | --outputs $DATA_DIR/${SPLIT}.bpe.src \ 10 | --workers 30 11 | python encode-zh.py \ 12 | --inputs $DATA_DIR/${SPLIT}.tgt \ 13 | --outputs $DATA_DIR/${SPLIT}.bpe.tgt \ 14 | --workers 30 15 | done 16 | 17 | fairseq-preprocess \ 18 | --user-dir mass --task masked_s2s \ 19 | --source-lang src --target-lang tgt \ 20 | --trainpref $DATA_DIR/train.bpe \ 21 | --validpref $DATA_DIR/valid.bpe \ 22 | --testpref $DATA_DIR/test.bpe \ 23 | --destdir $OUT_DIR \ 24 | --srcdict pretrained_model/MASS-zh/dict.txt \ 25 | --tgtdict pretrained_model/MASS-zh/dict.txt \ 26 | --workers 20 27 | 28 | DATA_DIR=data/textbook/raw 29 | DEST_DIR=data/textbook/processed 30 | 31 | for SPLIT in train valid test; do 32 | python encode-zh.py \ 33 | --inputs $DATA_DIR/${SPLIT}.txt \ 34 | --outputs $DATA_DIR/${SPLIT}.bpe \ 35 | --workers 30 36 | done 37 | 38 | fairseq-preprocess \ 39 | --user-dir mass \ 40 | --task translation_mix \ 41 | --only-source \ 42 | --trainpref ${DATA_DIR}/train.bpe \ 43 | --validpref ${DATA_DIR}/valid.bpe \ 44 | --testpref ${DATA_DIR}/test.bpe \ 45 | --destdir $DEST_DIR \ 46 | --workers 20 \ 47 | --srcdict pretrained_model/MASS-zh/dict.txt 48 | 49 | for split in train valid; do 50 | cp $DEST_DIR/$split.idx $DEST_DIR/$split.prim-None.prim.idx 51 | cp $DEST_DIR/$split.bin $DEST_DIR/$split.prim-None.prim.bin 52 | done 53 | 54 | cp $DEST_DIR/test.bin $DEST_DIR/test.noise-prim.prim.bin 55 | cp $DEST_DIR/test.idx $DEST_DIR/test.noise-prim.prim.idx 56 | cp $DEST_DIR/dict.txt $DEST_DIR/dict.noise.txt 57 | cp $DEST_DIR/dict.txt $DEST_DIR/dict.prim.txt 58 | -------------------------------------------------------------------------------- /mass/bert_dictionary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
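# BertDictionary (below) adapts fairseq's Dictionary to the BERT-style vocabulary shipped
# with the MASS checkpoints: symbols are read from the vocabulary file, the special tokens
# are remapped to [PAD]/[UNK]/[CLS]/[SEP], and the first 999 entries are treated as special
# symbols (nspecial = 999).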
5 | 6 | from collections import Counter 7 | from multiprocessing import Pool 8 | import os 9 | 10 | import torch 11 | 12 | from fairseq.tokenizer import tokenize_line 13 | from fairseq.binarizer import safe_readline 14 | from fairseq.data import data_utils, Dictionary 15 | 16 | 17 | class BertDictionary(Dictionary): 18 | """A mapping from symbols to consecutive integers""" 19 | 20 | def __init__( 21 | self, 22 | pad='', 23 | eos='', 24 | unk='', 25 | bos='', 26 | extra_special_symbols=None, 27 | ): 28 | super().__init__(pad=pad, 29 | eos=eos, 30 | unk=unk, 31 | bos=bos, 32 | extra_special_symbols=extra_special_symbols) 33 | 34 | @classmethod 35 | def load_from_file(cls, filename): 36 | d = cls() 37 | d.symbols = [] 38 | d.count = [] 39 | d.indices = {} 40 | 41 | with open(filename, 'r', encoding='utf-8', errors='ignore') as input_file: 42 | for line in input_file: 43 | k, v = line.split(' ') 44 | d.add_symbol(k) 45 | 46 | d.unk_word = '[UNK]' 47 | d.pad_word = '[PAD]' 48 | d.eos_word = '[SEP]' 49 | d.bos_word = '[CLS]' 50 | 51 | d.bos_index = d.add_symbol('[CLS]') 52 | d.pad_index = d.add_symbol('[PAD]') 53 | d.eos_index = d.add_symbol('[SEP]') 54 | d.unk_index = d.add_symbol('[UNK]') 55 | 56 | d.nspecial = 999 57 | return d 58 | 59 | def save(self, f): 60 | """Stores dictionary into a text file""" 61 | ex_keys, ex_vals = self._get_meta() 62 | self._save(f, zip(ex_keys + self.symbols, ex_vals + self.count)) 63 | -------------------------------------------------------------------------------- /mass/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from fairseq import utils 4 | 5 | 6 | class LearnedPositionalEmbedding(nn.Embedding): 7 | """ 8 | This module learns positional embeddings up to a fixed maximum size. 9 | Padding ids are ignored by either offsetting based on padding_idx 10 | or by setting padding_idx to None and ensuring that the appropriate 11 | position ids are passed to the forward function. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | num_embeddings: int, 17 | embedding_dim: int, 18 | padding_idx: int, 19 | ): 20 | super().__init__(num_embeddings, embedding_dim, padding_idx) 21 | self.onnx_trace = False 22 | 23 | def forward(self, input, incremental_state=None, positions=None): 24 | """Input is expected to be of size [bsz x seqlen].""" 25 | assert ( 26 | (positions is None) or (self.padding_idx is None) 27 | ), "If positions is pre-computed then padding_idx should not be set." 28 | 29 | if positions is None: 30 | if incremental_state is not None: 31 | # positions is the same for every token when decoding a single step 32 | # Without the int() cast, it doesn't work in some cases when exporting to ONNX 33 | positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1))) 34 | else: 35 | positions = utils.make_positions( 36 | input.data, self.padding_idx, onnx_trace=self.onnx_trace, 37 | ) 38 | return super().forward(positions) 39 | 40 | def max_positions(self): 41 | """Maximum number of supported positions.""" 42 | if self.padding_idx is not None: 43 | return self.num_embeddings - self.padding_idx - 1 44 | else: 45 | return self.num_embeddings 46 | 47 | def _forward(self, positions): 48 | return super().forward(positions) 49 | -------------------------------------------------------------------------------- /run/evaluate_oxford_oald.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | # Read arguments 3 | 4 | set -e 5 | export CUDA_VISIBLE_DEVICES=7 6 | POSITIONAL=() 7 | while [[ $# -gt 0 ]]; do 8 | key="$1" 9 | case $key in 10 | --model_dir) 11 | MODEL_DIR="$2" 12 | shift 2 13 | ;; 14 | *) 15 | POSITIONAL+=("$1") 16 | shift 17 | ;; 18 | esac 19 | done 20 | set -- "${POSITIONAL[@]}" 21 | 22 | DATA_DIR=data/oxford/processed 23 | OALD_DATA_DIR=data/annotated-oxford-oald-test/processed 24 | MODEL=$MODEL_DIR/checkpoint_best.pt 25 | 26 | fairseq-generate $DATA_DIR \ 27 | --path $MODEL \ 28 | --user-dir mass \ 29 | --task translation_mix \ 30 | --model_lang_pairs src-tgt oald-oald \ 31 | --lang-pairs src-tgt \ 32 | --dae-styles oald \ 33 | --batch-size 128 \ 34 | --skip-invalid-size-inputs-valid-test \ 35 | --beam 5 \ 36 | --lenpen 1.0 \ 37 | --min-len 2 \ 38 | --max-len-b 30 \ 39 | --unkpen 3 \ 40 | --no-repeat-ngram-size 3 \ 41 | 2>&1 | tee $MODEL_DIR/output_src_tgt.txt 42 | #bash metrics/calc_metrics.sh $MODEL_DIR oxford $CUDA >$MODEL_DIR/log_oxford_metrics.txt 43 | 44 | cp $OALD_DATA_DIR/test.src-oald.src.bin $DATA_DIR/test.src-oald.src.bin 45 | cp $OALD_DATA_DIR/test.src-oald.src.idx $DATA_DIR/test.src-oald.src.idx 46 | cp $OALD_DATA_DIR/test.src-oald.oald.bin $DATA_DIR/test.src-oald.oald.bin 47 | cp $OALD_DATA_DIR/test.src-oald.oald.idx $DATA_DIR/test.src-oald.oald.idx 48 | cp $OALD_DATA_DIR/dict.oald.txt $DATA_DIR/dict.oald.txt 49 | 50 | fairseq-generate $DATA_DIR \ 51 | --path $MODEL \ 52 | --user-dir mass \ 53 | --task translation_mix \ 54 | --model_lang_pairs src-tgt oald-oald \ 55 | --lang-pairs src-oald \ 56 | --dae-styles oald \ 57 | --batch-size 128 \ 58 | --skip-invalid-size-inputs-valid-test \ 59 | --beam 5 \ 60 | --lenpen 1.0 \ 61 | --min-len 2 \ 62 | --max-len-b 30 \ 63 | --unkpen 3 \ 64 | --no-repeat-ngram-size 3 \ 65 | 2>&1 | tee $MODEL_DIR/output_src_oald.txt 66 | #bash metrics/calc_metrics.sh $MODEL_DIR oald $CUDA >$MODEL_DIR/log_oald_metrics.txt 67 | -------------------------------------------------------------------------------- /run/data_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | DATA_DIR=data/oxford/raw 4 | OUT_DIR=data/oxford/processed 5 | 6 | for SPLIT in train valid test; do 7 | python encode.py \ 8 | --inputs $DATA_DIR/oxford.${SPLIT}.src \ 9 | --outputs $DATA_DIR/oxford.${SPLIT}.bpe.src \ 10 | --workers 30 11 | python encode.py \ 12 | --inputs $DATA_DIR/oxford.${SPLIT}.tgt \ 13 | --outputs $DATA_DIR/oxford.${SPLIT}.bpe.tgt \ 14 | --workers 30 15 | done 16 | 17 | fairseq-preprocess \ 18 | --user-dir mass --task masked_s2s \ 19 | --source-lang src --target-lang tgt \ 20 | --trainpref $DATA_DIR/oxford.train.bpe \ 21 | --validpref $DATA_DIR/oxford.valid.bpe \ 22 | --testpref $DATA_DIR/oxford.test.bpe \ 23 | --destdir $OUT_DIR \ 24 | --srcdict pretrained_model/MASS-zh/dict.txt \ 25 | --tgtdict pretrained_model/MASS-zh/dict.txt \ 26 | --workers 30 27 | 28 | DATA_DIR=data/aligned-oxford-oald-test/raw 29 | OUT_DIR=data/aligned-oxford-oald-test/processed 30 | 31 | python encode.py \ 32 | --inputs $DATA_DIR/test.src \ 33 | --outputs $DATA_DIR/test.bpe.src \ 34 | --workers 30 35 | 36 | python encode.py \ 37 | --inputs $DATA_DIR/test.oxford \ 38 | --outputs $DATA_DIR/test.bpe.oxford \ 39 | --workers 30 40 | 41 | python encode.py \ 42 | --inputs $DATA_DIR/test.oald \ 43 | --outputs $DATA_DIR/test.bpe.oald \ 44 | --workers 30 45 | 46 | fairseq-preprocess \ 47 | --user-dir mass --task masked_s2s \ 48 | --source-lang src --target-lang oxford \ 49 | --testpref 
$DATA_DIR/test.bpe \ 50 | --destdir $OUT_DIR \ 51 | --srcdict pretrained_model/MASS/dict.txt \ 52 | --tgtdict pretrained_model/MASS/dict.txt \ 53 | --workers 20 54 | 55 | fairseq-preprocess \ 56 | --user-dir mass --task masked_s2s \ 57 | --source-lang src --target-lang oald \ 58 | --testpref $DATA_DIR/test.bpe \ 59 | --destdir $OUT_DIR \ 60 | --srcdict pretrained_model/MASS/dict.txt \ 61 | --tgtdict pretrained_model/MASS/dict.txt \ 62 | --workers 20 63 | 64 | cp $DEST_DIR/dict.txt $DEST_DIR/dict.oxford.txt 65 | cp $DEST_DIR/dict.txt $DEST_DIR/dict.oald.txt 66 | 67 | DATA_DIR=data/oald/raw 68 | DEST_DIR=data/oald/processed 69 | 70 | for SPLIT in train valid test; do 71 | python encode.py \ 72 | --inputs $DATA_DIR/oald.${SPLIT}.txt \ 73 | --outputs $DATA_DIR/oald.${SPLIT}.bpe \ 74 | --workers 30 75 | done 76 | 77 | fairseq-preprocess \ 78 | --user-dir mass \ 79 | --task translation_mix \ 80 | --only-source \ 81 | --trainpref ${DATA_DIR}/oald.train.bpe \ 82 | --validpref ${DATA_DIR}/oald.valid.bpe \ 83 | --testpref ${DATA_DIR}/oald.test.bpe \ 84 | --destdir $DEST_DIR \ 85 | --workers 20 \ 86 | --srcdict pretrained_model/MASS/dict.txt 87 | 88 | for split in train valid; do 89 | cp $DEST_DIR/$split.idx $DEST_DIR/$split.oald-None.oald.idx 90 | cp $DEST_DIR/$split.bin $DEST_DIR/$split.oald-None.oald.bin 91 | done 92 | 93 | cp $DEST_DIR/test.bin $DEST_DIR/test.noise-oald.oald.bin 94 | cp $DEST_DIR/test.idx $DEST_DIR/test.noise-oald.oald.idx 95 | cp $DEST_DIR/dict.txt $DEST_DIR/dict.noise.txt 96 | cp $DEST_DIR/dict.txt $DEST_DIR/dict.oald.txt 97 | -------------------------------------------------------------------------------- /metrics/hsk_freq.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pandas as pd 3 | import sys 4 | import jieba 5 | 6 | 7 | def read_hsk_vocab(path): 8 | hsk_vocab = {} 9 | vocab = pd.read_excel(path) 10 | level_keys = { 11 | '一级': 1, 12 | '二级': 2, 13 | '三级': 3, 14 | '四级': 4, 15 | '五级': 5, 16 | '六级': 6, 17 | '七-九级': 7 18 | } 19 | for k in level_keys: 20 | for word in vocab[k].dropna(): 21 | if '/' in word: 22 | words = word.split('/') 23 | for w in words: 24 | hsk_vocab[w.strip()] = level_keys[k] 25 | else: 26 | hsk_vocab[word.strip()] = level_keys[k] 27 | return hsk_vocab 28 | 29 | 30 | def count_nums(file_path, hsk_vocab): 31 | level_freq = {} 32 | num_tokens = 0 33 | num_types = 0 34 | with open(file_path) as fr: 35 | for line in fr: 36 | line = line.strip().replace(' ', '') 37 | seg_line = jieba.lcut(line) 38 | for word in seg_line: 39 | if word in hsk_vocab: 40 | level = hsk_vocab[word] 41 | else: 42 | level = 8 43 | if level not in level_freq: 44 | level_freq[level] = {} 45 | if word not in level_freq[level]: 46 | level_freq[level][word] = 0 47 | num_types += 1 48 | level_freq[level][word] += 1 49 | num_tokens += 1 50 | return level_freq, num_tokens, num_types 51 | 52 | 53 | def main(*argv): 54 | if not argv: 55 | argv = sys.argv[1:] 56 | assert len(argv) == 1 57 | 58 | inp_file = argv[0] 59 | hsk_vocab = read_hsk_vocab('metrics/vocab.xls') 60 | 61 | level_freq, num_tokens, num_types = count_nums(inp_file, hsk_vocab) 62 | # for level in level_freq: 63 | # level_types = len(level_freq[level]) 64 | # level_tokens = sum(level_freq[level].values()) 65 | # level_type_rate = (level_types / num_types) * 100 66 | # level_token_rate = (level_tokens / num_tokens) * 100 67 | # print(f"{level_type_rate},{level_token_rate}") 68 | 69 | low_level_types = 0 70 | low_level_tokens = 0 71 | for level in [1, 2, 3]: 
72 | low_level_types += len(level_freq[level]) 73 | low_level_tokens += sum(level_freq[level].values()) 74 | low_level_type_rate = (low_level_types / num_types) * 100 75 | low_level_token_rate = (low_level_tokens / num_tokens) * 100 76 | # print(f"Low Level Type Rate: {low_level_type_rate:.2f}%") 77 | print(f"Low Level Token Rate: {low_level_token_rate:.2f}%") 78 | 79 | high_level_types = 0 80 | high_level_tokens = 0 81 | for level in [7, 8]: 82 | high_level_types += len(level_freq[level]) 83 | high_level_tokens += sum(level_freq[level].values()) 84 | high_level_type_rate = (high_level_types / num_types) * 100 85 | high_level_token_rate = (high_level_tokens / num_tokens) * 100 86 | # print(f"High Level Type Rate: {high_level_type_rate:.2f}%") 87 | print(f"High Level Token Rate: {high_level_token_rate:.2f}%") 88 | 89 | return 0 90 | 91 | 92 | if __name__ == '__main__': 93 | sys.exit(main()) 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multitasking Framework for Unsupervised Simple Definition Generation 2 | 3 | Source code for the paper **Multitasking Framework for Unsupervised Simple Definition Generation** published at **ACL 2022**. 4 | 5 | ## Requirements 6 | ### Training Environment 7 | - Pytorch 8 | - fairseq 9 | - blingfire 10 | 11 | To install them, run: 12 | 13 | ``` 14 | pip install -r requirements-train.txt 15 | ``` 16 | 17 | ### Evaluation Environment 18 | - Pytorch 19 | - Sentence-Transformers 20 | - Jieba 21 | - NLTK 22 | - Pandas 23 | - scipy 24 | - xlrd 25 | - EASSE 26 | 27 | To install them, run: 28 | 29 | ``` 30 | pip install -r requirements-eval.txt 31 | git clone https://github.com/feralvam/easse.git 32 | cd easse 33 | pip install . 34 | ``` 35 | 36 | ## Usage 37 | 1. All data, including the Chinese and English DG datasets and the simple text corpora mentioned in the paper, is already placed in the folder "data". 38 | 39 | 2. Please download the pretrained MASS model parameters from \[[en](https://modelrelease.blob.core.windows.net/mass/mass-base-uncased.tar.gz)|[zh](https://stublcuedu-my.sharepoint.com/:u:/g/personal/201921296062_stu_blcu_edu_cn/EZpcGUWQanxAt0XZNWb6QqsBauh4dqaR0JdF5u8ia5zJIQ?e=X2tV8r)], unzip them, and put the unzipped files into the folders "pretrained_model/MASS" and "pretrained_model/MASS-zh", respectively. 40 | 41 | 3. To preprocess the dataset, please run the following command: 42 | ```shell 43 | bash run/data_process.sh # for English 44 | # or 45 | bash run/data_process_zh.sh # for Chinese 46 | ``` 47 | 48 | 4. To train a SimpDefiner that can simultaneously generate complex and simple definitions, run the following command: 49 | ```shell 50 | bash run/train_oxford_oald_multi_task.sh # for English 51 | # or 52 | bash run/train_cwn_textbook_multi_task.sh # for Chinese 53 | ``` 54 | Model checkpoints will be saved under the `checkpoints` directory. 55 | 56 | 5. To evaluate the trained model and generate definitions (both complex and simple) with it, run the following command: 57 | 58 | ```shell 59 | bash run/evaluate_oxford_oald.sh --model_dir [model-dir] # for English 60 | # or 61 | bash run/evaluate_cwn_textbook.sh --model_dir [model-dir] # for Chinese 62 | ``` 63 | The generated definitions are written as raw `fairseq-generate` logs into the same checkpoint directory; the sketch below shows how these logs are reduced to one definition per line. 64 |
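Before any metric is computed, `metrics/calc_metrics.sh` extracts the hypotheses from that log with a small `grep`/`sed`/`sort`/`cut` pipeline and strips the WordPiece `##` markers. The snippet below is a minimal Python sketch of the same step, handy for inspecting the generated definitions directly; the file names are illustrative, not part of the repository.

```python
# Minimal sketch of the post-processing done by metrics/calc_metrics.sh:
# turn a fairseq-generate log into one detokenized hypothesis per line.
# Hypothesis lines look like "H-<sample id>\t<score>\t<tokens>".
hyps = []
with open('output_src_tgt.txt', encoding='utf-8') as fr:   # log written by the evaluate_*.sh scripts
    for line in fr:
        if line.startswith('H-'):
            idx, _score, tokens = line.rstrip('\n').split('\t')
            hyps.append((int(idx[2:]), tokens.replace(' ##', '')))  # undo WordPiece splits (sed "s/ ##//g")
hyps.sort()  # restore the original sample order (sort -n -k 1)
with open('output_src_tgt.ordered.tgt', 'w', encoding='utf-8') as fw:
    for _, sent in hyps:
        fw.write(sent + '\n')
```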
65 | 6. To run the automatic metrics on the generated definitions, use the following command: 66 | ```shell 67 | bash metrics/calc_metrics.sh --model_dir [model-dir] --style [oxford|oald|cwn|textbook] 68 | ``` 69 | The `--style` argument selects which set of definitions to evaluate: `oxford` and `oald` are the English styles, while `cwn` and `textbook` are the Chinese ones. 70 | 71 | ## Cite 72 | 73 | ```bibtex 74 | @inproceedings{kong-etal-2022-simpdefiner, 75 | title = "Multitasking Framework for Unsupervised Simple Definition Generation", 76 | author = "Kong, Cunliang and 77 | Chen, Yun and 78 | Zhang, Hengyuan and 79 | Yang, Liner and 80 | Yang, Erhong", 81 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics", 82 | year = "2022" 83 | } 84 | ``` 85 | ## Contact 86 | If you have questions, suggestions, or bug reports, please email cunliang.kong@outlook.com 87 | -------------------------------------------------------------------------------- /metrics/calc_metrics.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | set -e 3 | export CUDA_VISIBLE_DEVICES=0 4 | 5 | POSITIONAL=() 6 | while [[ $# -gt 0 ]]; do 7 | key="$1" 8 | case $key in 9 | --model_dir) 10 | MODEL_DIR="$2" 11 | shift 2 12 | ;; 13 | --style) 14 | STYLE="$2" 15 | shift 2 16 | ;; 17 | *) 18 | POSITIONAL+=("$1") 19 | shift 20 | ;; 21 | esac 22 | done 23 | set -- "${POSITIONAL[@]}" 24 | 25 | OXFORD_SRC=data/oxford/raw/oxford.test.src 26 | OXFORD_TGT=data/oxford/raw/oxford.test.tgt 27 | OALD_SRC=data/aligned-oxford-oald-text/raw/test.src 28 | OALD_TGT=data/aligned-oxford-oald-text/raw/test.oald 29 | OALD_TGT_COMPLEX=data/aligned-oxford-oald-text/raw/test.oxford 30 | 31 | CWN_SRC=data/cwn/raw/test.src 32 | CWN_TGT=data/cwn/raw/test.tgt 33 | 34 | if [[ "${STYLE}" == "oxford" ]]; then 35 | grep ^H "${MODEL_DIR}/output_src_tgt.txt" | 36 | sed 's/^H-//' | 37 | sort -n -k 1 | 38 | cut -f 3 | 39 | sed "s/ ##//g" \ 40 | >"${MODEL_DIR}/output_src_tgt.ordered.tgt" 41 | echo "Calculating BLEU Score" 42 | python metrics/calc_bleu.py "${OXFORD_SRC}" "${OXFORD_TGT}" "${MODEL_DIR}/output_src_tgt.ordered.tgt" 43 | echo "Calculating Semantic Score" 44 | CUDA_VISIBLE_DEVICES=${CUDA} python metrics/calc_sent_sim.py --out_path "${MODEL_DIR}/output_src_tgt.ordered.tgt" --tgt_path "${OXFORD_TGT}" 45 | python metrics/sari.py ${OALD_TGT_COMPLEX} ${MODEL_DIR}/output_src_tgt.ordered.tgt ${OALD_TGT} 46 | 47 | elif [[ "${STYLE}" == "oald" ]]; then 48 | grep ^H "${MODEL_DIR}/output_src_oald.txt" | 49 | sed 's/^H-//' | 50 | sort -n -k 1 | 51 | cut -f 3 | 52 | sed "s/ ##//g" \ 53 | >"${MODEL_DIR}/output_src_oald.ordered.tgt" 54 | echo "Calculating BLEU Score" 55 | python metrics/calc_bleu.py "${OALD_SRC}" "${OALD_TGT}" "${MODEL_DIR}/output_src_oald.ordered.tgt" 56 | echo "Calculating Semantic Score" 57 | CUDA_VISIBLE_DEVICES=${CUDA} python metrics/calc_sent_sim.py --out_path "${MODEL_DIR}/output_src_oald.ordered.tgt" --tgt_path "${OALD_TGT}" 58 | python metrics/sari.py ${OALD_TGT_COMPLEX} ${MODEL_DIR}/output_src_oald.ordered.tgt ${OALD_TGT} 59 | 60 | elif [[ "${STYLE}" == "cwn" ]]; then 61 | grep ^H "${MODEL_DIR}/output_src_tgt.txt" | 62 | sed 's/^H-//' | 63 | sort -n -k 1 | 64 | cut -f 3 | 65 | sed "s/ ##//g" \ 66 | >"${MODEL_DIR}/output_src_tgt.ordered.tgt" 67 | echo "Calculating BLEU Score" 68 | python metrics/calc_bleu_zh.py "${CWN_SRC}" "${CWN_TGT}" "${MODEL_DIR}/output_src_tgt.ordered.tgt" 69 | echo "Calculating Semantic Score" 70 |
CUDA_VISIBLE_DEVICES=${CUDA} python metrics/calc_sent_sim_zh.py --out_path "${MODEL_DIR}/output_src_tgt.ordered.tgt" --tgt_path "${CWN_TGT}" 71 | python metrics/hsk_freq.py ${MODEL_DIR}/output_src_tgt.ordered.tgt 72 | 73 | elif [[ "${STYLE}" == "textbook" ]]; then 74 | grep ^H "${MODEL_DIR}/output_src_textbook.txt" | 75 | sed 's/^H-//' | 76 | sort -n -k 1 | 77 | cut -f 3 | 78 | sed "s/ ##//g" \ 79 | >"${MODEL_DIR}/output_src_textbook.ordered.tgt" 80 | echo "Calculating BLEU Score" 81 | python metrics/calc_bleu_zh.py "${CWN_SRC}" "${CWN_TGT}" "${MODEL_DIR}/output_src_textbook.ordered.tgt" 82 | echo "Calculating Semantic Score" 83 | CUDA_VISIBLE_DEVICES=${CUDA} python metrics/calc_sent_sim_zh.py --out_path "${MODEL_DIR}/output_src_textbook.ordered.tgt" --tgt_path "${CWN_TGT}" 84 | python metrics/hsk_freq.py ${MODEL_DIR}/output_src_textbook.ordered.tgt 85 | fi 86 | -------------------------------------------------------------------------------- /encode-zh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import sys 4 | 5 | from collections import Counter 6 | from multiprocessing import Pool 7 | 8 | from tokenization_bert import BertTokenizer 9 | from blingfire import text_to_words 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--inputs", 16 | nargs="+", 17 | default=['-'], 18 | help="input files to filter/encode", 19 | ) 20 | parser.add_argument( 21 | "--outputs", 22 | nargs="+", 23 | default=['-'], 24 | help="path to save encoded outputs", 25 | ) 26 | parser.add_argument( 27 | "--keep-empty", 28 | action="store_true", 29 | help="keep empty lines", 30 | ) 31 | parser.add_argument( 32 | "--tokenizer", 33 | type=str, 34 | default='bpe', 35 | help="which tokenizer to use", 36 | ) 37 | parser.add_argument("--workers", type=int, default=20) 38 | args = parser.parse_args() 39 | 40 | assert len(args.inputs) == len(args.outputs), \ 41 | "number of input and output paths should match" 42 | 43 | with contextlib.ExitStack() as stack: 44 | inputs = [ 45 | stack.enter_context(open(input, "r", encoding="utf-8", errors='ignore')) 46 | if input != "-" else sys.stdin 47 | for input in args.inputs 48 | ] 49 | outputs = [ 50 | stack.enter_context(open(output, "w", encoding="utf-8")) 51 | if output != "-" else sys.stdout 52 | for output in args.outputs 53 | ] 54 | 55 | encoder = MultiprocessingEncoder(args) 56 | pool = Pool(args.workers, initializer=encoder.initializer) 57 | encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) 58 | 59 | stats = Counter() 60 | for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): 61 | if filt == "PASS": 62 | for enc_line, output_h in zip(enc_lines, outputs): 63 | print(enc_line, file=output_h) 64 | else: 65 | stats["num_filtered_" + filt] += 1 66 | if i % 10000 == 0: 67 | print("processed {} lines".format(i), file=sys.stderr) 68 | 69 | for k, v in stats.most_common(): 70 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr) 71 | 72 | 73 | class MultiprocessingEncoder(object): 74 | 75 | def __init__(self, args): 76 | self.args = args 77 | 78 | def initializer(self): 79 | global bpe 80 | bpe = BertTokenizer.from_pretrained('bert-base-chinese') 81 | 82 | def encode(self, line): 83 | global bpe 84 | subword = bpe.tokenize(line) 85 | return subword 86 | 87 | def decode(self, tokens): 88 | global bpe 89 | return bpe.decode(tokens) 90 | 91 | def encode_lines(self, lines): 92 | """ 93 | Encode a set of lines. 
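With the default --tokenizer 'bpe', each line is split by the Chinese BERT WordPiece tokenizer ('bert-base-chinese'); any other value falls back to blingfire's text_to_words. A batch containing an empty line is skipped (returns ["EMPTY", None]) unless --keep-empty is set.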
All lines will be encoded together. 94 | """ 95 | enc_lines = [] 96 | for line in lines: 97 | line = line.strip() 98 | if len(line) == 0 and not self.args.keep_empty: 99 | return ["EMPTY", None] 100 | if self.args.tokenizer == 'bpe': 101 | tokens = self.encode(line) 102 | enc_lines.append(" ".join(tokens)) 103 | else: 104 | enc_lines.append(text_to_words(line)) 105 | return ["PASS", enc_lines] 106 | 107 | def decode_lines(self, lines): 108 | dec_lines = [] 109 | for line in lines: 110 | tokens = map(int, line.strip().split()) 111 | dec_lines.append(self.decode(tokens)) 112 | return ["PASS", dec_lines] 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import sys 4 | 5 | from collections import Counter 6 | from multiprocessing import Pool 7 | 8 | from tokenization_bert import BertTokenizer 9 | from blingfire import text_to_words 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--inputs", 16 | nargs="+", 17 | default=['-'], 18 | help="input files to filter/encode", 19 | ) 20 | parser.add_argument( 21 | "--outputs", 22 | nargs="+", 23 | default=['-'], 24 | help="path to save encoded outputs", 25 | ) 26 | parser.add_argument( 27 | "--keep-empty", 28 | action="store_true", 29 | help="keep empty lines", 30 | ) 31 | parser.add_argument( 32 | "--tokenizer", 33 | type=str, 34 | default='bpe', 35 | help="which tokenizer to use", 36 | ) 37 | parser.add_argument("--workers", type=int, default=20) 38 | args = parser.parse_args() 39 | 40 | assert len(args.inputs) == len(args.outputs), \ 41 | "number of input and output paths should match" 42 | 43 | with contextlib.ExitStack() as stack: 44 | inputs = [ 45 | stack.enter_context(open(input, "r", encoding="utf-8", errors='ignore')) 46 | if input != "-" else sys.stdin 47 | for input in args.inputs 48 | ] 49 | outputs = [ 50 | stack.enter_context(open(output, "w", encoding="utf-8")) 51 | if output != "-" else sys.stdout 52 | for output in args.outputs 53 | ] 54 | 55 | encoder = MultiprocessingEncoder(args) 56 | pool = Pool(args.workers, initializer=encoder.initializer) 57 | encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) 58 | 59 | stats = Counter() 60 | for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): 61 | if filt == "PASS": 62 | for enc_line, output_h in zip(enc_lines, outputs): 63 | print(enc_line, file=output_h) 64 | else: 65 | stats["num_filtered_" + filt] += 1 66 | if i % 10000 == 0: 67 | print("processed {} lines".format(i), file=sys.stderr) 68 | 69 | for k, v in stats.most_common(): 70 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr) 71 | 72 | 73 | class MultiprocessingEncoder(object): 74 | 75 | def __init__(self, args): 76 | self.args = args 77 | 78 | def initializer(self): 79 | global bpe 80 | bpe = BertTokenizer.from_pretrained('pretrained_model/bert_uncased_base_vocab.txt') 81 | 82 | def encode(self, line): 83 | global bpe 84 | subword = bpe.tokenize(line) 85 | return subword 86 | 87 | def decode(self, tokens): 88 | global bpe 89 | return bpe.decode(tokens) 90 | 91 | def encode_lines(self, lines): 92 | """ 93 | Encode a set of lines. All lines will be encoded together. 
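With the default --tokenizer 'bpe', lines are split with the uncased English BERT WordPiece vocabulary (pretrained_model/bert_uncased_base_vocab.txt); any other value falls back to blingfire's text_to_words. A batch containing an empty line is skipped (returns ["EMPTY", None]) unless --keep-empty is set.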
94 | """ 95 | enc_lines = [] 96 | for line in lines: 97 | line = line.strip() 98 | if len(line) == 0 and not self.args.keep_empty: 99 | return ["EMPTY", None] 100 | if self.args.tokenizer == 'bpe': 101 | tokens = self.encode(line) 102 | enc_lines.append(" ".join(tokens)) 103 | else: 104 | enc_lines.append(text_to_words(line)) 105 | return ["PASS", enc_lines] 106 | 107 | def decode_lines(self, lines): 108 | dec_lines = [] 109 | for line in lines: 110 | tokens = map(int, line.strip().split()) 111 | dec_lines.append(self.decode(tokens)) 112 | return ["PASS", dec_lines] 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /metrics/calc_bleu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import random 5 | import string 6 | from nltk.translate import bleu_score 7 | from subprocess import Popen, PIPE 8 | from collections import defaultdict 9 | 10 | 11 | def bleu(hyp, raw_data, bleu_path="metrics/sentence-bleu", nltk="cpp"): 12 | assert nltk in ['cpp', 'corpus', 'sentence'], \ 13 | "nltk param should be cpp/corpus/sentence" 14 | assert len(hyp) == len(raw_data), \ 15 | "sentence num in hyp not equal to dataset" 16 | tmp_dir = "/tmp" 17 | suffix = ''.join(random.sample(string.ascii_letters + string.digits, 8)) 18 | hyp_path = os.path.join(tmp_dir, 'hyp-' + suffix) 19 | base_ref_path = os.path.join(tmp_dir, 'ref-' + suffix) 20 | to_be_deleted = set() 21 | to_be_deleted.add(hyp_path) 22 | 23 | ref_dict = defaultdict(list) 24 | for word, exp, sense in raw_data: 25 | ref_dict[word].append(sense) 26 | 27 | score = 0 28 | num_hyp = 0 29 | if nltk == 'corpus': 30 | refs = [] 31 | with open(os.devnull, 'w') as devnull: 32 | for idx, desc in enumerate(hyp): 33 | word = raw_data[idx][0] 34 | if nltk == 'sentence': 35 | if len(desc) == 0: 36 | auto_reweigh = False 37 | else: 38 | auto_reweigh = True 39 | bleu = bleu_score.sentence_bleu( 40 | [r.split(' ') for r in ref_dict[word]], 41 | desc, 42 | smoothing_function=bleu_score.SmoothingFunction().method2, 43 | auto_reweigh=auto_reweigh) 44 | score += bleu 45 | num_hyp += 1 46 | 47 | elif nltk == 'corpus': 48 | refs.append([r.split(' ') for r in ref_dict[word]]) 49 | 50 | elif nltk == 'cpp': 51 | ref_paths = [] 52 | for i, ref in enumerate(ref_dict[word][:30]): 53 | ref_path = base_ref_path + str(i) 54 | with open(ref_path, 'w') as f: 55 | f.write(ref + '\n') 56 | ref_paths.append(ref_path) 57 | to_be_deleted.add(ref_path) 58 | 59 | with open(hyp_path, 'w') as f: 60 | f.write(' '.join(desc) + '\n') 61 | 62 | rp = Popen(['cat', hyp_path], stdout=PIPE) 63 | bp = Popen([bleu_path] + ref_paths, stdin=rp.stdout, stdout=PIPE, stderr=devnull) 64 | out, err = bp.communicate() 65 | bleu = float(out.strip()) 66 | score += bleu 67 | num_hyp += 1 68 | 69 | else: 70 | raise ValueError("nltk must be sentence/corpus/cpp") 71 | if nltk == 'cpp': 72 | for f in to_be_deleted: 73 | if os.path.exists(f): 74 | os.remove(f) 75 | if nltk == 'corpus': 76 | bleu = bleu_score.corpus_bleu(refs, [h for h in hyp], 77 | smoothing_function=bleu_score.SmoothingFunction().method2) 78 | ret_bleu = bleu 79 | else: 80 | ret_bleu = score / num_hyp 81 | 82 | return ret_bleu 83 | 84 | 85 | def main(argv=None): 86 | if argv is None: 87 | argv = sys.argv[1:] 88 | assert len(argv) == 3 89 | gold_src = argv[0] 90 | gold_tgt = argv[1] 91 | hyp_file = argv[2] 92 | 93 | hyp_data = [] 94 | raw_data = [] 95 | with 
open(gold_src) as fr_src, \ 96 | open(gold_tgt) as fr_tgt, \ 97 | open(hyp_file) as fr_hyp: 98 | src_content = fr_src.readlines() 99 | tgt_content = fr_tgt.readlines() 100 | hyp_content = fr_hyp.readlines() 101 | assert len(src_content) == len(tgt_content) == len(hyp_content) 102 | for src, tgt, hyp in zip(src_content, tgt_content, hyp_content): 103 | word, exp = src.strip().split(' [SEP] ') 104 | hyp_data.append(hyp.strip().split(' ')) 105 | raw_data.append((word, exp, tgt.strip())) 106 | 107 | bleu_score = bleu(hyp_data, raw_data, nltk='cpp') 108 | print(f"BLEU Score: {bleu_score}") 109 | return 0 110 | 111 | 112 | if __name__ == "__main__": 113 | sys.exit(main()) 114 | -------------------------------------------------------------------------------- /metrics/calc_bleu_zh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import random 5 | import string 6 | from nltk.translate import bleu_score 7 | from subprocess import Popen, PIPE 8 | from collections import defaultdict 9 | import jieba 10 | 11 | 12 | def bleu(hyp, raw_data, bleu_path="metrics/sentence-bleu", nltk="cpp"): 13 | assert nltk in ['cpp', 'corpus', 'sentence'], \ 14 | "nltk param should be cpp/corpus/sentence" 15 | assert len(hyp) == len(raw_data), \ 16 | "sentence num in hyp not equal to dataset" 17 | tmp_dir = "/tmp" 18 | suffix = ''.join(random.sample(string.ascii_letters + string.digits, 8)) 19 | hyp_path = os.path.join(tmp_dir, 'hyp-' + suffix) 20 | base_ref_path = os.path.join(tmp_dir, 'ref-' + suffix) 21 | to_be_deleted = set() 22 | to_be_deleted.add(hyp_path) 23 | 24 | ref_dict = defaultdict(list) 25 | for word, exp, sense in raw_data: 26 | ref_dict[word].append(sense) 27 | 28 | score = 0 29 | num_hyp = 0 30 | if nltk == 'corpus': 31 | refs = [] 32 | with open(os.devnull, 'w') as devnull: 33 | for idx, desc in enumerate(hyp): 34 | word = raw_data[idx][0] 35 | if nltk == 'sentence': 36 | if len(desc) == 0: 37 | auto_reweigh = False 38 | else: 39 | auto_reweigh = True 40 | bleu = bleu_score.sentence_bleu( 41 | [r.split(' ') for r in ref_dict[word]], 42 | desc, 43 | smoothing_function=bleu_score.SmoothingFunction().method2, 44 | auto_reweigh=auto_reweigh) 45 | score += bleu 46 | num_hyp += 1 47 | 48 | elif nltk == 'corpus': 49 | refs.append([r.split(' ') for r in ref_dict[word]]) 50 | 51 | elif nltk == 'cpp': 52 | ref_paths = [] 53 | for i, ref in enumerate(ref_dict[word][:30]): 54 | ref_path = base_ref_path + str(i) 55 | with open(ref_path, 'w') as f: 56 | f.write(ref + '\n') 57 | ref_paths.append(ref_path) 58 | to_be_deleted.add(ref_path) 59 | 60 | with open(hyp_path, 'w') as f: 61 | f.write(' '.join(desc) + '\n') 62 | 63 | rp = Popen(['cat', hyp_path], stdout=PIPE) 64 | bp = Popen([bleu_path] + ref_paths, stdin=rp.stdout, stdout=PIPE, stderr=devnull) 65 | out, err = bp.communicate() 66 | bleu = float(out.strip()) 67 | score += bleu 68 | num_hyp += 1 69 | 70 | else: 71 | raise ValueError("nltk must be sentence/corpus/cpp") 72 | if nltk == 'cpp': 73 | for f in to_be_deleted: 74 | if os.path.exists(f): 75 | os.remove(f) 76 | if nltk == 'corpus': 77 | bleu = bleu_score.corpus_bleu(refs, [h for h in hyp], 78 | smoothing_function=bleu_score.SmoothingFunction().method2) 79 | ret_bleu = bleu 80 | else: 81 | ret_bleu = score / num_hyp 82 | 83 | return ret_bleu 84 | 85 | 86 | def main(argv=None): 87 | if argv is None: 88 | argv = sys.argv[1:] 89 | assert len(argv) == 3 90 | gold_src = argv[0] 91 | gold_tgt = argv[1] 92 | hyp_file = 
argv[2] 93 | 94 | hyp_data = [] 95 | raw_data = [] 96 | with open(gold_src) as fr_src, \ 97 | open(gold_tgt) as fr_tgt, \ 98 | open(hyp_file) as fr_hyp: 99 | src_content = fr_src.readlines() 100 | tgt_content = fr_tgt.readlines() 101 | hyp_content = fr_hyp.readlines() 102 | assert len(src_content) == len(tgt_content) == len(hyp_content) 103 | for src, tgt, hyp in zip(src_content, tgt_content, hyp_content): 104 | word, exp = src.strip().split(' [SEP] ') 105 | exp = ' '.join(jieba.lcut(exp.replace(' ', ''))) 106 | hyp = jieba.lcut(hyp.strip().replace(' ', '')) 107 | tgt = ' '.join(jieba.lcut(tgt.strip().replace(' ', ''))) 108 | hyp_data.append(hyp) 109 | raw_data.append((word, exp, tgt)) 110 | 111 | bleu_score = bleu(hyp_data, raw_data, nltk='cpp') 112 | print(f"BLEU Score: {bleu_score}") 113 | return 0 114 | 115 | 116 | if __name__ == "__main__": 117 | sys.exit(main()) 118 | -------------------------------------------------------------------------------- /mass/masked_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | import time 5 | import math 6 | 7 | from fairseq import utils 8 | from fairseq.data import data_utils, LanguagePairDataset 9 | 10 | 11 | class MaskedLanguagePairDataset(LanguagePairDataset): 12 | """ Wrapper for masked language datasets 13 | (support monolingual and bilingual) 14 | 15 | For monolingual dataset: 16 | [x1, x2, x3, x4, x5] 17 | || 18 | VV 19 | [x1, _, _, x4, x5] => [x2, x3] 20 | 21 | default, _ will be replaced by 8:1:1 (mask, self, rand), 22 | """ 23 | def __init__( 24 | self, 25 | src, src_sizes, src_dict, 26 | tgt=None, tgt_sizes=None, tgt_dict=None, 27 | left_pad_source=True, left_pad_target=False, 28 | max_source_positions=1024, max_target_positions=1024, 29 | shuffle=True, mask_prob=0.15, pred_probs=None, block_size=64, 30 | ): 31 | self.src = src 32 | self.tgt = tgt 33 | self.src_sizes = src_sizes 34 | self.tgt_sizes = tgt_sizes 35 | self.src_dict = src_dict 36 | self.tgt_dict = tgt_dict 37 | self.left_pad_source = left_pad_source 38 | self.left_pad_target = left_pad_target 39 | self.shuffle = shuffle 40 | 41 | self.mask_prob = mask_prob 42 | self.pred_probs = pred_probs 43 | self.block_size = block_size 44 | 45 | def __getitem__(self, index): 46 | pkgs = {'id': index} 47 | tgt_item = self.tgt[index] if self.tgt is not None else None 48 | src_item = self.src[index] 49 | 50 | positions = np.arange(0, len(self.src[index])) 51 | masked_pos = [] 52 | for i in range(1, len(src_item), self.block_size): 53 | block = positions[i: i + self.block_size] 54 | masked_len = int(len(block) * self.mask_prob) 55 | masked_block_start = np.random.choice(block[:len(block) - int(masked_len) + 1], 1)[0] 56 | masked_pos.extend(positions[masked_block_start : masked_block_start + masked_len]) 57 | masked_pos = np.array(masked_pos) 58 | 59 | pkgs['target'] = src_item[masked_pos].clone() 60 | pkgs['prev_output_tokens'] = src_item[masked_pos - 1].clone() 61 | pkgs['positions'] = torch.LongTensor(masked_pos) + self.src_dict.pad_index 62 | src_item[masked_pos] = self.replace(src_item[masked_pos]) 63 | pkgs['source'] = src_item 64 | return pkgs 65 | 66 | def collate(self, samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False): 67 | if len(samples) == 0: 68 | return {} 69 | 70 | def merge(x, left_pad, move_eos_to_beginning=False): 71 | return data_utils.collate_tokens( 72 | x, pad_idx, eos_idx, left_pad, move_eos_to_beginning 73 | ) 74 | 75 | id = 
torch.LongTensor([s['id'] for s in samples]) 76 | source = merge([s['source'] for s in samples], left_pad=left_pad_source) 77 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 78 | 79 | prev_output_tokens = merge([s['prev_output_tokens'] for s in samples], left_pad=left_pad_target) 80 | positions = merge([s['positions'] for s in samples], left_pad=left_pad_target) 81 | target = merge([s['target'] for s in samples], left_pad=left_pad_target) 82 | ntokens = target.numel() 83 | 84 | batch = { 85 | 'id' : id, 86 | 'nsentences': len(samples), 87 | 'net_input' : { 88 | 'src_lengths': src_lengths, 89 | 'src_tokens' : source, 90 | 'prev_output_tokens': prev_output_tokens, 91 | 'positions' : positions, 92 | }, 93 | 'target' : target, 94 | 'ntokens': ntokens, 95 | } 96 | return batch 97 | 98 | def collater(self, samples): 99 | return self.collate(samples, self.src_dict.pad(), self.src_dict.eos()) 100 | 101 | def size(self, index): 102 | return self.src.sizes[index] 103 | 104 | def replace(self, x): 105 | _x_real = x 106 | _x_rand = _x_real.clone().random_(self.src_dict.nspecial, len(self.src_dict)) 107 | _x_mask = _x_real.clone().fill_(self.src_dict.index('[MASK]')) 108 | probs = torch.multinomial(self.pred_probs, len(x), replacement=True) 109 | _x = _x_mask * (probs == 0).long() + \ 110 | _x_real * (probs == 1).long() + \ 111 | _x_rand * (probs == 2).long() 112 | return _x 113 | -------------------------------------------------------------------------------- /mass/masked_s2s.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import numpy as np 8 | 9 | import torch 10 | 11 | from collections import OrderedDict 12 | from fairseq import utils 13 | from fairseq.data import ( 14 | data_utils, 15 | Dictionary, 16 | TokenBlockDataset, 17 | ) 18 | from fairseq.tasks import FairseqTask, register_task 19 | from .masked_dataset import MaskedLanguagePairDataset 20 | from .bert_dictionary import BertDictionary 21 | 22 | 23 | @register_task('masked_s2s') 24 | class MaskedS2STask(FairseqTask): 25 | """ 26 | Train a sequence-to-sequence task 27 | 28 | Args: 29 | dictionary (~fairseq.data.Dictionary): the dictionary for the input of 30 | the language model 31 | """ 32 | 33 | @staticmethod 34 | def add_args(parser): 35 | """Add task-specific arguments to the parser.""" 36 | # fmt: off 37 | parser.add_argument('data', help='path to data directory') 38 | parser.add_argument('--sample-break-mode', default='none', 39 | choices=['none', 'complete', 'complete_doc', 'eos'], 40 | help='If omitted or "none", fills each sample with tokens-per-sample ' 41 | 'tokens. If set to "complete", splits samples only at the end ' 42 | 'of sentence, but may include multiple sentences per sample. ' 43 | '"complete_doc" is similar but respects doc boundaries. 
' 44 | 'If set to "eos", includes only one sentence per sample.') 45 | parser.add_argument('--tokens-per-sample', default=512, type=int, 46 | help='max number of tokens per sample for text dataset') 47 | parser.add_argument('--lazy-load', action='store_true', 48 | help='load the dataset lazily') 49 | parser.add_argument('--raw-text', default=False, action='store_true', 50 | help='load raw text dataset') 51 | 52 | parser.add_argument('--mask-s2s-prob', default=0.15, type=float, 53 | help='probability of replacing a token with mask') 54 | parser.add_argument('--mask-s2s-mask-keep-rand', default="0.8,0.1,0.1", type=str, 55 | help='Word prediction probability for decoder mask') 56 | 57 | # fmt: on 58 | 59 | def __init__(self, args, dictionary): 60 | super().__init__(args) 61 | self.dictionary = dictionary 62 | 63 | @classmethod 64 | def setup_task(cls, args, **kwargs): 65 | """Setup the task (e.g., load dictionaries). 66 | 67 | Args: 68 | args (argparse.Namespace): parsed command-line arguments 69 | """ 70 | if getattr(args, 'raw_text', False): 71 | utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw') 72 | args.dataset_impl = 'raw' 73 | elif getattr(args, 'lazy_load', False): 74 | utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy') 75 | args.dataset_impl = 'lazy' 76 | 77 | paths = args.data.split(':') 78 | 79 | dictionary = cls.load_dictionary(os.path.join(paths[0], 'dict.txt')) 80 | print('| dictionary: {} types'.format(len(dictionary))) 81 | return cls(args, dictionary) 82 | 83 | @classmethod 84 | def load_dictionary(cls, filename): 85 | return BertDictionary.load_from_file(filename) 86 | 87 | def train_step(self, sample, model, criterion, optimizer, ignore_grad=False): 88 | model.train() 89 | loss, sample_size, logging_output = criterion(model, sample) 90 | if ignore_grad: 91 | loss *= 0 92 | optimizer.backward(loss) 93 | return loss, sample_size, logging_output 94 | 95 | def valid_step(self, sample, model, criterion): 96 | model.eval() 97 | with torch.no_grad(): 98 | loss, sample_size, logging_output = criterion(model, sample) 99 | return loss, sample_size, logging_output 100 | 101 | def build_model(self, args): 102 | from fairseq import models 103 | model = models.build_model(args, self) 104 | return model 105 | 106 | def load_dataset(self, split, epoch=0, combine=False, **kwargs): 107 | """Load a given dataset split. 
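The split is read with data_utils.load_indexed_dataset and wrapped by build_s2s_dataset into a TokenBlockDataset and a MaskedLanguagePairDataset, so MASS-style masked spans are predicted from the corrupted source.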
108 | 109 | Args: 110 | split (str): name of the split (e.g., train, valid, test) 111 | """ 112 | paths = self.args.data.split(':') 113 | assert len(paths) > 0 114 | data_path = paths[epoch % len(paths)] 115 | split_path = os.path.join(data_path, split) 116 | 117 | dataset = data_utils.load_indexed_dataset( 118 | split_path, 119 | self.dictionary, 120 | self.args.dataset_impl, 121 | combine=combine, 122 | ) 123 | if dataset is None: 124 | raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) 125 | 126 | self.datasets[split] = self.build_s2s_dataset(dataset) 127 | 128 | def build_s2s_dataset(self, dataset): 129 | dataset = TokenBlockDataset( 130 | dataset, 131 | dataset.sizes, 132 | self.args.tokens_per_sample, 133 | pad=self.source_dictionary.pad(), 134 | eos=self.source_dictionary.eos(), 135 | break_mode=self.args.sample_break_mode, 136 | ) 137 | 138 | pred_probs = torch.FloatTensor([float(x) for x in self.args.mask_s2s_mask_keep_rand.split(',')]) 139 | 140 | s2s_dataset = MaskedLanguagePairDataset( 141 | dataset, dataset.sizes, self.source_dictionary, 142 | shuffle=True, mask_prob=self.args.mask_s2s_prob, 143 | pred_probs=pred_probs, 144 | ) 145 | return s2s_dataset 146 | 147 | def build_dataset_for_inference(self, src_tokens, src_lengths): 148 | raise NotImplementedError 149 | 150 | def inference_step(self, generator, models, sample, prefix_tokens=None): 151 | raise NotImplementedError 152 | 153 | @property 154 | def source_dictionary(self): 155 | """Return the :class:`~fairseq.data.Dictionary` for the language 156 | model.""" 157 | return self.dictionary 158 | 159 | @property 160 | def target_dictionary(self): 161 | """Return the :class:`~fairseq.data.Dictionary` for the language 162 | model.""" 163 | return self.dictionary 164 | 165 | def max_positions(self): 166 | max_positions = 1024 167 | if hasattr(self.args, 'max_positions'): 168 | max_positions = min(max_positions, self.args.max_positions) 169 | if hasattr(self.args, 'max_source_positions'): 170 | max_positions = min(max_positions, self.args.max_source_positions) 171 | if hasattr(self.args, 'max_target_positions'): 172 | max_positions = min(max_positions, self.args.max_target_positions) 173 | return (max_positions, max_positions) 174 | -------------------------------------------------------------------------------- /mass/language_pair_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
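# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original repository): how the
# MaskedLanguagePairDataset.replace() method shown earlier chooses, per
# selected token, between substituting [MASK], keeping the real token, or
# sampling a random one. The three probabilities come from
# --mask-s2s-mask-keep-rand (default "0.8,0.1,0.1"), which
# MaskedS2STask.build_s2s_dataset parses into `pred_probs`. The vocabulary
# size, number of special symbols, and [MASK] id below are made-up toy values.
def _sketch_mask_keep_rand():
    import torch
    pred_probs = torch.FloatTensor([0.8, 0.1, 0.1])   # P(mask), P(keep), P(random)
    x_real = torch.LongTensor([5, 6, 7, 8])            # token ids selected for prediction
    x_mask = x_real.clone().fill_(9)                   # toy id of [MASK]
    x_rand = x_real.clone().random_(4, 10)             # random ids outside the 4 toy specials
    # draw one of {0: mask, 1: keep, 2: random} independently for each position
    probs = torch.multinomial(pred_probs, len(x_real), replacement=True)
    return (x_mask * (probs == 0).long()
            + x_real * (probs == 1).long()
            + x_rand * (probs == 2).long())
# ---------------------------------------------------------------------------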
5 | 6 | import numpy as np 7 | import torch 8 | 9 | from fairseq.data import data_utils, FairseqDataset 10 | 11 | 12 | def collate( 13 | samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, 14 | input_feeding=True, lang_pair=None 15 | ): 16 | if len(samples) == 0: 17 | return {} 18 | 19 | def merge(key, left_pad, move_eos_to_beginning=False): 20 | return data_utils.collate_tokens( 21 | [s[key] for s in samples], 22 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 23 | ) 24 | 25 | def check_alignment(alignment, src_len, tgt_len): 26 | if alignment is None or len(alignment) == 0: 27 | return False 28 | if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1: 29 | print("| alignment size mismatch found, skipping alignment!") 30 | return False 31 | return True 32 | 33 | def compute_alignment_weights(alignments): 34 | """ 35 | Given a tensor of shape [:, 2] containing the source-target indices 36 | corresponding to the alignments, a weight vector containing the 37 | inverse frequency of each target index is computed. 38 | For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then 39 | a tensor containing [1., 0.5, 0.5, 1] should be returned (since target 40 | index 3 is repeated twice) 41 | """ 42 | align_tgt = alignments[:, 1] 43 | _, align_tgt_i, align_tgt_c = torch.unique(align_tgt, return_inverse=True, return_counts=True) 44 | align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] 45 | return 1. / align_weights.float() 46 | 47 | id = torch.LongTensor([s['id'] for s in samples]) 48 | src_tokens = merge('source', left_pad=left_pad_source) 49 | # sort by descending source length 50 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 51 | src_lengths, sort_order = src_lengths.sort(descending=True) 52 | id = id.index_select(0, sort_order) 53 | src_tokens = src_tokens.index_select(0, sort_order) 54 | 55 | prev_output_tokens = None 56 | target = None 57 | if samples[0].get('target', None) is not None: 58 | target = merge('target', left_pad=left_pad_target) 59 | target = target.index_select(0, sort_order) 60 | tgt_lengths = torch.LongTensor([s['target'].numel() for s in samples]).index_select(0, sort_order) 61 | ntokens = sum(len(s['target']) for s in samples) 62 | 63 | if input_feeding: 64 | # we create a shifted version of targets for feeding the 65 | # previous output token(s) into the next decoder step 66 | prev_output_tokens = merge( 67 | 'target', 68 | left_pad=left_pad_target, 69 | move_eos_to_beginning=True, 70 | ) 71 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 72 | else: 73 | ntokens = sum(len(s['source']) for s in samples) 74 | 75 | batch = { 76 | 'id': id, 77 | 'nsentences': len(samples), 78 | 'ntokens': ntokens, 79 | 'net_input': { 80 | 'src_tokens': src_tokens, 81 | 'src_lengths': src_lengths, 82 | 'lang_pair': lang_pair, 83 | }, 84 | 'target': target, 85 | } 86 | if prev_output_tokens is not None: 87 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 88 | 89 | if samples[0].get('alignment', None) is not None: 90 | bsz, tgt_sz = batch['target'].shape 91 | src_sz = batch['net_input']['src_tokens'].shape[1] 92 | 93 | offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) 94 | offsets[:, 1] += (torch.arange(len(sort_order), dtype=torch.long) * tgt_sz) 95 | if left_pad_source: 96 | offsets[:, 0] += (src_sz - src_lengths) 97 | if left_pad_target: 98 | offsets[:, 1] += (tgt_sz - tgt_lengths) 99 | 100 | alignments = [ 101 | alignment + offset 102 | 
for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) 103 | for alignment in [samples[align_idx]['alignment'].view(-1, 2)] 104 | if check_alignment(alignment, src_len, tgt_len) 105 | ] 106 | 107 | if len(alignments) > 0: 108 | alignments = torch.cat(alignments, dim=0) 109 | align_weights = compute_alignment_weights(alignments) 110 | 111 | batch['alignments'] = alignments 112 | batch['align_weights'] = align_weights 113 | 114 | return batch 115 | 116 | 117 | class LanguagePairDataset(FairseqDataset): 118 | """ 119 | A pair of torch.utils.data.Datasets. 120 | Args: 121 | src (torch.utils.data.Dataset): source dataset to wrap 122 | src_sizes (List[int]): source sentence lengths 123 | src_dict (~fairseq.data.Dictionary): source vocabulary 124 | tgt (torch.utils.data.Dataset, optional): target dataset to wrap 125 | tgt_sizes (List[int], optional): target sentence lengths 126 | tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary 127 | left_pad_source (bool, optional): pad source tensors on the left side 128 | (default: True). 129 | left_pad_target (bool, optional): pad target tensors on the left side 130 | (default: False). 131 | max_source_positions (int, optional): max number of tokens in the 132 | source sentence (default: 1024). 133 | max_target_positions (int, optional): max number of tokens in the 134 | target sentence (default: 1024). 135 | shuffle (bool, optional): shuffle dataset elements before batching 136 | (default: True). 137 | input_feeding (bool, optional): create a shifted version of the targets 138 | to be passed into the model for teacher forcing (default: True). 139 | remove_eos_from_source (bool, optional): if set, removes eos from end 140 | of source if it's present (default: False). 141 | append_eos_to_target (bool, optional): if set, appends eos to end of 142 | target if it's absent (default: False). 143 | align_dataset (torch.utils.data.Dataset, optional): dataset 144 | containing alignments. 145 | append_bos (bool, optional): if set, appends bos to the beginning of 146 | source/target sentence. 
147 | """ 148 | 149 | def __init__( 150 | self, src, src_sizes, src_dict, 151 | tgt=None, tgt_sizes=None, tgt_dict=None, 152 | left_pad_source=True, left_pad_target=False, 153 | max_source_positions=1024, max_target_positions=1024, 154 | shuffle=True, input_feeding=True, 155 | remove_eos_from_source=False, append_eos_to_target=False, 156 | align_dataset=None, 157 | append_bos=False, 158 | lang_pair=None, 159 | ): 160 | if tgt_dict is not None: 161 | assert src_dict.pad() == tgt_dict.pad() 162 | assert src_dict.eos() == tgt_dict.eos() 163 | assert src_dict.unk() == tgt_dict.unk() 164 | self.src = src 165 | self.tgt = tgt 166 | self.src_sizes = np.array(src_sizes) 167 | self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None 168 | self.src_dict = src_dict 169 | self.tgt_dict = tgt_dict 170 | self.left_pad_source = left_pad_source 171 | self.left_pad_target = left_pad_target 172 | self.max_source_positions = max_source_positions 173 | self.max_target_positions = max_target_positions 174 | self.shuffle = shuffle 175 | self.input_feeding = input_feeding 176 | self.remove_eos_from_source = remove_eos_from_source 177 | self.append_eos_to_target = append_eos_to_target 178 | self.align_dataset = align_dataset 179 | self.lang_pair = lang_pair 180 | if self.align_dataset is not None: 181 | assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided" 182 | self.append_bos = append_bos 183 | 184 | def __getitem__(self, index): 185 | tgt_item = self.tgt[index] if self.tgt is not None else None 186 | src_item = self.src[index] 187 | # Append EOS to end of tgt sentence if it does not have an EOS and remove 188 | # EOS from end of src sentence if it exists. This is useful when we use 189 | # use existing datasets for opposite directions i.e., when we want to 190 | # use tgt_dataset as src_dataset and vice versa 191 | if self.append_eos_to_target: 192 | eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() 193 | if self.tgt and self.tgt[index][-1] != eos: 194 | tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) 195 | 196 | if self.append_bos: 197 | bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() 198 | if self.tgt and self.tgt[index][0] != bos: 199 | tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) 200 | 201 | bos = self.src_dict.bos() 202 | if self.src[index][-1] != bos: 203 | src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) 204 | 205 | if self.remove_eos_from_source: 206 | eos = self.src_dict.eos() 207 | if self.src[index][-1] == eos: 208 | src_item = self.src[index][:-1] 209 | 210 | example = { 211 | 'id': index, 212 | 'source': src_item, 213 | 'target': tgt_item, 214 | } 215 | if self.align_dataset is not None: 216 | example['alignment'] = self.align_dataset[index] 217 | return example 218 | 219 | def __len__(self): 220 | return len(self.src) 221 | 222 | def collater(self, samples): 223 | """Merge a list of samples to form a mini-batch. 224 | Args: 225 | samples (List[dict]): samples to collate 226 | Returns: 227 | dict: a mini-batch with the following keys: 228 | - `id` (LongTensor): example IDs in the original input order 229 | - `ntokens` (int): total number of tokens in the batch 230 | - `net_input` (dict): the input to the Model, containing keys: 231 | - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in 232 | the source sentence of shape `(bsz, src_len)`. Padding will 233 | appear on the left if *left_pad_source* is ``True``. 
234 | - `src_lengths` (LongTensor): 1D Tensor of the unpadded 235 | lengths of each source sentence of shape `(bsz)` 236 | - `prev_output_tokens` (LongTensor): a padded 2D Tensor of 237 | tokens in the target sentence, shifted right by one 238 | position for teacher forcing, of shape `(bsz, tgt_len)`. 239 | This key will not be present if *input_feeding* is 240 | ``False``. Padding will appear on the left if 241 | *left_pad_target* is ``True``. 242 | - `target` (LongTensor): a padded 2D Tensor of tokens in the 243 | target sentence of shape `(bsz, tgt_len)`. Padding will appear 244 | on the left if *left_pad_target* is ``True``. 245 | """ 246 | return collate( 247 | samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(), 248 | left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, 249 | input_feeding=self.input_feeding, lang_pair=self.lang_pair, 250 | ) 251 | 252 | def num_tokens(self, index): 253 | """Return the number of tokens in a sample. This value is used to 254 | enforce ``--max-tokens`` during batching.""" 255 | return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) 256 | 257 | def size(self, index): 258 | """Return an example's size as a float or tuple. This value is used when 259 | filtering a dataset with ``--max-positions``.""" 260 | return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) 261 | 262 | def ordered_indices(self): 263 | """Return an ordered list of indices. Batches will be constructed based 264 | on this order.""" 265 | if self.shuffle: 266 | indices = np.random.permutation(len(self)) 267 | else: 268 | indices = np.arange(len(self)) 269 | if self.tgt_sizes is not None: 270 | indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] 271 | return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] 272 | 273 | @property 274 | def supports_prefetch(self): 275 | return ( 276 | getattr(self.src, 'supports_prefetch', False) 277 | and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None) 278 | ) 279 | 280 | def prefetch(self, indices): 281 | self.src.prefetch(indices) 282 | if self.tgt is not None: 283 | self.tgt.prefetch(indices) 284 | if self.align_dataset is not None: 285 | self.align_dataset.prefetch(indices) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Train a new model on one or across multiple GPUs. 
8 | """ 9 | 10 | import collections 11 | import math 12 | import random 13 | import os 14 | 15 | import numpy as np 16 | import torch 17 | 18 | from mass.trainer import Trainer 19 | from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils 20 | from fairseq.data import iterators 21 | from fairseq.meters import AverageMeter, StopwatchMeter 22 | import wandb 23 | 24 | 25 | def main(args, init_distributed=False): 26 | if args.device_id == 0: 27 | wandb.init(project='def-gen-simp-final', 28 | entity='cunliang-kong', 29 | reinit=False) 30 | wandb.run.name = os.path.split(args.save_dir)[1] 31 | 32 | config = wandb.config 33 | hyper_param_of_interest = [ 34 | 'lr', 'min_lr', 'clip_norm', 'warmup_init_lr', 'warmup_updates', 35 | 'weight_decay', 'label_smoothing', 'update_freq', 'max_tokens', 36 | 'max_epoch', 'dropout', 'dae_styles' 37 | ] 38 | for hp in hyper_param_of_interest: 39 | setattr(config, hp, getattr(args, hp)) 40 | utils.import_user_module(args) 41 | 42 | assert args.max_tokens is not None or args.max_sentences is not None, \ 43 | 'Must specify batch size either with --max-tokens or --max-sentences' 44 | 45 | # Initialize CUDA and distributed training 46 | if torch.cuda.is_available() and not args.cpu: 47 | torch.cuda.set_device(args.device_id) 48 | np.random.seed(args.seed) 49 | torch.manual_seed(args.seed) 50 | if init_distributed: 51 | args.distributed_rank = distributed_utils.distributed_init(args) 52 | 53 | if distributed_utils.is_master(args): 54 | checkpoint_utils.verify_checkpoint_directory(args.save_dir) 55 | 56 | # Print args 57 | print(args) 58 | 59 | # Setup task, e.g., translation, language modeling, etc. 60 | task = tasks.setup_task(args) 61 | 62 | # Load valid dataset (we load training data below, based on the latest checkpoint) 63 | for valid_sub_split in args.valid_subset.split(','): 64 | task.load_dataset(valid_sub_split, combine=False, epoch=0) 65 | 66 | # Build model and criterion 67 | model = task.build_model(args) 68 | criterion = task.build_criterion(args) 69 | print(model) 70 | print('| model {}, criterion {}'.format(args.arch, 71 | criterion.__class__.__name__)) 72 | print('| num. model params: {} (num. 
trained: {})'.format( 73 | sum(p.numel() for p in model.parameters()), 74 | sum(p.numel() for p in model.parameters() if p.requires_grad), 75 | )) 76 | 77 | # Build trainer 78 | trainer = Trainer(args, task, model, criterion) 79 | print('| training on {} GPUs'.format(args.distributed_world_size)) 80 | print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( 81 | args.max_tokens, 82 | args.max_sentences, 83 | )) 84 | 85 | # Load the latest checkpoint if one is available and restore the 86 | # corresponding train iterator 87 | extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) 88 | 89 | # Train until the learning rate gets too small 90 | max_epoch = args.max_epoch or math.inf 91 | max_update = args.max_update or math.inf 92 | lr = trainer.get_lr() 93 | train_meter = StopwatchMeter() 94 | train_meter.start() 95 | valid_subsets = args.valid_subset.split(',') 96 | while (lr > args.min_lr and (epoch_itr.epoch < max_epoch or 97 | (epoch_itr.epoch == max_epoch 98 | and epoch_itr._next_epoch_itr is not None)) 99 | and trainer.get_num_updates() < max_update): 100 | # train for one epoch 101 | train(args, trainer, task, epoch_itr) 102 | 103 | if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: 104 | valid_losses = validate(args, trainer, task, epoch_itr, 105 | valid_subsets) 106 | else: 107 | valid_losses = [None] 108 | 109 | # only use first validation loss to update the learning rate 110 | lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) 111 | 112 | # save checkpoint 113 | if epoch_itr.epoch % args.save_interval == 0: 114 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, 115 | valid_losses[0]) 116 | 117 | reload_dataset = ':' in getattr(args, 'data', '') 118 | # sharded data: get train iterator for next epoch 119 | epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, 120 | load_dataset=reload_dataset) 121 | train_meter.stop() 122 | print('| done training in {:.1f} seconds'.format(train_meter.sum)) 123 | if args.device_id == 0: 124 | wandb.finish() 125 | 126 | 127 | def train(args, trainer, task, epoch_itr): 128 | """Train the model for one epoch.""" 129 | # Update parameters every N batches 130 | update_freq = args.update_freq[epoch_itr.epoch - 1] \ 131 | if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] 132 | 133 | # Initialize data iterator 134 | itr = epoch_itr.next_epoch_itr( 135 | fix_batches_to_gpus=args.fix_batches_to_gpus, 136 | shuffle=(epoch_itr.epoch >= args.curriculum), 137 | ) 138 | itr = iterators.GroupedIterator(itr, update_freq) 139 | progress = progress_bar.build_progress_bar( 140 | args, 141 | itr, 142 | epoch_itr.epoch, 143 | no_progress_bar='simple', 144 | ) 145 | 146 | extra_meters = collections.defaultdict(lambda: AverageMeter()) 147 | valid_subsets = args.valid_subset.split(',') 148 | max_update = args.max_update or math.inf 149 | for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): 150 | log_output = trainer.train_step(samples) 151 | if log_output is None: 152 | continue 153 | 154 | # log mid-epoch stats 155 | stats = get_training_stats(trainer) 156 | for k, v in log_output.items(): 157 | if k in [ 158 | 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' 159 | ]: 160 | continue # these are already logged above 161 | if 'loss' in k or k == 'accuracy': 162 | extra_meters[k].update(v, log_output['sample_size']) 163 | else: 164 | extra_meters[k].update(v) 165 | stats[k] = extra_meters[k].avg 166 | progress.log(stats, tag='train', 
step=stats['num_updates']) 167 | 168 | src_tgt, style_style = args.model_lang_pairs 169 | 170 | if args.device_id == 0: 171 | wandb_log = { 172 | 'train/loss': stats['loss'].avg, 173 | 'train/nll_loss': stats['nll_loss'].avg, 174 | 'train/ppl': stats['ppl'], 175 | 'train/lr': stats['lr'], 176 | 'src-tgt/loss': stats[f'{src_tgt}:loss'], 177 | 'src-tgt/nll_loss': stats[f'{src_tgt}:nll_loss'], 178 | 'style-style/loss': stats[f'{style_style}:loss'], 179 | 'style-style/nll_loss': stats[f'{style_style}:nll_loss'], 180 | } 181 | wandb.log(wandb_log) 182 | 183 | # ignore the first mini-batch in words-per-second and updates-per-second calculation 184 | if i == 0: 185 | trainer.get_meter('wps').reset() 186 | trainer.get_meter('ups').reset() 187 | 188 | num_updates = trainer.get_num_updates() 189 | if (not args.disable_validation and args.save_interval_updates > 0 190 | and num_updates % args.save_interval_updates == 0 191 | and num_updates > 0): 192 | valid_losses = validate(args, trainer, task, epoch_itr, 193 | valid_subsets) 194 | checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, 195 | valid_losses[0]) 196 | 197 | if num_updates >= max_update: 198 | break 199 | 200 | # log end-of-epoch stats 201 | stats = get_training_stats(trainer) 202 | for k, meter in extra_meters.items(): 203 | stats[k] = meter.avg 204 | progress.print(stats, tag='train', step=stats['num_updates']) 205 | 206 | # reset training meters 207 | for k in [ 208 | 'train_loss', 209 | 'train_nll_loss', 210 | 'wps', 211 | 'ups', 212 | 'wpb', 213 | 'bsz', 214 | 'gnorm', 215 | 'clip', 216 | ]: 217 | meter = trainer.get_meter(k) 218 | if meter is not None: 219 | meter.reset() 220 | 221 | 222 | def get_training_stats(trainer): 223 | stats = collections.OrderedDict() 224 | stats['loss'] = trainer.get_meter('train_loss') 225 | if trainer.get_meter('train_nll_loss').count > 0: 226 | nll_loss = trainer.get_meter('train_nll_loss') 227 | stats['nll_loss'] = nll_loss 228 | else: 229 | nll_loss = trainer.get_meter('train_loss') 230 | stats['ppl'] = utils.get_perplexity(nll_loss.avg) 231 | stats['wps'] = trainer.get_meter('wps') 232 | stats['ups'] = trainer.get_meter('ups') 233 | stats['wpb'] = trainer.get_meter('wpb') 234 | stats['bsz'] = trainer.get_meter('bsz') 235 | stats['num_updates'] = trainer.get_num_updates() 236 | stats['lr'] = trainer.get_lr() 237 | stats['gnorm'] = trainer.get_meter('gnorm') 238 | stats['clip'] = trainer.get_meter('clip') 239 | stats['oom'] = trainer.get_meter('oom') 240 | if trainer.get_meter('loss_scale') is not None: 241 | stats['loss_scale'] = trainer.get_meter('loss_scale') 242 | stats['wall'] = round(trainer.get_meter('wall').elapsed_time) 243 | stats['train_wall'] = trainer.get_meter('train_wall') 244 | return stats 245 | 246 | 247 | def validate(args, trainer, task, epoch_itr, subsets): 248 | """Evaluate the model on the validation set(s) and return the losses.""" 249 | 250 | if args.fixed_validation_seed is not None: 251 | # set fixed seed for every validation 252 | utils.set_torch_seed(args.fixed_validation_seed) 253 | 254 | valid_losses = [] 255 | for subset in subsets: 256 | # Initialize data iterator 257 | itr = task.get_batch_iterator( 258 | dataset=task.dataset(subset), 259 | max_tokens=args.max_tokens_valid, 260 | max_sentences=args.max_sentences_valid, 261 | max_positions=utils.resolve_max_positions( 262 | task.max_positions(), 263 | trainer.get_model().max_positions(), 264 | ), 265 | ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, 266 | 
required_batch_size_multiple=args.required_batch_size_multiple, 267 | seed=args.seed, 268 | num_shards=args.distributed_world_size, 269 | shard_id=args.distributed_rank, 270 | num_workers=args.num_workers, 271 | ).next_epoch_itr(shuffle=False) 272 | progress = progress_bar.build_progress_bar( 273 | args, 274 | itr, 275 | epoch_itr.epoch, 276 | prefix='valid on \'{}\' subset'.format(subset), 277 | no_progress_bar='simple') 278 | 279 | # reset validation loss meters 280 | for k in ['valid_loss', 'valid_nll_loss']: 281 | meter = trainer.get_meter(k) 282 | if meter is not None: 283 | meter.reset() 284 | extra_meters = collections.defaultdict(lambda: AverageMeter()) 285 | 286 | for sample in progress: 287 | log_output = trainer.valid_step(sample) 288 | 289 | for k, v in log_output.items(): 290 | if k in [ 291 | 'loss', 'nll_loss', 'ntokens', 'nsentences', 292 | 'sample_size' 293 | ]: 294 | continue 295 | extra_meters[k].update(v) 296 | 297 | # log validation stats 298 | stats = get_valid_stats(trainer, args, extra_meters) 299 | for k, meter in extra_meters.items(): 300 | stats[k] = meter.avg 301 | progress.print(stats, tag=subset, step=trainer.get_num_updates()) 302 | 303 | if args.device_id == 0: 304 | wandb_log = { 305 | 'valid/loss': stats['loss'].avg, 306 | 'valid/nll_loss': stats['nll_loss'].avg, 307 | 'valid/ppl': stats['ppl'], 308 | } 309 | wandb.log(wandb_log) 310 | 311 | valid_losses.append(stats[args.best_checkpoint_metric].avg if args. 312 | best_checkpoint_metric == 313 | 'loss' else stats[args.best_checkpoint_metric]) 314 | return valid_losses 315 | 316 | 317 | def get_valid_stats(trainer, args, extra_meters=None): 318 | stats = collections.OrderedDict() 319 | stats['loss'] = trainer.get_meter('valid_loss') 320 | if trainer.get_meter('valid_nll_loss').count > 0: 321 | nll_loss = trainer.get_meter('valid_nll_loss') 322 | stats['nll_loss'] = nll_loss 323 | else: 324 | nll_loss = stats['loss'] 325 | stats['ppl'] = utils.get_perplexity(nll_loss.avg) 326 | stats['num_updates'] = trainer.get_num_updates() 327 | if hasattr(checkpoint_utils.save_checkpoint, 'best'): 328 | key = 'best_{0}'.format(args.best_checkpoint_metric) 329 | best_function = max if args.maximize_best_checkpoint_metric else min 330 | 331 | current_metric = None 332 | if args.best_checkpoint_metric == 'loss': 333 | current_metric = stats['loss'].avg 334 | elif args.best_checkpoint_metric in extra_meters: 335 | current_metric = extra_meters[args.best_checkpoint_metric].avg 336 | elif args.best_checkpoint_metric in stats: 337 | current_metric = stats[args.best_checkpoint_metric] 338 | else: 339 | raise ValueError("best_checkpoint_metric not found in logs") 340 | 341 | stats[key] = best_function( 342 | checkpoint_utils.save_checkpoint.best, 343 | current_metric, 344 | ) 345 | return stats 346 | 347 | 348 | def distributed_main(i, args, start_rank=0): 349 | args.device_id = i 350 | if args.distributed_rank is None: # torch.multiprocessing.spawn 351 | args.distributed_rank = start_rank + i 352 | main(args, init_distributed=True) 353 | 354 | 355 | def cli_main(): 356 | parser = options.get_training_parser() 357 | args = options.parse_args_and_arch(parser) 358 | 359 | if args.distributed_init_method is None: 360 | distributed_utils.infer_init_method(args) 361 | 362 | if args.distributed_init_method is not None: 363 | # distributed training 364 | if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: 365 | start_rank = args.distributed_rank 366 | args.distributed_rank = None # assign automatically 367 | 
torch.multiprocessing.spawn( 368 | fn=distributed_main, 369 | args=(args, start_rank), 370 | nprocs=torch.cuda.device_count(), 371 | ) 372 | else: 373 | distributed_main(args.device_id, args) 374 | elif args.distributed_world_size > 1: 375 | # fallback for single node with multiple GPUs 376 | assert args.distributed_world_size <= torch.cuda.device_count() 377 | port = random.randint(10000, 20000) 378 | args.distributed_init_method = 'tcp://localhost:{port}'.format( 379 | port=port) 380 | args.distributed_rank = None # set based on device id 381 | if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': 382 | print( 383 | '| NOTE: you may get better performance with: --ddp-backend=no_c10d' 384 | ) 385 | torch.multiprocessing.spawn( 386 | fn=distributed_main, 387 | args=(args, ), 388 | nprocs=args.distributed_world_size, 389 | ) 390 | else: 391 | # single GPU training 392 | main(args) 393 | 394 | 395 | if __name__ == '__main__': 396 | cli_main() 397 | -------------------------------------------------------------------------------- /mass/multihead_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import torch 8 | from torch import nn 9 | from torch.nn import Parameter 10 | import torch.nn.functional as F 11 | 12 | from fairseq import utils 13 | 14 | 15 | class MultiheadAttention(nn.Module): 16 | """Multi-headed attention. 17 | See "Attention Is All You Need" for more details. 18 | """ 19 | 20 | def __init__(self, 21 | embed_dim, 22 | num_heads, 23 | kdim=None, 24 | vdim=None, 25 | dropout=0., 26 | bias=True, 27 | add_bias_kv=False, 28 | add_zero_attn=False, 29 | self_attention=False, 30 | encoder_decoder_attention=False, 31 | tgt_types=None, 32 | enable_torch_version=None): 33 | super().__init__() 34 | self.embed_dim = embed_dim 35 | self.kdim = kdim if kdim is not None else embed_dim 36 | self.vdim = vdim if vdim is not None else embed_dim 37 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim 38 | 39 | self.num_heads = num_heads 40 | self.dropout = dropout 41 | self.head_dim = embed_dim // num_heads 42 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 43 | self.scaling = self.head_dim**-0.5 44 | 45 | self.self_attention = self_attention 46 | self.encoder_decoder_attention = encoder_decoder_attention 47 | 48 | assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \ 49 | 'value to be of the same size' 50 | 51 | self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) 52 | self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) 53 | if tgt_types is not None and isinstance(tgt_types, dict) and len(tgt_types) > 1: 54 | self.q_proj = nn.ModuleList( 55 | [nn.Linear(embed_dim, embed_dim, bias=bias) for _ in range(len(tgt_types))]) 56 | self.tgt_types = tgt_types 57 | else: 58 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 59 | self.tgt_types = None 60 | 61 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 62 | 63 | if add_bias_kv: 64 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 65 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 66 | else: 67 | self.bias_k = self.bias_v = None 68 | 69 | self.add_zero_attn = add_zero_attn 70 | 71 | self.reset_parameters() 72 | 73 | self.onnx_trace = 
False 74 | 75 | if enable_torch_version is not None: 76 | self.enable_torch_version = enable_torch_version 77 | else: 78 | self.enable_torch_version = False 79 | if hasattr(F, "multi_head_attention_forward"): 80 | self.enable_torch_version = True 81 | else: 82 | self.enable_torch_version = False 83 | 84 | def prepare_for_onnx_export_(self): 85 | self.onnx_trace = True 86 | 87 | def reset_parameters(self): 88 | if self.qkv_same_dim: 89 | # Empirically observed the convergence to be much better with 90 | # the scaled initialization 91 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) 92 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) 93 | if self.tgt_types is not None: 94 | for q_proj in self.q_proj: 95 | nn.init.xavier_uniform_(q_proj.weight, gain=1 / math.sqrt(2)) 96 | else: 97 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) 98 | else: 99 | nn.init.xavier_uniform_(self.k_proj.weight) 100 | nn.init.xavier_uniform_(self.v_proj.weight) 101 | if self.tgt_types is not None: 102 | for q_proj in self.q_proj: 103 | nn.init.xavier_uniform_(q_proj.weight, gain=1 / math.sqrt(2)) 104 | else: 105 | nn.init.xavier_uniform_(self.q_proj.weight) 106 | 107 | nn.init.xavier_uniform_(self.out_proj.weight) 108 | nn.init.constant_(self.out_proj.bias, 0.) 109 | if self.bias_k is not None: 110 | nn.init.xavier_normal_(self.bias_k) 111 | if self.bias_v is not None: 112 | nn.init.xavier_normal_(self.bias_v) 113 | 114 | def forward( 115 | self, 116 | query, 117 | key, 118 | value, 119 | key_padding_mask=None, 120 | incremental_state=None, 121 | need_weights=True, 122 | static_kv=False, 123 | attn_mask=None, 124 | before_softmax=False, 125 | need_head_weights=False, 126 | tgt_type=None, 127 | ): 128 | """Input shape: Time x Batch x Channel 129 | Args: 130 | key_padding_mask (ByteTensor, optional): mask to exclude 131 | keys that are pads, of shape `(batch, src_len)`, where 132 | padding elements are indicated by 1s. 133 | need_weights (bool, optional): return the attention weights, 134 | averaged over heads (default: False). 135 | attn_mask (ByteTensor, optional): typically used to 136 | implement causal attention, where the mask prevents the 137 | attention from looking forward in time (default: None). 138 | before_softmax (bool, optional): return the raw attention 139 | weights and values before the attention softmax. 140 | need_head_weights (bool, optional): return the attention 141 | weights for each head. Implies *need_weights*. Default: 142 | return the average attention weights over all heads. 
143 | """ 144 | if need_head_weights: 145 | need_weights = True 146 | 147 | tgt_len, bsz, embed_dim = query.size() 148 | assert embed_dim == self.embed_dim 149 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 150 | 151 | if self.enable_torch_version and not self.onnx_trace and incremental_state is None and not static_kv: 152 | return F.multi_head_attention_forward(query, 153 | key, 154 | value, 155 | self.embed_dim, 156 | self.num_heads, 157 | torch.empty([0]), 158 | torch.cat((self.q_proj.bias, self.k_proj.bias, 159 | self.v_proj.bias)), 160 | self.bias_k, 161 | self.bias_v, 162 | self.add_zero_attn, 163 | self.dropout, 164 | self.out_proj.weight, 165 | self.out_proj.bias, 166 | self.training, 167 | key_padding_mask, 168 | need_weights, 169 | attn_mask, 170 | use_separate_proj_weight=True, 171 | q_proj_weight=self.q_proj.weight, 172 | k_proj_weight=self.k_proj.weight, 173 | v_proj_weight=self.v_proj.weight) 174 | 175 | if incremental_state is not None: 176 | saved_state = self._get_input_buffer(incremental_state) 177 | if 'prev_key' in saved_state: 178 | # previous time steps are cached - no need to recompute 179 | # key and value if they are static 180 | if static_kv: 181 | assert self.encoder_decoder_attention and not self.self_attention 182 | key = value = None 183 | else: 184 | saved_state = None 185 | 186 | if self.self_attention: 187 | # q = self.q_proj(query) 188 | k = self.k_proj(query) 189 | v = self.v_proj(query) 190 | elif self.encoder_decoder_attention: 191 | # encoder-decoder attention 192 | # q = self.q_proj(query) 193 | if key is None: 194 | assert value is None 195 | k = v = None 196 | else: 197 | k = self.k_proj(key) 198 | v = self.v_proj(key) 199 | 200 | else: 201 | # q = self.q_proj(query) 202 | k = self.k_proj(key) 203 | v = self.v_proj(value) 204 | if tgt_type is not None and self.tgt_types is not None: 205 | q = self.q_proj[self.tgt_types[tgt_type]](query) 206 | else: 207 | q = self.q_proj(query) 208 | q *= self.scaling 209 | 210 | if self.bias_k is not None: 211 | assert self.bias_v is not None 212 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 213 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 214 | if attn_mask is not None: 215 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 216 | if key_padding_mask is not None: 217 | key_padding_mask = torch.cat( 218 | [key_padding_mask, 219 | key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], 220 | dim=1) 221 | 222 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 223 | if k is not None: 224 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 225 | if v is not None: 226 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 227 | 228 | if saved_state is not None: 229 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 230 | if 'prev_key' in saved_state: 231 | prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) 232 | if static_kv: 233 | k = prev_key 234 | else: 235 | k = torch.cat((prev_key, k), dim=1) 236 | if 'prev_value' in saved_state: 237 | prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) 238 | if static_kv: 239 | v = prev_value 240 | else: 241 | v = torch.cat((prev_value, v), dim=1) 242 | key_padding_mask = self._append_prev_key_padding_mask( 243 | key_padding_mask=key_padding_mask, 244 | prev_key_padding_mask=saved_state.get('prev_key_padding_mask', None), 245 | batch_size=bsz, 
246 | src_len=k.size(1), 247 | static_kv=static_kv, 248 | ) 249 | 250 | saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) 251 | saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) 252 | saved_state['prev_key_padding_mask'] = key_padding_mask 253 | 254 | self._set_input_buffer(incremental_state, saved_state) 255 | 256 | src_len = k.size(1) 257 | 258 | # This is part of a workaround to get around fork/join parallelism 259 | # not supporting Optional types. 260 | if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): 261 | key_padding_mask = None 262 | 263 | if key_padding_mask is not None: 264 | assert key_padding_mask.size(0) == bsz 265 | assert key_padding_mask.size(1) == src_len 266 | 267 | if self.add_zero_attn: 268 | src_len += 1 269 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 270 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 271 | if attn_mask is not None: 272 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 273 | if key_padding_mask is not None: 274 | key_padding_mask = torch.cat([ 275 | key_padding_mask, 276 | torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask) 277 | ], 278 | dim=1) 279 | 280 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 281 | attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) 282 | 283 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 284 | 285 | if attn_mask is not None: 286 | attn_mask = attn_mask.unsqueeze(0) 287 | if self.onnx_trace: 288 | attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) 289 | attn_weights += attn_mask 290 | 291 | if key_padding_mask is not None: 292 | # don't attend to padding symbols 293 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 294 | attn_weights = attn_weights.masked_fill( 295 | key_padding_mask.unsqueeze(1).unsqueeze(2), 296 | float('-inf'), 297 | ) 298 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 299 | 300 | if before_softmax: 301 | return attn_weights, v 302 | 303 | attn_weights_float = utils.softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace) 304 | attn_weights = attn_weights_float.type_as(attn_weights) 305 | attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), 306 | p=self.dropout, 307 | training=self.training) 308 | 309 | attn = torch.bmm(attn_probs, v) 310 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 311 | if (self.onnx_trace and attn.size(1) == 1): 312 | # when ONNX tracing a single decoder step (sequence length == 1) 313 | # the transpose is a no-op copy before view, thus unnecessary 314 | attn = attn.contiguous().view(tgt_len, bsz, embed_dim) 315 | else: 316 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 317 | attn = self.out_proj(attn) 318 | 319 | if need_weights: 320 | attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, 321 | src_len).transpose(1, 0) 322 | if not need_head_weights: 323 | # average attention weights over heads 324 | attn_weights = attn_weights.mean(dim=0) 325 | else: 326 | attn_weights = None 327 | 328 | return attn, attn_weights 329 | 330 | @staticmethod 331 | def _append_prev_key_padding_mask( 332 | key_padding_mask, 333 | prev_key_padding_mask, 334 | batch_size, 335 | src_len, 336 | static_kv, 337 | ): 338 | # saved key padding masks have shape (bsz, seq_len) 339 | if prev_key_padding_mask is not None and static_kv: 340 | 
key_padding_mask = prev_key_padding_mask 341 | elif prev_key_padding_mask is not None and key_padding_mask is not None: 342 | key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1) 343 | # During incremental decoding, as the padding token enters and 344 | # leaves the frame, there will be a time when prev or current 345 | # is None 346 | elif prev_key_padding_mask is not None: 347 | filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1)).bool() 348 | if prev_key_padding_mask.is_cuda: 349 | filler = filler.cuda() 350 | key_padding_mask = torch.cat((prev_key_padding_mask, filler), dim=1) 351 | elif key_padding_mask is not None: 352 | filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)).bool() 353 | if key_padding_mask.is_cuda: 354 | filler = filler.cuda() 355 | key_padding_mask = torch.cat((filler, key_padding_mask), dim=1) 356 | return key_padding_mask 357 | 358 | def reorder_incremental_state(self, incremental_state, new_order): 359 | """Reorder buffered internal state (for incremental generation).""" 360 | input_buffer = self._get_input_buffer(incremental_state) 361 | if input_buffer is not None: 362 | for k in input_buffer.keys(): 363 | if input_buffer[k] is not None: 364 | input_buffer[k] = input_buffer[k].index_select(0, new_order) 365 | self._set_input_buffer(incremental_state, input_buffer) 366 | 367 | def _get_input_buffer(self, incremental_state): 368 | return utils.get_incremental_state( 369 | self, 370 | incremental_state, 371 | 'attn_state', 372 | ) or {} 373 | 374 | def _set_input_buffer(self, incremental_state, buffer): 375 | utils.set_incremental_state( 376 | self, 377 | incremental_state, 378 | 'attn_state', 379 | buffer, 380 | ) 381 | 382 | def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): 383 | return attn_weights 384 | 385 | def upgrade_state_dict_named(self, state_dict, name): 386 | prefix = name + '.' if name != '' else '' 387 | items_to_add = {} 388 | keys_to_remove = [] 389 | for k in state_dict.keys(): 390 | if k.endswith(prefix + 'in_proj_weight'): 391 | # in_proj_weight used to be q + k + v with same dimensions 392 | dim = int(state_dict[k].shape[0] / 3) 393 | items_to_add[prefix + 'q_proj.weight'] = state_dict[k][:dim] 394 | items_to_add[prefix + 'k_proj.weight'] = state_dict[k][dim:2 * dim] 395 | items_to_add[prefix + 'v_proj.weight'] = state_dict[k][2 * dim:] 396 | 397 | keys_to_remove.append(k) 398 | 399 | k_bias = prefix + 'in_proj_bias' 400 | if k_bias in state_dict.keys(): 401 | dim = int(state_dict[k].shape[0] / 3) 402 | items_to_add[prefix + 'q_proj.bias'] = state_dict[k_bias][:dim] 403 | items_to_add[prefix + 'k_proj.bias'] = state_dict[k_bias][dim:2 * dim] 404 | items_to_add[prefix + 'v_proj.bias'] = state_dict[k_bias][2 * dim:] 405 | 406 | keys_to_remove.append(prefix + 'in_proj_bias') 407 | 408 | for k in keys_to_remove: 409 | del state_dict[k] 410 | 411 | for key, value in items_to_add.items(): 412 | state_dict[key] = value -------------------------------------------------------------------------------- /tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
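# ---------------------------------------------------------------------------
# Illustrative note (not part of the original repository): the
# MultiheadAttention in mass/multihead_attention.py above keeps one query
# projection per target style when a `tgt_types` dict is passed (for example
# {'complex': 0, 'simple': 1} — a made-up mapping), while the key and value
# projections stay shared across styles:
#
#     self.q_proj = nn.ModuleList(
#         [nn.Linear(embed_dim, embed_dim, bias=bias) for _ in range(len(tgt_types))])
#
# and forward(..., tgt_type='simple') then selects the matching projection:
#
#     q = self.q_proj[self.tgt_types[tgt_type]](query)
#
# so a shared decoder can attend with style-specific queries.
# ---------------------------------------------------------------------------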
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from tokenization_utils import PreTrainedTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 37 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 38 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 39 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 40 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 41 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 42 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 43 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 44 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 45 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 46 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 47 | } 48 | } 49 | 50 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 51 | 'bert-base-uncased': 512, 52 | 'bert-large-uncased': 512, 53 | 'bert-base-cased': 512, 54 | 'bert-large-cased': 512, 55 | 'bert-base-multilingual-uncased': 512, 56 | 'bert-base-multilingual-cased': 512, 57 | 'bert-base-chinese': 512, 58 | 'bert-base-german-cased': 512, 59 | 'bert-large-uncased-whole-word-masking': 512, 60 | 'bert-large-cased-whole-word-masking': 512, 61 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 62 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 63 | 'bert-base-cased-finetuned-mrpc': 512, 64 | } 65 | 66 | def load_vocab(vocab_file): 67 | """Loads a vocabulary file into a dictionary.""" 68 | vocab = collections.OrderedDict() 69 | with open(vocab_file, "r", encoding="utf-8") as reader: 70 | tokens = reader.readlines() 71 | for index, token in enumerate(tokens): 72 | 
token = token.rstrip('\n') 73 | vocab[token] = index 74 | return vocab 75 | 76 | 77 | def whitespace_tokenize(text): 78 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 79 | text = text.strip() 80 | if not text: 81 | return [] 82 | tokens = text.split() 83 | return tokens 84 | 85 | 86 | class BertTokenizer(PreTrainedTokenizer): 87 | r""" 88 | Constructs a BertTokenizer. 89 | :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 90 | Args: 91 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 92 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 93 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 94 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 95 | minimum of this value (if specified) and the underlying BERT model's sequence length. 96 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 97 | do_wordpiece_only=False 98 | """ 99 | 100 | vocab_files_names = VOCAB_FILES_NAMES 101 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 102 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 103 | 104 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, 105 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", 106 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): 107 | """Constructs a BertTokenizer. 108 | Args: 109 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 110 | **do_lower_case**: (`optional`) boolean (default True) 111 | Whether to lower case the input 112 | Only has an effect when do_basic_tokenize=True 113 | **do_basic_tokenize**: (`optional`) boolean (default True) 114 | Whether to do basic tokenization before wordpiece. 115 | **never_split**: (`optional`) list of string 116 | List of tokens which will never be split during tokenization. 117 | Only has an effect when do_basic_tokenize=True 118 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 119 | Whether to tokenize Chinese characters. 120 | This should likely be deactivated for Japanese: 121 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 122 | """ 123 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 124 | pad_token=pad_token, cls_token=cls_token, 125 | mask_token=mask_token, **kwargs) 126 | if not os.path.isfile(vocab_file): 127 | raise ValueError( 128 | "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " 129 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 130 | self.vocab = load_vocab(vocab_file) 131 | self.ids_to_tokens = collections.OrderedDict( 132 | [(ids, tok) for tok, ids in self.vocab.items()]) 133 | self.do_basic_tokenize = do_basic_tokenize 134 | if do_basic_tokenize: 135 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 136 | never_split=never_split, 137 | tokenize_chinese_chars=tokenize_chinese_chars) 138 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 139 | 140 | @property 141 | def vocab_size(self): 142 | return len(self.vocab) 143 | 144 | def _tokenize(self, text): 145 | split_tokens = [] 146 | if self.do_basic_tokenize: 147 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 148 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 149 | split_tokens.append(sub_token) 150 | else: 151 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 152 | return split_tokens 153 | 154 | def _convert_token_to_id(self, token): 155 | """ Converts a token (str/unicode) in an id using the vocab. """ 156 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 157 | 158 | def _convert_id_to_token(self, index): 159 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 160 | return self.ids_to_tokens.get(index, self.unk_token) 161 | 162 | def convert_tokens_to_string(self, tokens): 163 | """ Converts a sequence of tokens (string) in a single string. """ 164 | out_string = ' '.join(tokens).replace(' ##', '').strip() 165 | return out_string 166 | 167 | def save_vocabulary(self, vocab_path): 168 | """Save the tokenizer vocabulary to a directory or file.""" 169 | index = 0 170 | if os.path.isdir(vocab_path): 171 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) 172 | with open(vocab_file, "w", encoding="utf-8") as writer: 173 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 174 | if index != token_index: 175 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 176 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 177 | index = token_index 178 | writer.write(token + u'\n') 179 | index += 1 180 | return (vocab_file,) 181 | 182 | @classmethod 183 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 184 | """ Instantiate a BertTokenizer from pre-trained vocabulary files. 185 | """ 186 | if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: 187 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 188 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 189 | "`do_lower_case` to False. We are setting `do_lower_case=False` for you but " 190 | "you may want to check this behavior.") 191 | kwargs['do_lower_case'] = False 192 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 193 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 194 | "`do_lower_case` to False. 
We are setting `do_lower_case=True` for you " 195 | "but you may want to check this behavior.") 196 | kwargs['do_lower_case'] = True 197 | 198 | return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 199 | 200 | 201 | class BasicTokenizer(object): 202 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 203 | 204 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 205 | """ Constructs a BasicTokenizer. 206 | Args: 207 | **do_lower_case**: Whether to lower case the input. 208 | **never_split**: (`optional`) list of str 209 | Kept for backward compatibility purposes. 210 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 211 | List of token not to split. 212 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 213 | Whether to tokenize Chinese characters. 214 | This should likely be deactivated for Japanese: 215 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 216 | """ 217 | if never_split is None: 218 | never_split = [] 219 | self.do_lower_case = do_lower_case 220 | self.never_split = never_split 221 | self.tokenize_chinese_chars = tokenize_chinese_chars 222 | 223 | def tokenize(self, text, never_split=None): 224 | """ Basic Tokenization of a piece of text. 225 | Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. 226 | Args: 227 | **never_split**: (`optional`) list of str 228 | Kept for backward compatibility purposes. 229 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 230 | List of token not to split. 231 | """ 232 | never_split = self.never_split + (never_split if never_split is not None else []) 233 | text = self._clean_text(text) 234 | # This was added on November 1st, 2018 for the multilingual and Chinese 235 | # models. This is also applied to the English models now, but it doesn't 236 | # matter since the English models were not trained on any Chinese data 237 | # and generally don't have any Chinese data in them (there are Chinese 238 | # characters in the vocabulary because Wikipedia does have some Chinese 239 | # words in the English Wikipedia.). 
240 | if self.tokenize_chinese_chars: 241 | text = self._tokenize_chinese_chars(text) 242 | orig_tokens = whitespace_tokenize(text) 243 | split_tokens = [] 244 | for token in orig_tokens: 245 | if self.do_lower_case and token not in never_split: 246 | token = token.lower() 247 | token = self._run_strip_accents(token) 248 | split_tokens.extend(self._run_split_on_punc(token)) 249 | 250 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 251 | return output_tokens 252 | 253 | def _run_strip_accents(self, text): 254 | """Strips accents from a piece of text.""" 255 | text = unicodedata.normalize("NFD", text) 256 | output = [] 257 | for char in text: 258 | cat = unicodedata.category(char) 259 | if cat == "Mn": 260 | continue 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _run_split_on_punc(self, text, never_split=None): 265 | """Splits punctuation on a piece of text.""" 266 | if never_split is not None and text in never_split: 267 | return [text] 268 | chars = list(text) 269 | i = 0 270 | start_new_word = True 271 | output = [] 272 | while i < len(chars): 273 | char = chars[i] 274 | if _is_punctuation(char): 275 | output.append([char]) 276 | start_new_word = True 277 | else: 278 | if start_new_word: 279 | output.append([]) 280 | start_new_word = False 281 | output[-1].append(char) 282 | i += 1 283 | 284 | return ["".join(x) for x in output] 285 | 286 | def _tokenize_chinese_chars(self, text): 287 | """Adds whitespace around any CJK character.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if self._is_chinese_char(cp): 292 | output.append(" ") 293 | output.append(char) 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | def _is_chinese_char(self, cp): 300 | """Checks whether CP is the codepoint of a CJK character.""" 301 | # This defines a "chinese character" as anything in the CJK Unicode block: 302 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 303 | # 304 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 305 | # despite its name. The modern Korean Hangul alphabet is a different block, 306 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 307 | # space-separated words, so they are not treated specially and handled 308 | # like the all of the other languages. 309 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 310 | (cp >= 0x3400 and cp <= 0x4DBF) or # 311 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 312 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 313 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 314 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 315 | (cp >= 0xF900 and cp <= 0xFAFF) or # 316 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 317 | return True 318 | 319 | return False 320 | 321 | def _clean_text(self, text): 322 | """Performs invalid character removal and whitespace cleanup on text.""" 323 | output = [] 324 | for char in text: 325 | cp = ord(char) 326 | if cp == 0 or cp == 0xfffd or _is_control(char): 327 | continue 328 | if _is_whitespace(char): 329 | output.append(" ") 330 | else: 331 | output.append(char) 332 | return "".join(output) 333 | 334 | 335 | class WordpieceTokenizer(object): 336 | """Runs WordPiece tokenization.""" 337 | 338 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 339 | self.vocab = vocab 340 | self.unk_token = unk_token 341 | self.max_input_chars_per_word = max_input_chars_per_word 342 | 343 | def tokenize(self, text): 344 | """Tokenizes a piece of text into its word pieces. 
345 | This uses a greedy longest-match-first algorithm to perform tokenization 346 | using the given vocabulary. 347 | For example: 348 | input = "unaffable" 349 | output = ["un", "##aff", "##able"] 350 | Args: 351 | text: A single token or whitespace separated tokens. This should have 352 | already been passed through `BasicTokenizer`. 353 | Returns: 354 | A list of wordpiece tokens. 355 | """ 356 | 357 | output_tokens = [] 358 | for token in whitespace_tokenize(text): 359 | chars = list(token) 360 | if len(chars) > self.max_input_chars_per_word: 361 | output_tokens.append(self.unk_token) 362 | continue 363 | 364 | is_bad = False 365 | start = 0 366 | sub_tokens = [] 367 | while start < len(chars): 368 | end = len(chars) 369 | cur_substr = None 370 | while start < end: 371 | substr = "".join(chars[start:end]) 372 | if start > 0: 373 | substr = "##" + substr 374 | if substr in self.vocab: 375 | cur_substr = substr 376 | break 377 | end -= 1 378 | if cur_substr is None: 379 | is_bad = True 380 | break 381 | sub_tokens.append(cur_substr) 382 | start = end 383 | 384 | if is_bad: 385 | output_tokens.append(self.unk_token) 386 | else: 387 | output_tokens.extend(sub_tokens) 388 | return output_tokens 389 | 390 | 391 | def _is_whitespace(char): 392 | """Checks whether `chars` is a whitespace character.""" 393 | # \t, \n, and \r are technically contorl characters but we treat them 394 | # as whitespace since they are generally considered as such. 395 | if char == " " or char == "\t" or char == "\n" or char == "\r": 396 | return True 397 | cat = unicodedata.category(char) 398 | if cat == "Zs": 399 | return True 400 | return False 401 | 402 | 403 | def _is_control(char): 404 | """Checks whether `chars` is a control character.""" 405 | # These are technically control characters but we count them as whitespace 406 | # characters. 407 | if char == "\t" or char == "\n" or char == "\r": 408 | return False 409 | cat = unicodedata.category(char) 410 | if cat.startswith("C"): 411 | return True 412 | return False 413 | 414 | 415 | def _is_punctuation(char): 416 | """Checks whether `chars` is a punctuation character.""" 417 | cp = ord(char) 418 | # We treat all non-letter/number ASCII as punctuation. 419 | # Characters such as "^", "$", and "`" are not in the Unicode 420 | # Punctuation class but we treat them as punctuation anyways, for 421 | # consistency. 422 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 423 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 424 | return True 425 | cat = unicodedata.category(char) 426 | if cat.startswith("P"): 427 | return True 428 | return False -------------------------------------------------------------------------------- /mass/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Train a network across multiple GPUs. 8 | """ 9 | 10 | from collections import OrderedDict 11 | import contextlib 12 | from itertools import chain 13 | import math 14 | import os 15 | import sys 16 | 17 | import torch 18 | 19 | from fairseq import checkpoint_utils, distributed_utils, models, optim, utils 20 | from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter 21 | from fairseq.optim import lr_scheduler 22 | 23 | 24 | class Trainer(object): 25 | """Main class for data parallel training. 
26 | This class supports synchronous distributed data parallel training, 27 | where multiple workers each have a full model replica and gradients 28 | are accumulated across workers before each update. We use 29 | :class:`~torch.nn.parallel.DistributedDataParallel` to handle 30 | communication of the gradients across workers. 31 | """ 32 | 33 | def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None): 34 | self.args = args 35 | self.task = task 36 | 37 | # copy model and criterion to current device 38 | self._criterion = criterion 39 | self._model = model 40 | self.cuda = torch.cuda.is_available() and not args.cpu 41 | if args.fp16: 42 | self._criterion = self._criterion.half() 43 | self._model = self._model.half() 44 | if self.cuda: 45 | self._criterion = self._criterion.cuda() 46 | self._model = self._model.cuda() 47 | 48 | self._dummy_batch = dummy_batch 49 | self._oom_batch = oom_batch or dummy_batch 50 | 51 | self._lr_scheduler = None 52 | self._num_updates = 0 53 | self._optim_history = None 54 | self._optimizer = None 55 | self._prev_grad_norm = None 56 | self._wrapped_criterion = None 57 | self._wrapped_model = None 58 | 59 | # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes. 60 | # It is less flexible and syncs only the default stats. 61 | self._all_reduce_list = [0.0] * 6 62 | self.fast_stat_sync = args.fast_stat_sync 63 | 64 | self.init_meters(args) 65 | 66 | def init_meters(self, args): 67 | self.meters = OrderedDict() 68 | self.meters['train_loss'] = AverageMeter() 69 | self.meters['train_nll_loss'] = AverageMeter() 70 | self.meters['valid_loss'] = AverageMeter() 71 | self.meters['valid_nll_loss'] = AverageMeter() 72 | self.meters['wps'] = TimeMeter() # words per second 73 | self.meters['ups'] = TimeMeter() # updates per second 74 | self.meters['wpb'] = AverageMeter() # words per batch 75 | self.meters['bsz'] = AverageMeter() # sentences per batch 76 | self.meters['gnorm'] = AverageMeter() # gradient norm 77 | self.meters['clip'] = AverageMeter() # % of updates clipped 78 | self.meters['oom'] = AverageMeter() # out of memory 79 | if args.fp16: 80 | self.meters['loss_scale'] = AverageMeter() # dynamic loss scale 81 | self.meters['wall'] = TimeMeter() # wall time in seconds 82 | self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds 83 | 84 | @property 85 | def criterion(self): 86 | if self._wrapped_criterion is None: 87 | if ( 88 | utils.has_parameters(self._criterion) 89 | and self.args.distributed_world_size > 1 90 | and not self.args.use_bmuf 91 | ): 92 | self._wrapped_criterion = models.DistributedFairseqModel( 93 | self.args, self._criterion 94 | ) 95 | else: 96 | self._wrapped_criterion = self._criterion 97 | return self._wrapped_criterion 98 | 99 | @property 100 | def model(self): 101 | if self._wrapped_model is None: 102 | if self.args.distributed_world_size > 1 and not self.args.use_bmuf: 103 | self._wrapped_model = models.DistributedFairseqModel( 104 | self.args, self._model, 105 | ) 106 | else: 107 | self._wrapped_model = self._model 108 | return self._wrapped_model 109 | 110 | @property 111 | def optimizer(self): 112 | if self._optimizer is None: 113 | self._build_optimizer() 114 | return self._optimizer 115 | 116 | @property 117 | def lr_scheduler(self): 118 | if self._lr_scheduler is None: 119 | self._build_optimizer() # this will initialize self._lr_scheduler 120 | return self._lr_scheduler 121 | 122 | def _build_optimizer(self): 123 | params = list( 124 | filter( 125 | lambda p: p.requires_grad, 
126 | chain(self.model.parameters(), self.criterion.parameters()), 127 | ) 128 | ) 129 | 130 | if self.args.fp16: 131 | if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: 132 | print('| WARNING: your device does NOT support faster training with --fp16, ' 133 | 'please switch to FP32 which is likely to be faster') 134 | if self.args.memory_efficient_fp16: 135 | self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params) 136 | else: 137 | self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) 138 | else: 139 | if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: 140 | print('| NOTICE: your device may support faster training with --fp16') 141 | self._optimizer = optim.build_optimizer(self.args, params) 142 | 143 | if self.args.use_bmuf: 144 | self._optimizer = optim.FairseqBMUF(self.args, self._optimizer) 145 | 146 | # We should initialize the learning rate scheduler immediately after 147 | # building the optimizer, so that the initial learning rate is set. 148 | self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) 149 | self._lr_scheduler.step_update(0) 150 | 151 | def save_checkpoint(self, filename, extra_state): 152 | """Save all training state in a checkpoint file.""" 153 | if distributed_utils.is_master(self.args): # only save one checkpoint 154 | extra_state['train_meters'] = self.meters 155 | checkpoint_utils.save_state( 156 | filename, self.args, self.get_model().state_dict(), self.get_criterion(), 157 | self.optimizer, self.lr_scheduler, self.get_num_updates(), 158 | self._optim_history, extra_state, 159 | ) 160 | 161 | def load_checkpoint( 162 | self, 163 | filename, 164 | reset_optimizer=False, 165 | reset_lr_scheduler=False, 166 | optimizer_overrides=None, 167 | reset_meters=False, 168 | ): 169 | """Load all training state from a checkpoint file.""" 170 | extra_state, self._optim_history, last_optim_state = None, [], None 171 | 172 | try: 173 | from fairseq.fb_pathmgr import fb_pathmgr 174 | bexists = fb_pathmgr.isfile(filename) 175 | except (ModuleNotFoundError, ImportError): 176 | bexists = os.path.exists(filename) 177 | 178 | if bexists: 179 | state = checkpoint_utils.load_checkpoint_to_cpu(filename) 180 | 181 | # load model parameters 182 | try: 183 | self.get_model().load_state_dict(state['model'], strict=True, args=self.args) 184 | if utils.has_parameters(self.get_criterion()): 185 | self.get_criterion().load_state_dict(state['criterion'], strict=True) 186 | except Exception: 187 | raise Exception( 188 | 'Cannot load model parameters from checkpoint {}; ' 189 | 'please ensure that the architectures match.'.format(filename) 190 | ) 191 | 192 | extra_state = state['extra_state'] 193 | self._optim_history = state['optimizer_history'] 194 | last_optim_state = state.get('last_optimizer_state', None) 195 | 196 | if last_optim_state is not None and not reset_optimizer: 197 | # rebuild optimizer after loading model, since params may have changed 198 | self._build_optimizer() 199 | 200 | # only reload optimizer and lr_scheduler if they match 201 | last_optim = self._optim_history[-1] 202 | assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \ 203 | 'Criterion does not match; please reset the optimizer (--reset-optimizer).' 204 | assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \ 205 | 'Optimizer does not match; please reset the optimizer (--reset-optimizer).' 
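                # Usage sketch (the checkpoint path here is hypothetical): a training
                # script typically resumes with
                #   trainer.load_checkpoint('checkpoints/checkpoint_last.pt')
                # and passes reset_optimizer=True (the --reset-optimizer flag) when the
                # criterion or optimizer class has changed between runs.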
206 | 207 | if not reset_lr_scheduler: 208 | self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state']) 209 | self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) 210 | 211 | self.set_num_updates(last_optim['num_updates']) 212 | 213 | if extra_state is not None: 214 | epoch = extra_state['train_iterator']['epoch'] 215 | print('| loaded checkpoint {} (epoch {} @ {} updates)'.format( 216 | filename, epoch, self.get_num_updates())) 217 | 218 | self.lr_step(epoch) 219 | 220 | if 'train_meters' in extra_state and not reset_meters: 221 | self.meters.update(extra_state['train_meters']) 222 | del extra_state['train_meters'] 223 | 224 | # reset TimeMeters, since their start times don't make sense anymore 225 | for meter in self.meters.values(): 226 | if isinstance(meter, TimeMeter): 227 | meter.reset() 228 | else: 229 | print('| no existing checkpoint found {}'.format(filename)) 230 | 231 | return extra_state 232 | 233 | def get_train_iterator(self, epoch, combine=True, load_dataset=True, data_selector=None, shard_batch_itr=True): 234 | """Return an EpochBatchIterator over the training set for a given epoch.""" 235 | if load_dataset: 236 | print('| loading train data for epoch {}'.format(epoch)) 237 | self.task.load_dataset( 238 | self.args.train_subset, 239 | epoch=epoch, 240 | combine=combine, 241 | data_selector=data_selector, 242 | ) 243 | return self.task.get_batch_iterator( 244 | dataset=self.task.dataset(self.args.train_subset), 245 | max_tokens=self.args.max_tokens, 246 | max_sentences=self.args.max_sentences, 247 | max_positions=utils.resolve_max_positions( 248 | self.task.max_positions(), 249 | self.model.max_positions(), 250 | ), 251 | ignore_invalid_inputs=True, 252 | required_batch_size_multiple=self.args.required_batch_size_multiple, 253 | seed=self.args.seed, 254 | num_shards=self.args.distributed_world_size if shard_batch_itr else 1, 255 | shard_id=self.args.distributed_rank if shard_batch_itr else 0, 256 | num_workers=self.args.num_workers, 257 | epoch=epoch, 258 | ) 259 | 260 | def train_step(self, samples, dummy_batch=False, raise_oom=False): 261 | """Do forward, backward and parameter update.""" 262 | if self._dummy_batch is None: 263 | self._dummy_batch = samples[0] 264 | 265 | self._set_seed() 266 | self.model.train() 267 | self.criterion.train() 268 | self.zero_grad() 269 | 270 | if not dummy_batch: 271 | self.meters['train_wall'].start() 272 | 273 | # forward and backward pass 274 | logging_outputs, sample_sizes, ooms = [], [], 0 275 | for i, sample in enumerate(samples): 276 | sample = self._prepare_sample(sample) 277 | if sample is None: 278 | # when sample is None, run forward/backward on a dummy batch 279 | # and ignore the resulting gradients 280 | sample = self._prepare_sample(self._dummy_batch) 281 | ignore_grad = True 282 | else: 283 | ignore_grad = False 284 | 285 | def maybe_no_sync(): 286 | """ 287 | Whenever *samples* contains more than one mini-batch, we 288 | want to accumulate gradients locally and only call 289 | all-reduce in the last backwards pass. 
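                For illustration (using fairseq's --update-freq setting): with
                --update-freq 4, *samples* holds four mini-batches, the first three
                backward passes run inside ``model.no_sync()``, and gradients are
                all-reduced only on the fourth.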
290 | """ 291 | if ( 292 | self.args.distributed_world_size > 1 293 | and hasattr(self.model, 'no_sync') 294 | and i < len(samples) - 1 295 | ): 296 | return self.model.no_sync() 297 | else: 298 | return contextlib.ExitStack() # dummy contextmanager 299 | 300 | try: 301 | with maybe_no_sync(): 302 | # forward and backward 303 | loss, sample_size, logging_output = self.task.train_step( 304 | sample, self.model, self.criterion, self.optimizer, 305 | ignore_grad 306 | ) 307 | 308 | if not ignore_grad: 309 | logging_outputs.append(logging_output) 310 | sample_sizes.append(sample_size) 311 | 312 | if self.fast_stat_sync: 313 | self._all_reduce_list[0] += sample_size 314 | self._all_reduce_list[1] += logging_output.get('nsentences', 0.0) 315 | self._all_reduce_list[2] += logging_output.get('loss', 0.0) 316 | self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0) 317 | self._all_reduce_list[4] += logging_output.get('ntokens', 0.0) 318 | except RuntimeError as e: 319 | if 'out of memory' in str(e): 320 | self._log_oom(e) 321 | if raise_oom: 322 | raise e 323 | print("| WARNING: attempting to recover from OOM in forward/backward pass", 324 | file=sys.stderr) 325 | ooms += 1 326 | self.zero_grad() 327 | else: 328 | raise e 329 | 330 | if self.fast_stat_sync: 331 | self._all_reduce_list[5] += ooms 332 | 333 | 334 | if ooms > 0 and self._oom_batch is not None: 335 | self.handle_ooms(ooms) 336 | 337 | if dummy_batch: 338 | return None 339 | 340 | # gather logging outputs from all replicas 341 | if self.fast_stat_sync: 342 | # rework all_gather_list 343 | all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list) 344 | if self._sync_stats(): 345 | torch.distributed.all_reduce(all_reduce_list_tensor) 346 | # Normalize loss and nll_loss by "sample_size" 347 | # and convert to log base 2 348 | all_reduce_list_tensor[2:4].div_( 349 | ( 350 | all_reduce_list_tensor[0:1] * 351 | torch.log(torch.cuda.DoubleTensor([2])) 352 | ) 353 | ) 354 | self._all_reduce_list = all_reduce_list_tensor.tolist() 355 | logging_output = {} 356 | [ 357 | sample_size, 358 | logging_output['nsentences'], 359 | logging_output['loss'], 360 | logging_output['nll_loss'], 361 | logging_output['ntokens'], 362 | ooms, 363 | ] = self._all_reduce_list 364 | elif self._sync_stats(): 365 | logging_outputs, sample_sizes, ooms, prev_norms = \ 366 | zip(*distributed_utils.all_gather_list( 367 | [logging_outputs, sample_sizes, ooms, self._prev_grad_norm], 368 | )) 369 | logging_outputs = list(chain.from_iterable(logging_outputs)) 370 | sample_sizes = list(chain.from_iterable(sample_sizes)) 371 | ooms = sum(ooms) 372 | 373 | if not self.args.use_bmuf: 374 | assert ( 375 | all(norm == prev_norms[0] for norm in prev_norms) 376 | or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms) 377 | ), 'Fatal error: gradients are inconsistent between workers' 378 | 379 | self.meters['oom'].update(ooms, len(samples)) 380 | if ooms == self.args.distributed_world_size * len(samples): 381 | print('| WARNING: OOM in all workers, skipping update') 382 | self.zero_grad() 383 | return None 384 | 385 | if not self.fast_stat_sync: 386 | # aggregate logging outputs and sample sizes 387 | logging_output = self.task.aggregate_logging_outputs( 388 | logging_outputs, self.get_criterion() 389 | ) 390 | sample_size = self.task.grad_denom(sample_sizes, self.get_criterion()) 391 | 392 | if not all(k in logging_output for k in ['ntokens', 'nsentences']): 393 | raise Exception(( 394 | 'Please update the {}.aggregate_logging_outputs() method to ' 
395 | 'return ntokens and nsentences' 396 | ).format(self.task.__class__.__name__)) 397 | 398 | try: 399 | # normalize grads by sample size 400 | if sample_size > 0: 401 | self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size)) 402 | 403 | # clip grads 404 | grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm) 405 | self._prev_grad_norm = grad_norm 406 | 407 | # take an optimization step 408 | self.optimizer.step() 409 | self.set_num_updates(self.get_num_updates() + 1) 410 | 411 | # task specific update per step 412 | self.task.update_step(self._num_updates) 413 | 414 | # update meters 415 | ntokens = logging_output.get('ntokens', 0) 416 | nsentences = logging_output.get('nsentences', 0) 417 | self.meters['wps'].update(ntokens) 418 | self.meters['ups'].update(1.) 419 | self.meters['wpb'].update(ntokens) 420 | self.meters['bsz'].update(nsentences) 421 | self.meters['gnorm'].update(grad_norm) 422 | self.meters['clip'].update( 423 | 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0. 424 | ) 425 | self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size) 426 | if 'train_acc' in self.meters: 427 | self.meters['train_acc'].update( 428 | logging_output.get('acc', 0), sample_size) 429 | 430 | if 'nll_loss' in logging_output: 431 | self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens) 432 | 433 | # clear CUDA cache to reduce memory fragmentation 434 | if (self.args.empty_cache_freq > 0 and 435 | ((self.get_num_updates() + self.args.empty_cache_freq - 1) % 436 | self.args.empty_cache_freq) == 0 and 437 | torch.cuda.is_available() and 438 | not self.args.cpu): 439 | torch.cuda.empty_cache() 440 | except OverflowError as e: 441 | print('| WARNING: overflow detected, ' + str(e)) 442 | self.zero_grad() 443 | logging_output = None 444 | except RuntimeError as e: 445 | if 'out of memory' in str(e): 446 | self._log_oom(e) 447 | print('| ERROR: OOM during optimization, irrecoverable') 448 | raise e 449 | 450 | if self.args.fp16: 451 | self.meters['loss_scale'].reset() 452 | self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale) 453 | 454 | self.clear_buffered_stats() 455 | self.meters['train_wall'].stop() 456 | 457 | return logging_output 458 | 459 | def valid_step(self, sample, raise_oom=False): 460 | """Do forward pass in evaluation mode.""" 461 | with torch.no_grad(): 462 | self.model.eval() 463 | self.criterion.eval() 464 | 465 | sample = self._prepare_sample(sample) 466 | if sample is None: 467 | sample = self._prepare_sample(self._dummy_batch) 468 | ignore_results = True 469 | else: 470 | ignore_results = False 471 | 472 | try: 473 | _loss, sample_size, logging_output = self.task.valid_step( 474 | sample, self.model, self.criterion 475 | ) 476 | except RuntimeError as e: 477 | if 'out of memory' in str(e): 478 | self._log_oom(e) 479 | if not raise_oom: 480 | print('| WARNING: ran out of memory in validation step, retrying batch') 481 | for p in self.model.parameters(): 482 | if p.grad is not None: 483 | p.grad = None # free some memory 484 | if self.cuda: 485 | torch.cuda.empty_cache() 486 | return self.valid_step(sample, raise_oom=True) 487 | raise e 488 | 489 | if ignore_results: 490 | logging_output, sample_size = {}, 0 491 | 492 | # gather logging outputs from all replicas 493 | if self.args.distributed_world_size > 1: 494 | logging_output, sample_size = zip(*distributed_utils.all_gather_list( 495 | [logging_output, sample_size], 496 | )) 497 | logging_output = 
list(logging_output) 498 | sample_size = list(sample_size) 499 | else: 500 | logging_output = [logging_output] 501 | sample_size = [sample_size] 502 | 503 | # aggregate logging outputs and sample sizes 504 | logging_output = self.task.aggregate_logging_outputs( 505 | logging_output, self.get_criterion() 506 | ) 507 | sample_size = self.task.grad_denom( 508 | sample_size, self.get_criterion() 509 | ) 510 | 511 | # update meters for validation 512 | ntokens = logging_output.get('ntokens', 0) 513 | self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size) 514 | if 'valid_acc' in self.meters: 515 | self.meters['valid_acc'].update( 516 | logging_output.get('acc', 0), sample_size) 517 | 518 | if 'nll_loss' in logging_output: 519 | self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens) 520 | 521 | return logging_output 522 | 523 | def dummy_train_step(self, dummy_batch): 524 | """Dummy training step for warming caching allocator.""" 525 | self.train_step(dummy_batch, dummy_batch=True) 526 | self.zero_grad() 527 | 528 | def handle_ooms(self, number_of_ooms): 529 | """ 530 | c10d accumulates/syncs gradients between gpus during backward pass. 531 | In case of OOMs, gpus may fail to sync, so we manually iterate 532 | extra to make sure each gpu makes same number of iterations. 533 | """ 534 | for _ in range(number_of_ooms): 535 | self.train_step([self._oom_batch], True) 536 | 537 | def zero_grad(self): 538 | self.optimizer.zero_grad() 539 | 540 | def clear_buffered_stats(self): 541 | self._all_reduce_list = [0.0] * 6 542 | 543 | def lr_step(self, epoch, val_loss=None): 544 | """Adjust the learning rate based on the validation loss.""" 545 | self.lr_scheduler.step(epoch, val_loss) 546 | # prefer updating the LR based on the number of steps 547 | return self.lr_step_update() 548 | 549 | def lr_step_update(self): 550 | """Update the learning rate after each update.""" 551 | return self.lr_scheduler.step_update(self.get_num_updates()) 552 | 553 | def get_lr(self): 554 | """Get the current learning rate.""" 555 | return self.optimizer.get_lr() 556 | 557 | def get_model(self): 558 | """Get the (non-wrapped) model instance.""" 559 | return self._model 560 | 561 | def get_criterion(self): 562 | """Get the (non-wrapped) criterion instance.""" 563 | return self._criterion 564 | 565 | def get_meter(self, name): 566 | """Get a specific meter by name.""" 567 | if name not in self.meters: 568 | return None 569 | return self.meters[name] 570 | 571 | def get_num_updates(self): 572 | """Get the number of parameters updates.""" 573 | return self._num_updates 574 | 575 | def set_num_updates(self, num_updates): 576 | """Set the number of parameters updates.""" 577 | self._num_updates = num_updates 578 | self.lr_step_update() 579 | 580 | def _prepare_sample(self, sample): 581 | if sample is None or len(sample) == 0: 582 | return None 583 | 584 | if self.cuda: 585 | sample = utils.move_to_cuda(sample) 586 | 587 | def apply_half(t): 588 | if t.dtype is torch.float32: 589 | return t.half() 590 | return t 591 | 592 | if self.args.fp16: 593 | sample = utils.apply_to_sample(apply_half, sample) 594 | 595 | return sample 596 | 597 | def _set_seed(self): 598 | # Set seed based on args.seed and the update number so that we get 599 | # reproducible results when resuming from checkpoints 600 | seed = self.args.seed + self.get_num_updates() 601 | torch.manual_seed(seed) 602 | if self.cuda: 603 | torch.cuda.manual_seed(seed) 604 | 605 | def _sync_stats(self): 606 | return ( 607 | 
self.args.distributed_world_size > 1 and 608 | ( 609 | (not self.args.use_bmuf) or 610 | ( 611 | self.args.use_bmuf 612 | and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 613 | ) 614 | ) 615 | ) 616 | 617 | def _log_oom(self, exc): 618 | msg = '| OOM: Ran out of memory with exception: {}'.format(exc) 619 | # TODO: print should really go to logger, this print goes 620 | # to stderr, which is buffered, which in many cases is not 621 | # printed out if another exception happens. 622 | # NB(jerry): added a flush to mitigate this 623 | print(msg, file=sys.stderr) 624 | if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"): 625 | for device_idx in range(torch.cuda.device_count()): 626 | print(torch.cuda.memory_summary(device=device_idx), 627 | file=sys.stderr) 628 | sys.stderr.flush() -------------------------------------------------------------------------------- /mass/sequence_generator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | from fairseq import search, utils 6 | from fairseq.data import data_utils 7 | from fairseq.models import FairseqIncrementalDecoder 8 | 9 | 10 | class SequenceGenerator(object): 11 | def __init__( 12 | self, 13 | tgt_dict, 14 | beam_size=1, 15 | max_len_a=0, 16 | max_len_b=200, 17 | min_len=1, 18 | normalize_scores=True, 19 | len_penalty=1., 20 | unk_penalty=0., 21 | retain_dropout=False, 22 | sampling=False, 23 | sampling_topk=-1, 24 | sampling_topp=-1.0, 25 | temperature=1., 26 | diverse_beam_groups=-1, 27 | diverse_beam_strength=0.5, 28 | match_source_len=False, 29 | no_repeat_ngram_size=0, 30 | ): 31 | """Generates translations of a given source sentence. 32 | Args: 33 | tgt_dict (~fairseq.data.Dictionary): target dictionary 34 | beam_size (int, optional): beam width (default: 1) 35 | max_len_a/b (int, optional): generate sequences of maximum length 36 | ax + b, where x is the source length 37 | min_len (int, optional): the minimum length of the generated output 38 | (not including end-of-sentence) 39 | normalize_scores (bool, optional): normalize scores by the length 40 | of the output (default: True) 41 | len_penalty (float, optional): length penalty, where <1.0 favors 42 | shorter, >1.0 favors longer sentences (default: 1.0) 43 | unk_penalty (float, optional): unknown word penalty, where <0 44 | produces more unks, >0 produces fewer (default: 0.0) 45 | retain_dropout (bool, optional): use dropout when generating 46 | (default: False) 47 | sampling (bool, optional): sample outputs instead of beam search 48 | (default: False) 49 | sampling_topk (int, optional): only sample among the top-k choices 50 | at each step (default: -1) 51 | sampling_topp (float, optional): only sample among the smallest set 52 | of words whose cumulative probability mass exceeds p 53 | at each step (default: -1.0) 54 | temperature (float, optional): temperature, where values 55 | >1.0 produce more uniform samples and values <1.0 produce 56 | sharper samples (default: 1.0) 57 | diverse_beam_groups/strength (float, optional): parameters for 58 | Diverse Beam Search sampling 59 | match_source_len (bool, optional): outputs should match the source 60 | length (default: False) 61 | """ 62 | self.pad = tgt_dict.pad() 63 | self.unk = tgt_dict.unk() 64 | self.eos = tgt_dict.eos() 65 | self.vocab_size = len(tgt_dict) 66 | self.beam_size = beam_size 67 | # the max beam size is the dictionary size - 1, since we never select pad 68 | self.beam_size = min(beam_size, self.vocab_size - 
1) 69 | self.max_len_a = max_len_a 70 | self.max_len_b = max_len_b 71 | self.min_len = min_len 72 | self.normalize_scores = normalize_scores 73 | self.len_penalty = len_penalty 74 | self.unk_penalty = unk_penalty 75 | self.retain_dropout = retain_dropout 76 | self.temperature = temperature 77 | self.match_source_len = match_source_len 78 | self.no_repeat_ngram_size = no_repeat_ngram_size 79 | assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling' 80 | assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling' 81 | assert temperature > 0, '--temperature must be greater than 0' 82 | 83 | if sampling: 84 | self.search = search.Sampling(tgt_dict, sampling_topk, sampling_topp) 85 | elif diverse_beam_groups > 0: 86 | self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength) 87 | elif match_source_len: 88 | self.search = search.LengthConstrainedBeamSearch( 89 | tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0, 90 | ) 91 | else: 92 | self.search = search.BeamSearch(tgt_dict) 93 | 94 | @torch.no_grad() 95 | def generate(self, models, sample, **kwargs): 96 | """Generate a batch of translations. 97 | Args: 98 | models (List[~fairseq.models.FairseqModel]): ensemble of models 99 | sample (dict): batch 100 | prefix_tokens (torch.LongTensor, optional): force decoder to begin 101 | with these tokens 102 | bos_token (int, optional): beginning of sentence token 103 | (default: self.eos) 104 | """ 105 | model = EnsembleModel(models) 106 | return self._generate(model, sample, **kwargs) 107 | 108 | @torch.no_grad() 109 | def _generate( 110 | self, 111 | model, 112 | sample, 113 | prefix_tokens=None, 114 | bos_token=None, 115 | **kwargs 116 | ): 117 | if not self.retain_dropout: 118 | model.eval() 119 | 120 | # model.forward normally channels prev_output_tokens into the decoder 121 | # separately, but SequenceGenerator directly calls model.encoder 122 | encoder_input = { 123 | k: v for k, v in sample['net_input'].items() 124 | if k != 'prev_output_tokens' 125 | } 126 | lang_pair = encoder_input['lang_pair'] 127 | 128 | src_tokens = encoder_input['src_tokens'] 129 | src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) 130 | input_size = src_tokens.size() 131 | # batch dimension goes first followed by source lengths 132 | bsz = input_size[0] 133 | src_len = input_size[1] 134 | beam_size = self.beam_size 135 | 136 | if self.match_source_len: 137 | max_len = src_lengths.max().item() 138 | else: 139 | max_len = min( 140 | int(self.max_len_a * src_len + self.max_len_b), 141 | # exclude the EOS marker 142 | model.max_decoder_positions() - 1, 143 | ) 144 | 145 | # compute the encoder output for each beam 146 | encoder_outs = model.forward_encoder(encoder_input) 147 | new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) 148 | new_order = new_order.to(src_tokens.device).long() 149 | encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) 150 | 151 | # initialize buffers 152 | scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0) 153 | scores_buf = scores.clone() 154 | tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) 155 | tokens_buf = tokens.clone() 156 | tokens[:, 0] = self.eos if bos_token is None else bos_token 157 | attn, attn_buf = None, None 158 | 159 | # The blacklist indicates candidates that should be ignored. 160 | # For example, suppose we're sampling and have already finalized 2/5 161 | # samples. 
Then the blacklist would mark 2 positions as being ignored, 162 | # so that we only finalize the remaining 3 samples. 163 | blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask 164 | 165 | # list of completed sentences 166 | finalized = [[] for i in range(bsz)] 167 | finished = [False for i in range(bsz)] 168 | num_remaining_sent = bsz 169 | 170 | # number of candidate hypos per step 171 | cand_size = 2 * beam_size # 2 x beam size in case half are EOS 172 | 173 | # offset arrays for converting between different indexing schemes 174 | bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) 175 | cand_offsets = torch.arange(0, cand_size).type_as(tokens) 176 | 177 | # helper function for allocating buffers on the fly 178 | buffers = {} 179 | 180 | def buffer(name, type_of=tokens): # noqa 181 | if name not in buffers: 182 | buffers[name] = type_of.new() 183 | return buffers[name] 184 | 185 | def is_finished(sent, step, unfin_idx): 186 | """ 187 | Check whether we've finished generation for a given sentence, by 188 | comparing the worst score among finalized hypotheses to the best 189 | possible score among unfinalized hypotheses. 190 | """ 191 | assert len(finalized[sent]) <= beam_size 192 | if len(finalized[sent]) == beam_size or step == max_len: 193 | return True 194 | return False 195 | 196 | def finalize_hypos(step, bbsz_idx, eos_scores): 197 | """ 198 | Finalize the given hypotheses at this step, while keeping the total 199 | number of finalized hypotheses per sentence <= beam_size. 200 | Note: the input must be in the desired finalization order, so that 201 | hypotheses that appear earlier in the input are preferred to those 202 | that appear later. 203 | Args: 204 | step: current time step 205 | bbsz_idx: A vector of indices in the range [0, bsz*beam_size), 206 | indicating which hypotheses to finalize 207 | eos_scores: A vector of the same size as bbsz_idx containing 208 | scores for each hypothesis 209 | """ 210 | assert bbsz_idx.numel() == eos_scores.numel() 211 | 212 | # clone relevant token and attention tensors 213 | tokens_clone = tokens.index_select(0, bbsz_idx) 214 | tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS 215 | assert not tokens_clone.eq(self.eos).any() 216 | tokens_clone[:, step] = self.eos 217 | attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None 218 | 219 | # compute scores per token position 220 | pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1] 221 | pos_scores[:, step] = eos_scores 222 | # convert from cumulative to per-position scores 223 | pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] 224 | 225 | # normalize sentence-level scores 226 | if self.normalize_scores: 227 | eos_scores /= (step + 1) ** self.len_penalty 228 | 229 | cum_unfin = [] 230 | prev = 0 231 | for f in finished: 232 | if f: 233 | prev += 1 234 | else: 235 | cum_unfin.append(prev) 236 | 237 | sents_seen = set() 238 | for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())): 239 | unfin_idx = idx // beam_size 240 | sent = unfin_idx + cum_unfin[unfin_idx] 241 | 242 | sents_seen.add((sent, unfin_idx)) 243 | 244 | if self.match_source_len and step > src_lengths[unfin_idx]: 245 | score = -math.inf 246 | 247 | def get_hypo(): 248 | 249 | if attn_clone is not None: 250 | # remove padding tokens from attn scores 251 | hypo_attn = attn_clone[i] 252 | else: 253 | hypo_attn = None 254 | 255 | return { 256 | 'tokens': 
tokens_clone[i], 257 | 'score': score, 258 | 'attention': hypo_attn, # src_len x tgt_len 259 | 'alignment': None, 260 | 'positional_scores': pos_scores[i], 261 | } 262 | 263 | if len(finalized[sent]) < beam_size: 264 | finalized[sent].append(get_hypo()) 265 | 266 | newly_finished = [] 267 | for sent, unfin_idx in sents_seen: 268 | # check termination conditions for this sentence 269 | if not finished[sent] and is_finished(sent, step, unfin_idx): 270 | finished[sent] = True 271 | newly_finished.append(unfin_idx) 272 | return newly_finished 273 | 274 | reorder_state = None 275 | batch_idxs = None 276 | for step in range(max_len + 1): # one extra step for EOS marker 277 | # reorder decoder internal states based on the prev choice of beams 278 | if reorder_state is not None: 279 | if batch_idxs is not None: 280 | # update beam indices to take into account removed sentences 281 | corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs) 282 | reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size) 283 | model.reorder_incremental_state(reorder_state) 284 | encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state) 285 | lprobs, avg_attn_scores = model.forward_decoder( 286 | tokens[:, :step + 1], encoder_outs=encoder_outs, temperature=self.temperature, lang_pair=lang_pair, 287 | ) 288 | 289 | lprobs[:, self.pad] = -math.inf # never select pad 290 | lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty 291 | 292 | # handle max length constraint 293 | if step >= max_len: 294 | lprobs[:, :self.eos] = -math.inf 295 | lprobs[:, self.eos + 1:] = -math.inf 296 | 297 | # handle prefix tokens (possibly with different lengths) 298 | if prefix_tokens is not None and step < prefix_tokens.size(1) and step < max_len: 299 | prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) 300 | prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) 301 | prefix_mask = prefix_toks.ne(self.pad) 302 | lprobs[prefix_mask] = -math.inf 303 | lprobs[prefix_mask] = lprobs[prefix_mask].scatter_( 304 | -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] 305 | ) 306 | # if prefix includes eos, then we should make sure tokens and 307 | # scores are the same across all beams 308 | eos_mask = prefix_toks.eq(self.eos) 309 | if eos_mask.any(): 310 | # validate that the first beam matches the prefix 311 | first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1] 312 | eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] 313 | target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] 314 | assert (first_beam == target_prefix).all() 315 | 316 | def replicate_first_beam(tensor, mask): 317 | tensor = tensor.view(-1, beam_size, tensor.size(-1)) 318 | tensor[mask] = tensor[mask][:, :1, :] 319 | return tensor.view(-1, tensor.size(-1)) 320 | 321 | # copy tokens, scores and lprobs from the first beam to all beams 322 | tokens = replicate_first_beam(tokens, eos_mask_batch_dim) 323 | scores = replicate_first_beam(scores, eos_mask_batch_dim) 324 | lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim) 325 | elif step < self.min_len: 326 | # minimum length constraint (does not apply if using prefix_tokens) 327 | lprobs[:, self.eos] = -math.inf 328 | 329 | if self.no_repeat_ngram_size > 0: 330 | # for each beam and batch sentence, generate a list of previous ngrams 331 | gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] 332 | for bbsz_idx in range(bsz * beam_size): 333 | gen_tokens = tokens[bbsz_idx].tolist() 334 | 
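                    # Worked example (token ids here are hypothetical): with
                    # no_repeat_ngram_size == 2 and gen_tokens == [5, 7, 5, 9], the
                    # loop below records {(5,): [7, 9], (7,): [5]}, i.e. for each
                    # (n-1)-gram prefix, the tokens that would complete an n-gram
                    # already generated in this hypothesis.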
for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]): 335 | gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \ 336 | gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] 337 | 338 | # Record attention scores 339 | if avg_attn_scores is not None: 340 | if attn is None: 341 | attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) 342 | attn_buf = attn.clone() 343 | attn[:, :, step + 1].copy_(avg_attn_scores) 344 | 345 | scores = scores.type_as(lprobs) 346 | scores_buf = scores_buf.type_as(lprobs) 347 | eos_bbsz_idx = buffer('eos_bbsz_idx') 348 | eos_scores = buffer('eos_scores', type_of=scores) 349 | 350 | self.search.set_src_lengths(src_lengths) 351 | 352 | if self.no_repeat_ngram_size > 0: 353 | def calculate_banned_tokens(bbsz_idx): 354 | # before decoding the next token, prevent decoding of ngrams that have already appeared 355 | ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) 356 | return gen_ngrams[bbsz_idx].get(ngram_index, []) 357 | 358 | if step + 2 - self.no_repeat_ngram_size >= 0: 359 | # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet 360 | banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)] 361 | else: 362 | banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)] 363 | 364 | for bbsz_idx in range(bsz * beam_size): 365 | lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf 366 | 367 | cand_scores, cand_indices, cand_beams = self.search.step( 368 | step, 369 | lprobs.view(bsz, -1, self.vocab_size), 370 | scores.view(bsz, beam_size, -1)[:, :, :step], 371 | ) 372 | 373 | # cand_bbsz_idx contains beam indices for the top candidate 374 | # hypotheses, with a range of values: [0, bsz*beam_size), 375 | # and dimensions: [bsz, cand_size] 376 | cand_bbsz_idx = cand_beams.add(bbsz_offsets) 377 | 378 | # finalize hypotheses that end in eos, except for blacklisted ones 379 | # or candidates with a score of -inf 380 | eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) 381 | eos_mask[:, :beam_size][blacklist] = 0 382 | 383 | # only consider eos when it's among the top beam_size indices 384 | torch.masked_select( 385 | cand_bbsz_idx[:, :beam_size], 386 | mask=eos_mask[:, :beam_size], 387 | out=eos_bbsz_idx, 388 | ) 389 | 390 | finalized_sents = set() 391 | if eos_bbsz_idx.numel() > 0: 392 | torch.masked_select( 393 | cand_scores[:, :beam_size], 394 | mask=eos_mask[:, :beam_size], 395 | out=eos_scores, 396 | ) 397 | finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores) 398 | num_remaining_sent -= len(finalized_sents) 399 | 400 | assert num_remaining_sent >= 0 401 | if num_remaining_sent == 0: 402 | break 403 | assert step < max_len 404 | 405 | if len(finalized_sents) > 0: 406 | new_bsz = bsz - len(finalized_sents) 407 | 408 | # construct batch_idxs which holds indices of batches to keep for the next pass 409 | batch_mask = cand_indices.new_ones(bsz) 410 | batch_mask[cand_indices.new(finalized_sents)] = 0 411 | batch_idxs = batch_mask.nonzero().squeeze(-1) 412 | 413 | eos_mask = eos_mask[batch_idxs] 414 | cand_beams = cand_beams[batch_idxs] 415 | bbsz_offsets.resize_(new_bsz, 1) 416 | cand_bbsz_idx = cand_beams.add(bbsz_offsets) 417 | cand_scores = cand_scores[batch_idxs] 418 | cand_indices = cand_indices[batch_idxs] 419 | if prefix_tokens is not None: 420 | prefix_tokens = prefix_tokens[batch_idxs] 421 | src_lengths = src_lengths[batch_idxs] 422 | blacklist = blacklist[batch_idxs] 423 | 424 | scores = scores.view(bsz, 
-1)[batch_idxs].view(new_bsz * beam_size, -1) 425 | scores_buf.resize_as_(scores) 426 | tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) 427 | tokens_buf.resize_as_(tokens) 428 | if attn is not None: 429 | attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1) 430 | attn_buf.resize_as_(attn) 431 | bsz = new_bsz 432 | else: 433 | batch_idxs = None 434 | 435 | # Set active_mask so that values > cand_size indicate eos or 436 | # blacklisted hypos and values < cand_size indicate candidate 437 | # active hypos. After this, the min values per row are the top 438 | # candidate active hypos. 439 | active_mask = buffer('active_mask') 440 | eos_mask[:, :beam_size] |= blacklist 441 | torch.add( 442 | eos_mask.type_as(cand_offsets) * cand_size, 443 | cand_offsets[:eos_mask.size(1)], 444 | out=active_mask, 445 | ) 446 | 447 | # get the top beam_size active hypotheses, which are just the hypos 448 | # with the smallest values in active_mask 449 | active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist') 450 | torch.topk( 451 | active_mask, k=beam_size, dim=1, largest=False, 452 | out=(new_blacklist, active_hypos) 453 | ) 454 | 455 | # update blacklist to ignore any finalized hypos 456 | blacklist = new_blacklist.ge(cand_size)[:, :beam_size] 457 | assert (~blacklist).any(dim=1).all() 458 | 459 | active_bbsz_idx = buffer('active_bbsz_idx') 460 | torch.gather( 461 | cand_bbsz_idx, dim=1, index=active_hypos, 462 | out=active_bbsz_idx, 463 | ) 464 | active_scores = torch.gather( 465 | cand_scores, dim=1, index=active_hypos, 466 | out=scores[:, step].view(bsz, beam_size), 467 | ) 468 | 469 | active_bbsz_idx = active_bbsz_idx.view(-1) 470 | active_scores = active_scores.view(-1) 471 | 472 | # copy tokens and scores for active hypotheses 473 | torch.index_select( 474 | tokens[:, :step + 1], dim=0, index=active_bbsz_idx, 475 | out=tokens_buf[:, :step + 1], 476 | ) 477 | torch.gather( 478 | cand_indices, dim=1, index=active_hypos, 479 | out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], 480 | ) 481 | if step > 0: 482 | torch.index_select( 483 | scores[:, :step], dim=0, index=active_bbsz_idx, 484 | out=scores_buf[:, :step], 485 | ) 486 | torch.gather( 487 | cand_scores, dim=1, index=active_hypos, 488 | out=scores_buf.view(bsz, beam_size, -1)[:, :, step], 489 | ) 490 | 491 | # copy attention for active hypotheses 492 | if attn is not None: 493 | torch.index_select( 494 | attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, 495 | out=attn_buf[:, :, :step + 2], 496 | ) 497 | 498 | # swap buffers 499 | tokens, tokens_buf = tokens_buf, tokens 500 | scores, scores_buf = scores_buf, scores 501 | if attn is not None: 502 | attn, attn_buf = attn_buf, attn 503 | 504 | # reorder incremental state in decoder 505 | reorder_state = active_bbsz_idx 506 | 507 | # sort by score descending 508 | for sent in range(len(finalized)): 509 | finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) 510 | return finalized 511 | 512 | 513 | class EnsembleModel(torch.nn.Module): 514 | """A wrapper around an ensemble of models.""" 515 | 516 | def __init__(self, models): 517 | super().__init__() 518 | self.models = torch.nn.ModuleList(models) 519 | self.incremental_states = None 520 | if all(isinstance(m.decoder, FairseqIncrementalDecoder) for m in models): 521 | self.incremental_states = {m: {} for m in models} 522 | 523 | def has_encoder(self): 524 | return hasattr(self.models[0], 'encoder') 525 | 526 | def max_decoder_positions(self): 527 | 
return min(m.max_decoder_positions() for m in self.models) 528 | 529 | @torch.no_grad() 530 | def forward_encoder(self, encoder_input): 531 | if not self.has_encoder(): 532 | return None 533 | return [model.encoder(**encoder_input) for model in self.models] 534 | 535 | @torch.no_grad() 536 | def forward_decoder(self, tokens, encoder_outs, temperature=1., lang_pair=None): 537 | if len(self.models) == 1: 538 | return self._decode_one( 539 | tokens, 540 | self.models[0], 541 | encoder_outs[0] if self.has_encoder() else None, 542 | self.incremental_states, 543 | log_probs=True, 544 | temperature=temperature, 545 | lang_pair=lang_pair, 546 | ) 547 | 548 | log_probs = [] 549 | avg_attn = None 550 | for model, encoder_out in zip(self.models, encoder_outs): 551 | probs, attn = self._decode_one( 552 | tokens, 553 | model, 554 | encoder_out, 555 | self.incremental_states, 556 | log_probs=True, 557 | temperature=temperature, 558 | lang_pair=lang_pair, 559 | ) 560 | log_probs.append(probs) 561 | if attn is not None: 562 | if avg_attn is None: 563 | avg_attn = attn 564 | else: 565 | avg_attn.add_(attn) 566 | avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(len(self.models)) 567 | if avg_attn is not None: 568 | avg_attn.div_(len(self.models)) 569 | return avg_probs, avg_attn 570 | 571 | def _decode_one( 572 | self, tokens, model, encoder_out, incremental_states, log_probs, 573 | temperature=1., lang_pair=None, 574 | ): 575 | if self.incremental_states is not None: 576 | decoder_out = list(model.forward_decoder( 577 | tokens, encoder_out=encoder_out, incremental_state=self.incremental_states[model], lang_pair=lang_pair, 578 | )) 579 | else: 580 | decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out, lang_pair=lang_pair,)) 581 | decoder_out[0] = decoder_out[0][:, -1:, :] 582 | if temperature != 1.: 583 | decoder_out[0].div_(temperature) 584 | attn = decoder_out[1] 585 | if type(attn) is dict: 586 | attn = attn.get('attn', None) 587 | if attn is not None: 588 | attn = attn[:, -1, :] 589 | probs = model.get_normalized_probs(decoder_out, log_probs=log_probs) 590 | probs = probs[:, -1, :] 591 | return probs, attn 592 | 593 | def reorder_encoder_out(self, encoder_outs, new_order): 594 | if not self.has_encoder(): 595 | return 596 | return [ 597 | model.encoder.reorder_encoder_out(encoder_out, new_order) 598 | for model, encoder_out in zip(self.models, encoder_outs) 599 | ] 600 | 601 | def reorder_incremental_state(self, new_order): 602 | if self.incremental_states is None: 603 | return 604 | for model in self.models: 605 | model.decoder.reorder_incremental_state(self.incremental_states[model], new_order) -------------------------------------------------------------------------------- /tokenization_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | import json 22 | import six 23 | from io import open 24 | 25 | from file_utils import cached_path 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json' 30 | ADDED_TOKENS_FILE = 'added_tokens.json' 31 | 32 | 33 | class PreTrainedTokenizer(object): 34 | """ Base class for all tokenizers. 35 | Handle all the shared methods for tokenization and special tokens as well as methods dowloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary. 36 | This class also contain the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...). 37 | Class attributes (overridden by derived classes): 38 | - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string). 39 | - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file. 40 | - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size. 41 | Parameters: 42 | - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` 43 | - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` 44 | - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` 45 | - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` 46 | - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` 47 | - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` 48 | - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` 49 | - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` 50 | """ 51 | vocab_files_names = {} 52 | pretrained_vocab_files_map = {} 53 | max_model_input_sizes = {} 54 | 55 | SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token", 56 | "pad_token", "cls_token", "mask_token", 57 | "additional_special_tokens"] 58 | 59 | @property 60 | def bos_token(self): 61 | """ Beginning of sentence token (string). 
Log an error if used while not having been set. """ 62 | if self._bos_token is None: 63 | logger.error("Using bos_token, but it is not set yet.") 64 | return self._bos_token 65 | 66 | @property 67 | def eos_token(self): 68 | """ End of sentence token (string). Log an error if used while not having been set. """ 69 | if self._eos_token is None: 70 | logger.error("Using eos_token, but it is not set yet.") 71 | return self._eos_token 72 | 73 | @property 74 | def unk_token(self): 75 | """ Unknown token (string). Log an error if used while not having been set. """ 76 | if self._unk_token is None: 77 | logger.error("Using unk_token, but it is not set yet.") 78 | return self._unk_token 79 | 80 | @property 81 | def sep_token(self): 82 | """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """ 83 | if self._sep_token is None: 84 | logger.error("Using sep_token, but it is not set yet.") 85 | return self._sep_token 86 | 87 | @property 88 | def pad_token(self): 89 | """ Padding token (string). Log an error if used while not having been set. """ 90 | if self._pad_token is None: 91 | logger.error("Using pad_token, but it is not set yet.") 92 | return self._pad_token 93 | 94 | @property 95 | def cls_token(self): 96 | """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """ 97 | if self._cls_token is None: 98 | logger.error("Using cls_token, but it is not set yet.") 99 | return self._cls_token 100 | 101 | @property 102 | def mask_token(self): 103 | """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """ 104 | if self._mask_token is None: 105 | logger.error("Using mask_token, but it is not set yet.") 106 | return self._mask_token 107 | 108 | @property 109 | def additional_special_tokens(self): 110 | """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. 
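        (Illustration with hypothetical token strings: passing
        ``additional_special_tokens=['<def>', '<word>']`` at construction time
        keeps such markers from being split during tokenization.)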
""" 111 | if self._additional_special_tokens is None: 112 | logger.error("Using additional_special_tokens, but it is not set yet.") 113 | return self._additional_special_tokens 114 | 115 | @bos_token.setter 116 | def bos_token(self, value): 117 | self._bos_token = value 118 | 119 | @eos_token.setter 120 | def eos_token(self, value): 121 | self._eos_token = value 122 | 123 | @unk_token.setter 124 | def unk_token(self, value): 125 | self._unk_token = value 126 | 127 | @sep_token.setter 128 | def sep_token(self, value): 129 | self._sep_token = value 130 | 131 | @pad_token.setter 132 | def pad_token(self, value): 133 | self._pad_token = value 134 | 135 | @cls_token.setter 136 | def cls_token(self, value): 137 | self._cls_token = value 138 | 139 | @mask_token.setter 140 | def mask_token(self, value): 141 | self._mask_token = value 142 | 143 | @additional_special_tokens.setter 144 | def additional_special_tokens(self, value): 145 | self._additional_special_tokens = value 146 | 147 | def __init__(self, max_len=None, **kwargs): 148 | self._bos_token = None 149 | self._eos_token = None 150 | self._unk_token = None 151 | self._sep_token = None 152 | self._pad_token = None 153 | self._cls_token = None 154 | self._mask_token = None 155 | self._additional_special_tokens = [] 156 | 157 | self.max_len = max_len if max_len is not None else int(1e12) 158 | self.added_tokens_encoder = {} 159 | self.added_tokens_decoder = {} 160 | 161 | for key, value in kwargs.items(): 162 | if key in self.SPECIAL_TOKENS_ATTRIBUTES: 163 | if key == 'additional_special_tokens': 164 | assert isinstance(value, (list, tuple)) and all( 165 | isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value) 166 | else: 167 | assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode)) 168 | setattr(self, key, value) 169 | 170 | @classmethod 171 | def from_pretrained(cls, *inputs, **kwargs): 172 | r""" Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer. 173 | Parameters: 174 | pretrained_model_name_or_path: either: 175 | - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``. 176 | - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``. 177 | - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``. 178 | cache_dir: (`optional`) string: 179 | Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used. 180 | inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method. 181 | kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details. 182 | Examples:: 183 | # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer 184 | # Download vocabulary from S3 and cache. 
185 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
186 | # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
187 | tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
188 | # If the tokenizer uses a single vocabulary file, you can point directly to this file
189 | tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
190 | # You can link tokens to special vocabulary when instantiating
191 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
192 | # You should be sure '<unk>' is in the vocabulary when doing that.
193 | # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
194 | assert tokenizer.unk_token == '<unk>'
195 | """
196 | return cls._from_pretrained(*inputs, **kwargs)
197 | 
198 | @classmethod
199 | def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
200 | cache_dir = kwargs.pop('cache_dir', None)
201 | 
202 | s3_models = list(cls.max_model_input_sizes.keys())
203 | vocab_files = {}
204 | if pretrained_model_name_or_path in s3_models:
205 | # Get the vocabulary from AWS S3 bucket
206 | for file_id, map_list in cls.pretrained_vocab_files_map.items():
207 | vocab_files[file_id] = map_list[pretrained_model_name_or_path]
208 | else:
209 | # Get the vocabulary from local files
210 | logger.info(
211 | "Model name '{}' not found in model shortcut name list ({}). "
212 | "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
213 | pretrained_model_name_or_path, ', '.join(s3_models),
214 | pretrained_model_name_or_path))
215 | 
216 | # Look for the tokenizer main vocabulary files
217 | for file_id, file_name in cls.vocab_files_names.items():
218 | if os.path.isdir(pretrained_model_name_or_path):
219 | # If a directory is provided we look for the standard filenames
220 | full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
221 | else:
222 | # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
223 | full_file_name = pretrained_model_name_or_path
224 | if not os.path.exists(full_file_name):
225 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
226 | full_file_name = None
227 | vocab_files[file_id] = full_file_name
228 | 
229 | # Look for the additional tokens files
230 | all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
231 | 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
232 | 
233 | # If a path to a file was provided, get the parent directory
234 | saved_directory = pretrained_model_name_or_path
235 | if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
236 | saved_directory = os.path.dirname(saved_directory)
237 | 
238 | for file_id, file_name in all_vocab_files_names.items():
239 | full_file_name = os.path.join(saved_directory, file_name)
240 | if not os.path.exists(full_file_name):
241 | logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
242 | full_file_name = None
243 | vocab_files[file_id] = full_file_name
244 | 
245 | if all(full_file_name is None for full_file_name in vocab_files.values()):
246 | logger.error(
247 | "Model name '{}' was not found in model name list ({}). 
" 248 | "We assumed '{}' was a path or url but couldn't find tokenizer files" 249 | "at this path or url.".format( 250 | pretrained_model_name_or_path, ', '.join(s3_models), 251 | pretrained_model_name_or_path, )) 252 | return None 253 | 254 | # Get files from url, cache, or disk depending on the case 255 | try: 256 | resolved_vocab_files = {} 257 | for file_id, file_path in vocab_files.items(): 258 | if file_path is None: 259 | resolved_vocab_files[file_id] = None 260 | else: 261 | resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir) 262 | except EnvironmentError: 263 | if pretrained_model_name_or_path in s3_models: 264 | logger.error("Couldn't reach server to download vocabulary.") 265 | else: 266 | logger.error( 267 | "Model name '{}' was not found in model name list ({}). " 268 | "We assumed '{}' was a path or url but couldn't find files {} " 269 | "at this path or url.".format( 270 | pretrained_model_name_or_path, ', '.join(s3_models), 271 | pretrained_model_name_or_path, str(vocab_files.keys()))) 272 | return None 273 | 274 | for file_id, file_path in vocab_files.items(): 275 | if file_path == resolved_vocab_files[file_id]: 276 | logger.info("loading file {}".format(file_path)) 277 | else: 278 | logger.info("loading file {} from cache at {}".format( 279 | file_path, resolved_vocab_files[file_id])) 280 | 281 | # Set max length if needed 282 | if pretrained_model_name_or_path in cls.max_model_input_sizes: 283 | # if we're using a pretrained model, ensure the tokenizer 284 | # wont index sequences longer than the number of positional embeddings 285 | max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] 286 | if max_len is not None and isinstance(max_len, (int, float)): 287 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 288 | 289 | # Merge resolved_vocab_files arguments in kwargs. 290 | added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None) 291 | special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None) 292 | for args_name, file_path in resolved_vocab_files.items(): 293 | if args_name not in kwargs: 294 | kwargs[args_name] = file_path 295 | if special_tokens_map_file is not None: 296 | special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8")) 297 | for key, value in special_tokens_map.items(): 298 | if key not in kwargs: 299 | kwargs[key] = value 300 | 301 | # Instantiate tokenizer. 302 | tokenizer = cls(*inputs, **kwargs) 303 | 304 | # Add supplementary tokens. 305 | if added_tokens_file is not None: 306 | added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8")) 307 | added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} 308 | tokenizer.added_tokens_encoder.update(added_tok_encoder) 309 | tokenizer.added_tokens_decoder.update(added_tok_decoder) 310 | 311 | return tokenizer 312 | 313 | def save_pretrained(self, save_directory): 314 | """ Save the tokenizer vocabulary files (with added tokens) and the 315 | special-tokens-to-class-attributes-mapping to a directory. 316 | This method make sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method. 
317 | """ 318 | if not os.path.isdir(save_directory): 319 | logger.error("Saving directory ({}) should be a directory".format(save_directory)) 320 | return 321 | 322 | special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE) 323 | added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE) 324 | 325 | with open(special_tokens_map_file, 'w', encoding='utf-8') as f: 326 | f.write(json.dumps(self.special_tokens_map, ensure_ascii=False)) 327 | 328 | with open(added_tokens_file, 'w', encoding='utf-8') as f: 329 | if self.added_tokens_encoder: 330 | out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False) 331 | else: 332 | out_str = u"{}" 333 | f.write(out_str) 334 | 335 | vocab_files = self.save_vocabulary(save_directory) 336 | 337 | return vocab_files + (special_tokens_map_file, added_tokens_file) 338 | 339 | def save_vocabulary(self, save_directory): 340 | """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens 341 | and special token mappings. 342 | Please use :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method. 343 | """ 344 | raise NotImplementedError 345 | 346 | def vocab_size(self): 347 | """ Size of the base vocabulary (without the added tokens) """ 348 | raise NotImplementedError 349 | 350 | def __len__(self): 351 | """ Size of the full vocabulary with the added tokens """ 352 | return self.vocab_size + len(self.added_tokens_encoder) 353 | 354 | def add_tokens(self, new_tokens): 355 | """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the 356 | vocabulary, they are added to it with indices starting from length of the current vocabulary. 357 | Parameters: 358 | new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). 359 | Returns: 360 | Number of tokens added to the vocabulary. 361 | Examples:: 362 | # Let's see how to increase the vocabulary of Bert model and tokenizer 363 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 364 | model = BertModel.from_pretrained('bert-base-uncased') 365 | num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2']) 366 | print('We have added', num_added_toks, 'tokens') 367 | model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer. 368 | """ 369 | if not new_tokens: 370 | return 0 371 | 372 | to_add_tokens = [] 373 | for token in new_tokens: 374 | assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode)) 375 | if token != self.unk_token and \ 376 | self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token): 377 | to_add_tokens.append(token) 378 | logger.info("Adding %s to the vocabulary", token) 379 | 380 | added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens)) 381 | added_tok_decoder = {v: k for k, v in added_tok_encoder.items()} 382 | self.added_tokens_encoder.update(added_tok_encoder) 383 | self.added_tokens_decoder.update(added_tok_decoder) 384 | 385 | return len(to_add_tokens) 386 | 387 | def add_special_tokens(self, special_tokens_dict): 388 | """ Add a dictionary of special tokens (eos, pad, cls...) 
to the encoder and link them
389 | to class attributes. If special tokens are NOT in the vocabulary, they are added
390 | to it (indexed starting from the last index of the current vocabulary).
391 | Parameters:
392 | special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``].
393 | 
394 | Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
395 | Returns:
396 | Number of tokens added to the vocabulary.
397 | Examples::
398 | # Let's see how to add a new classification token to GPT-2
399 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
400 | model = GPT2Model.from_pretrained('gpt2')
401 | special_tokens_dict = {'cls_token': '<CLS>'}
402 | num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
403 | print('We have added', num_added_toks, 'tokens')
404 | model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
405 | assert tokenizer.cls_token == '<CLS>'
406 | """
407 | if not special_tokens_dict:
408 | return 0
409 | 
410 | added_tokens = 0
411 | for key, value in special_tokens_dict.items():
412 | assert key in self.SPECIAL_TOKENS_ATTRIBUTES
413 | if key == 'additional_special_tokens':
414 | assert isinstance(value, (list, tuple)) and all(
415 | isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
416 | added_tokens += self.add_tokens(value)
417 | else:
418 | assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
419 | added_tokens += self.add_tokens([value])
420 | logger.info("Assigning %s to the %s key of the tokenizer", value, key)
421 | setattr(self, key, value)
422 | 
423 | return added_tokens
424 | 
425 | def tokenize(self, text, **kwargs):
426 | """ Converts a string into a sequence of tokens (string), using the tokenizer.
427 | Split in words for word-based vocabulary or sub-words for sub-word-based
428 | vocabularies (BPE/SentencePieces/WordPieces).
429 | Take care of added tokens.
430 | """
431 | 
432 | def split_on_tokens(tok_list, text):
433 | if not text:
434 | return []
435 | if not tok_list:
436 | return self._tokenize(text, **kwargs)
437 | tok = tok_list[0]
438 | split_text = text.split(tok)
439 | return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
440 | for sub_text in split_text), [])[:-1]
441 | 
442 | added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
443 | tokenized_text = split_on_tokens(added_tokens, text)
444 | return tokenized_text
445 | 
446 | def _tokenize(self, text, **kwargs):
447 | """ Converts a string into a sequence of tokens (string), using the tokenizer.
448 | Split in words for word-based vocabulary or sub-words for sub-word-based
449 | vocabularies (BPE/SentencePieces/WordPieces).
450 | Do NOT take care of added tokens.
451 | """
452 | raise NotImplementedError
453 | 
454 | def convert_tokens_to_ids(self, tokens):
455 | """ Converts a single token, or a sequence of tokens, (str/unicode) into a single integer id
456 | (resp. a sequence of ids), using the vocabulary.
457 | """ 458 | if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)): 459 | return self._convert_token_to_id_with_added_voc(tokens) 460 | 461 | ids = [] 462 | for token in tokens: 463 | ids.append(self._convert_token_to_id_with_added_voc(token)) 464 | if len(ids) > self.max_len: 465 | logger.warning("Token indices sequence length is longer than the specified maximum sequence length " 466 | "for this model ({} > {}). Running this sequence through the model will result in " 467 | "indexing errors".format(len(ids), self.max_len)) 468 | return ids 469 | 470 | def _convert_token_to_id_with_added_voc(self, token): 471 | if token in self.added_tokens_encoder: 472 | return self.added_tokens_encoder[token] 473 | return self._convert_token_to_id(token) 474 | 475 | def _convert_token_to_id(self, token): 476 | raise NotImplementedError 477 | 478 | def encode(self, text): 479 | """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. 480 | 481 | Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``. 482 | """ 483 | return self.convert_tokens_to_ids(self.tokenize(text)) 484 | 485 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 486 | """ Converts a single index or a sequence of indices (integers) in a token " 487 | (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens. 488 | Args: 489 | skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False 490 | """ 491 | if isinstance(ids, int): 492 | if ids in self.added_tokens_decoder: 493 | return self.added_tokens_decoder[ids] 494 | else: 495 | return self._convert_id_to_token(ids) 496 | tokens = [] 497 | for index in ids: 498 | if index in self.all_special_ids and skip_special_tokens: 499 | continue 500 | if index in self.added_tokens_decoder: 501 | tokens.append(self.added_tokens_decoder[index]) 502 | else: 503 | tokens.append(self._convert_id_to_token(index)) 504 | return tokens 505 | 506 | def _convert_id_to_token(self, index): 507 | raise NotImplementedError 508 | 509 | def convert_tokens_to_string(self, tokens): 510 | """ Converts a sequence of tokens (string) in a single string. 511 | The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)) 512 | but we often want to remove sub-word tokenization artifacts at the same time. 513 | """ 514 | return ' '.join(self.convert_ids_to_tokens(tokens)) 515 | 516 | def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): 517 | """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary 518 | with options to remove special tokens and clean up tokenization spaces. 519 | Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. 520 | """ 521 | filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) 522 | text = self.convert_tokens_to_string(filtered_tokens) 523 | if clean_up_tokenization_spaces: 524 | text = self.clean_up_tokenization(text) 525 | return text 526 | 527 | @property 528 | def special_tokens_map(self): 529 | """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their 530 | values ('', ''...) 
531 | """ 532 | set_attr = {} 533 | for attr in self.SPECIAL_TOKENS_ATTRIBUTES: 534 | attr_value = getattr(self, "_" + attr) 535 | if attr_value: 536 | set_attr[attr] = attr_value 537 | return set_attr 538 | 539 | @property 540 | def all_special_tokens(self): 541 | """ List all the special tokens ('', ''...) mapped to class attributes 542 | (cls_token, unk_token...). 543 | """ 544 | all_toks = [] 545 | set_attr = self.special_tokens_map 546 | for attr_value in set_attr.values(): 547 | all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value]) 548 | all_toks = list(set(all_toks)) 549 | return all_toks 550 | 551 | @property 552 | def all_special_ids(self): 553 | """ List the vocabulary indices of the special tokens ('', ''...) mapped to 554 | class attributes (cls_token, unk_token...). 555 | """ 556 | all_toks = self.all_special_tokens 557 | all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks) 558 | return all_ids 559 | 560 | @staticmethod 561 | def clean_up_tokenization(out_string): 562 | """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. 563 | """ 564 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' 565 | ).replace(" ' ", 566 | "'").replace( 567 | " n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" 568 | ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", 569 | "'re") 570 | return out_string --------------------------------------------------------------------------------