├── src ├── s2s_ft │ ├── __init__.py │ ├── config.py │ ├── tokenization_minilm.py │ ├── tokenization_unilm.py │ ├── convert_state_dict.py │ ├── s2s_loader.py │ ├── configuration_minilm.py │ ├── configuration_unilm.py │ ├── utils.py │ └── modeling.py ├── decode.sh ├── finetune.sh ├── setup.py ├── gen_seq_from_trace.py ├── evaluations │ ├── my_eval_for_cnndm.py │ ├── eval_for_xsum.py │ ├── eval_for_cnndm.py │ └── bs_pyrouge.py ├── decode_seq2seq.py ├── README.md └── run_seq2seq.py ├── LICENSE └── README.md /src/s2s_ft/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/decode.sh: -------------------------------------------------------------------------------- 1 | # path of the fine-tuned checkpoint 2 | DIR=/dnn/sheng.s/dissUnilm/ 3 | kd_weight=0.6 4 | num_training_steps=45000 5 | 6 | MODEL_PATH=${DIR}distill_checkpoints_kd${kd_weight}_step${num_training_steps}/ckpt-${num_training_steps}/ 7 | #SPLIT=dev 8 | SPLIT=test 9 | # input file that you would like to decode 10 | INPUT_JSON=${DIR}cnndm.${SPLIT}.uncased_tokenized.json 11 | 12 | export CUDA_VISIBLE_DEVICES=6 13 | export OMP_NUM_THREADS=4 14 | export MKL_NUM_THREADS=4 15 | 16 | python decode_seq2seq.py \ 17 | --fp16 --model_type unilm --tokenizer_name unilm1.2-base-uncased --do_lower_case --input_file ${INPUT_JSON} --split $SPLIT \ 18 | --model_path ${MODEL_PATH} --max_seq_length 768 --max_tgt_length 160 --batch_size 32 --beam_size 5 \ 19 | --length_penalty 0.7 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "." --min_len 48 20 | 21 | #SPLIT=dev 22 | GOLD_PATH=${DIR}${SPLIT}.target 23 | # ${MODEL_PATH}.${SPLIT} is the predicted target file 24 | python evaluations/eval_for_cnndm.py --pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT} --trunc_len 160 --perl 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/finetune.sh: -------------------------------------------------------------------------------- 1 | # path of training data 2 | DIR=/dnn/sheng.s/dissUnilm/ 3 | kd_weight=0.6 4 | num_training_steps=60000 5 | export TEACHER_MODEL=${DIR}yang_cnndm_unilmv1.2.pt 6 | export TRAIN_FILE=${DIR}cnndm.train.uncased_tokenized.json 7 | # folder used to save fine-tuned checkpoints 8 | export OUTPUT_DIR=${DIR}distill_checkpoints_kd${kd_weight}_step${num_training_steps}_real 9 | # folder used to cache package dependencies 10 | export CACHE_DIR=${DIR}transformer_package_cache 11 | 12 | 13 | #export CUDA_VISIBLE_DEVICES=0,1,2,3 14 | export CUDA_VISIBLE_DEVICES=4,5,6,7 15 | python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \ 16 | --train_file $TRAIN_FILE --output_dir $OUTPUT_DIR \ 17 | --model_type unilm --model_name_or_path unilm1.2-base-uncased --do_lower_case --fp16 --fp16_opt_level O2 \ 18 | --max_source_seq_length 608 --max_target_seq_length 160 --per_gpu_train_batch_size 8 --gradient_accumulation_steps 2 \ 19 | --learning_rate 7e-5 --num_warmup_steps 1000 --num_training_steps $num_training_steps --cache_dir $CACHE_DIR --save_steps 1500 \ 20 | --use_distill 1 --kd_weight $kd_weight --teacher_model $TEACHER_MODEL --teacher_dropout 1 --min_lr 0 21 | -------------------------------------------------------------------------------- /src/setup.py: -------------------------------------------------------------------------------- 1 | from io import open 2 | from setuptools import find_packages, setup 3 | 4 | 5 | extras = { 6 | 'serving': ['pydantic', 'uvicorn', 'fastapi'], 7 | 'serving-tf': ['pydantic', 'uvicorn', 'fastapi'], 8 | 'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch'] 9 | } 10 | extras['all'] = [package for package in extras.values()] 11 | 12 | setup( 13 | name="s2s-ft", 14 | version="0.0.1", 15 | author="UniLM Team", 16 | author_email="unilm@microsoft.com", 17 | description="Fine-Tuning Bidirectional Transformers for Sequence-to-Sequence Learning", 18 | long_description=open("README.md", "r", encoding='utf-8').read(), 19 | long_description_content_type="text/markdown", 20 | keywords='Fine-Tuning Bidirectional Transformers for Sequence-to-Sequence Learning', 21 | license='Apache', 22 | url="https://github.com/microsoft/unilm/tree/master/s2s-ft", 23 | packages=find_packages(exclude=["*.tests", "*.tests.*", 24 | "tests.*", "tests"]), 25 | install_requires=['numpy', 26 | 'boto3', 27 | 'requests', 28 | 'tqdm', 29 | 'regex != 2019.12.17', 30 | 'sentencepiece', 31 | 'sacremoses', 32 | 'tensorboardX', 33 | 'transformers >= 2.3.0'], 34 | extras_require=extras, 35 | python_requires='>=3.5.0', 36 | classifiers=[ 37 | 'Programming Language :: Python :: 3', 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /src/s2s_ft/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import logging 4 | from transformers import BertConfig, RobertaConfig 5 | from s2s_ft.configuration_unilm import UnilmConfig 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class BertForSeq2SeqConfig(BertConfig): 11 | def __init__(self, label_smoothing=0.1, source_type_id=0, target_type_id=1, **kwargs): 12 | super(BertForSeq2SeqConfig, self).__init__(**kwargs) 13 | self.label_smoothing = label_smoothing 14 | self.source_type_id = 
source_type_id 15 | self.target_type_id = target_type_id 16 | 17 | @classmethod 18 | def from_exist_config(cls, config, label_smoothing=0.1, max_position_embeddings=None): 19 | required_keys = [ 20 | "vocab_size", "hidden_size", "num_hidden_layers", "num_attention_heads", 21 | "hidden_act", "intermediate_size", "hidden_dropout_prob", "attention_probs_dropout_prob", 22 | "max_position_embeddings", "type_vocab_size", "initializer_range", "layer_norm_eps"] 23 | 24 | kwargs = {} 25 | for key in required_keys: 26 | assert hasattr(config, key) 27 | kwargs[key] = getattr(config, key) 28 | 29 | kwargs["vocab_size_or_config_json_file"] = kwargs["vocab_size"] 30 | if isinstance(config, RobertaConfig): 31 | kwargs["type_vocab_size"] = 0 32 | kwargs["max_position_embeddings"] = kwargs["max_position_embeddings"] - 2 33 | 34 | additional_keys = [ 35 | "source_type_id", "target_type_id" 36 | ] 37 | for key in additional_keys: 38 | if hasattr(config, key): 39 | kwargs[key] = getattr(config, key) 40 | 41 | if max_position_embeddings is not None and max_position_embeddings > config.max_position_embeddings: 42 | kwargs["max_position_embeddings"] = max_position_embeddings 43 | logger.info(" ** Change max position embeddings to %d ** " % max_position_embeddings) 44 | 45 | return cls(label_smoothing=label_smoothing, **kwargs) 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Noisy Self-Knowledge Distillation for Text Summarization 2 | Code for the NAACL 2021 paper 'Noisy Self-Knowledge Distillation for Text Summarization' 3 | 4 | The code is based on UNILM, and the summarization data can be downloaded from https://github.com/microsoft/unilm/tree/master/s2s-ft 5 | 6 | 7 | 8 | 9 | ## Train teacher model 10 | 11 | MODEL_PATH=../models/xsum.unilm/ckpt-40000 12 | SPLIT=test 13 | INPUT_JSON=../data/xsum.test.uncased_tokenized.json 14 | 15 | export CUDA_VISIBLE_DEVICES=5 16 | export OMP_NUM_THREADS=4 17 | export MKL_NUM_THREADS=4 18 | 19 | python decode_seq2seq.py \ 20 | --fp16 --model_type unilm --tokenizer_name unilm1.2-base-uncased --input_file ${INPUT_JSON} --split $SPLIT --do_lower_case \ 21 | --model_path ${MODEL_PATH} --max_seq_length 512 --max_tgt_length 48 --batch_size 32 --beam_size 5 \ 22 | --length_penalty 0 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "."
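The tokenized `*.json` data files are in the UniLM s2s-ft format: one JSON object per line. Judging from `load_and_cache_examples` in `src/s2s_ft/utils.py`, each object needs a `src` and a `tgt` field, given either as raw text or as a list of already-tokenized word pieces. A minimal sketch of writing one training example in this format (the file name and example texts are placeholders):

    import json
    from s2s_ft.tokenization_unilm import UnilmTokenizer

    # assumption: the s2s-ft package from this repo is installed (see src/setup.py)
    tokenizer = UnilmTokenizer.from_pretrained("unilm1.2-base-uncased", do_lower_case=True)
    example = {
        "src": tokenizer.tokenize("the full article text goes here ."),
        "tgt": tokenizer.tokenize("the reference summary goes here ."),
    }
    with open("xsum.train.uncased_tokenized.json", "a", encoding="utf-8") as f:
        f.write(json.dumps(example) + "\n")  # one example per line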
23 | 24 | ## Distill a student model 25 | 26 | TRAIN_FILE=../data/xsum.train.uncased_tokenized.json 27 | CACHE_DIR=../../cache 28 | OUTPUT_DIR=../models/xsum.unilm.distill 29 | TEACHER=../models/xsum.unilm/ckpt-40000/pytorch_model.bin 30 | 31 | BATCH_SIZE=8 32 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 33 | python -m torch.distributed.launch --nproc_per_node=8 --master_port 29886 run_seq2seq.py \ 34 | --train_file $TRAIN_FILE --output_dir $OUTPUT_DIR \ 35 | --model_type unilm --model_name_or_path unilm1.2-base-uncased --do_lower_case --fp16 --fp16_opt_level O2 \ 36 | --max_source_seq_length 464 --max_target_seq_length 48 --per_gpu_train_batch_size $BATCH_SIZE --gradient_accumulation_steps 1 \ 37 | --learning_rate 7e-5 --num_warmup_steps 500 --num_training_steps 40000 --cache_dir $CACHE_DIR --save_steps 2000 \ 38 | --use_distill 1 --kd_weight 0.6 --teacher_dropout_prob 0.15 --use_teacher_dropout 1 --teacher_model $TEACHER --word_drop_prob 0.1 --use_noisy_student 1 --sent_shuffle_k 2 39 | 40 | ## Decode 41 | MODEL_PATH=../models/xsum.unilm.distill/ckpt-40000 42 | SPLIT=test 43 | INPUT_JSON=../data/xsum.test.uncased_tokenized.json 44 | 45 | export CUDA_VISIBLE_DEVICES=1 46 | export OMP_NUM_THREADS=4 47 | export MKL_NUM_THREADS=4 48 | 49 | python decode_seq2seq.py --model_type unilm --tokenizer_name unilm1.2-base-uncased --input_file ${INPUT_JSON} --split $SPLIT --do_lower_case \ 50 | --model_path ${MODEL_PATH} --max_seq_length 512 --max_tgt_length 48 --batch_size 32 --beam_size 8 \ 51 | --length_penalty 0.9 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "." --min_len 5 52 | 53 | 54 | -------------------------------------------------------------------------------- /src/s2s_ft/tokenization_minilm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License (MIT) 3 | 4 | # Copyright (c) Microsoft Corporation 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | """Tokenization classes for MiniLM.""" 24 | 25 | from __future__ import absolute_import, division, print_function, unicode_literals 26 | 27 | import collections 28 | import logging 29 | import os 30 | import unicodedata 31 | from io import open 32 | 33 | from transformers.tokenization_bert import BertTokenizer, whitespace_tokenize 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 38 | 39 | PRETRAINED_VOCAB_FILES_MAP = { 40 | 'vocab_file': 41 | { 42 | 'minilm-l12-h384-uncased': "https://unilm.blob.core.windows.net/ckpt/minilm-l12-h384-uncased-vocab.txt", 43 | } 44 | } 45 | 46 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 47 | 'minilm-l12-h384-uncased': 512, 48 | } 49 | 50 | 51 | class MinilmTokenizer(BertTokenizer): 52 | r""" 53 | Constructs a MinilmTokenizer. 54 | :class:`~transformers.MinilmTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 55 | Args: 56 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 57 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 58 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 59 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 60 | minimum of this value (if specified) and the underlying BERT model's sequence length. 61 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 62 | do_wordpiece_only=False 63 | """ 64 | 65 | vocab_files_names = VOCAB_FILES_NAMES 66 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 67 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 68 | 69 | 70 | class WhitespaceTokenizer(object): 71 | def tokenize(self, text): 72 | return whitespace_tokenize(text) 73 | -------------------------------------------------------------------------------- /src/s2s_ft/tokenization_unilm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License (MIT) 3 | 4 | # Copyright (c) Microsoft Corporation 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | """Tokenization classes for UniLM.""" 24 | 25 | from __future__ import absolute_import, division, print_function, unicode_literals 26 | 27 | import collections 28 | import logging 29 | import os 30 | import unicodedata 31 | from io import open 32 | 33 | from transformers.tokenization_bert import BertTokenizer, whitespace_tokenize 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 38 | 39 | PRETRAINED_VOCAB_FILES_MAP = { 40 | 'vocab_file': 41 | { 42 | 'unilm-large-cased': "https://unilm.blob.core.windows.net/ckpt/unilm-large-cased-vocab.txt", 43 | 'unilm-base-cased': "https://unilm.blob.core.windows.net/ckpt/unilm-base-cased-vocab.txt", 44 | 'unilm1-large-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-large-cased-vocab.txt", 45 | 'unilm1-base-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-base-cased-vocab.txt", 46 | 'unilm1.2-base-uncased': "https://unilm.blob.core.windows.net/ckpt/unilm1.2-base-uncased-vocab.txt" 47 | } 48 | } 49 | 50 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 51 | 'unilm-large-cased': 512, 52 | 'unilm-base-cased': 512, 53 | 'unilm1-large-cased': 512, 54 | 'unilm1-base-cased': 512, 55 | 'unilm1.2-base-uncased': 512, 56 | } 57 | 58 | 59 | class UnilmTokenizer(BertTokenizer): 60 | r""" 61 | Constructs a UnilmTokenizer. 62 | :class:`~transformers.UnilmTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 63 | Args: 64 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 65 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 66 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 67 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 68 | minimum of this value (if specified) and the underlying BERT model's sequence length. 69 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 70 | do_wordpiece_only=False 71 | """ 72 | 73 | vocab_files_names = VOCAB_FILES_NAMES 74 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 75 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 76 | 77 | 78 | class WhitespaceTokenizer(object): 79 | def tokenize(self, text): 80 | return whitespace_tokenize(text) 81 | -------------------------------------------------------------------------------- /src/s2s_ft/convert_state_dict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from transformers.modeling_utils import cached_path, WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def get_checkpoint_from_transformer_cache( 10 | archive_file, pretrained_model_name_or_path, pretrained_model_archive_map, 11 | cache_dir, force_download, proxies, resume_download, 12 | ): 13 | try: 14 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, 15 | proxies=proxies, resume_download=resume_download) 16 | except EnvironmentError: 17 | if pretrained_model_name_or_path in pretrained_model_archive_map: 18 | msg = "Couldn't reach server at '{}' to download pretrained weights.".format( 19 | archive_file) 20 | else: 21 | msg = "Model name '{}' was not found in model name list ({}). 
" \ 22 | "We assumed '{}' was a path or url to model weight files named one of {} but " \ 23 | "couldn't find any such file at this path or url.".format( 24 | pretrained_model_name_or_path, 25 | ', '.join(pretrained_model_archive_map.keys()), 26 | archive_file, 27 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME]) 28 | raise EnvironmentError(msg) 29 | 30 | if resolved_archive_file == archive_file: 31 | logger.info("loading weights file {}".format(archive_file)) 32 | else: 33 | logger.info("loading weights file {} from cache at {}".format( 34 | archive_file, resolved_archive_file)) 35 | 36 | return torch.load(resolved_archive_file, map_location='cpu') 37 | 38 | 39 | def hf_roberta_to_hf_bert(state_dict): 40 | logger.info(" * Convert Huggingface RoBERTa format to Huggingface BERT format * ") 41 | 42 | new_state_dict = {} 43 | 44 | for key in state_dict: 45 | value = state_dict[key] 46 | if key == 'roberta.embeddings.position_embeddings.weight': 47 | value = value[2:] 48 | if key == 'roberta.embeddings.token_type_embeddings.weight': 49 | continue 50 | if key.startswith('roberta'): 51 | key = 'bert.' + key[8:] 52 | elif key.startswith('lm_head'): 53 | if 'layer_norm' in key or 'dense' in key: 54 | key = 'cls.predictions.transform.' + key[8:] 55 | else: 56 | key = 'cls.predictions.' + key[8:] 57 | key = key.replace('layer_norm', 'LayerNorm') 58 | 59 | new_state_dict[key] = value 60 | 61 | return new_state_dict 62 | 63 | 64 | def hf_distilbert_to_hf_bert(state_dict): 65 | logger.info(" * Convert Huggingface DistilBERT format to Huggingface BERT format * ") 66 | 67 | new_state_dict = {} 68 | 69 | for key in state_dict: 70 | value = state_dict[key] 71 | if key == 'roberta.embeddings.position_embeddings.weight': 72 | value = value[2:] 73 | if key == 'roberta.embeddings.token_type_embeddings.weight': 74 | continue 75 | if key.startswith('roberta'): 76 | key = 'bert.' + key[8:] 77 | elif key.startswith('lm_head'): 78 | if 'layer_norm' in key or 'dense' in key: 79 | key = 'cls.predictions.transform.' + key[8:] 80 | else: 81 | key = 'cls.predictions.' 
+ key[8:] 82 | key = key.replace('layer_norm', 'LayerNorm') 83 | 84 | new_state_dict[key] = value 85 | 86 | return new_state_dict 87 | 88 | 89 | def hf_bert_to_hf_bert(state_dict): 90 | # keep no change 91 | return state_dict 92 | 93 | 94 | state_dict_convert = { 95 | 'bert': hf_bert_to_hf_bert, 96 | 'unilm': hf_bert_to_hf_bert, 97 | 'minilm': hf_bert_to_hf_bert, 98 | 'roberta': hf_roberta_to_hf_bert, 99 | 'xlm-roberta': hf_roberta_to_hf_bert, 100 | 'distilbert': hf_distilbert_to_hf_bert, 101 | } 102 | -------------------------------------------------------------------------------- /src/s2s_ft/s2s_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from random import randint, shuffle, choice 4 | from random import random as rand 5 | import math 6 | import logging 7 | import torch 8 | import torch.utils.data 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def get_random_word(vocab_words): 15 | i = randint(0, len(vocab_words)-1) 16 | return vocab_words[i] 17 | 18 | 19 | def batch_list_to_batch_tensors(batch): 20 | batch_tensors = [] 21 | for x in zip(*batch): 22 | if x[0] is None: 23 | batch_tensors.append(None) 24 | elif isinstance(x[0], torch.Tensor): 25 | batch_tensors.append(torch.stack(x)) 26 | else: 27 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 28 | return batch_tensors 29 | 30 | 31 | def _get_word_split_index(tokens, st, end): 32 | split_idx = [] 33 | i = st 34 | while i < end: 35 | if (not tokens[i].startswith('##')) or (i == st): 36 | split_idx.append(i) 37 | i += 1 38 | split_idx.append(end) 39 | return split_idx 40 | 41 | 42 | def _expand_whole_word(tokens, st, end): 43 | new_st, new_end = st, end 44 | while (new_st >= 0) and tokens[new_st].startswith('##'): 45 | new_st -= 1 46 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 47 | new_end += 1 48 | return new_st, new_end 49 | 50 | 51 | class Pipeline(): 52 | """ Pre-process Pipeline Class : callable """ 53 | 54 | def __init__(self): 55 | super().__init__() 56 | self.skipgram_prb = None 57 | self.skipgram_size = None 58 | self.pre_whole_word = None 59 | self.mask_whole_word = None 60 | self.word_subsample_prb = None 61 | self.sp_prob = None 62 | self.pieces_dir = None 63 | self.vocab_words = None 64 | self.pieces_threshold = 10 65 | self.call_count = 0 66 | self.offline_mode = False 67 | self.skipgram_size_geo_list = None 68 | self.span_same_mask = False 69 | 70 | def __call__(self, instance): 71 | raise NotImplementedError 72 | 73 | 74 | class Preprocess4Seq2seqDecoder(Pipeline): 75 | """ Pre-processing steps for pretraining transformer """ 76 | 77 | def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128, 78 | mode="s2s", pos_shift=False, source_type_id=0, target_type_id=1, 79 | cls_token='[CLS]', sep_token='[SEP]', pad_token='[PAD]'): 80 | super().__init__() 81 | self.max_len = max_len 82 | self.vocab_words = vocab_words # vocabulary (sub)words 83 | self.indexer = indexer # function from token to token index 84 | self.max_len = max_len 85 | self._tril_matrix = torch.tril(torch.ones((max_len, max_len), dtype=torch.long)) 86 | self.task_idx = 3 # relax projection layer for different tasks 87 | assert mode in ("s2s", "l2r") 88 | self.mode = mode 89 | self.max_tgt_length = max_tgt_length 90 | self.pos_shift = pos_shift 91 | 92 | self.cls_token = cls_token 93 | self.sep_token = sep_token 94 | self.pad_token = pad_token 95 | 96 | self.source_type_id = source_type_id 97 | self.target_type_id = 
target_type_id 98 | 99 | self.cc = 0 100 | 101 | def __call__(self, instance): 102 | tokens_a, max_a_len = instance 103 | 104 | padded_tokens_a = [self.cls_token] + tokens_a + [self.sep_token] 105 | assert len(padded_tokens_a) <= max_a_len + 2 106 | if max_a_len + 2 > len(padded_tokens_a): 107 | padded_tokens_a += [self.pad_token] * \ 108 | (max_a_len + 2 - len(padded_tokens_a)) 109 | assert len(padded_tokens_a) == max_a_len + 2 110 | max_len_in_batch = min(self.max_tgt_length + 111 | max_a_len + 2, self.max_len) 112 | tokens = padded_tokens_a 113 | segment_ids = [self.source_type_id] * (len(padded_tokens_a)) \ 114 | + [self.target_type_id] * (max_len_in_batch - len(padded_tokens_a)) 115 | 116 | mask_qkv = None 117 | 118 | position_ids = [] 119 | for i in range(len(tokens_a) + 2): 120 | position_ids.append(i) 121 | for i in range(len(tokens_a) + 2, max_a_len + 2): 122 | position_ids.append(0) 123 | for i in range(max_a_len + 2, max_len_in_batch): 124 | position_ids.append(i - (max_a_len + 2) + len(tokens_a) + 2) 125 | 126 | # Token Indexing 127 | input_ids = self.indexer(tokens) 128 | 129 | self.cc += 1 130 | if self.cc < 20: 131 | logger.info("Input src = %s" % " ".join(self.vocab_words[tk_id] for tk_id in input_ids)) 132 | 133 | # Zero Padding 134 | input_mask = torch.zeros( 135 | max_len_in_batch, max_len_in_batch, dtype=torch.long) 136 | if self.mode == "s2s": 137 | input_mask[:, :len(tokens_a)+2].fill_(1) 138 | else: 139 | st, end = 0, len(tokens_a) + 2 140 | input_mask[st:end, st:end].copy_( 141 | self._tril_matrix[:end, :end]) 142 | input_mask[end:, :len(tokens_a)+2].fill_(1) 143 | second_st, second_end = len(padded_tokens_a), max_len_in_batch 144 | 145 | input_mask[second_st:second_end, second_st:second_end].copy_( 146 | self._tril_matrix[:second_end-second_st, :second_end-second_st]) 147 | 148 | return (input_ids, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx) 149 | -------------------------------------------------------------------------------- /src/s2s_ft/configuration_minilm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License (MIT) 3 | 4 | # Copyright (c) Microsoft Corporation 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | """ MiniLM model configuration """ 24 | 25 | from __future__ import absolute_import, division, print_function, unicode_literals 26 | 27 | import json 28 | import logging 29 | import sys 30 | from io import open 31 | 32 | from transformers.configuration_utils import PretrainedConfig 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 37 | 'minilm-l12-h384-uncased': "https://unilm.blob.core.windows.net/ckpt/minilm-l12-h384-uncased-config.json", 38 | } 39 | 40 | 41 | class MinilmConfig(PretrainedConfig): 42 | r""" 43 | :class:`~transformers.MinilmConfig` is the configuration class to store the configuration of a 44 | `MinilmModel`. 45 | Arguments: 46 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `MiniLMModel`. 47 | hidden_size: Size of the encoder layers and the pooler layer. 48 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 49 | num_attention_heads: Number of attention heads for each attention layer in 50 | the Transformer encoder. 51 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 52 | layer in the Transformer encoder. 53 | hidden_act: The non-linear activation function (function or string) in the 54 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 55 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 56 | layers in the embeddings, encoder, and pooler. 57 | attention_probs_dropout_prob: The dropout ratio for the attention 58 | probabilities. 59 | max_position_embeddings: The maximum sequence length that this model might 60 | ever be used with. Typically set this to something large just in case 61 | (e.g., 512 or 1024 or 2048). 62 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 63 | `MiniLMModel`. 64 | initializer_range: The sttdev of the truncated_normal_initializer for 65 | initializing all weight matrices. 66 | layer_norm_eps: The epsilon used by LayerNorm. 
67 | """ 68 | pretrained_config_archive_map = MINILM_PRETRAINED_CONFIG_ARCHIVE_MAP 69 | 70 | def __init__(self, 71 | vocab_size=28996, 72 | hidden_size=768, 73 | num_hidden_layers=12, 74 | num_attention_heads=12, 75 | intermediate_size=3072, 76 | hidden_act="gelu", 77 | hidden_dropout_prob=0.1, 78 | attention_probs_dropout_prob=0.1, 79 | max_position_embeddings=512, 80 | type_vocab_size=6, 81 | initializer_range=0.02, 82 | layer_norm_eps=1e-12, 83 | source_type_id=0, 84 | target_type_id=1, 85 | **kwargs): 86 | super(MinilmConfig, self).__init__(**kwargs) 87 | if isinstance(vocab_size, str) or (sys.version_info[0] == 2 88 | and isinstance(vocab_size, unicode)): 89 | with open(vocab_size, "r", encoding='utf-8') as reader: 90 | json_config = json.loads(reader.read()) 91 | for key, value in json_config.items(): 92 | self.__dict__[key] = value 93 | elif isinstance(vocab_size, int): 94 | self.vocab_size = vocab_size 95 | self.hidden_size = hidden_size 96 | self.num_hidden_layers = num_hidden_layers 97 | self.num_attention_heads = num_attention_heads 98 | self.hidden_act = hidden_act 99 | self.intermediate_size = intermediate_size 100 | self.hidden_dropout_prob = hidden_dropout_prob 101 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 102 | self.max_position_embeddings = max_position_embeddings 103 | self.type_vocab_size = type_vocab_size 104 | self.initializer_range = initializer_range 105 | self.layer_norm_eps = layer_norm_eps 106 | self.source_type_id = source_type_id 107 | self.target_type_id = target_type_id 108 | else: 109 | raise ValueError("First argument must be either a vocabulary size (int)" 110 | " or the path to a pretrained model config file (str)") 111 | -------------------------------------------------------------------------------- /src/s2s_ft/configuration_unilm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # The MIT License (MIT) 3 | 4 | # Copyright (c) Microsoft Corporation 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | """ UniLM model configuration """ 24 | 25 | from __future__ import absolute_import, division, print_function, unicode_literals 26 | 27 | import json 28 | import logging 29 | import sys 30 | from io import open 31 | 32 | from transformers.configuration_utils import PretrainedConfig 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP = { 37 | 'unilm-large-cased': "https://unilm.blob.core.windows.net/ckpt/unilm-large-cased-config.json", 38 | 'unilm-base-cased': "https://unilm.blob.core.windows.net/ckpt/unilm-base-cased-config.json", 39 | 'unilm1-large-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-large-cased-config.json", 40 | 'unilm1-base-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-base-cased-config.json", 41 | 'unilm1.2-base-uncased': "https://unilm.blob.core.windows.net/ckpt/unilm1.2-base-uncased-config.json", 42 | } 43 | 44 | 45 | class UnilmConfig(PretrainedConfig): 46 | r""" 47 | :class:`~transformers.UnilmConfig` is the configuration class to store the configuration of a 48 | `UnilmModel`. 49 | Arguments: 50 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `UnilmModel`. 51 | hidden_size: Size of the encoder layers and the pooler layer. 52 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 53 | num_attention_heads: Number of attention heads for each attention layer in 54 | the Transformer encoder. 55 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 56 | layer in the Transformer encoder. 57 | hidden_act: The non-linear activation function (function or string) in the 58 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 59 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 60 | layers in the embeddings, encoder, and pooler. 61 | attention_probs_dropout_prob: The dropout ratio for the attention 62 | probabilities. 63 | max_position_embeddings: The maximum sequence length that this model might 64 | ever be used with. Typically set this to something large just in case 65 | (e.g., 512 or 1024 or 2048). 66 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 67 | `UnilmModel`. 68 | initializer_range: The sttdev of the truncated_normal_initializer for 69 | initializing all weight matrices. 70 | layer_norm_eps: The epsilon used by LayerNorm. 
71 | """ 72 | pretrained_config_archive_map = UNILM_PRETRAINED_CONFIG_ARCHIVE_MAP 73 | 74 | def __init__(self, 75 | vocab_size=28996, 76 | hidden_size=768, 77 | num_hidden_layers=12, 78 | num_attention_heads=12, 79 | intermediate_size=3072, 80 | hidden_act="gelu", 81 | hidden_dropout_prob=0.1, 82 | attention_probs_dropout_prob=0.1, 83 | max_position_embeddings=512, 84 | type_vocab_size=6, 85 | initializer_range=0.02, 86 | layer_norm_eps=1e-12, 87 | source_type_id=0, 88 | target_type_id=1, 89 | **kwargs): 90 | super(UnilmConfig, self).__init__(**kwargs) 91 | if isinstance(vocab_size, str) or (sys.version_info[0] == 2 92 | and isinstance(vocab_size, unicode)): 93 | with open(vocab_size, "r", encoding='utf-8') as reader: 94 | json_config = json.loads(reader.read()) 95 | for key, value in json_config.items(): 96 | self.__dict__[key] = value 97 | elif isinstance(vocab_size, int): 98 | self.vocab_size = vocab_size 99 | self.hidden_size = hidden_size 100 | self.num_hidden_layers = num_hidden_layers 101 | self.num_attention_heads = num_attention_heads 102 | self.hidden_act = hidden_act 103 | self.intermediate_size = intermediate_size 104 | self.hidden_dropout_prob = hidden_dropout_prob 105 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 106 | self.max_position_embeddings = max_position_embeddings 107 | self.type_vocab_size = type_vocab_size 108 | self.initializer_range = initializer_range 109 | self.layer_norm_eps = layer_norm_eps 110 | self.source_type_id = source_type_id 111 | self.target_type_id = target_type_id 112 | else: 113 | raise ValueError("First argument must be either a vocabulary size (int)" 114 | " or the path to a pretrained model config file (str)") 115 | -------------------------------------------------------------------------------- /src/gen_seq_from_trace.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import math 3 | import argparse 4 | import glob 5 | import logging 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import unicodedata 9 | 10 | from transformers import BertTokenizer, RobertaTokenizer 11 | from s2s_ft.tokenization_unilm import UnilmTokenizer 12 | from s2s_ft.tokenization_minilm import MinilmTokenizer 13 | 14 | 15 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 16 | datefmt='%m/%d/%Y %H:%M:%S', 17 | level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | TOKENIZER_CLASSES = { 22 | 'bert': BertTokenizer, 23 | 'minilm': MinilmTokenizer, 24 | 'roberta': RobertaTokenizer, 25 | 'unilm': UnilmTokenizer, 26 | } 27 | 28 | 29 | def read_traces_from_file(file_name): 30 | with open(file_name, "rb") as fin: 31 | meta = pickle.load(fin) 32 | num_samples = meta["num_samples"] 33 | samples = [] 34 | for _ in range(num_samples): 35 | samples.append(pickle.load(fin)) 36 | return samples 37 | 38 | 39 | def get_best_sequence(sample, eos_id, pad_id, length_penalty=None, alpha=None, expect=None, min_len=None): 40 | # if not any((length_penalty, alpha, expect, min_len)): 41 | # raise ValueError( 42 | # "You can only specify length penalty or alpha, but not both.") 43 | scores = sample["scores"] 44 | wids_list = sample["wids"] 45 | ptrs = sample["ptrs"] 46 | 47 | last_frame_id = len(scores) - 1 48 | for i, wids in enumerate(wids_list): 49 | if all(wid in (eos_id, pad_id) for wid in wids): 50 | last_frame_id = i 51 | break 52 | while all(wid == pad_id for wid in wids_list[last_frame_id]): 53 | last_frame_id -= 1 54 | 55 | max_score = -math.inf 
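# Note (added): the loop below re-scans every recorded beam frame. A hypothesis is only
# scored where it ends in EOS/PAD (or at the last frame), and its raw score can be adjusted
# in one of three mutually exclusive ways before comparison:
#   * --length_penalty together with --expect : subtract length_penalty * |length - expect|
#   * --length_penalty alone                  : add length_penalty * length (favors longer outputs)
#   * --alpha                                 : GNMT-style normalization, score / ((5 + length) / 6) ** alpha
# The best-scoring end position is then traced back through `ptrs` to recover the token ids.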
56 | frame_id = -1 57 | pos_in_frame = -1 58 | 59 | for fid in range(last_frame_id + 1): 60 | for i, wid in enumerate(wids_list[fid]): 61 | if fid <= last_frame_id and scores[fid][i] >= 0: 62 | # skip paddings 63 | continue 64 | if (wid in (eos_id, pad_id)) or fid == last_frame_id: 65 | s = scores[fid][i] 66 | if length_penalty: 67 | if expect: 68 | s -= length_penalty * math.fabs(fid+1 - expect) 69 | else: 70 | s += length_penalty * (fid + 1) 71 | elif alpha: 72 | s = s / math.pow((5 + fid + 1) / 6.0, alpha) 73 | if s > max_score: 74 | # if (frame_id != -1) and min_len and (fid+1 < min_len): 75 | # continue 76 | max_score = s 77 | frame_id = fid 78 | pos_in_frame = i 79 | if frame_id == -1: 80 | seq = [] 81 | else: 82 | seq = [wids_list[frame_id][pos_in_frame]] 83 | for fid in range(frame_id, 0, -1): 84 | pos_in_frame = ptrs[fid][pos_in_frame] 85 | seq.append(wids_list[fid - 1][pos_in_frame]) 86 | seq.reverse() 87 | return seq 88 | 89 | 90 | def detokenize(tk_list): 91 | r_list = [] 92 | for tk in tk_list: 93 | if tk.startswith('##') and len(r_list) > 0: 94 | r_list[-1] = r_list[-1] + tk[2:] 95 | else: 96 | r_list.append(tk) 97 | return r_list 98 | 99 | 100 | def simple_postprocess(tk_list): 101 | # truncate duplicate punctuations 102 | while tk_list and len(tk_list) > 4 and len(tk_list[-1]) == 1 and unicodedata.category(tk_list[-1]).startswith('P') and all(it == tk_list[-1] for it in tk_list[-4:]): 103 | tk_list = tk_list[:-3] 104 | return tk_list 105 | 106 | 107 | # def include_unk(line): 108 | # return " UNK ".join(line.split('')).strip() 109 | 110 | 111 | def main(args): 112 | tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained( 113 | args.tokenizer_name, do_lower_case=args.do_lower_case, 114 | cache_dir=args.cache_dir if args.cache_dir else None) 115 | eos_token = tokenizer.sep_token 116 | pad_token = tokenizer.pad_token 117 | 118 | eos_id, pad_id = set(tokenizer.convert_tokens_to_ids([eos_token, pad_token])) 119 | logger.info("*********************************************") 120 | logger.info(" EOS TOKEN = {}, ID = {}".format(eos_token, eos_id)) 121 | logger.info(" PAD TOKEN = {}, ID = {}".format(pad_token, pad_id)) 122 | logger.info("*********************************************") 123 | 124 | for input_file in tqdm(glob.glob(args.input)): 125 | if not Path(input_file+'.trace.pickle').exists(): 126 | continue 127 | print(input_file) 128 | samples = read_traces_from_file(input_file+'.trace.pickle') 129 | 130 | results = [] 131 | 132 | for s in samples: 133 | word_ids = get_best_sequence(s, eos_id, pad_id, alpha=args.alpha, 134 | length_penalty=args.length_penalty, expect=args.expect, min_len=args.min_len) 135 | tokens = tokenizer.convert_ids_to_tokens(word_ids) 136 | buf = [] 137 | for t in tokens: 138 | if t in (eos_token, pad_token): 139 | break 140 | else: 141 | buf.append(t) 142 | if args.model_type == "roberta": 143 | output_text = " ".join(simple_postprocess(tokenizer.convert_tokens_to_string(buf).split(' '))) 144 | if '\n' in output_text: 145 | output_text = " [S_SEPX_SEP] ".join(output_text.split('\n')) 146 | else: 147 | output_text = " ".join(simple_postprocess(detokenize(buf))) 148 | 149 | results.append(output_text) 150 | 151 | fn_out = input_file + '.' 
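# The output file name appends whichever re-scoring options were used
# (e.g. <input>.lenp0.7 or <input>.alp0.6minl5), so runs with different settings do not overwrite each other.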
152 | if args.length_penalty: 153 | fn_out += 'lenp'+str(args.length_penalty) 154 | if args.expect: 155 | fn_out += 'exp'+str(args.expect) 156 | if args.alpha: 157 | fn_out += 'alp'+str(args.alpha) 158 | if args.min_len: 159 | fn_out += 'minl'+str(args.min_len) 160 | with open(fn_out, "w", encoding="utf-8") as fout: 161 | for line in results: 162 | fout.write(line) 163 | fout.write("\n") 164 | logger.info("Output file = [%s]" % fn_out) 165 | 166 | if __name__ == "__main__": 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--input", type=str, help="Input file.") 169 | parser.add_argument("--model_type", default=None, type=str, required=True, 170 | help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys())) 171 | parser.add_argument("--alpha", default=None, type=float) 172 | parser.add_argument("--length_penalty", default=None, type=float) 173 | parser.add_argument("--expect", default=None, type=float, 174 | help="Expectation of target length.") 175 | parser.add_argument("--min_len", default=None, type=int) 176 | # tokenizer_name 177 | parser.add_argument("--tokenizer_name", default=None, type=str, required=True, 178 | help="tokenizer name") 179 | parser.add_argument("--do_lower_case", action='store_true', 180 | help="Set this flag if you are using an uncased model.") 181 | parser.add_argument("--cache_dir", default=None, type=str, 182 | help="Where do you want to store the pre-trained models downloaded from s3") 183 | args = parser.parse_args() 184 | 185 | main(args) 186 | -------------------------------------------------------------------------------- /src/s2s_ft/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import itertools 4 | import logging 5 | import os 6 | import json 7 | import random 8 | import glob 9 | import torch 10 | import tqdm 11 | import torch.utils.data 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class Seq2seqDatasetForBert(torch.utils.data.Dataset): 18 | def __init__( 19 | self, features, max_source_len, max_target_len, 20 | vocab_size, cls_id, sep_id, pad_id, mask_id, 21 | random_prob, keep_prob, offset, num_training_instances, word_drop_prob=0,word_shuffle_k=0,sent_shuffle_k=0,sent_drop_prob=0, 22 | span_len=1, span_prob=1.0): 23 | self.features = features 24 | self.max_source_len = max_source_len 25 | self.max_target_len = max_target_len 26 | self.offset = offset 27 | if offset > 0: 28 | logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset) 29 | self.cls_id = cls_id 30 | self.sep_id = sep_id 31 | self.pad_id = pad_id 32 | self.random_prob = random_prob 33 | self.keep_prob = keep_prob 34 | self.mask_id = mask_id 35 | self.vocab_size = vocab_size 36 | self.num_training_instances = num_training_instances 37 | self.span_len = span_len 38 | self.span_prob = span_prob 39 | self.word_drop_prob = word_drop_prob 40 | self.word_shuffle_k = word_shuffle_k 41 | self.sent_shuffle_k = sent_shuffle_k 42 | self.sent_drop_prob = sent_drop_prob 43 | self.src_sentence_sep_id = 1011 44 | 45 | def __len__(self): 46 | return int(self.num_training_instances) 47 | 48 | def __trunk(self, ids, max_len): 49 | if len(ids) > max_len - 1: 50 | ids = ids[:max_len - 1] 51 | ids = ids + [self.sep_id] 52 | return ids 53 | 54 | def __pad(self, ids, max_len): 55 | if len(ids) < max_len: 56 | return ids + [self.pad_id] * (max_len - len(ids)) 57 | else: 58 | assert len(ids) == max_len 59 | return ids 60 | 61 
| def __getitem__(self, idx): 62 | idx = (self.offset + idx) % len(self.features) 63 | feature = self.features[idx] 64 | noisy_source_ids_1 = feature["source_ids"] 65 | noisy_source_ids_2 = feature["source_ids"] 66 | if(self.word_drop_prob>0): 67 | noisy_source_ids_1 = [_id for _id in noisy_source_ids_1 if random.random()>self.word_drop_prob] 68 | noisy_source_ids_2 = [_id for _id in noisy_source_ids_2 if random.random()>self.word_drop_prob] 69 | 70 | 71 | # def perm(i,k): 72 | # return i[0] + (k + 1) * random.random() 73 | 74 | if(self.word_shuffle_k>0): 75 | noisy_source_ids_1 = [x for _, x in sorted(enumerate(noisy_source_ids_1), key=lambda t:t[0]+ (self.word_shuffle_k + 1) * random.random())] 76 | noisy_source_ids_2 = [x for _, x in sorted(enumerate(noisy_source_ids_2), key=lambda t:t[0]+ (self.word_shuffle_k + 1) * random.random())] 77 | 78 | if(self.sent_shuffle_k>0 or self.sent_drop_prob>0): 79 | size = len(noisy_source_ids_1) 80 | idx_list = [idx + 1 for idx, val in 81 | enumerate(noisy_source_ids_1) if val == self.src_sentence_sep_id] 82 | if(len(idx_list)>5): 83 | 84 | source_ids_splitted = [noisy_source_ids_1[i: j] for i, j in zip([0] + idx_list, idx_list + ([size] if idx_list[-1] != size else []))] 85 | source_ids_splitted = [sent for sent in source_ids_splitted if random.random() > self.sent_drop_prob] 86 | 87 | source_ids_splitted = [x for _, x in sorted(enumerate(source_ids_splitted), key=lambda t:t[0]+ (self.sent_shuffle_k + 1) * random.random())] 88 | noisy_source_ids_1 = list(itertools.chain.from_iterable(source_ids_splitted)) 89 | size = len(noisy_source_ids_2) 90 | idx_list = [idx + 1 for idx, val in 91 | enumerate(noisy_source_ids_2) if val == self.src_sentence_sep_id] 92 | if(len(idx_list)>5): 93 | 94 | source_ids_splitted = [noisy_source_ids_2[i: j] for i, j in zip([0] + idx_list, idx_list + ([size] if idx_list[-1] != size else []))] 95 | source_ids_splitted = [sent for sent in source_ids_splitted if random.random() > self.sent_drop_prob] 96 | 97 | source_ids_splitted = [x for _, x in sorted(enumerate(source_ids_splitted), key=lambda t:t[0]+ (self.sent_shuffle_k + 1) * random.random())] 98 | noisy_source_ids_2 = list(itertools.chain.from_iterable(source_ids_splitted)) 99 | 100 | 101 | 102 | source_ids = self.__trunk([self.cls_id] + feature["source_ids"], self.max_source_len) 103 | noisy_source_ids_1 = self.__trunk([self.cls_id] + noisy_source_ids_1, self.max_source_len) 104 | noisy_source_ids_2 = self.__trunk([self.cls_id] + noisy_source_ids_2, self.max_source_len) 105 | 106 | 107 | target_ids = self.__trunk(feature["target_ids"], self.max_target_len) 108 | pseudo_ids = [] 109 | for tk_id in target_ids: 110 | p = random.random() 111 | if p < self.keep_prob: 112 | pseudo_ids.append(tk_id) 113 | elif p < self.keep_prob + self.random_prob: 114 | pseudo_ids.append(random.randint(0, self.vocab_size - 1)) 115 | else: 116 | pseudo_ids.append(self.mask_id) 117 | 118 | num_source_tokens = len(source_ids) 119 | num_noisy_source_tokens_1 = len(noisy_source_ids_1) 120 | num_noisy_source_tokens_2 = len(noisy_source_ids_2) 121 | num_target_tokens = len(target_ids) 122 | 123 | source_ids = self.__pad(source_ids, self.max_source_len) 124 | noisy_source_ids_1 = self.__pad(noisy_source_ids_1, self.max_source_len) 125 | noisy_source_ids_2 = self.__pad(noisy_source_ids_2, self.max_source_len) 126 | target_ids = self.__pad(target_ids, self.max_target_len) 127 | pseudo_ids = self.__pad(pseudo_ids, self.max_target_len) 128 | 129 | if self.span_len > 1: 130 | span_ids = [] 131 | span_id = 1 
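# Note (added): when span_len > 1, target positions are grouped into pseudo-mask spans.
# With probability span_prob a span of random length in [2, span_len] is drawn (capped at the
# remaining number of target tokens); otherwise the span covers a single token. All positions
# in a span share the same span_id, so downstream masking can treat the span as one unit.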
132 | while len(span_ids) < num_target_tokens: 133 | p = random.random() 134 | if p < self.span_prob: 135 | span_len = random.randint(2, self.span_len) 136 | span_len = min(span_len, num_target_tokens - len(span_ids)) 137 | else: 138 | span_len = 1 139 | span_ids.extend([span_id] * span_len) 140 | span_id += 1 141 | span_ids = self.__pad(span_ids, self.max_target_len) 142 | return source_ids, noisy_source_ids_1, noisy_source_ids_2, target_ids, pseudo_ids, num_source_tokens, num_noisy_source_tokens_1, num_noisy_source_tokens_2, num_target_tokens, span_ids 143 | else: 144 | return source_ids, noisy_source_ids_1, noisy_source_ids_2, target_ids, pseudo_ids, num_source_tokens, num_noisy_source_tokens_1, num_noisy_source_tokens_2, num_target_tokens 145 | 146 | 147 | def batch_list_to_batch_tensors(batch): 148 | batch_tensors = [] 149 | for x in zip(*batch): 150 | if isinstance(x[0], torch.Tensor): 151 | batch_tensors.append(torch.stack(x)) 152 | else: 153 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 154 | return batch_tensors 155 | 156 | 157 | def get_max_epoch_model(output_dir): 158 | fn_model_list = glob.glob(os.path.join(output_dir, "model.*.bin")) 159 | fn_optim_list = glob.glob(os.path.join(output_dir, "optim.*.bin")) 160 | if (not fn_model_list) or (not fn_optim_list): 161 | return None 162 | os.path.basename(output_dir) 163 | both_set = set([int(os.path.basename(fn).split('.')[1]) for fn in fn_model_list] 164 | ) & set([int(os.path.basename(fn).split('.')[1]) for fn in fn_optim_list]) 165 | if both_set: 166 | return max(both_set) 167 | else: 168 | return None 169 | 170 | 171 | def load_and_cache_examples( 172 | example_file, tokenizer, local_rank, cached_features_file, shuffle=True): 173 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 174 | if local_rank not in [-1, 0]: 175 | torch.distributed.barrier() 176 | 177 | if cached_features_file is not None and os.path.exists(cached_features_file): 178 | logger.info("Loading features from cached file %s", cached_features_file) 179 | features = torch.load(cached_features_file) 180 | else: 181 | logger.info("Creating features from dataset file at %s", example_file) 182 | 183 | examples = [] 184 | with open(example_file, mode="r", encoding="utf-8") as reader: 185 | for line in reader: 186 | examples.append(json.loads(line)) 187 | features = [] 188 | 189 | for example in tqdm.tqdm(examples): 190 | if isinstance(example["src"], list): 191 | source_tokens = example["src"] 192 | target_tokens = example["tgt"] 193 | else: 194 | source_tokens = tokenizer.tokenize(example["src"]) 195 | target_tokens = tokenizer.tokenize(example["tgt"]) 196 | features.append({ 197 | "source_ids": tokenizer.convert_tokens_to_ids(source_tokens), 198 | "target_ids": tokenizer.convert_tokens_to_ids(target_tokens), 199 | }) 200 | 201 | if shuffle: 202 | random.shuffle(features) 203 | 204 | if local_rank in [-1, 0] and cached_features_file is not None: 205 | logger.info("Saving features into cached file %s", cached_features_file) 206 | torch.save(features, cached_features_file) 207 | 208 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 209 | if local_rank == 0: 210 | torch.distributed.barrier() 211 | 212 | return features 213 | -------------------------------------------------------------------------------- /src/evaluations/my_eval_for_cnndm.py: -------------------------------------------------------------------------------- 1 | 
"""BERT finetuning runner.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import logging 9 | import glob 10 | import json 11 | import argparse 12 | import math 13 | import string 14 | from multiprocessing import Pool, cpu_count 15 | from tqdm import tqdm, trange 16 | from pathlib import Path 17 | import numpy as np 18 | # pip install py-rouge 19 | import rouge 20 | import time 21 | import tempfile 22 | import shutil 23 | 24 | # pip install pyrouge 25 | from evaluations.bs_pyrouge import Rouge155 26 | 27 | 28 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 29 | datefmt='%m/%d/%Y %H:%M:%S', 30 | level=logging.INFO) 31 | logger = logging.getLogger(__name__) 32 | 33 | parser = argparse.ArgumentParser() 34 | 35 | # Required parameters 36 | parser.add_argument("--gold", type=str, help="Gold output file.") 37 | parser.add_argument("--pred", type=str, help="Input prediction file.") 38 | parser.add_argument("--split", type=str, default="", 39 | help="Data split (train/dev/test).") 40 | parser.add_argument("--save_best", action='store_true', 41 | help="Save best epoch.") 42 | parser.add_argument("--only_eval_best", action='store_true', 43 | help="Only evaluate best epoch.") 44 | parser.add_argument("--trunc_len", type=int, default=60, 45 | help="Truncate line by the maximum length.") 46 | parser.add_argument("--duplicate_rate", type=float, default=0.7, 47 | help="If the duplicat rate (compared with history) is large, we can discard the current sentence.") 48 | default_process_count = max(1, cpu_count() - 1) 49 | parser.add_argument("--processes", type=int, default=default_process_count, 50 | help="Number of processes to use (default %(default)s)") 51 | parser.add_argument("--perl", action='store_true', 52 | help="Using the perl script.") 53 | parser.add_argument('--lazy_eval', action='store_true', 54 | help="Skip evaluation if the .rouge file exists.") 55 | args = parser.parse_args() 56 | 57 | SPECIAL_TOKEN = ["[UNK]", "[PAD]", "[CLS]", "[MASK]"] 58 | evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2, 59 | limit_length=False, apply_avg=True) 60 | 61 | 62 | def test_rouge(cand, ref): 63 | temp_dir = tempfile.mkdtemp() 64 | candidates = cand 65 | references = ref 66 | assert len(candidates) == len(references) 67 | 68 | cnt = len(candidates) 69 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 70 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) 71 | if not os.path.isdir(tmp_dir): 72 | os.mkdir(tmp_dir) 73 | os.mkdir(tmp_dir + "/candidate") 74 | os.mkdir(tmp_dir + "/reference") 75 | try: 76 | for i in range(cnt): 77 | if len(references[i]) < 1: 78 | continue 79 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 80 | encoding="utf-8") as f: 81 | f.write(candidates[i]) 82 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 83 | encoding="utf-8") as f: 84 | f.write(references[i]) 85 | r = Rouge155(temp_dir=temp_dir) 86 | r.model_dir = tmp_dir + "/reference/" 87 | r.system_dir = tmp_dir + "/candidate/" 88 | r.model_filename_pattern = 'ref.#ID#.txt' 89 | r.system_filename_pattern = r'cand.(\d+).txt' 90 | rouge_results = r.convert_and_evaluate() 91 | print(rouge_results) 92 | results_dict = r.output_to_dict(rouge_results) 93 | finally: 94 | if os.path.isdir(tmp_dir): 95 | shutil.rmtree(tmp_dir) 96 | return results_dict 97 | 98 | 99 | def rouge_results_to_str(results_dict): 100 | return ">> 
ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 101 | results_dict["rouge_1_f_score"] * 100, 102 | results_dict["rouge_2_f_score"] * 100, 103 | results_dict["rouge_l_f_score"] * 100, 104 | results_dict["rouge_1_recall"] * 100, 105 | results_dict["rouge_2_recall"] * 100, 106 | results_dict["rouge_l_recall"] * 100 107 | ) 108 | 109 | 110 | def count_tokens(tokens): 111 | counter = {} 112 | for t in tokens: 113 | if t in counter.keys(): 114 | counter[t] += 1 115 | else: 116 | counter[t] = 1 117 | return counter 118 | 119 | 120 | def get_f1(text_a, text_b): 121 | tokens_a = text_a.lower().split() 122 | tokens_b = text_b.lower().split() 123 | if len(tokens_a) == 0 or len(tokens_b) == 0: 124 | return 1 if len(tokens_a) == len(tokens_b) else 0 125 | set_a = count_tokens(tokens_a) 126 | set_b = count_tokens(tokens_b) 127 | match = 0 128 | for token in set_a.keys(): 129 | if token in set_b.keys(): 130 | match += min(set_a[token], set_b[token]) 131 | p = match / len(tokens_a) 132 | r = match / len(tokens_b) 133 | return 2.0 * p * r / (p + r + 1e-5) 134 | 135 | 136 | def remove_duplicate(l_list, duplicate_rate): 137 | tk_list = [l.lower().split() for l in l_list] 138 | r_list = [] 139 | history_set = set() 140 | for i, w_list in enumerate(tk_list): 141 | w_set = set(w_list) 142 | if len(w_set & history_set)/len(w_set) <= duplicate_rate: 143 | r_list.append(l_list[i]) 144 | history_set |= w_set 145 | return r_list 146 | 147 | 148 | def process_eval(eval_fn): 149 | gold_list = [] 150 | with open(args.gold, "r", encoding="utf-8") as f_in: 151 | for l in f_in: 152 | line = l.strip().replace(" ", '\n') 153 | gold_list.append(line) 154 | 155 | pred_list = [] 156 | with open(eval_fn, "r", encoding="utf-8") as f_in: 157 | for l in f_in: 158 | buf = [] 159 | for sentence in l.strip().split("[X_SEP]"): 160 | 161 | while " " in sentence: 162 | sentence = sentence.replace(" ", " ") + ' .' 
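# Note (added): each candidate sentence is then filtered before ROUGE scoring: a sentence is
# skipped if its unigram F1 against an already-kept sentence exceeds 1.0 (with this threshold the
# check is effectively disabled) or if it has 4 or fewer tokens; --duplicate_rate removes sentences
# that largely overlap earlier ones, and --trunc_len caps the total summary length in words.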
163 | 164 | if any(get_f1(sentence, s) > 1.0 for s in buf): 165 | continue 166 | s_len = len(sentence.split()) 167 | if s_len <= 4: 168 | continue 169 | buf.append(sentence) 170 | if args.duplicate_rate and args.duplicate_rate < 1: 171 | buf = remove_duplicate(buf, args.duplicate_rate) 172 | if args.trunc_len: 173 | num_left = args.trunc_len 174 | trunc_list = [] 175 | for bit in buf: 176 | tk_list = bit.split() 177 | n = min(len(tk_list), num_left) 178 | trunc_list.append(' '.join(tk_list[:n])) 179 | num_left -= n 180 | if num_left <= 0: 181 | break 182 | else: 183 | trunc_list = buf 184 | line = "\n".join(trunc_list) 185 | pred_list.append(line) 186 | with open(eval_fn+'.post', 'w', encoding='utf-8') as f_out: 187 | for l in pred_list: 188 | f_out.write(l.replace('\n', ' [X_SEP] ').strip()) 189 | f_out.write('\n') 190 | # rouge scores 191 | if len(pred_list) < len(gold_list): 192 | # evaluate subset 193 | gold_list = gold_list[:len(pred_list)] 194 | assert len(pred_list) == len(gold_list) 195 | if args.perl: 196 | scores = test_rouge(pred_list, gold_list) 197 | else: 198 | scores = evaluator.get_scores(pred_list, [[it] for it in gold_list]) 199 | return eval_fn, scores 200 | 201 | 202 | def main(): 203 | if args.perl: 204 | eval_fn_list = list(glob.glob(args.pred)) 205 | else: 206 | eval_fn_list = [eval_fn for eval_fn in glob.glob(args.pred) if not( 207 | args.lazy_eval and Path(eval_fn+".rouge").exists())] 208 | eval_fn_list = list(filter(lambda fn: not(fn.endswith( 209 | '.post') or fn.endswith('.rouge')), eval_fn_list)) 210 | 211 | if args.only_eval_best: 212 | best_epoch_dict = {} 213 | for dir_path in set(Path(fn).parent for fn in eval_fn_list): 214 | fn_save = os.path.join(dir_path, 'save_best.dev') 215 | if Path(fn_save).exists(): 216 | with open(fn_save, 'r') as f_in: 217 | __, o_name, __ = f_in.read().strip().split('\n') 218 | epoch = o_name.split('.')[1] 219 | best_epoch_dict[dir_path] = epoch 220 | new_eval_fn_list = [] 221 | for fn in eval_fn_list: 222 | dir_path = Path(fn).parent 223 | if dir_path in best_epoch_dict: 224 | if Path(fn).name.split('.')[1] == best_epoch_dict[dir_path]: 225 | new_eval_fn_list.append(fn) 226 | eval_fn_list = new_eval_fn_list 227 | 228 | logger.info("***** Evaluation: %s *****", ','.join(eval_fn_list)) 229 | num_pool = min(args.processes, len(eval_fn_list)) 230 | p = Pool(num_pool) 231 | r_list = p.imap_unordered(process_eval, eval_fn_list) 232 | r_list = sorted([(fn, scores) 233 | for fn, scores in r_list], key=lambda x: x[0]) 234 | rg2_dict = {} 235 | for fn, scores in r_list: 236 | print(fn) 237 | if args.perl: 238 | print(rouge_results_to_str(scores)) 239 | else: 240 | rg2_dict[fn] = scores['rouge-2']['f'] 241 | print( 242 | "ROUGE-1: {}\tROUGE-2: {}\n".format(scores['rouge-1']['f'], scores['rouge-2']['f'])) 243 | with open(fn+".rouge", 'w') as f_out: 244 | f_out.write(json.dumps( 245 | {'rg1': scores['rouge-1']['f'], 'rg2': scores['rouge-2']['f']})) 246 | p.close() 247 | p.join() 248 | 249 | if args.save_best: 250 | # find best results 251 | group_dict = {} 252 | for k, v in rg2_dict.items(): 253 | d_name, o_name = Path(k).parent, Path(k).name 254 | if (d_name not in group_dict) or (v > group_dict[d_name][1]): 255 | group_dict[d_name] = (o_name, v) 256 | # compare and save the best result 257 | for k, v in group_dict.items(): 258 | fn = os.path.join(k, 'save_best.'+args.split) 259 | o_name_s, rst_s = v 260 | should_save = True 261 | if Path(fn).exists(): 262 | with open(fn, 'r') as f_in: 263 | rst_f = 
float(f_in.read().strip().split('\n')[-1]) 264 | if rst_s <= rst_f: 265 | should_save = False 266 | if should_save: 267 | with open(fn, 'w') as f_out: 268 | f_out.write('{0}\n{1}\n{2}\n'.format(k, o_name_s, rst_s)) 269 | 270 | 271 | if __name__ == "__main__": 272 | main() 273 | -------------------------------------------------------------------------------- /src/evaluations/eval_for_xsum.py: -------------------------------------------------------------------------------- 1 | """BERT finetuning runner.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import logging 9 | import glob 10 | import json 11 | import argparse 12 | import math 13 | import string 14 | from multiprocessing import Pool, cpu_count 15 | from tqdm import tqdm, trange 16 | from pathlib import Path 17 | import numpy as np 18 | # pip install py-rouge 19 | import rouge 20 | import time 21 | import tempfile 22 | import shutil 23 | 24 | # pip install pyrouge 25 | from evaluations.bs_pyrouge import Rouge155 26 | 27 | 28 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 29 | datefmt='%m/%d/%Y %H:%M:%S', 30 | level=logging.INFO) 31 | logger = logging.getLogger(__name__) 32 | 33 | parser = argparse.ArgumentParser() 34 | 35 | # Required parameters 36 | parser.add_argument("--gold", type=str, help="Gold output file.") 37 | parser.add_argument("--pred", type=str, help="Input prediction file.") 38 | parser.add_argument("--split", type=str, default="", 39 | help="Data split (train/dev/test).") 40 | parser.add_argument("--save_best", action='store_true', 41 | help="Save best epoch.") 42 | parser.add_argument("--only_eval_best", action='store_true', 43 | help="Only evaluate best epoch.") 44 | parser.add_argument("--trunc_len", type=int, default=0, 45 | help="Truncate line by the maximum length.") 46 | default_process_count = max(1, cpu_count() - 1) 47 | parser.add_argument("--processes", type=int, default=default_process_count, 48 | help="Number of processes to use (default %(default)s)") 49 | parser.add_argument("--perl", action='store_true', 50 | help="Using the perl script.") 51 | parser.add_argument('--lazy_eval', action='store_true', 52 | help="Skip evaluation if the .rouge file exists.") 53 | args = parser.parse_args() 54 | 55 | evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2, 56 | limit_length=False, apply_avg=True, weight_factor=1.2) 57 | 58 | 59 | def test_rouge(cand, ref): 60 | temp_dir = tempfile.mkdtemp() 61 | candidates = cand 62 | references = ref 63 | assert len(candidates) == len(references) 64 | 65 | cnt = len(candidates) 66 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 67 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) 68 | if not os.path.isdir(tmp_dir): 69 | os.mkdir(tmp_dir) 70 | os.mkdir(tmp_dir + "/candidate") 71 | os.mkdir(tmp_dir + "/reference") 72 | try: 73 | for i in range(cnt): 74 | if len(references[i]) < 1: 75 | continue 76 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 77 | encoding="utf-8") as f: 78 | f.write(candidates[i]) 79 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 80 | encoding="utf-8") as f: 81 | f.write(references[i]) 82 | r = Rouge155(temp_dir=temp_dir) 83 | r.model_dir = tmp_dir + "/reference/" 84 | r.system_dir = tmp_dir + "/candidate/" 85 | r.model_filename_pattern = 'ref.#ID#.txt' 86 | r.system_filename_pattern = r'cand.(\d+).txt' 87 | rouge_results = 
r.convert_and_evaluate() 88 | print(rouge_results) 89 | results_dict = r.output_to_dict(rouge_results) 90 | finally: 91 | if os.path.isdir(tmp_dir): 92 | shutil.rmtree(tmp_dir) 93 | return results_dict 94 | 95 | 96 | def rouge_results_to_str(results_dict): 97 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 98 | results_dict["rouge_1_f_score"] * 100, 99 | results_dict["rouge_2_f_score"] * 100, 100 | results_dict["rouge_l_f_score"] * 100, 101 | results_dict["rouge_1_recall"] * 100, 102 | results_dict["rouge_2_recall"] * 100, 103 | results_dict["rouge_l_recall"] * 100 104 | ) 105 | 106 | 107 | def count_tokens(tokens): 108 | counter = {} 109 | for t in tokens: 110 | if t in counter.keys(): 111 | counter[t] += 1 112 | else: 113 | counter[t] = 1 114 | return counter 115 | 116 | 117 | def get_f1(text_a, text_b): 118 | tokens_a = text_a.lower().split() 119 | tokens_b = text_b.lower().split() 120 | if len(tokens_a) == 0 or len(tokens_b) == 0: 121 | return 1 if len(tokens_a) == len(tokens_b) else 0 122 | set_a = count_tokens(tokens_a) 123 | set_b = count_tokens(tokens_b) 124 | match = 0 125 | for token in set_a.keys(): 126 | if token in set_b.keys(): 127 | match += min(set_a[token], set_b[token]) 128 | p = match / len(tokens_a) 129 | r = match / len(tokens_b) 130 | return 2.0 * p * r / (p + r + 1e-5) 131 | 132 | 133 | _tok_dict = {} 134 | 135 | 136 | def _is_digit(w): 137 | for ch in w: 138 | if not(ch.isdigit() or ch == ','): 139 | return False 140 | return True 141 | 142 | 143 | def fix_tokenization(text): 144 | input_tokens = text.split() 145 | output_tokens = [] 146 | i = 0 147 | prev_dash = False 148 | while i < len(input_tokens): 149 | tok = input_tokens[i] 150 | flag_prev_dash = False 151 | if tok in _tok_dict.keys(): 152 | output_tokens.append(_tok_dict[tok]) 153 | i += 1 154 | elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t": 155 | output_tokens[-1] = output_tokens[-1][:-1] 156 | output_tokens.append("n't") 157 | i += 2 158 | elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"): 159 | output_tokens.append("'"+input_tokens[i + 1]) 160 | i += 2 161 | elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".": 162 | output_tokens.append("...") 163 | i += 3 164 | elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]): 165 | # $ 3 , 000 -> $ 3,000 166 | output_tokens[-1] += ','+input_tokens[i + 1] 167 | i += 2 168 | elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit(): 169 | # 3 . 03 -> $ 3.03 170 | output_tokens[-1] += '.'+input_tokens[i + 1] 171 | i += 2 172 | elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.': 173 | # U . N . -> U.N. 
174 | k = i+3 175 | while k+2 < len(input_tokens): 176 | if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.': 177 | k += 2 178 | else: 179 | break 180 | output_tokens[-1] += ''.join(input_tokens[i:k]) 181 | i += 2 182 | elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation: 183 | output_tokens[-1] += tok 184 | i += 1 185 | else: 186 | output_tokens.append(tok) 187 | i += 1 188 | prev_dash = flag_prev_dash 189 | return " ".join(output_tokens) 190 | 191 | 192 | def process_eval(eval_fn): 193 | gold_list = [] 194 | with open(args.gold, "r", encoding="utf-8") as f_in: 195 | for l in f_in: 196 | line = l.strip() 197 | gold_list.append(line) 198 | 199 | pred_list = [] 200 | with open(eval_fn, "r", encoding="utf-8") as f_in: 201 | for l in f_in: 202 | buf = [] 203 | sentence = fix_tokenization(l.strip()).replace("(", " -LRB- ").replace(")", " -RRB- ") 204 | while " " in sentence: 205 | sentence = sentence.replace(" ", " ") 206 | buf.append(sentence) 207 | if args.trunc_len: 208 | num_left = args.trunc_len 209 | trunc_list = [] 210 | for bit in buf: 211 | tk_list = bit.split() 212 | n = min(len(tk_list), num_left) 213 | trunc_list.append(' '.join(tk_list[:n])) 214 | num_left -= n 215 | if num_left <= 0: 216 | break 217 | else: 218 | trunc_list = buf 219 | line = "\n".join(trunc_list) 220 | pred_list.append(line) 221 | with open(eval_fn+'.post', 'w', encoding='utf-8') as f_out: 222 | for l in pred_list: 223 | f_out.write(l.strip()) 224 | f_out.write('\n') 225 | # rouge scores 226 | if len(pred_list) < len(gold_list): 227 | # evaluate subset 228 | gold_list = gold_list[:len(pred_list)] 229 | assert len(pred_list) == len(gold_list) 230 | if args.perl: 231 | scores = test_rouge(pred_list, gold_list) 232 | else: 233 | scores = evaluator.get_scores(pred_list, [[it] for it in gold_list]) 234 | return eval_fn, scores 235 | 236 | 237 | def main(): 238 | if args.perl: 239 | eval_fn_list = list(glob.glob(args.pred)) 240 | else: 241 | eval_fn_list = [eval_fn for eval_fn in glob.glob(args.pred) if not( 242 | args.lazy_eval and Path(eval_fn+".rouge").exists())] 243 | eval_fn_list = list(filter(lambda fn: not(fn.endswith( 244 | '.post') or fn.endswith('.rouge')), eval_fn_list)) 245 | 246 | if args.only_eval_best: 247 | best_epoch_dict = {} 248 | for dir_path in set(Path(fn).parent for fn in eval_fn_list): 249 | fn_save = os.path.join(dir_path, 'save_best.dev') 250 | if Path(fn_save).exists(): 251 | with open(fn_save, 'r') as f_in: 252 | __, o_name, __ = f_in.read().strip().split('\n') 253 | epoch = o_name.split('.')[1] 254 | best_epoch_dict[dir_path] = epoch 255 | new_eval_fn_list = [] 256 | for fn in eval_fn_list: 257 | dir_path = Path(fn).parent 258 | if dir_path in best_epoch_dict: 259 | if Path(fn).name.split('.')[1] == best_epoch_dict[dir_path]: 260 | new_eval_fn_list.append(fn) 261 | eval_fn_list = new_eval_fn_list 262 | 263 | logger.info("***** Evaluation: %s *****", ','.join(eval_fn_list)) 264 | num_pool = min(args.processes, len(eval_fn_list)) 265 | p = Pool(num_pool) 266 | r_list = p.imap_unordered(process_eval, eval_fn_list) 267 | r_list = sorted([(fn, scores) 268 | for fn, scores in r_list], key=lambda x: x[0]) 269 | rg2_dict = {} 270 | for fn, scores in r_list: 271 | print(fn) 272 | if args.perl: 273 | print(rouge_results_to_str(scores)) 274 | else: 275 | rg2_dict[fn] = scores['rouge-2']['f'] 276 | print( 277 | "ROUGE-1: {}\tROUGE-2: {}\n".format(scores['rouge-1']['f'], scores['rouge-2']['f'])) 278 | with open(fn+".rouge", 'w') 
as f_out: 279 | f_out.write(json.dumps( 280 | {'rg1': scores['rouge-1']['f'], 'rg2': scores['rouge-2']['f']})) 281 | p.close() 282 | p.join() 283 | 284 | if args.save_best: 285 | # find best results 286 | group_dict = {} 287 | for k, v in rg2_dict.items(): 288 | d_name, o_name = Path(k).parent, Path(k).name 289 | if (d_name not in group_dict) or (v > group_dict[d_name][1]): 290 | group_dict[d_name] = (o_name, v) 291 | # compare and save the best result 292 | for k, v in group_dict.items(): 293 | fn = os.path.join(k, 'save_best.'+args.split) 294 | o_name_s, rst_s = v 295 | should_save = True 296 | if Path(fn).exists(): 297 | with open(fn, 'r') as f_in: 298 | rst_f = float(f_in.read().strip().split('\n')[-1]) 299 | if rst_s <= rst_f: 300 | should_save = False 301 | if should_save: 302 | with open(fn, 'w') as f_out: 303 | f_out.write('{0}\n{1}\n{2}\n'.format(k, o_name_s, rst_s)) 304 | logger.info("Should save: {}".format(json.dumps(v, indent=2))) 305 | 306 | 307 | if __name__ == "__main__": 308 | main() 309 | -------------------------------------------------------------------------------- /src/decode_seq2seq.py: -------------------------------------------------------------------------------- 1 | """BERT finetuning runner.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import json 9 | import logging 10 | import argparse 11 | import math 12 | from tqdm import tqdm, trange 13 | import numpy as np 14 | import torch 15 | import random 16 | import pickle 17 | 18 | from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder, BertConfig 19 | from transformers.tokenization_bert import whitespace_tokenize 20 | import s2s_ft.s2s_loader as seq2seq_loader 21 | from s2s_ft.utils import load_and_cache_examples 22 | from transformers import \ 23 | BertTokenizer, RobertaTokenizer 24 | from s2s_ft.tokenization_unilm import UnilmTokenizer 25 | from s2s_ft.tokenization_minilm import MinilmTokenizer 26 | 27 | TOKENIZER_CLASSES = { 28 | 'bert': BertTokenizer, 29 | 'minilm': MinilmTokenizer, 30 | 'roberta': RobertaTokenizer, 31 | 'unilm': UnilmTokenizer, 32 | } 33 | 34 | class WhitespaceTokenizer(object): 35 | def tokenize(self, text): 36 | return whitespace_tokenize(text) 37 | 38 | 39 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 40 | datefmt='%m/%d/%Y %H:%M:%S', 41 | level=logging.INFO) 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | def detokenize(tk_list): 46 | r_list = [] 47 | for tk in tk_list: 48 | if tk.startswith('##') and len(r_list) > 0: 49 | r_list[-1] = r_list[-1] + tk[2:] 50 | else: 51 | r_list.append(tk) 52 | return r_list 53 | 54 | 55 | def ascii_print(text): 56 | text = text.encode("ascii", "ignore") 57 | print(text) 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser() 62 | 63 | # Required parameters 64 | parser.add_argument("--model_type", default=None, type=str, required=True, 65 | help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys())) 66 | parser.add_argument("--model_path", default=None, type=str, required=True, 67 | help="Path to the model checkpoint.") 68 | parser.add_argument("--config_path", default=None, type=str, 69 | help="Path to config.json for the model.") 70 | 71 | # tokenizer_name 72 | parser.add_argument("--tokenizer_name", default=None, type=str, required=True, 73 | help="tokenizer name") 74 | parser.add_argument("--max_seq_length", default=512, type=int, 75 | help="The maximum total input 
sequence length after WordPiece tokenization. \n" 76 | "Sequences longer than this will be truncated, and sequences shorter \n" 77 | "than this will be padded.") 78 | 79 | # decoding parameters 80 | parser.add_argument('--fp16', action='store_true', 81 | help="Whether to use 16-bit float precision instead of 32-bit") 82 | parser.add_argument('--amp', action='store_true', 83 | help="Whether to use amp for fp16") 84 | parser.add_argument("--input_file", type=str, help="Input file") 85 | parser.add_argument('--subset', type=int, default=0, 86 | help="Decode a subset of the input dataset.") 87 | parser.add_argument("--output_file", type=str, help="output file") 88 | parser.add_argument("--split", type=str, default="", 89 | help="Data split (train/val/test).") 90 | parser.add_argument('--tokenized_input', action='store_true', 91 | help="Whether the input is tokenized.") 92 | parser.add_argument('--seed', type=int, default=123, 93 | help="random seed for initialization") 94 | parser.add_argument("--do_lower_case", action='store_true', 95 | help="Set this flag if you are using an uncased model.") 96 | parser.add_argument('--batch_size', type=int, default=4, 97 | help="Batch size for decoding.") 98 | parser.add_argument('--beam_size', type=int, default=1, 99 | help="Beam size for searching") 100 | parser.add_argument('--length_penalty', type=float, default=0, 101 | help="Length penalty for beam search") 102 | 103 | parser.add_argument('--forbid_duplicate_ngrams', action='store_true') 104 | parser.add_argument('--forbid_ignore_word', type=str, default=None, 105 | help="Forbid the word during forbid_duplicate_ngrams") 106 | parser.add_argument("--min_len", default=1, type=int) 107 | parser.add_argument('--need_score_traces', action='store_true') 108 | parser.add_argument('--ngram_size', type=int, default=3) 109 | parser.add_argument('--mode', default="s2s", 110 | choices=["s2s", "l2r", "both"]) 111 | parser.add_argument('--max_tgt_length', type=int, default=128, 112 | help="maximum length of target sequence") 113 | parser.add_argument('--s2s_special_token', action='store_true', 114 | help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.") 115 | parser.add_argument('--s2s_add_segment', action='store_true', 116 | help="Additional segmental for the encoder of S2S.") 117 | parser.add_argument('--s2s_share_segment', action='store_true', 118 | help="Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment).") 119 | parser.add_argument('--pos_shift', action='store_true', 120 | help="Using position shift for fine-tuning.") 121 | parser.add_argument("--cache_dir", default=None, type=str, 122 | help="Where do you want to store the pre-trained models downloaded from s3") 123 | 124 | args = parser.parse_args() 125 | 126 | if args.need_score_traces and args.beam_size <= 1: 127 | raise ValueError( 128 | "Score trace is only available for beam search with beam size > 1.") 129 | if args.max_tgt_length >= args.max_seq_length - 2: 130 | raise ValueError("Maximum tgt length exceeds max seq length - 2.") 131 | 132 | device = torch.device( 133 | "cuda" if torch.cuda.is_available() else "cpu") 134 | n_gpu = torch.cuda.device_count() 135 | 136 | if args.seed > 0: 137 | random.seed(args.seed) 138 | np.random.seed(args.seed) 139 | torch.manual_seed(args.seed) 140 | if n_gpu > 0: 141 | torch.cuda.manual_seed_all(args.seed) 142 | else: 143 | random_seed = random.randint(0, 10000) 144 | logger.info("Set random seed as: {}".format(random_seed)) 145 | random.seed(random_seed) 146 | 
np.random.seed(random_seed) 147 | torch.manual_seed(random_seed) 148 | if n_gpu > 0: 149 | torch.cuda.manual_seed_all(args.seed) 150 | 151 | tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained( 152 | args.tokenizer_name, do_lower_case=args.do_lower_case, 153 | cache_dir=args.cache_dir if args.cache_dir else None, max_len=args.max_seq_length) 154 | 155 | if args.model_type == "roberta": 156 | vocab = tokenizer.encoder 157 | else: 158 | vocab = tokenizer.vocab 159 | 160 | 161 | config_file = args.config_path if args.config_path else os.path.join(args.model_path, "config.json") 162 | logger.info("Read decoding config from: %s" % config_file) 163 | config = BertConfig.from_json_file(config_file) 164 | 165 | bi_uni_pipeline = [] 166 | bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder( 167 | list(vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length, 168 | max_tgt_length=args.max_tgt_length, pos_shift=args.pos_shift, 169 | source_type_id=config.source_type_id, target_type_id=config.target_type_id, 170 | cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token, pad_token=tokenizer.pad_token)) 171 | 172 | mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids( 173 | [tokenizer.mask_token, tokenizer.sep_token, tokenizer.sep_token]) 174 | forbid_ignore_set = None 175 | if args.forbid_ignore_word: 176 | w_list = [] 177 | for w in args.forbid_ignore_word.split('|'): 178 | if w.startswith('[') and w.endswith(']'): 179 | w_list.append(w.upper()) 180 | else: 181 | w_list.append(w) 182 | forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list)) 183 | print(args.model_path) 184 | found_checkpoint_flag = False 185 | for model_recover_path in [args.model_path.strip()]: 186 | logger.info("***** Recover model: %s *****", model_recover_path) 187 | found_checkpoint_flag = True 188 | model = BertForSeq2SeqDecoder.from_pretrained( 189 | model_recover_path, config=config, mask_word_id=mask_word_id, search_beam_size=args.beam_size, 190 | length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id, 191 | forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, 192 | ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode, 193 | max_position_embeddings=args.max_seq_length, pos_shift=args.pos_shift, 194 | ) 195 | 196 | if args.fp16: 197 | model.half() 198 | model.to(device) 199 | if n_gpu > 1: 200 | model = torch.nn.DataParallel(model) 201 | 202 | torch.cuda.empty_cache() 203 | model.eval() 204 | next_i = 0 205 | max_src_length = args.max_seq_length - 2 - args.max_tgt_length 206 | 207 | to_pred = load_and_cache_examples( 208 | args.input_file, tokenizer, local_rank=-1, 209 | cached_features_file=None, shuffle=False) 210 | 211 | input_lines = [] 212 | for line in to_pred: 213 | _line = tokenizer.convert_ids_to_tokens(line["source_ids"])[:max_src_length] 214 | 215 | input_lines.append(_line) 216 | if args.subset > 0: 217 | logger.info("Decoding subset: %d", args.subset) 218 | input_lines = input_lines[:args.subset] 219 | 220 | input_lines = sorted(list(enumerate(input_lines)), 221 | key=lambda x: -len(x[1])) 222 | output_lines = [""] * len(input_lines) 223 | score_trace_list = [None] * len(input_lines) 224 | total_batch = math.ceil(len(input_lines) / args.batch_size) 225 | 226 | with tqdm(total=total_batch) as pbar: 227 | batch_count = 0 228 | first_batch = True 229 | while next_i < len(input_lines): 230 | _chunk = input_lines[next_i:next_i + args.batch_size] 231 | buf_id = [x[0] for x in 
_chunk] 232 | buf = [x[1] for x in _chunk] 233 | next_i += args.batch_size 234 | batch_count += 1 235 | max_a_len = max([len(x) for x in buf]) 236 | instances = [] 237 | for instance in [(x, max_a_len) for x in buf]: 238 | for proc in bi_uni_pipeline: 239 | instances.append(proc(instance)) 240 | with torch.no_grad(): 241 | batch = seq2seq_loader.batch_list_to_batch_tensors( 242 | instances) 243 | batch = [ 244 | t.to(device) if t is not None else None for t in batch] 245 | input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch 246 | traces = model(input_ids, token_type_ids, 247 | position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv) 248 | if args.beam_size > 1: 249 | traces = {k: v.tolist() for k, v in traces.items()} 250 | output_ids = traces['pred_seq'] 251 | else: 252 | output_ids = traces.tolist() 253 | for i in range(len(buf)): 254 | w_ids = output_ids[i] 255 | output_buf = tokenizer.convert_ids_to_tokens(w_ids) 256 | output_tokens = [] 257 | for t in output_buf: 258 | if t in (tokenizer.sep_token, tokenizer.pad_token): 259 | break 260 | output_tokens.append(t) 261 | if args.model_type == "roberta": 262 | output_sequence = tokenizer.convert_tokens_to_string(output_tokens) 263 | else: 264 | output_sequence = ' '.join(detokenize(output_tokens)) 265 | if '\n' in output_sequence: 266 | output_sequence = " [X_SEP] ".join(output_sequence.split('\n')) 267 | output_lines[buf_id[i]] = output_sequence 268 | if first_batch or batch_count % 50 == 0: 269 | logger.info("{} = {}".format(buf_id[i], output_sequence)) 270 | if args.need_score_traces: 271 | score_trace_list[buf_id[i]] = { 272 | 'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]} 273 | pbar.update(1) 274 | first_batch = False 275 | if args.output_file: 276 | fn_out = args.output_file 277 | else: 278 | fn_out = model_recover_path+'.'+args.split 279 | with open(fn_out, "w", encoding="utf-8") as fout: 280 | for l in output_lines: 281 | fout.write(l) 282 | fout.write("\n") 283 | 284 | if args.need_score_traces: 285 | with open(fn_out + ".trace.pickle", "wb") as fout_trace: 286 | pickle.dump( 287 | {"version": 0.0, "num_samples": len(input_lines)}, fout_trace) 288 | for x in score_trace_list: 289 | pickle.dump(x, fout_trace) 290 | 291 | if not found_checkpoint_flag: 292 | logger.info("Not found the model checkpoint file!") 293 | 294 | 295 | if __name__ == "__main__": 296 | main() 297 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # s2s-ft: Sequence-to-Sequence Fine-Tuning 2 | **A PyTorch package used to fine-tune pre-trained Transformers for sequence-to-sequence language generation** 3 | 4 | ## Environment 5 | 6 | The recommended way to run the code is using docker: 7 | ```bash 8 | docker run -it --rm --runtime=nvidia --ipc=host --privileged pytorch/pytorch:1.2-cuda10.0-cudnn7-devel bash 9 | ``` 10 | 11 | The following Python package need to be installed: 12 | ```bash 13 | pip install --user methodtools py-rouge pyrouge nltk 14 | python -c "import nltk; nltk.download('punkt')" 15 | git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext 16 | ``` 17 | 18 | Install the repo as a package: 19 | ```bash 20 | git clone this repo into ${code_dir} 21 | 22 | cd ${code_dir} ; pip install --editable . 
23 | ``` 24 | 25 | ## Pre-trained Models 26 | 27 | We recommend to use the uncased model: 28 | - [unilm1.2-base-uncased](https://unilm.blob.core.windows.net/ckpt/unilm1.2-base-uncased.bin): 12-layer, 768-hidden, 12-heads, 110M parameters 29 | 30 | If you would like to use a cased model: 31 | - [unilm1-base-cased](https://unilm.blob.core.windows.net/ckpt/unilm1-base-cased.bin): 12-layer, 768-hidden, 12-heads, 110M parameters 32 | - [unilm1-large-cased](https://unilm.blob.core.windows.net/ckpt/unilm1-large-cased.bin): 24-layer, 1024-hidden, 16-heads, 340M parameters 33 | 34 | If you prefer [small pretrained models](https://github.com/microsoft/unilm/tree/master/minilm) for faster inference speed: 35 | - [minilm-l12-h384-uncased](https://1drv.ms/u/s!AjHn0yEmKG8qixAYyu2Fvq5ulnU7?e=DFApTA): 12-layer, 384-hidden, 12-heads, 33M parameters 36 | 37 | ## Input File Format 38 | 39 | We support two dataset formats: 40 | 41 | 1. Text format: each line contains a json string of an example. `"src"` contains source sequence text, `"tgt"` contains target sequence text (`"tgt"` can be ignored for decoding). The data should be pre-processed as follows: 42 | 43 | ```bash 44 | {"src": "Messages posted on social media claimed the user planned to `` kill as many people as possible ''", "tgt": "Threats to kill pupils in a shooting at a Blackpool school are being investigated by Lancashire police ."} 45 | {"src": "Media playback is unsupported on your device", "tgt": "A slide running the entire length of one of the steepest city centre streets in Europe has been turned into a massive three-lane water adventure ."} 46 | {"src": "Chris Erskine crossed low for Kris Doolan to tap home and give the Jags an early lead .", "tgt": "Partick Thistle will finish in the Scottish Premiership 's top six for the first time after beating Motherwell"} 47 | ``` 48 | 49 | 2. Tokenized format: if you use tokenized data (with the same WordPiece tokenizers as BERT), `"src"` is a list of source sequence tokens, and `"tgt"` is a list of target sequence tokens (`"tgt"` can be ignored for decoding): 50 | 51 | ```bash 52 | {"src": ["messages", "posted", "on", "social", "media", "claimed", "the", "user", "planned", "to", "\"", "kill", "as", "many", "people", "as", "possible", "\""], "tgt": ["threats", "to", "kill", "pupils", "in", "a", "shooting", "at", "a", "blackpool", "school", "are", "being", "investigated", "by", "lancashire", "police", "."]} 53 | {"src": ["media", "playback", "is", "un", "##su", "##pp", "##orted", "on", "your", "device"], "tgt": ["a", "slide", "running", "the", "entire", "length", "of", "one", "of", "the", "steep", "##est", "city", "centre", "streets", "in", "europe", "has", "been", "turned", "into", "a", "massive", "three", "-", "lane", "water", "adventure", "."]} 54 | {"src": ["chris", "erskine", "crossed", "low", "for", "kris", "doo", "##lan", "to", "tap", "home", "and", "give", "the", "ja", "##gs", "an", "early", "lead", "."], "tgt": ["part", "##ick", "thistle", "will", "finish", "in", "the", "scottish", "premiership", "'", "s", "top", "six", "for", "the", "first", "time", "after", "beating", "mother", "##well"]} 55 | ``` 56 | 57 | The code automatically detects the input format. If the json line contains `list`, we process the input as the tokenized format; if the json line contains `string`, the code will tokenize them. 
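To make the format detection concrete, here is a minimal sketch of the idea (an illustration only, not the project's actual loader in `s2s_ft/utils.py`; the helper name `load_jsonl_examples` and the use of `BertTokenizer` are assumptions):

```python
import json
from transformers import BertTokenizer  # assumption: any WordPiece tokenizer behaves the same way here


def load_jsonl_examples(path, tokenizer):
    """Hypothetical helper: read one json example per line and auto-detect its format."""
    examples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            example = json.loads(line)
            if isinstance(example["src"], list):
                # tokenized format: "src" (and "tgt", if present) are already WordPiece token lists
                src_tokens = example["src"]
                tgt_tokens = example.get("tgt", [])
            else:
                # text format: tokenize the raw strings with the model's tokenizer
                src_tokens = tokenizer.tokenize(example["src"])
                tgt_tokens = tokenizer.tokenize(example.get("tgt", ""))
            examples.append({"src": src_tokens, "tgt": tgt_tokens})
    return examples


# usage sketch:
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
# examples = load_jsonl_examples("train.json", tokenizer)
```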
58 | 59 | 60 | ## Example: [XSum](https://github.com/EdinburghNLP/XSum) with unilm1.2-base-uncased 61 | 62 | ### Fine-tuning 63 | 64 | Pre-processed json dataset links: [text format](https://unilm.blob.core.windows.net/s2s-ft-data/xsum.json.zip), or [tokenized format](https://unilm.blob.core.windows.net/s2s-ft-data/xsum.uncased_tokenized.zip). 65 | 66 | ```bash 67 | # path of training data 68 | TRAIN_FILE=/your/path/to/train.json 69 | # folder used to save fine-tuned checkpoints 70 | OUTPUT_DIR=/your/path/to/save_checkpoints 71 | # folder used to cache package dependencies 72 | CACHE_DIR=/your/path/to/transformer_package_cache 73 | 74 | export CUDA_VISIBLE_DEVICES=0,1,2,3 75 | python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \ 76 | --train_file ${TRAIN_FILE} --output_dir ${OUTPUT_DIR} \ 77 | --model_type unilm --model_name_or_path unilm1.2-base-uncased \ 78 | --do_lower_case --fp16 --fp16_opt_level O2 --max_source_seq_length 464 --max_target_seq_length 48 \ 79 | --per_gpu_train_batch_size 16 --gradient_accumulation_steps 1 \ 80 | --learning_rate 7e-5 --num_warmup_steps 500 --num_training_steps 32000 --cache_dir ${CACHE_DIR} 81 | ``` 82 | 83 | - The fine-tuning batch size = `number of gpus` * `per_gpu_train_batch_size` * `gradient_accumulation_steps`. So in the above example, the batch size is `4*16*1 = 64`. Adjust the three arguments together to keep the total batch size unchanged. 84 | - `--do_lower_case`: for uncased models 85 | 86 | ### Decoding 87 | 88 | ```bash 89 | # path of the fine-tuned checkpoint 90 | MODEL_PATH=/your/path/to/model_checkpoint 91 | SPLIT=validation 92 | # input file that you would like to decode 93 | INPUT_JSON=/your/path/to/${SPLIT}.json 94 | 95 | export CUDA_VISIBLE_DEVICES=0 96 | export OMP_NUM_THREADS=4 97 | export MKL_NUM_THREADS=4 98 | 99 | python decode_seq2seq.py \ 100 | --fp16 --model_type unilm --tokenizer_name unilm1.2-base-uncased --input_file ${INPUT_JSON} --split $SPLIT --do_lower_case \ 101 | --model_path ${MODEL_PATH} --max_seq_length 512 --max_tgt_length 48 --batch_size 32 --beam_size 5 \ 102 | --length_penalty 0 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "." 103 | ``` 104 | 105 | - The decoding results are saved at `${MODEL_PATH}.${SPLIT}`. 106 | - `--do_lower_case`: for uncased models 107 | 108 | ### Evaluation 109 | 110 | The gold reference text files can be downloaded [here](https://unilm.blob.core.windows.net/s2s-ft-data/xsum.eval.zip). 111 | 112 | ```bash 113 | SPLIT=validation 114 | GOLD_PATH=/your/path/to/${SPLIT}.target 115 | # ${MODEL_PATH}.${SPLIT} is the predicted target file 116 | python evaluations/eval_for_xsum.py --pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT} 117 | ``` 118 | 119 | 120 | ## Example: [XSum](https://github.com/EdinburghNLP/XSum) with minilm-l12-h384-uncased 121 | 122 | ### Fine-tuning 123 | 124 | Pre-processed json dataset links: [text format](https://unilm.blob.core.windows.net/s2s-ft-data/xsum.json.zip), or [tokenized format](https://unilm.blob.core.windows.net/s2s-ft-data/xsum.uncased_tokenized.zip).
125 | 126 | ```bash 127 | # path of training data 128 | TRAIN_FILE=/your/path/to/train.json 129 | # folder used to save fine-tuned checkpoints 130 | OUTPUT_DIR=/your/path/to/save_checkpoints 131 | # folder used to cache package dependencies 132 | CACHE_DIR=/your/path/to/transformer_package_cache 133 | 134 | export CUDA_VISIBLE_DEVICES=0,1,2,3 135 | python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \ 136 | --train_file ${TRAIN_FILE} --output_dir ${OUTPUT_DIR} \ 137 | --model_type minilm --model_name_or_path minilm-l12-h384-uncased \ 138 | --do_lower_case --fp16 --fp16_opt_level O2 --max_source_seq_length 464 --max_target_seq_length 48 \ 139 | --per_gpu_train_batch_size 16 --gradient_accumulation_steps 1 \ 140 | --learning_rate 1e-4 --num_warmup_steps 500 --num_training_steps 108000 --cache_dir ${CACHE_DIR} 141 | ``` 142 | 143 | - The fine-tuning batch size = `number of gpus` * `per_gpu_train_batch_size` * `gradient_accumulation_steps`. So in the above example, the batch size is `4*16*1 = 64`. Adjust the three arguments together to keep the total batch size unchanged. 144 | - `--do_lower_case`: for uncased models 145 | 146 | ### Decoding 147 | 148 | ```bash 149 | # path of the fine-tuned checkpoint 150 | MODEL_PATH=/your/path/to/model_checkpoint 151 | SPLIT=validation 152 | # input file that you would like to decode 153 | INPUT_JSON=/your/path/to/${SPLIT}.json 154 | 155 | export CUDA_VISIBLE_DEVICES=0 156 | export OMP_NUM_THREADS=4 157 | export MKL_NUM_THREADS=4 158 | 159 | python decode_seq2seq.py \ 160 | --fp16 --model_type minilm --tokenizer_name minilm-l12-h384-uncased --input_file ${INPUT_JSON} --split $SPLIT --do_lower_case \ 161 | --model_path ${MODEL_PATH} --max_seq_length 512 --max_tgt_length 48 --batch_size 32 --beam_size 5 \ 162 | --length_penalty 0 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "." 163 | ``` 164 | 165 | - The decoding results are saved at `${MODEL_PATH}.${SPLIT}`. 166 | - `--do_lower_case`: for uncased models 167 | 168 | ### Evaluation 169 | 170 | The gold reference text files can be downloaded [here](https://unilm.blob.core.windows.net/s2s-ft-data/xsum.eval.zip). 171 | 172 | ```bash 173 | SPLIT=validation 174 | GOLD_PATH=/your/path/to/${SPLIT}.target 175 | # ${MODEL_PATH}.${SPLIT} is the predicted target file 176 | python evaluations/eval_for_xsum.py --pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT} 177 | ``` 178 | 179 | 180 | ## Example: CNN / Daily Mail with unilm1-base-cased 181 | 182 | Pre-processed json dataset links: [tokenized format](https://unilm.blob.core.windows.net/s2s-ft-data/cnndm.cased_tokenized.zip).
183 | 184 | ### Fine-tuning 185 | 186 | ```bash 187 | # path of training data 188 | export TRAIN_FILE=/your/path/to/train.json 189 | # path used to cache training data 190 | export CACHED_FEATURE_FILE=/your/path/to/cnndm_train.cased.features.pt 191 | # folder used to save fine-tuned checkpoints 192 | export OUTPUT_DIR=/your/path/to/save_checkpoints 193 | # folder used to cache package dependencies 194 | export CACHE_DIR=/your/path/to/transformer_package_cache 195 | 196 | export CUDA_VISIBLE_DEVICES=0,1,2,3 197 | python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \ 198 | --train_file $TRAIN_FILE --cached_train_features_file $CACHED_FEATURE_FILE --output_dir $OUTPUT_DIR \ 199 | --model_type unilm --model_name_or_path unilm1-base-cased --fp16 --fp16_opt_level O2 \ 200 | --max_source_seq_length 608 --max_target_seq_length 160 --per_gpu_train_batch_size 8 --gradient_accumulation_steps 2 \ 201 | --learning_rate 7e-5 --num_warmup_steps 1000 --num_training_steps 45000 --cache_dir $CACHE_DIR --save_steps 1500 202 | ``` 203 | 204 | - The fine-tuning batch size = `number of gpus` * `per_gpu_train_batch_size` * `gradient_accumulation_steps`. So in the above example, the batch size is `4*8*2 = 64`. Adjust the three arguments together to keep the total batch size unchanged. 205 | - A fine-tuned checkpoint is available [here](https://unilm.blob.core.windows.net/ckpt/cnndm.unilm1-base-cased.bin). 206 | 207 | 208 | ### Decoding 209 | 210 | ```bash 211 | # path of the fine-tuned checkpoint 212 | MODEL_PATH=/your/path/to/model_checkpoint 213 | SPLIT=dev 214 | # input file that you would like to decode 215 | INPUT_JSON=/your/path/to/${SPLIT}.json 216 | 217 | export CUDA_VISIBLE_DEVICES=0 218 | export OMP_NUM_THREADS=4 219 | export MKL_NUM_THREADS=4 220 | 221 | python decode_seq2seq.py \ 222 | --fp16 --model_type unilm --tokenizer_name unilm1-base-cased --input_file ${INPUT_JSON} --split $SPLIT \ 223 | --model_path ${MODEL_PATH} --max_seq_length 768 --max_tgt_length 160 --batch_size 32 --beam_size 5 \ 224 | --length_penalty 0 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "." 225 | ``` 226 | 227 | - The decoding results are saved at `${MODEL_PATH}.${SPLIT}`. 228 | 229 | ### Evaluation 230 | 231 | The gold reference text files can be downloaded [here](https://unilm.blob.core.windows.net/s2s-ft-data/cnndm.eval.zip). 232 | 233 | ```bash 234 | SPLIT=dev 235 | GOLD_PATH=/your/path/to/${SPLIT}.target 236 | # ${MODEL_PATH}.${SPLIT} is the predicted target file 237 | python evaluations/eval_for_cnndm.py --pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT} --trunc_len 160 238 | ``` 239 | 240 | 241 | 242 | 243 | ## Example: CNN / Daily Mail with unilm1.2-base-uncased 244 | 245 | Pre-processed json dataset links: [tokenized format](https://unilm.blob.core.windows.net/s2s-ft-data/cnndm.uncased_tokenized.zip).
246 | 247 | ### Fine-tuning 248 | 249 | ```bash 250 | # path of training data 251 | export TRAIN_FILE=/your/path/to/train.json 252 | # folder used to save fine-tuned checkpoints 253 | export OUTPUT_DIR=/your/path/to/save_checkpoints 254 | # folder used to cache package dependencies 255 | export CACHE_DIR=/your/path/to/transformer_package_cache 256 | 257 | export CUDA_VISIBLE_DEVICES=0,1,2,3 258 | python -m torch.distributed.launch --nproc_per_node=4 run_seq2seq.py \ 259 | --train_file $TRAIN_FILE --output_dir $OUTPUT_DIR \ 260 | --model_type unilm --model_name_or_path unilm1.2-base-uncased --do_lower_case --fp16 --fp16_opt_level O2 \ 261 | --max_source_seq_length 608 --max_target_seq_length 160 --per_gpu_train_batch_size 8 --gradient_accumulation_steps 2 \ 262 | --learning_rate 7e-5 --num_warmup_steps 1000 --num_training_steps 45000 --cache_dir $CACHE_DIR --save_steps 1500 263 | ``` 264 | 265 | - The fine-tuning batch size = `number of gpus` * `per_gpu_train_batch_size` * `gradient_accumulation_steps`. So in the above example, the batch size is `4*8*2 = 64`. Adjust the three arguments together to keep the total batch size unchanged. 266 | - `--do_lower_case`: for uncased models 267 | 268 | ### Decoding 269 | 270 | ```bash 271 | # path of the fine-tuned checkpoint 272 | MODEL_PATH=/your/path/to/model_checkpoint 273 | SPLIT=dev 274 | # input file that you would like to decode 275 | INPUT_JSON=/your/path/to/${SPLIT}.json 276 | 277 | export CUDA_VISIBLE_DEVICES=0 278 | export OMP_NUM_THREADS=4 279 | export MKL_NUM_THREADS=4 280 | 281 | python decode_seq2seq.py \ 282 | --fp16 --model_type unilm --tokenizer_name unilm1.2-base-uncased --do_lower_case --input_file ${INPUT_JSON} --split $SPLIT \ 283 | --model_path ${MODEL_PATH} --max_seq_length 768 --max_tgt_length 160 --batch_size 32 --beam_size 5 \ 284 | --length_penalty 0 --forbid_duplicate_ngrams --mode s2s --forbid_ignore_word "." --min_len 48 285 | ``` 286 | 287 | - The decoding results are saved at `${MODEL_PATH}.${SPLIT}`. 288 | 289 | ### Evaluation 290 | 291 | The gold reference text files can be downloaded [here](https://unilm.blob.core.windows.net/s2s-ft-data/cnndm.eval.zip). 292 | 293 | ```bash 294 | SPLIT=dev 295 | GOLD_PATH=/your/path/to/${SPLIT}.target 296 | # ${MODEL_PATH}.${SPLIT} is the predicted target file 297 | python evaluations/eval_for_cnndm.py --pred ${MODEL_PATH}.${SPLIT} --gold ${GOLD_PATH} --split ${SPLIT} --trunc_len 160 298 | ``` 299 | 300 | ## License 301 | This project is licensed under the license found in the LICENSE file in the root directory of this source tree. 302 | Portions of the source code are based on the [transformers](https://github.com/huggingface/transformers) project.
303 | 304 | [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct) 305 | -------------------------------------------------------------------------------- /src/evaluations/eval_for_cnndm.py: -------------------------------------------------------------------------------- 1 | """BERT finetuning runner.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import logging 9 | import glob 10 | import json 11 | import argparse 12 | import math 13 | import string 14 | from multiprocessing import Pool, cpu_count 15 | from tqdm import tqdm, trange 16 | from pathlib import Path 17 | import numpy as np 18 | # pip install py-rouge 19 | import rouge 20 | import time 21 | import tempfile 22 | import shutil 23 | 24 | # pip install pyrouge 25 | from evaluations.bs_pyrouge import Rouge155 26 | 27 | 28 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 29 | datefmt='%m/%d/%Y %H:%M:%S', 30 | level=logging.INFO) 31 | logger = logging.getLogger(__name__) 32 | 33 | parser = argparse.ArgumentParser() 34 | 35 | # Required parameters 36 | parser.add_argument("--gold", type=str, help="Gold output file.") 37 | parser.add_argument("--pred", type=str, help="Input prediction file.") 38 | parser.add_argument("--split", type=str, default="", 39 | help="Data split (train/dev/test).") 40 | parser.add_argument("--save_best", action='store_true', 41 | help="Save best epoch.") 42 | parser.add_argument("--only_eval_best", action='store_true', 43 | help="Only evaluate best epoch.") 44 | parser.add_argument("--trunc_len", type=int, default=60, 45 | help="Truncate line by the maximum length.") 46 | parser.add_argument("--duplicate_rate", type=float, default=0.7, 47 | help="If the duplicat rate (compared with history) is large, we can discard the current sentence.") 48 | default_process_count = max(1, cpu_count() - 1) 49 | parser.add_argument("--processes", type=int, default=default_process_count, 50 | help="Number of processes to use (default %(default)s)") 51 | parser.add_argument("--perl", action='store_true', 52 | help="Using the perl script.") 53 | parser.add_argument('--lazy_eval', action='store_true', 54 | help="Skip evaluation if the .rouge file exists.") 55 | args = parser.parse_args() 56 | 57 | SPECIAL_TOKEN = ["[UNK]", "[PAD]", "[CLS]", "[MASK]"] 58 | evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2, 59 | limit_length=False, apply_avg=True, weight_factor=1.2) 60 | 61 | 62 | def test_rouge(cand, ref): 63 | temp_dir = tempfile.mkdtemp() 64 | candidates = cand 65 | references = ref 66 | assert len(candidates) == len(references) 67 | 68 | cnt = len(candidates) 69 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 70 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) 71 | if not os.path.isdir(tmp_dir): 72 | os.mkdir(tmp_dir) 73 | os.mkdir(tmp_dir + "/candidate") 74 | os.mkdir(tmp_dir + "/reference") 75 | try: 76 | for i in range(cnt): 77 | if len(references[i]) < 1: 78 | continue 79 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 80 | encoding="utf-8") as f: 81 | f.write(candidates[i]) 82 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 83 | encoding="utf-8") as f: 84 | f.write(references[i]) 85 | r = Rouge155(temp_dir=temp_dir) 86 | r.model_dir = tmp_dir + "/reference/" 87 | r.system_dir = tmp_dir + "/candidate/" 88 | r.model_filename_pattern = 'ref.#ID#.txt' 89 | 
r.system_filename_pattern = r'cand.(\d+).txt' 90 | rouge_results = r.convert_and_evaluate() 91 | print(rouge_results) 92 | results_dict = r.output_to_dict(rouge_results) 93 | finally: 94 | if os.path.isdir(tmp_dir): 95 | shutil.rmtree(tmp_dir) 96 | return results_dict 97 | 98 | 99 | def rouge_results_to_str(results_dict): 100 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 101 | results_dict["rouge_1_f_score"] * 100, 102 | results_dict["rouge_2_f_score"] * 100, 103 | results_dict["rouge_l_f_score"] * 100, 104 | results_dict["rouge_1_recall"] * 100, 105 | results_dict["rouge_2_recall"] * 100, 106 | results_dict["rouge_l_recall"] * 100 107 | ) 108 | 109 | 110 | def count_tokens(tokens): 111 | counter = {} 112 | for t in tokens: 113 | if t in counter.keys(): 114 | counter[t] += 1 115 | else: 116 | counter[t] = 1 117 | return counter 118 | 119 | 120 | def get_f1(text_a, text_b): 121 | tokens_a = text_a.lower().split() 122 | tokens_b = text_b.lower().split() 123 | if len(tokens_a) == 0 or len(tokens_b) == 0: 124 | return 1 if len(tokens_a) == len(tokens_b) else 0 125 | set_a = count_tokens(tokens_a) 126 | set_b = count_tokens(tokens_b) 127 | match = 0 128 | for token in set_a.keys(): 129 | if token in set_b.keys(): 130 | match += min(set_a[token], set_b[token]) 131 | p = match / len(tokens_a) 132 | r = match / len(tokens_b) 133 | return 2.0 * p * r / (p + r + 1e-5) 134 | 135 | 136 | _tok_dict = {"(": "-LRB-", ")": "-RRB-", 137 | "[": "-LSB-", "]": "-RSB-", 138 | "{": "-LCB-", "}": "-RCB-"} 139 | 140 | 141 | def _is_digit(w): 142 | for ch in w: 143 | if not(ch.isdigit() or ch == ','): 144 | return False 145 | return True 146 | 147 | 148 | def fix_tokenization(text): 149 | input_tokens = text.split() 150 | output_tokens = [] 151 | has_left_quote = False 152 | has_left_single_quote = False 153 | 154 | i = 0 155 | prev_dash = False 156 | while i < len(input_tokens): 157 | tok = input_tokens[i] 158 | flag_prev_dash = False 159 | if tok in _tok_dict.keys(): 160 | output_tokens.append(_tok_dict[tok]) 161 | i += 1 162 | elif tok == "\"": 163 | if has_left_quote: 164 | output_tokens.append("''") 165 | else: 166 | output_tokens.append("``") 167 | has_left_quote = not has_left_quote 168 | i += 1 169 | elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t": 170 | output_tokens[-1] = output_tokens[-1][:-1] 171 | output_tokens.append("n't") 172 | i += 2 173 | elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"): 174 | output_tokens.append("'"+input_tokens[i + 1]) 175 | i += 2 176 | elif tok == "'": 177 | if has_left_single_quote: 178 | output_tokens.append("'") 179 | else: 180 | output_tokens.append("`") 181 | has_left_single_quote = not has_left_single_quote 182 | i += 1 183 | elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".": 184 | output_tokens.append("...") 185 | i += 3 186 | elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]): 187 | # $ 3 , 000 -> $ 3,000 188 | output_tokens[-1] += ','+input_tokens[i + 1] 189 | i += 2 190 | elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit(): 191 | # 3 . 03 -> $ 3.03 192 | output_tokens[-1] += '.'+input_tokens[i + 1] 193 | i += 2 194 | elif tok == "." 
and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.': 195 | # U . N . -> U.N. 196 | k = i+3 197 | while k+2 < len(input_tokens): 198 | if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.': 199 | k += 2 200 | else: 201 | break 202 | output_tokens[-1] += ''.join(input_tokens[i:k]) 203 | i += 2 204 | elif tok == "-": 205 | if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": 206 | output_tokens.append("--") 207 | i += 2 208 | elif i == len(input_tokens) - 1 or i == 0: 209 | output_tokens.append("-") 210 | i += 1 211 | elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation: 212 | output_tokens[-1] += "-" 213 | i += 1 214 | flag_prev_dash = True 215 | else: 216 | output_tokens.append("-") 217 | i += 1 218 | elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation: 219 | output_tokens[-1] += tok 220 | i += 1 221 | else: 222 | output_tokens.append(tok) 223 | i += 1 224 | prev_dash = flag_prev_dash 225 | return " ".join(output_tokens) 226 | 227 | 228 | def remove_duplicate(l_list, duplicate_rate): 229 | tk_list = [l.lower().split() for l in l_list] 230 | r_list = [] 231 | history_set = set() 232 | for i, w_list in enumerate(tk_list): 233 | w_set = set(w_list) 234 | if len(w_set & history_set)/len(w_set) <= duplicate_rate: 235 | r_list.append(l_list[i]) 236 | history_set |= w_set 237 | return r_list 238 | 239 | 240 | def process_eval(eval_fn): 241 | gold_list = [] 242 | with open(args.gold, "r", encoding="utf-8") as f_in: 243 | for l in f_in: 244 | line = l.strip().replace(" ", '\n') 245 | gold_list.append(line) 246 | 247 | pred_list = [] 248 | with open(eval_fn, "r", encoding="utf-8") as f_in: 249 | for l in f_in: 250 | buf = [] 251 | for sentence in l.strip().split("[X_SEP]"): 252 | sentence = fix_tokenization(sentence) 253 | 254 | sentence = sentence.replace("(", " -LRB- ").replace(")", " -RRB- ") 255 | sentence = sentence.replace("[", " -LSB- ").replace("]", " -RSB- ") 256 | while " " in sentence: 257 | sentence = sentence.replace(" ", " ") 258 | 259 | if any(get_f1(sentence, s) > 1.0 for s in buf): 260 | continue 261 | s_len = len(sentence.split()) 262 | if s_len <= 4: 263 | continue 264 | buf.append(sentence) 265 | if args.duplicate_rate and args.duplicate_rate < 1: 266 | buf = remove_duplicate(buf, args.duplicate_rate) 267 | if args.trunc_len: 268 | num_left = args.trunc_len 269 | trunc_list = [] 270 | for bit in buf: 271 | tk_list = bit.split() 272 | n = min(len(tk_list), num_left) 273 | trunc_list.append(' '.join(tk_list[:n])) 274 | num_left -= n 275 | if num_left <= 0: 276 | break 277 | else: 278 | trunc_list = buf 279 | line = "\n".join(trunc_list) 280 | pred_list.append(line) 281 | with open(eval_fn+'.post', 'w', encoding='utf-8') as f_out: 282 | for l in pred_list: 283 | f_out.write(l.replace('\n', ' [X_SEP] ').strip()) 284 | f_out.write('\n') 285 | # rouge scores 286 | if len(pred_list) < len(gold_list): 287 | # evaluate subset 288 | gold_list = gold_list[:len(pred_list)] 289 | assert len(pred_list) == len(gold_list) 290 | if args.perl: 291 | scores = test_rouge(pred_list, gold_list) 292 | else: 293 | scores = evaluator.get_scores(pred_list, [[it] for it in gold_list]) 294 | return eval_fn, scores 295 | 296 | 297 | def main(): 298 | if args.perl: 299 | eval_fn_list = list(glob.glob(args.pred)) 300 | 
else: 301 | eval_fn_list = [eval_fn for eval_fn in glob.glob(args.pred) if not( 302 | args.lazy_eval and Path(eval_fn+".rouge").exists())] 303 | eval_fn_list = list(filter(lambda fn: not(fn.endswith( 304 | '.post') or fn.endswith('.rouge')), eval_fn_list)) 305 | 306 | if args.only_eval_best: 307 | best_epoch_dict = {} 308 | for dir_path in set(Path(fn).parent for fn in eval_fn_list): 309 | fn_save = os.path.join(dir_path, 'save_best.dev') 310 | if Path(fn_save).exists(): 311 | with open(fn_save, 'r') as f_in: 312 | __, o_name, __ = f_in.read().strip().split('\n') 313 | epoch = o_name.split('.')[1] 314 | best_epoch_dict[dir_path] = epoch 315 | new_eval_fn_list = [] 316 | for fn in eval_fn_list: 317 | dir_path = Path(fn).parent 318 | if dir_path in best_epoch_dict: 319 | if Path(fn).name.split('.')[1] == best_epoch_dict[dir_path]: 320 | new_eval_fn_list.append(fn) 321 | eval_fn_list = new_eval_fn_list 322 | 323 | logger.info("***** Evaluation: %s *****", ','.join(eval_fn_list)) 324 | num_pool = min(args.processes, len(eval_fn_list)) 325 | p = Pool(num_pool) 326 | r_list = p.imap_unordered(process_eval, eval_fn_list) 327 | r_list = sorted([(fn, scores) 328 | for fn, scores in r_list], key=lambda x: x[0]) 329 | rg2_dict = {} 330 | for fn, scores in r_list: 331 | print(fn) 332 | if args.perl: 333 | print(rouge_results_to_str(scores)) 334 | else: 335 | rg2_dict[fn] = scores['rouge-2']['f'] 336 | print( 337 | "ROUGE-1: {}\tROUGE-2: {}\n".format(scores['rouge-1']['f'], scores['rouge-2']['f'])) 338 | with open(fn+".rouge", 'w') as f_out: 339 | f_out.write(json.dumps( 340 | {'rg1': scores['rouge-1']['f'], 'rg2': scores['rouge-2']['f']})) 341 | p.close() 342 | p.join() 343 | 344 | if args.save_best: 345 | # find best results 346 | group_dict = {} 347 | for k, v in rg2_dict.items(): 348 | d_name, o_name = Path(k).parent, Path(k).name 349 | if (d_name not in group_dict) or (v > group_dict[d_name][1]): 350 | group_dict[d_name] = (o_name, v) 351 | # compare and save the best result 352 | for k, v in group_dict.items(): 353 | fn = os.path.join(k, 'save_best.'+args.split) 354 | o_name_s, rst_s = v 355 | should_save = True 356 | if Path(fn).exists(): 357 | with open(fn, 'r') as f_in: 358 | rst_f = float(f_in.read().strip().split('\n')[-1]) 359 | if rst_s <= rst_f: 360 | should_save = False 361 | if should_save: 362 | with open(fn, 'w') as f_out: 363 | f_out.write('{0}\n{1}\n{2}\n'.format(k, o_name_s, rst_s)) 364 | 365 | 366 | if __name__ == "__main__": 367 | main() 368 | -------------------------------------------------------------------------------- /src/evaluations/bs_pyrouge.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals, division 2 | 3 | import os 4 | import re 5 | import codecs 6 | import platform 7 | 8 | from subprocess import check_output 9 | from tempfile import mkdtemp 10 | from functools import partial 11 | 12 | try: 13 | from configparser import ConfigParser 14 | except ImportError: 15 | from ConfigParser import ConfigParser 16 | 17 | from pyrouge.utils import log 18 | from pyrouge.utils.file_utils import verify_dir 19 | 20 | 21 | REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", 22 | "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} 23 | 24 | 25 | def clean(x): 26 | return re.sub( 27 | r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", 28 | lambda m: REMAP.get(m.group()), x) 29 | 30 | 31 | class DirectoryProcessor: 32 | 33 | @staticmethod 34 | def process(input_dir, output_dir, function): 
35 | """ 36 | Apply function to all files in input_dir and save the resulting ouput 37 | files in output_dir. 38 | 39 | """ 40 | if not os.path.exists(output_dir): 41 | os.makedirs(output_dir) 42 | logger = log.get_global_console_logger() 43 | logger.info("Processing files in {}.".format(input_dir)) 44 | input_file_names = os.listdir(input_dir) 45 | for input_file_name in input_file_names: 46 | input_file = os.path.join(input_dir, input_file_name) 47 | with codecs.open(input_file, "r", encoding="UTF-8") as f: 48 | input_string = f.read() 49 | output_string = function(input_string) 50 | output_file = os.path.join(output_dir, input_file_name) 51 | with codecs.open(output_file, "w", encoding="UTF-8") as f: 52 | f.write(clean(output_string.lower())) 53 | logger.info("Saved processed files to {}.".format(output_dir)) 54 | 55 | 56 | class Rouge155(object): 57 | """ 58 | This is a wrapper for the ROUGE 1.5.5 summary evaluation package. 59 | This class is designed to simplify the evaluation process by: 60 | 61 | 1) Converting summaries into a format ROUGE understands. 62 | 2) Generating the ROUGE configuration file automatically based 63 | on filename patterns. 64 | 65 | This class can be used within Python like this: 66 | 67 | rouge = Rouge155() 68 | rouge.system_dir = 'test/systems' 69 | rouge.model_dir = 'test/models' 70 | 71 | # The system filename pattern should contain one group that 72 | # matches the document ID. 73 | rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' 74 | 75 | # The model filename pattern has '#ID#' as a placeholder for the 76 | # document ID. If there are multiple model summaries, pyrouge 77 | # will use the provided regex to automatically match them with 78 | # the corresponding system summary. Here, [A-Z] matches 79 | # multiple model summaries for a given #ID#. 80 | rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 81 | 82 | rouge_output = rouge.evaluate() 83 | print(rouge_output) 84 | output_dict = rouge.output_to_dict(rouge_ouput) 85 | print(output_dict) 86 | -> {'rouge_1_f_score': 0.95652, 87 | 'rouge_1_f_score_cb': 0.95652, 88 | 'rouge_1_f_score_ce': 0.95652, 89 | 'rouge_1_precision': 0.95652, 90 | [...] 91 | 92 | 93 | To evaluate multiple systems: 94 | 95 | rouge = Rouge155() 96 | rouge.system_dir = '/PATH/TO/systems' 97 | rouge.model_dir = 'PATH/TO/models' 98 | for system_id in ['id1', 'id2', 'id3']: 99 | rouge.system_filename_pattern = \ 100 | 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) 101 | rouge.model_filename_pattern = \ 102 | 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 103 | rouge_output = rouge.evaluate(system_id) 104 | print(rouge_output) 105 | 106 | """ 107 | 108 | def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None): 109 | """ 110 | Create a Rouge155 object. 111 | 112 | rouge_dir: Directory containing Rouge-1.5.5.pl 113 | rouge_args: Arguments to pass through to ROUGE if you 114 | don't want to use the default pyrouge 115 | arguments. 
116 | 117 | """ 118 | self.temp_dir = temp_dir 119 | self.log = log.get_global_console_logger() 120 | self.__set_dir_properties() 121 | self._config_file = None 122 | self._settings_file = self.__get_config_path() 123 | self.__set_rouge_dir(rouge_dir) 124 | self.args = self.__clean_rouge_args(rouge_args) 125 | self._system_filename_pattern = None 126 | self._model_filename_pattern = None 127 | 128 | def save_home_dir(self): 129 | config = ConfigParser() 130 | section = 'pyrouge settings' 131 | config.add_section(section) 132 | config.set(section, 'home_dir', self._home_dir) 133 | with open(self._settings_file, 'w') as f: 134 | config.write(f) 135 | self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) 136 | 137 | @property 138 | def settings_file(self): 139 | """ 140 | Path of the setttings file, which stores the ROUGE home dir. 141 | 142 | """ 143 | return self._settings_file 144 | 145 | @property 146 | def bin_path(self): 147 | """ 148 | The full path of the ROUGE binary (although it's technically 149 | a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl 150 | 151 | """ 152 | if self._bin_path is None: 153 | raise Exception( 154 | "ROUGE path not set. Please set the ROUGE home directory " 155 | "and ensure that ROUGE-1.5.5.pl exists in it.") 156 | return self._bin_path 157 | 158 | @property 159 | def system_filename_pattern(self): 160 | """ 161 | The regular expression pattern for matching system summary 162 | filenames. The regex string. 163 | 164 | E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system 165 | filenames in the SPL2003/system folder of the ROUGE SPL example 166 | in the "sample-test" folder. 167 | 168 | Currently, there is no support for multiple systems. 169 | 170 | """ 171 | return self._system_filename_pattern 172 | 173 | @system_filename_pattern.setter 174 | def system_filename_pattern(self, pattern): 175 | self._system_filename_pattern = pattern 176 | 177 | @property 178 | def model_filename_pattern(self): 179 | """ 180 | The regular expression pattern for matching model summary 181 | filenames. The pattern needs to contain the string "#ID#", 182 | which is a placeholder for the document ID. 183 | 184 | E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model 185 | filenames in the SPL2003/system folder of the ROUGE SPL 186 | example in the "sample-test" folder. 187 | 188 | "#ID#" is a placeholder for the document ID which has been 189 | matched by the "(\d+)" part of the system filename pattern. 190 | The different model summaries for a given document ID are 191 | matched by the "[A-Z]" part. 192 | 193 | """ 194 | return self._model_filename_pattern 195 | 196 | @model_filename_pattern.setter 197 | def model_filename_pattern(self, pattern): 198 | self._model_filename_pattern = pattern 199 | 200 | @property 201 | def config_file(self): 202 | return self._config_file 203 | 204 | @config_file.setter 205 | def config_file(self, path): 206 | config_dir, _ = os.path.split(path) 207 | verify_dir(config_dir, "configuration file") 208 | self._config_file = path 209 | 210 | def split_sentences(self): 211 | """ 212 | ROUGE requires texts split into sentences. In case the texts 213 | are not already split, this method can be used. 
214 | 
215 |         """
216 |         from pyrouge.utils.sentence_splitter import PunktSentenceSplitter
217 |         self.log.info("Splitting sentences.")
218 |         ss = PunktSentenceSplitter()
219 | 
220 |         def sent_split_to_string(s): return "\n".join(ss.split(s))
221 |         process_func = partial(
222 |             DirectoryProcessor.process, function=sent_split_to_string)
223 |         self.__process_summaries(process_func)
224 | 
225 |     @staticmethod
226 |     def convert_summaries_to_rouge_format(input_dir, output_dir):
227 |         """
228 |         Convert all files in input_dir into a format ROUGE understands
229 |         and save the files to output_dir. The input files are assumed
230 |         to be plain text with one sentence per line.
231 | 
232 |         input_dir:  Path of directory containing the input files.
233 |         output_dir: Path of directory in which the converted files
234 |                     will be saved.
235 | 
236 |         """
237 |         DirectoryProcessor.process(
238 |             input_dir, output_dir, Rouge155.convert_text_to_rouge_format)
239 | 
240 |     @staticmethod
241 |     def convert_text_to_rouge_format(text, title="dummy title"):
242 |         """
243 |         Convert a text to a format ROUGE understands. The text is
244 |         assumed to contain one sentence per line.
245 | 
246 |         text:   The text to convert, containing one sentence per line.
247 |         title:  Optional title for the text. The title will appear
248 |                 in the converted file, but doesn't seem to have
249 |                 any other relevance.
250 | 
251 |         Returns: The converted text as string.
252 | 
253 |         """
254 |         sentences = text.split("\n")
255 |         sent_elems = [
256 |             "<a name=\"{i}\">[{i}]</a> "
257 |             "<a href=\"#{i}\" id={i}>{text}</a>".format(i=i, text=sent)
258 |             for i, sent in enumerate(sentences, start=1)]
259 |         html = """<html>
260 | <head>
261 | <title>{title}</title>
262 | </head>
263 | <body bgcolor="white">
264 | {elems}
265 | </body>
266 | </html>""".format(title=title, elems="\n".join(sent_elems))
267 | 
268 |         return html
269 | 
270 |     @staticmethod
271 |     def write_config_static(system_dir, system_filename_pattern,
272 |                             model_dir, model_filename_pattern,
273 |                             config_file_path, system_id=None):
274 |         """
275 |         Write the ROUGE configuration file, which is basically a list
276 |         of system summary files and their corresponding model summary
277 |         files.
278 | 
279 |         pyrouge uses regular expressions to automatically find the
280 |         matching model summary files for a given system summary file
281 |         (cf. docstrings for system_filename_pattern and
282 |         model_filename_pattern).
283 | 
284 |         system_dir:                 Path of directory containing
285 |                                     system summaries.
286 |         system_filename_pattern:    Regex string for matching
287 |                                     system summary filenames.
288 |         model_dir:                  Path of directory containing
289 |                                     model summaries.
290 |         model_filename_pattern:     Regex string for matching model
291 |                                     summary filenames.
292 |         config_file_path:           Path of the configuration file.
293 |         system_id:                  Optional system ID string which
294 |                                     will appear in the ROUGE output.
295 | 
296 |         """
297 |         system_filenames = [f for f in os.listdir(system_dir)]
298 |         system_models_tuples = []
299 | 
300 |         system_filename_pattern = re.compile(system_filename_pattern)
301 |         for system_filename in sorted(system_filenames):
302 |             match = system_filename_pattern.match(system_filename)
303 |             if match:
304 |                 id = match.groups(0)[0]
305 |                 model_filenames = [model_filename_pattern.replace('#ID#', id)]
306 |                 # model_filenames = Rouge155.__get_model_filenames_for_id(
307 |                 #     id, model_dir, model_filename_pattern)
308 |                 system_models_tuples.append(
309 |                     (system_filename, sorted(model_filenames)))
310 |         if not system_models_tuples:
311 |             raise Exception(
312 |                 "Did not find any files matching the pattern {} "
313 |                 "in the system summaries directory {}.".format(
314 |                     system_filename_pattern.pattern, system_dir))
315 | 
316 |         with codecs.open(config_file_path, 'w', encoding='utf-8') as f:
317 |             f.write('<ROUGE-EVAL version="1.55">')
318 |             for task_id, (system_filename, model_filenames) in enumerate(
319 |                     system_models_tuples, start=1):
320 | 
321 |                 eval_string = Rouge155.__get_eval_string(
322 |                     task_id, system_id,
323 |                     system_dir, system_filename,
324 |                     model_dir, model_filenames)
325 |                 f.write(eval_string)
326 |             f.write("</ROUGE-EVAL>")
327 | 
328 |     def write_config(self, config_file_path=None, system_id=None):
329 |         """
330 |         Write the ROUGE configuration file, which is basically a list
331 |         of system summary files and their matching model summary files.
332 | 
333 |         This is a non-static version of write_config_static().
334 | 
335 |         config_file_path:   Path of the configuration file.
336 |         system_id:          Optional system ID string which will
337 |                             appear in the ROUGE output.
338 | 
339 |         """
340 |         if not system_id:
341 |             system_id = 1
342 |         if (not config_file_path) or (not self._config_dir):
343 |             self._config_dir = mkdtemp(dir=self.temp_dir)
344 |             config_filename = "rouge_conf.xml"
345 |         else:
346 |             config_dir, config_filename = os.path.split(config_file_path)
347 |             verify_dir(config_dir, "configuration file")
348 |         self._config_file = os.path.join(self._config_dir, config_filename)
349 |         Rouge155.write_config_static(
350 |             self._system_dir, self._system_filename_pattern,
351 |             self._model_dir, self._model_filename_pattern,
352 |             self._config_file, system_id)
353 |         self.log.info(
354 |             "Written ROUGE configuration to {}".format(self._config_file))
355 | 
356 |     def evaluate(self, system_id=1, rouge_args=None):
357 |         """
358 |         Run ROUGE to evaluate the system summaries in system_dir against
359 |         the model summaries in model_dir. The summaries are assumed to
360 |         be in the one-sentence-per-line HTML format ROUGE understands.
361 | 
362 |         system_id:  Optional system ID which will be printed in
363 |                     ROUGE's output.
364 | 
365 |         Returns: Rouge output as string.
366 | 
367 |         """
368 |         self.write_config(system_id=system_id)
369 |         options = self.__get_options(rouge_args)
370 |         command = [self._bin_path] + options
371 |         self.log.info(
372 |             "Running ROUGE with command {}".format(" ".join(command)))
373 |         rouge_output = check_output(command).decode("UTF-8")
374 |         return rouge_output
375 | 
376 |     def convert_and_evaluate(self, system_id=1,
377 |                              split_sentences=False, rouge_args=None):
378 |         """
379 |         Convert plain text summaries to ROUGE format and run ROUGE to
380 |         evaluate the system summaries in system_dir against the model
381 |         summaries in model_dir. Optionally split texts into sentences
382 |         in case they aren't already.
383 | 384 | This is just a convenience method combining 385 | convert_summaries_to_rouge_format() and evaluate(). 386 | 387 | split_sentences: Optional argument specifying if 388 | sentences should be split. 389 | system_id: Optional system ID which will be printed 390 | in ROUGE's output. 391 | 392 | Returns: ROUGE output as string. 393 | 394 | """ 395 | if split_sentences: 396 | self.split_sentences() 397 | self.__write_summaries() 398 | rouge_output = self.evaluate(system_id, rouge_args) 399 | return rouge_output 400 | 401 | def output_to_dict(self, output): 402 | """ 403 | Convert the ROUGE output into python dictionary for further 404 | processing. 405 | 406 | """ 407 | # 0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632) 408 | pattern = re.compile( 409 | r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) " 410 | r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)") 411 | results = {} 412 | for line in output.split("\n"): 413 | match = pattern.match(line) 414 | if match: 415 | sys_id, rouge_type, measure, result, conf_begin, conf_end = \ 416 | match.groups() 417 | measure = { 418 | 'Average_R': 'recall', 419 | 'Average_P': 'precision', 420 | 'Average_F': 'f_score' 421 | }[measure] 422 | rouge_type = rouge_type.lower().replace("-", '_') 423 | key = "{}_{}".format(rouge_type, measure) 424 | results[key] = float(result) 425 | results["{}_cb".format(key)] = float(conf_begin) 426 | results["{}_ce".format(key)] = float(conf_end) 427 | return results 428 | 429 | ################################################################### 430 | # Private methods 431 | 432 | def __set_rouge_dir(self, home_dir=None): 433 | """ 434 | Verfify presence of ROUGE-1.5.5.pl and data folder, and set 435 | those paths. 436 | 437 | """ 438 | if not home_dir: 439 | self._home_dir = self.__get_rouge_home_dir_from_settings() 440 | else: 441 | self._home_dir = home_dir 442 | self.save_home_dir() 443 | self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl') 444 | self.data_dir = os.path.join(self._home_dir, 'data') 445 | if not os.path.exists(self._bin_path): 446 | raise Exception( 447 | "ROUGE binary not found at {}. Please set the " 448 | "correct path by running pyrouge_set_rouge_path " 449 | "/path/to/rouge/home.".format(self._bin_path)) 450 | 451 | def __get_rouge_home_dir_from_settings(self): 452 | config = ConfigParser() 453 | with open(self._settings_file) as f: 454 | if hasattr(config, "read_file"): 455 | config.read_file(f) 456 | else: 457 | # use deprecated python 2.x method 458 | config.readfp(f) 459 | rouge_home_dir = config.get('pyrouge settings', 'home_dir') 460 | return rouge_home_dir 461 | 462 | @staticmethod 463 | def __get_eval_string( 464 | task_id, system_id, 465 | system_dir, system_filename, 466 | model_dir, model_filenames): 467 | """ 468 | ROUGE can evaluate several system summaries for a given text 469 | against several model summaries, i.e. there is an m-to-n 470 | relation between system and model summaries. The system 471 | summaries are listed in the tag and the model summaries 472 | in the tag. pyrouge currently only supports one system 473 | summary per text, i.e. it assumes a 1-to-n relation between 474 | system and model summaries. 475 | 476 | """ 477 | peer_elems = "

<P ID=\"{id}\">{name}</P>".format(
478 |             id=system_id, name=system_filename)
479 | 
480 |         model_elems = ["<M ID=\"{id}\">{name}</M>".format(
481 |             id=chr(65 + i), name=name)
482 |             for i, name in enumerate(model_filenames)]
483 | 
484 |         model_elems = "\n\t\t\t".join(model_elems)
485 |         eval_string = """
486 |     <EVAL ID="{task_id}">
487 |         <MODEL-ROOT>{model_root}</MODEL-ROOT>
488 |         <PEER-ROOT>{peer_root}</PEER-ROOT>
489 |         <INPUT-FORMAT TYPE="SEE">
490 |         </INPUT-FORMAT>
491 |         <PEERS>
492 |             {peer_elems}
493 |         </PEERS>
494 |         <MODELS>
495 |             {model_elems}
496 |         </MODELS>
497 |     </EVAL>
498 | """.format(
499 |             task_id=task_id,
500 |             model_root=model_dir, model_elems=model_elems,
501 |             peer_root=system_dir, peer_elems=peer_elems)
502 |         return eval_string
503 | 
504 |     def __process_summaries(self, process_func):
505 |         """
506 |         Helper method that applies process_func to the files in the
507 |         system and model folders and saves the resulting files to new
508 |         system and model folders.
509 | 
510 |         """
511 |         temp_dir = mkdtemp(dir=self.temp_dir)
512 |         new_system_dir = os.path.join(temp_dir, "system")
513 |         os.mkdir(new_system_dir)
514 |         new_model_dir = os.path.join(temp_dir, "model")
515 |         os.mkdir(new_model_dir)
516 |         self.log.info(
517 |             "Processing summaries. Saving system files to {} and "
518 |             "model files to {}.".format(new_system_dir, new_model_dir))
519 |         process_func(self._system_dir, new_system_dir)
520 |         process_func(self._model_dir, new_model_dir)
521 |         self._system_dir = new_system_dir
522 |         self._model_dir = new_model_dir
523 | 
524 |     def __write_summaries(self):
525 |         self.log.info("Writing summaries.")
526 |         self.__process_summaries(self.convert_summaries_to_rouge_format)
527 | 
528 |     @staticmethod
529 |     def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern):
530 |         pattern = re.compile(model_filenames_pattern.replace('#ID#', id))
531 |         model_filenames = [
532 |             f for f in os.listdir(model_dir) if pattern.match(f)]
533 |         if not model_filenames:
534 |             raise Exception(
535 |                 "Could not find any model summaries for the system"
536 |                 " summary with ID {}. Specified model filename pattern was: "
537 |                 "{}".format(id, model_filenames_pattern))
538 |         return model_filenames
539 | 
540 |     def __get_options(self, rouge_args=None):
541 |         """
542 |         Get supplied command line arguments for ROUGE or use default
543 |         ones.
544 | 
545 |         """
546 |         if self.args:
547 |             options = self.args.split()
548 |         elif rouge_args:
549 |             options = rouge_args.split()
550 |         else:
551 |             options = [
552 |                 '-e', self._data_dir,
553 |                 '-c', 95,
554 |                 # '-2',
555 |                 # '-1',
556 |                 # '-U',
557 |                 '-m',
558 |                 # '-v',
559 |                 '-r', 1000,
560 |                 '-n', 2,
561 |                 # '-w', 1.2,
562 |                 '-a',
563 |             ]
564 |             options = list(map(str, options))
565 | 
566 |         options = self.__add_config_option(options)
567 |         return options
568 | 
569 |     def __create_dir_property(self, dir_name, docstring):
570 |         """
571 |         Generate getter and setter for a directory property.
572 | 
573 |         """
574 |         property_name = "{}_dir".format(dir_name)
575 |         private_name = "_" + property_name
576 |         setattr(self, private_name, None)
577 | 
578 |         def fget(self):
579 |             return getattr(self, private_name)
580 | 
581 |         def fset(self, path):
582 |             verify_dir(path, dir_name)
583 |             setattr(self, private_name, path)
584 | 
585 |         p = property(fget=fget, fset=fset, doc=docstring)
586 |         setattr(self.__class__, property_name, p)
587 | 
588 |     def __set_dir_properties(self):
589 |         """
590 |         Automatically generate the properties for directories.
591 | 592 | """ 593 | directories = [ 594 | ("home", "The ROUGE home directory."), 595 | ("data", "The path of the ROUGE 'data' directory."), 596 | ("system", "Path of the directory containing system summaries."), 597 | ("model", "Path of the directory containing model summaries."), 598 | ] 599 | for (dirname, docstring) in directories: 600 | self.__create_dir_property(dirname, docstring) 601 | 602 | def __clean_rouge_args(self, rouge_args): 603 | """ 604 | Remove enclosing quotation marks, if any. 605 | 606 | """ 607 | if not rouge_args: 608 | return 609 | quot_mark_pattern = re.compile('"(.+)"') 610 | match = quot_mark_pattern.match(rouge_args) 611 | if match: 612 | cleaned_args = match.group(1) 613 | return cleaned_args 614 | else: 615 | return rouge_args 616 | 617 | def __add_config_option(self, options): 618 | return options + [self._config_file] 619 | 620 | def __get_config_path(self): 621 | if platform.system() == "Windows": 622 | parent_dir = os.getenv("APPDATA") 623 | config_dir_name = "pyrouge" 624 | elif os.name == "posix": 625 | parent_dir = os.path.expanduser("~") 626 | config_dir_name = ".pyrouge" 627 | else: 628 | parent_dir = os.path.dirname(__file__) 629 | config_dir_name = "" 630 | config_dir = os.path.join(parent_dir, config_dir_name) 631 | if not os.path.exists(config_dir): 632 | os.makedirs(config_dir) 633 | return os.path.join(config_dir, 'settings.ini') 634 | 635 | 636 | if __name__ == "__main__": 637 | import argparse 638 | from utils.argparsers import rouge_path_parser 639 | 640 | parser = argparse.ArgumentParser(parents=[rouge_path_parser]) 641 | args = parser.parse_args() 642 | 643 | rouge = Rouge155(args.rouge_home) 644 | rouge.save_home_dir() 645 | -------------------------------------------------------------------------------- /src/run_seq2seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import json 7 | import random 8 | from collections import OrderedDict 9 | 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import (DataLoader, SequentialSampler) 13 | from torch.utils.data.distributed import DistributedSampler 14 | 15 | try: 16 | from torch.utils.tensorboard import SummaryWriter 17 | except: 18 | from tensorboardX import SummaryWriter 19 | 20 | import tqdm 21 | 22 | from s2s_ft.modeling import BertForSequenceToSequence, BertForSequenceToSequence_Distill, BertModel, BertOnlyMLMHead, \ 23 | get_linear_schedule_with_warmup 24 | from transformers import AdamW 25 | from transformers import \ 26 | RobertaConfig, BertConfig, \ 27 | BertTokenizer, RobertaTokenizer, \ 28 | XLMRobertaConfig, XLMRobertaTokenizer 29 | from s2s_ft.configuration_unilm import UnilmConfig 30 | from s2s_ft.tokenization_unilm import UnilmTokenizer 31 | from s2s_ft.configuration_minilm import MinilmConfig 32 | from s2s_ft.tokenization_minilm import MinilmTokenizer 33 | 34 | from s2s_ft import utils 35 | from s2s_ft.config import BertForSeq2SeqConfig 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | MODEL_CLASSES = { 41 | 'bert': (BertConfig, BertTokenizer), 42 | 'minilm': (MinilmConfig, MinilmTokenizer), 43 | 'roberta': (RobertaConfig, RobertaTokenizer), 44 | 'xlm-roberta': (XLMRobertaConfig, XLMRobertaTokenizer), 45 | 'unilm': (UnilmConfig, UnilmTokenizer), 46 | } 47 | 48 | 49 | def prepare_for_training(args, model, checkpoint_state_dict, amp): 50 | no_decay = ['bias', 'LayerNorm.weight'] 51 | 
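    # Standard BERT fine-tuning recipe for the optimizer: parameters whose
    # names contain an entry of no_decay (bias terms and LayerNorm weights)
    # go into a group with weight_decay=0.0, everything else receives
    # args.weight_decay. Both groups share args.learning_rate in AdamW below.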
optimizer_grouped_parameters = [ 52 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 53 | 'weight_decay': args.weight_decay}, 54 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 55 | ] 56 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 57 | 58 | if amp: 59 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 60 | if checkpoint_state_dict: 61 | amp.load_state_dict(checkpoint_state_dict['amp']) 62 | 63 | if checkpoint_state_dict: 64 | optimizer.load_state_dict(checkpoint_state_dict['optimizer']) 65 | model.load_state_dict(checkpoint_state_dict['model']) 66 | 67 | if(args.use_distill>0): 68 | teacher_pt= torch.load(args.teacher_model, map_location='cpu') 69 | teacher_model_pt = OrderedDict((k[5:] if k.startswith('bert') else k, v) for k, v in teacher_pt.items() if k.startswith('bert')) 70 | teacher_cls_pt = OrderedDict((k[4:] if k.startswith('cls') else k, v) for k, v in teacher_pt.items() if k.startswith('cls')) 71 | model.teacher_model.load_state_dict(teacher_model_pt, strict=True) 72 | model.teacher_cls.load_state_dict(teacher_cls_pt, strict=True) 73 | 74 | for p in model.teacher_model.parameters(): 75 | p.requires_grad = False 76 | for p in model.teacher_cls.parameters(): 77 | p.requires_grad = False 78 | 79 | 80 | # multi-gpu training (should be after apex fp16 initialization) 81 | if args.n_gpu > 1: 82 | model = torch.nn.DataParallel(model, find_unused_parameters=True) 83 | 84 | # Distributed training (should be after apex fp16 initialization) 85 | if args.local_rank != -1: 86 | model = torch.nn.parallel.DistributedDataParallel( 87 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 88 | 89 | return model, optimizer 90 | 91 | 92 | def train(args, training_features, model, tokenizer): 93 | """ Train the model """ 94 | if args.local_rank in [-1, 0] and args.log_dir: 95 | tb_writer = SummaryWriter(log_dir=args.log_dir) 96 | else: 97 | tb_writer = None 98 | 99 | if args.fp16: 100 | try: 101 | from apex import amp 102 | except ImportError: 103 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 104 | else: 105 | amp = None 106 | 107 | # model recover 108 | recover_step = utils.get_max_epoch_model(args.output_dir) 109 | 110 | if recover_step: 111 | model_recover_checkpoint = os.path.join(args.output_dir, "model.{}.bin".format(recover_step)) 112 | logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint) 113 | model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu') 114 | optimizer_recover_checkpoint = os.path.join(args.output_dir, "optim.{}.bin".format(recover_step)) 115 | checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu') 116 | checkpoint_state_dict['model'] = model_state_dict 117 | else: 118 | checkpoint_state_dict = None 119 | 120 | model.to(args.device) 121 | model, optimizer = prepare_for_training(args, model, checkpoint_state_dict, amp=amp) 122 | 123 | if args.n_gpu == 0 or args.no_cuda: 124 | per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps 125 | else: 126 | per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps 127 | 128 | train_batch_size = per_node_train_batch_size * (torch.distributed.get_world_size() if args.local_rank != -1 else 1) 
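    # train_batch_size is the number of examples consumed per optimizer
    # update across all processes: the per-process batch
    # (per_gpu_train_batch_size * n_gpu * gradient_accumulation_steps)
    # multiplied by the distributed world size. Hypothetical example:
    # --per_gpu_train_batch_size 8 --gradient_accumulation_steps 2 under a
    # 4-process distributed launch (n_gpu is forced to 1 per process) gives
    # 8 * 1 * 2 = 16 examples per process and 16 * 4 = 64 per update. The
    # same number is used below to shift the dataset offset when resuming
    # from a recovered global_step.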
129 | global_step = recover_step if recover_step else 0 130 | 131 | if args.num_training_steps == -1: 132 | args.num_training_steps = int(args.num_training_epochs * len(training_features) / train_batch_size) 133 | 134 | scheduler = get_linear_schedule_with_warmup( 135 | optimizer, num_warmup_steps=args.num_warmup_steps, 136 | num_training_steps=args.num_training_steps, last_epoch=-1, min_lr=args.min_lr, fix_lr=args.fix_lr) 137 | 138 | if checkpoint_state_dict: 139 | scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"]) 140 | 141 | train_dataset = utils.Seq2seqDatasetForBert( 142 | features=training_features, max_source_len=args.max_source_seq_length, 143 | max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size, 144 | cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id, 145 | mask_id=tokenizer.mask_token_id, random_prob=args.random_prob, keep_prob=args.keep_prob, 146 | offset=train_batch_size * global_step, num_training_instances=train_batch_size * args.num_training_steps, 147 | word_drop_prob=args.word_drop_prob, word_shuffle_k=args.word_shuffle_k,sent_shuffle_k=args.sent_shuffle_k,sent_drop_prob=args.sent_drop_prob 148 | ) 149 | 150 | logger.info("Check dataset:") 151 | for i in range(5): 152 | source_ids, noisy_source_ids_s, noisy_source_ids_t, target_ids, pseudo_ids, num_source_tokens, noisy_num_source_tokens_s, noisy_num_source_tokens_t, num_target_tokens = train_dataset.__getitem__(i) 153 | logger.info("Instance-%d" % i) 154 | logger.info("Source tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(source_ids))) 155 | logger.info("Target tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(target_ids))) 156 | 157 | logger.info("Mode = %s" % str(model)) 158 | 159 | # Train! 160 | logger.info(" ***** Running training ***** *") 161 | logger.info(" Num examples = %d", len(training_features)) 162 | logger.info(" Num Epochs = %.2f", len(train_dataset) / len(training_features)) 163 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 164 | logger.info(" Batch size per node = %d", per_node_train_batch_size) 165 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", train_batch_size) 166 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 167 | logger.info(" Total optimization steps = %d", args.num_training_steps) 168 | 169 | if args.num_training_steps <= global_step: 170 | logger.info("Training is done. 
Please use a new dir or clean this dir!") 171 | else: 172 | # The training features are shuffled 173 | train_sampler = SequentialSampler(train_dataset) \ 174 | if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False) 175 | train_dataloader = DataLoader( 176 | train_dataset, sampler=train_sampler, 177 | batch_size=per_node_train_batch_size // args.gradient_accumulation_steps, 178 | collate_fn=utils.batch_list_to_batch_tensors) 179 | 180 | train_iterator = tqdm.tqdm( 181 | train_dataloader, initial=global_step, 182 | desc="Iter (loss=X.XXX, lr=X.XXXXXXX)", disable=args.local_rank not in [-1, 0]) 183 | 184 | model.train() 185 | model.zero_grad() 186 | 187 | tr_loss, logging_loss = 0.0, 0.0 188 | 189 | for step, batch in enumerate(train_iterator): 190 | batch = tuple(t.to(args.device) for t in batch) 191 | if(args.use_distill>0): 192 | inputs = {'source_ids': batch[0], 193 | 'noisy_source_ids_s': batch[1], 194 | 'noisy_source_ids_t': batch[2], 195 | 'target_ids': batch[3], 196 | 'pseudo_ids': batch[4], 197 | 'num_source_tokens': batch[5], 198 | 'num_noisy_source_tokens_s': batch[6], 199 | 'num_noisy_source_tokens_t': batch[7], 200 | 'num_target_tokens': batch[8]} 201 | else: 202 | inputs = {'source_ids': batch[0], 203 | 'noisy_source_ids': batch[1], 204 | 'target_ids': batch[3], 205 | 'pseudo_ids': batch[4], 206 | 'num_source_tokens': batch[5], 207 | 'num_noisy_source_tokens': batch[6], 208 | 'num_target_tokens': batch[8]} 209 | loss = model(**inputs) 210 | if args.n_gpu > 1: 211 | loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training 212 | 213 | train_iterator.set_description('Iter (loss=%5.3f) lr=%9.7f' % (loss.item(), scheduler.get_lr()[0])) 214 | 215 | if args.gradient_accumulation_steps > 1: 216 | loss = loss / args.gradient_accumulation_steps 217 | 218 | if args.fp16: 219 | with amp.scale_loss(loss, optimizer) as scaled_loss: 220 | scaled_loss.backward() 221 | else: 222 | loss.backward() 223 | 224 | logging_loss += loss.item() 225 | if (step + 1) % args.gradient_accumulation_steps == 0: 226 | if args.fp16: 227 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 228 | else: 229 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 230 | 231 | optimizer.step() 232 | scheduler.step() # Update learning rate schedule 233 | model.zero_grad() 234 | global_step += 1 235 | 236 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 237 | logger.info("") 238 | logger.info(" Step [%d ~ %d]: %.2f", global_step - args.logging_steps, global_step, logging_loss) 239 | logging_loss = 0.0 240 | 241 | if args.local_rank in [-1, 0] and args.save_steps > 0 and \ 242 | (global_step % args.save_steps == 0 or global_step == args.num_training_steps): 243 | 244 | save_path = os.path.join(args.output_dir, "ckpt-%d" % global_step) 245 | os.makedirs(save_path, exist_ok=True) 246 | model_to_save = model.module if hasattr(model, "module") else model 247 | model_to_save.save_pretrained(save_path) 248 | 249 | optim_to_save = { 250 | "optimizer": optimizer.state_dict(), 251 | "lr_scheduler": scheduler.state_dict(), 252 | } 253 | if args.fp16: 254 | optim_to_save["amp"] = amp.state_dict() 255 | torch.save( 256 | optim_to_save, os.path.join(args.output_dir, 'optim.{}.bin'.format(global_step))) 257 | 258 | logger.info("Saving model checkpoint %d into %s", global_step, save_path) 259 | 260 | if args.local_rank in [-1, 0] and tb_writer: 261 | tb_writer.close() 262 | 263 | 264 
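# Checkpoint layout produced by train(): every args.save_steps updates the
# unwrapped model (model.module when wrapped by DataParallel or
# DistributedDataParallel) is exported with save_pretrained() into
# <output_dir>/ckpt-<global_step>/, while the optimizer, the LR scheduler and,
# under --fp16, the amp state are written to
# <output_dir>/optim.<global_step>.bin so a later run can reload them together
# with the model weights.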
| 265 | def get_args(): 266 | parser = argparse.ArgumentParser() 267 | 268 | # parser.add_argument("--train_source_file", default=None, type=str, required=True, 269 | # help="Training data contains source") 270 | # parser.add_argument("--train_target_file", default=None, type=str, required=True, 271 | # help="Training data contains target") 272 | parser.add_argument("--train_file", default=None, type=str, required=True, 273 | help="Training data (json format) for training. Keys: source and target") 274 | parser.add_argument("--model_type", default=None, type=str, required=True, 275 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 276 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 277 | help="Path to pre-trained model or shortcut name selected in the list:") 278 | parser.add_argument("--output_dir", default=None, type=str, required=True, 279 | help="The output directory where the model checkpoints and predictions will be written.") 280 | parser.add_argument("--log_dir", default=None, type=str, 281 | help="The output directory where the log will be written.") 282 | 283 | ## Other parameters 284 | parser.add_argument("--config_name", default=None, type=str, 285 | help="Pretrained config name or path if not the same as model_name") 286 | parser.add_argument("--tokenizer_name", default=None, type=str, 287 | help="Pretrained tokenizer name or path if not the same as model_name") 288 | parser.add_argument("--cache_dir", default=None, type=str, 289 | help="Where do you want to store the pre-trained models downloaded from s3") 290 | 291 | parser.add_argument("--max_source_seq_length", default=464, type=int, 292 | help="The maximum total source sequence length after WordPiece tokenization. Sequences " 293 | "longer than this will be truncated, and sequences shorter than this will be padded.") 294 | parser.add_argument("--max_target_seq_length", default=48, type=int, 295 | help="The maximum total target sequence length after WordPiece tokenization. 
Sequences " 296 | "longer than this will be truncated, and sequences shorter than this will be padded.") 297 | 298 | parser.add_argument("--cached_train_features_file", default=None, type=str, 299 | help="Cached training features file") 300 | parser.add_argument("--do_lower_case", action='store_true', 301 | help="Set this flag if you are using an uncased model.") 302 | 303 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 304 | help="Batch size per GPU/CPU for training.") 305 | parser.add_argument("--learning_rate", default=5e-5, type=float, 306 | help="The initial learning rate for Adam.") 307 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 308 | help="Number of updates steps to accumulate before performing a backward/update pass.") 309 | parser.add_argument("--weight_decay", default=0.01, type=float, 310 | help="Weight decay if we apply some.") 311 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 312 | help="Epsilon for Adam optimizer.") 313 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 314 | help="Max gradient norm.") 315 | parser.add_argument("--label_smoothing", default=0.1, type=float, 316 | help="Max gradient norm.") 317 | parser.add_argument("--num_training_steps", default=-1, type=int, 318 | help="set total number of training steps to perform") 319 | parser.add_argument("--num_training_epochs", default=10, type=int, 320 | help="set total number of training epochs to perform (--num_training_steps has higher priority)") 321 | parser.add_argument("--num_warmup_steps", default=0, type=int, 322 | help="Linear warmup over warmup_steps.") 323 | 324 | parser.add_argument("--F", default=0.1, type=float, 325 | help="prob to random replace a masked token") 326 | parser.add_argument("--keep_prob", default=0.1, type=float, 327 | help="prob to keep no change for a masked token") 328 | parser.add_argument("--random_prob", default=0.1, type=float, 329 | help="prob to random replace a masked token") 330 | 331 | parser.add_argument('--logging_steps', type=int, default=500, 332 | help="Log every X updates steps.") 333 | parser.add_argument('--save_steps', type=int, default=1500, 334 | help="Save checkpoint every X updates steps.") 335 | parser.add_argument("--no_cuda", action='store_true', 336 | help="Whether not to use CUDA when available") 337 | parser.add_argument('--seed', type=int, default=42, 338 | help="random seed for initialization") 339 | 340 | parser.add_argument("--local_rank", type=int, default=-1, 341 | help="local_rank for distributed training on gpus") 342 | parser.add_argument('--fp16', action='store_true', 343 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 344 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 345 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
346 | "See details at https://nvidia.github.io/apex/amp.html") 347 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") 348 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") 349 | 350 | 351 | parser.add_argument('--kd_weight', type=float, default=0, help="kd_weight") 352 | parser.add_argument("--use_distill", default=-1, type=int, help="use_distill") 353 | parser.add_argument("--teacher_model", default='', type=str, help="teacher_model") 354 | parser.add_argument("--use_teacher_dropout", default=-1, type=int, help="use_teacher_dropout") 355 | parser.add_argument("--teacher_dropout_prob", default=0.1, type=float, help="teacher_dropout_prob") 356 | parser.add_argument("--min_lr", default=0, type=float, help="min_lr") 357 | parser.add_argument("--fix_lr", default=0, type=float, help="fix_lr") 358 | 359 | parser.add_argument("--use_random_replace_input", default=-1, type=int, help="use_random_replace_input") 360 | parser.add_argument("--random_replace_input_p", default=-1, type=float, help="random_replace_input_p") 361 | parser.add_argument("--random_replace_input_k", default=-1, type=int, help="random_replace_input_k") 362 | parser.add_argument("--word_drop_prob", default=0, type=float, help="word_drop_prob") 363 | parser.add_argument("--word_shuffle_k", default=0, type=int, help="word_shuffle_k") 364 | parser.add_argument("--sent_shuffle_k", default=0, type=int, help="sent_shuffle_k") 365 | parser.add_argument("--sent_drop_prob", default=0, type=float, help="sent_drop_prob") 366 | parser.add_argument("--use_noisy_student", default=-1, type=int, help="sent_shuffle_k") 367 | parser.add_argument("--use_noisy_teacher", default=-1, type=int, help="use_noisy_teacher") 368 | 369 | 370 | parser.add_argument("--use_my_config", default=-1, type=int, help="use_my_config") 371 | parser.add_argument("--my_num_hidden_layers", default=6, type=int, help="my_num_hidden_layers") 372 | parser.add_argument("--my_num_attention_heads", default=8, type=int, help="my_num_attention_heads") 373 | parser.add_argument("--my_hidden_size", default=512, type=int, help="my_hidden_size") 374 | parser.add_argument("--my_intermediate_size", default=2048, type=int, help="my_intermediate_size") 375 | 376 | args = parser.parse_args() 377 | return args 378 | 379 | 380 | def prepare(args): 381 | # Setup distant debugging if needed 382 | if args.server_ip and args.server_port: 383 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 384 | import ptvsd 385 | print("Waiting for debugger attach") 386 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 387 | ptvsd.wait_for_attach() 388 | 389 | os.makedirs(args.output_dir, exist_ok=True) 390 | json.dump(args.__dict__, open(os.path.join( 391 | args.output_dir, 'train_opt.json'), 'w'), sort_keys=True, indent=2) 392 | 393 | # Setup CUDA, GPU & distributed training 394 | if args.local_rank == -1 or args.no_cuda: 395 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 396 | args.n_gpu = torch.cuda.device_count() 397 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 398 | torch.cuda.set_device(args.local_rank) 399 | device = torch.device("cuda", args.local_rank) 400 | torch.distributed.init_process_group(backend='nccl') 401 | args.n_gpu = 1 402 | args.device = device 403 | 404 | # Setup logging 405 | 
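    # Only the main process (local_rank in {-1, 0}) logs at INFO level; all
    # other ranks are restricted to WARN so that multi-process runs do not
    # emit duplicate progress messages.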
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 406 | datefmt='%m/%d/%Y %H:%M:%S', 407 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 408 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 409 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 410 | 411 | # Set seed 412 | random.seed(args.seed) 413 | np.random.seed(args.seed) 414 | torch.manual_seed(args.seed) 415 | if args.n_gpu > 0: 416 | torch.cuda.manual_seed_all(args.seed) 417 | 418 | logger.info("Training/evaluation parameters %s", args) 419 | 420 | # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. 421 | # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will 422 | # remove the need for this code, but it is still valid. 423 | if args.fp16: 424 | try: 425 | import apex 426 | apex.amp.register_half_function(torch, 'einsum') 427 | except ImportError: 428 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 429 | 430 | 431 | def get_model_and_tokenizer(args): 432 | config_class, tokenizer_class = MODEL_CLASSES[args.model_type] 433 | model_config = config_class.from_pretrained( 434 | args.config_name if args.config_name else args.model_name_or_path, 435 | cache_dir=args.cache_dir if args.cache_dir else None) 436 | 437 | if(args.use_my_config>0): 438 | model_config.num_hidden_layers = args.my_num_hidden_layers 439 | model_config.num_attention_heads = args.my_num_attention_heads 440 | model_config.hidden_size = args.my_hidden_size 441 | model_config.intermediate_size = args.my_intermediate_size 442 | 443 | config = BertForSeq2SeqConfig.from_exist_config( 444 | config=model_config, label_smoothing=args.label_smoothing, 445 | max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length) 446 | logger.info("Model config for seq2seq: %s", str(config)) 447 | 448 | tokenizer = tokenizer_class.from_pretrained( 449 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 450 | do_lower_case=args.do_lower_case, cache_dir=args.cache_dir if args.cache_dir else None) 451 | if(args.use_distill>0): 452 | config.use_teacher_dropout = args.use_teacher_dropout 453 | config.teacher_dropout_prob = args.teacher_dropout_prob 454 | if (args.use_my_config > 0): 455 | model = BertForSequenceToSequence(config=config) 456 | else: 457 | model = BertForSequenceToSequence_Distill.from_pretrained( 458 | args.model_name_or_path, config=config, model_type=args.model_type, 459 | reuse_position_embedding=True, 460 | cache_dir=args.cache_dir if args.cache_dir else None) 461 | model.kd_weight = args.kd_weight 462 | else: 463 | if (args.use_my_config > 0): 464 | model = BertForSequenceToSequence(config=config) 465 | else: 466 | model = BertForSequenceToSequence.from_pretrained( 467 | args.model_name_or_path, config=config, model_type=args.model_type, 468 | reuse_position_embedding=True, 469 | cache_dir=args.cache_dir if args.cache_dir else None) 470 | 471 | model.use_random_replace_input = args.use_random_replace_input 472 | model.random_replace_input_p = args.random_replace_input_p 473 | model.random_replace_input_k = args.random_replace_input_k 474 | model.use_noisy_student = args.use_noisy_student 475 | model.use_noisy_teacher = args.use_noisy_teacher 476 | 477 | return model, tokenizer 478 | 479 | 480 | 481 | 482 
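# get_model_and_tokenizer() covers three setups: plain fine-tuning loads
# BertForSequenceToSequence from the pretrained checkpoint; --use_distill > 0
# switches to BertForSequenceToSequence_Distill and attaches kd_weight, which
# presumably scales the distillation term of the training loss; and
# --use_my_config > 0 ignores the pretrained shapes and builds a randomly
# initialised student from the --my_* flags (defaults: 6 layers, 8 attention
# heads, hidden size 512, intermediate size 2048). The teacher weights for
# distillation are loaded separately in prepare_for_training() from
# --teacher_model and frozen there.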
| def main(): 483 | args = get_args() 484 | prepare(args) 485 | 486 | if args.local_rank not in [-1, 0]: 487 | torch.distributed.barrier() 488 | # Make sure only the first process in distributed training will download model & vocab 489 | # Load pretrained model and tokenizer 490 | model, tokenizer = get_model_and_tokenizer(args) 491 | # teacher_model = get_teacher_model(args) 492 | # model.teacher_model = teacher_model 493 | if args.local_rank == 0: 494 | torch.distributed.barrier() 495 | # Make sure only the first process in distributed training will download model & vocab 496 | 497 | if args.cached_train_features_file is None: 498 | args.cached_train_features_file = os.path.join(args.output_dir, "cached_features_for_training.pt") 499 | training_features = utils.load_and_cache_examples( 500 | example_file=args.train_file, tokenizer=tokenizer, local_rank=args.local_rank, 501 | cached_features_file=args.cached_train_features_file, shuffle=True, 502 | ) 503 | 504 | train(args, training_features, model, tokenizer) 505 | # train_distill(args, training_features, model, teacher_model, tokenizer) 506 | 507 | 508 | if __name__ == "__main__": 509 | main() 510 | -------------------------------------------------------------------------------- /src/s2s_ft/modeling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import logging 4 | import math 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | from torch.nn.modules.loss import _Loss 11 | import torch.nn.functional as F 12 | from torch.optim.lr_scheduler import _LRScheduler 13 | import types 14 | 15 | from transformers.modeling_bert import \ 16 | BertPreTrainedModel, BertSelfOutput, BertIntermediate, BertOutput, BertPredictionHeadTransform 17 | from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 18 | from transformers.modeling_xlm_roberta import XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 19 | 20 | from s2s_ft.config import BertForSeq2SeqConfig 21 | from s2s_ft.convert_state_dict import get_checkpoint_from_transformer_cache, state_dict_convert 22 | 23 | from typing import List, Tuple 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | BertLayerNorm = torch.nn.LayerNorm 28 | 29 | UNILM_PRETRAINED_MODEL_ARCHIVE_MAP = { 30 | 'unilm-base-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-base-cased.bin", 31 | 'unilm-large-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-large-cased.bin", 32 | 'unilm1-base-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-base-cased.bin", 33 | 'unilm1-large-cased': "https://unilm.blob.core.windows.net/ckpt/unilm1-large-cased.bin", 34 | 'unilm1.2-base-uncased': "https://unilm.blob.core.windows.net/ckpt/unilm1.2-base-uncased.bin" 35 | } 36 | 37 | MINILM_PRETRAINED_MODEL_ARCHIVE_MAP = { 38 | 'minilm-l12-h384-uncased': "https://unilm.blob.core.windows.net/ckpt/minilm-l12-h384-uncased.bin", 39 | } 40 | 41 | class BertPreTrainedForSeq2SeqModel(BertPreTrainedModel): 42 | """ An abstract class to handle weights initialization and 43 | a simple interface for dowloading and loading pretrained models. 
44 | """ 45 | config_class = BertForSeq2SeqConfig 46 | supported_convert_pretrained_model_archive_map = { 47 | "bert": BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 48 | "xlm-roberta": XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, 49 | "unilm": UNILM_PRETRAINED_MODEL_ARCHIVE_MAP, 50 | "minilm": MINILM_PRETRAINED_MODEL_ARCHIVE_MAP, 51 | } 52 | base_model_prefix = "bert_for_seq2seq" 53 | pretrained_model_archive_map = { 54 | **XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, 55 | **BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 56 | **UNILM_PRETRAINED_MODEL_ARCHIVE_MAP, 57 | **MINILM_PRETRAINED_MODEL_ARCHIVE_MAP, 58 | } 59 | 60 | def _init_weights(self, module): 61 | """ Initialize the weights """ 62 | if isinstance(module, (nn.Linear, nn.Embedding)): 63 | # Slightly different from the TF version which uses truncated_normal for initialization 64 | # cf https://github.com/pytorch/pytorch/pull/5617 65 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 66 | elif isinstance(module, BertLayerNorm): 67 | module.bias.data.zero_() 68 | module.weight.data.fill_(1.0) 69 | if isinstance(module, nn.Linear) and module.bias is not None: 70 | module.bias.data.zero_() 71 | 72 | @classmethod 73 | def from_pretrained(cls, pretrained_model_name_or_path, reuse_position_embedding=None, *model_args, **kwargs): 74 | model_type = kwargs.pop('model_type', None) 75 | if model_type is not None and "state_dict" not in kwargs: 76 | if model_type in cls.supported_convert_pretrained_model_archive_map: 77 | pretrained_model_archive_map = cls.supported_convert_pretrained_model_archive_map[model_type] 78 | if pretrained_model_name_or_path in pretrained_model_archive_map: 79 | state_dict = get_checkpoint_from_transformer_cache( 80 | archive_file=pretrained_model_archive_map[pretrained_model_name_or_path], 81 | pretrained_model_name_or_path=pretrained_model_name_or_path, 82 | pretrained_model_archive_map=pretrained_model_archive_map, 83 | cache_dir=kwargs.get("cache_dir", None), force_download=kwargs.get("force_download", None), 84 | proxies=kwargs.get("proxies", None), resume_download=kwargs.get("resume_download", None), 85 | ) 86 | state_dict = state_dict_convert[model_type](state_dict) 87 | kwargs["state_dict"] = state_dict 88 | elif os.path.isfile(pretrained_model_name_or_path): 89 | kwargs["state_dict"] = torch.load(pretrained_model_name_or_path, map_location='cpu') 90 | 91 | if kwargs["state_dict"] is None: 92 | logger.info("s2s-ft does't support the model !") 93 | raise NotImplementedError() 94 | 95 | config = kwargs["config"] 96 | state_dict = kwargs["state_dict"] 97 | # initialize new position embeddings (From Microsoft/UniLM) 98 | _k = 'bert.embeddings.position_embeddings.weight' 99 | # if _k in state_dict and config.max_position_embeddings != state_dict[_k].shape[0]: 100 | # logger.info("config.max_position_embeddings != state_dict[bert.embeddings.position_embeddings.weight] ({0} - {1})".format( 101 | # config.max_position_embeddings, state_dict[_k].shape[0])) 102 | # if config.max_position_embeddings > state_dict[_k].shape[0]: 103 | # old_size = state_dict[_k].shape[0] 104 | # # state_dict[_k].data = state_dict[_k].data.resize_(config.max_position_embeddings, state_dict[_k].shape[1]) 105 | # state_dict[_k].resize_( 106 | # config.max_position_embeddings, state_dict[_k].shape[1]) 107 | # start = old_size 108 | # while start < config.max_position_embeddings: 109 | # chunk_size = min( 110 | # old_size, config.max_position_embeddings - start) 111 | # state_dict[_k].data[start:start+chunk_size, 112 | # 
:].copy_(state_dict[_k].data[:chunk_size, :]) 113 | # start += chunk_size 114 | # elif config.max_position_embeddings < state_dict[_k].shape[0]: 115 | # state_dict[_k].data = state_dict[_k].data[:config.max_position_embeddings, :] 116 | 117 | _k = 'bert.embeddings.position_embeddings.weight' 118 | if _k in state_dict: 119 | if config.max_position_embeddings > state_dict[_k].shape[0]: 120 | logger.info("Resize > position embeddings !") 121 | old_vocab_size = state_dict[_k].shape[0] 122 | new_postion_embedding = state_dict[_k].data.new_tensor(torch.ones( 123 | size=(config.max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float) 124 | new_postion_embedding = nn.Parameter(data=new_postion_embedding, requires_grad=True) 125 | new_postion_embedding.data.normal_(mean=0.0, std=config.initializer_range) 126 | max_range = config.max_position_embeddings if reuse_position_embedding else old_vocab_size 127 | shift = 0 128 | while shift < max_range: 129 | delta = min(old_vocab_size, max_range - shift) 130 | new_postion_embedding.data[shift: shift + delta, :] = state_dict[_k][:delta, :] 131 | logger.info(" CP [%d ~ %d] into [%d ~ %d] " % (0, delta, shift, shift + delta)) 132 | shift += delta 133 | state_dict[_k] = new_postion_embedding.data 134 | del new_postion_embedding 135 | elif config.max_position_embeddings < state_dict[_k].shape[0]: 136 | logger.info("Resize < position embeddings !") 137 | old_vocab_size = state_dict[_k].shape[0] 138 | new_postion_embedding = state_dict[_k].data.new_tensor(torch.ones( 139 | size=(config.max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float) 140 | new_postion_embedding = nn.Parameter(data=new_postion_embedding, requires_grad=True) 141 | new_postion_embedding.data.normal_(mean=0.0, std=config.initializer_range) 142 | new_postion_embedding.data.copy_(state_dict[_k][:config.max_position_embeddings, :]) 143 | state_dict[_k] = new_postion_embedding.data 144 | del new_postion_embedding 145 | 146 | return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 147 | 148 | 149 | class BertEmbeddings(nn.Module): 150 | """Construct the embeddings from word, position and token_type embeddings. 
151 | """ 152 | def __init__(self, config): 153 | super(BertEmbeddings, self).__init__() 154 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 155 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 156 | if config.type_vocab_size > 0: 157 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 158 | else: 159 | self.token_type_embeddings = None 160 | 161 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 162 | # any TensorFlow checkpoint file 163 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 164 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 165 | 166 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): 167 | if input_ids is not None: 168 | input_shape = input_ids.size() 169 | else: 170 | input_shape = inputs_embeds.size()[:-1] 171 | 172 | seq_length = input_shape[1] 173 | device = input_ids.device if input_ids is not None else inputs_embeds.device 174 | if position_ids is None: 175 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device) 176 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 177 | if token_type_ids is None: 178 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 179 | 180 | if inputs_embeds is None: 181 | inputs_embeds = self.word_embeddings(input_ids) 182 | position_embeddings = self.position_embeddings(position_ids) 183 | 184 | embeddings = inputs_embeds + position_embeddings 185 | 186 | if self.token_type_embeddings: 187 | embeddings = embeddings + self.token_type_embeddings(token_type_ids) 188 | 189 | embeddings = self.LayerNorm(embeddings) 190 | embeddings = self.dropout(embeddings) 191 | return embeddings 192 | 193 | 194 | class BertSelfAttention(nn.Module): 195 | def __init__(self, config): 196 | super(BertSelfAttention, self).__init__() 197 | if config.hidden_size % config.num_attention_heads != 0: 198 | raise ValueError( 199 | "The hidden size (%d) is not a multiple of the number of attention " 200 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 201 | self.output_attentions = config.output_attentions 202 | 203 | self.num_attention_heads = config.num_attention_heads 204 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 205 | self.all_head_size = self.num_attention_heads * self.attention_head_size 206 | 207 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 208 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 209 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 210 | 211 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 212 | 213 | def transpose_for_scores(self, x): 214 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 215 | x = x.view(*new_x_shape) 216 | return x.permute(0, 2, 1, 3) 217 | 218 | def multi_head_attention(self, query, key, value, attention_mask): 219 | query_layer = self.transpose_for_scores(query) 220 | key_layer = self.transpose_for_scores(key) 221 | value_layer = self.transpose_for_scores(value) 222 | 223 | # Take the dot product between "query" and "key" to get the raw attention scores. 
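        # Scaled dot-product attention, computed per head:
        #   Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k) + mask) @ V
        # with d_k = attention_head_size = hidden_size / num_attention_heads;
        # the additive mask is 0 for visible positions and a large negative
        # value for positions that must not be attended to.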
224 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 225 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 226 | if attention_mask is not None: 227 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 228 | attention_scores = attention_scores + attention_mask 229 | 230 | # Normalize the attention scores to probabilities. 231 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 232 | 233 | # This is actually dropping out entire tokens to attend to, which might 234 | # seem a bit unusual, but is taken from the original Transformer paper. 235 | attention_probs = self.dropout(attention_probs) 236 | context_layer = torch.matmul(attention_probs, value_layer) 237 | 238 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 239 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 240 | context_layer = context_layer.view(*new_context_layer_shape) 241 | 242 | return (context_layer, attention_probs) if self.output_attentions else (context_layer,) 243 | 244 | def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, split_lengths=None): 245 | mixed_query_layer = self.query(hidden_states) 246 | if split_lengths: 247 | assert not self.output_attentions 248 | 249 | # If this is instantiated as a cross-attention module, the keys 250 | # and values come from an encoder; the attention mask needs to be 251 | # such that the encoder's padding tokens are not attended to. 252 | if encoder_hidden_states is not None: 253 | mixed_key_layer = self.key(encoder_hidden_states) 254 | mixed_value_layer = self.value(encoder_hidden_states) 255 | else: 256 | mixed_key_layer = self.key(hidden_states) 257 | mixed_value_layer = self.value(hidden_states) 258 | 259 | if split_lengths: 260 | query_parts = torch.split(mixed_query_layer, split_lengths, dim=1) 261 | key_parts = torch.split(mixed_key_layer, split_lengths, dim=1) 262 | value_parts = torch.split(mixed_value_layer, split_lengths, dim=1) 263 | 264 | key = None 265 | value = None 266 | outputs = [] 267 | sum_length = 0 268 | for (query, _key, _value, part_length) in zip(query_parts, key_parts, value_parts, split_lengths): 269 | key = _key if key is None else torch.cat((key, _key), dim=1) 270 | value = _value if value is None else torch.cat((value, _value), dim=1) 271 | sum_length += part_length 272 | outputs.append(self.multi_head_attention( 273 | query, key, value, attention_mask[:, :, sum_length - part_length: sum_length, :sum_length] 274 | )[0]) 275 | outputs = (torch.cat(outputs, dim=1), ) 276 | else: 277 | outputs = self.multi_head_attention( 278 | mixed_query_layer, mixed_key_layer, mixed_value_layer, attention_mask) 279 | return outputs 280 | 281 | 282 | class BertAttention(nn.Module): 283 | def __init__(self, config): 284 | super(BertAttention, self).__init__() 285 | self.self = BertSelfAttention(config) 286 | self.output = BertSelfOutput(config) 287 | 288 | def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, split_lengths=None): 289 | self_outputs = self.self( 290 | hidden_states, attention_mask=attention_mask, 291 | encoder_hidden_states=encoder_hidden_states, split_lengths=split_lengths) 292 | attention_output = self.output(self_outputs[0], hidden_states) 293 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 294 | return outputs 295 | 296 | 297 | class BertLayer(nn.Module): 298 | def __init__(self, config): 299 | super(BertLayer, 
self).__init__() 300 | self.attention = BertAttention(config) 301 | self.intermediate = BertIntermediate(config) 302 | self.output = BertOutput(config) 303 | 304 | def forward(self, hidden_states, attention_mask=None, split_lengths=None): 305 | self_attention_outputs = self.attention( 306 | hidden_states, attention_mask, split_lengths=split_lengths) 307 | attention_output = self_attention_outputs[0] 308 | 309 | intermediate_output = self.intermediate(attention_output) 310 | layer_output = self.output(intermediate_output, attention_output) 311 | outputs = (layer_output,) + self_attention_outputs[1:] 312 | return outputs 313 | 314 | 315 | class BertEncoder(nn.Module): 316 | def __init__(self, config): 317 | super(BertEncoder, self).__init__() 318 | self.output_attentions = config.output_attentions 319 | self.output_hidden_states = config.output_hidden_states 320 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 321 | 322 | def forward(self, hidden_states, attention_mask=None, split_lengths=None): 323 | all_hidden_states = () 324 | all_attentions = () 325 | for i, layer_module in enumerate(self.layer): 326 | if self.output_hidden_states: 327 | all_hidden_states = all_hidden_states + (hidden_states,) 328 | 329 | layer_outputs = layer_module(hidden_states, attention_mask, split_lengths=split_lengths) 330 | hidden_states = layer_outputs[0] 331 | 332 | if self.output_attentions: 333 | all_attentions = all_attentions + (layer_outputs[1],) 334 | 335 | # Add last layer 336 | if self.output_hidden_states: 337 | all_hidden_states = all_hidden_states + (hidden_states,) 338 | 339 | outputs = (hidden_states,) 340 | if self.output_hidden_states: 341 | outputs = outputs + (all_hidden_states,) 342 | if self.output_attentions: 343 | outputs = outputs + (all_attentions,) 344 | return outputs # last-layer hidden state, (all hidden states), (all attentions) 345 | 346 | 347 | class BertModel(BertPreTrainedForSeq2SeqModel): 348 | r""" 349 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 350 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` 351 | Sequence of hidden-states at the output of the last layer of the model. 352 | **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` 353 | Last layer hidden-state of the first token of the sequence (classification token) 354 | further processed by a Linear layer and a Tanh activation function. The Linear 355 | layer weights are trained from the next sentence prediction (classification) 356 | objective during Bert pretraining. This output is usually *not* a good summary 357 | of the semantic content of the input, you're often better with averaging or pooling 358 | the sequence of hidden-states for the whole input sequence. 359 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 360 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 361 | of shape ``(batch_size, sequence_length, hidden_size)``: 362 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 363 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 364 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 365 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
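        **Note**: unlike the stock ``BertModel``, :meth:`forward` here also accepts ``split_lengths``
        (the lengths of the packed source / target / pseudo segments); when provided, self-attention is
        computed segment by segment, with each query block attending only to the keys and values
        accumulated up to and including its own segment (further restricted by ``attention_mask``).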
366 | 367 | Examples:: 368 | 369 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 370 | model = BertModel.from_pretrained('bert-base-uncased') 371 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 372 | outputs = model(input_ids) 373 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 374 | 375 | """ 376 | def __init__(self, config): 377 | super(BertModel, self).__init__(config) 378 | self.config = config 379 | 380 | self.embeddings = BertEmbeddings(config) 381 | self.encoder = BertEncoder(config) 382 | 383 | @property 384 | def dtype(self) -> torch.dtype: 385 | try: 386 | return next(self.parameters()).dtype 387 | except StopIteration: 388 | # For nn.DataParallel compatibility in PyTorch 1.5 389 | 390 | def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, torch.Tensor]]: 391 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 392 | return tuples 393 | 394 | gen = self._named_members(get_members_fn=find_tensor_attributes) 395 | first_tuple = next(gen) 396 | return first_tuple[1].dtype 397 | 398 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, 399 | position_ids=None, inputs_embeds=None, split_lengths=None): 400 | if input_ids is not None and inputs_embeds is not None: 401 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 402 | elif input_ids is not None: 403 | input_shape = input_ids.size() 404 | elif inputs_embeds is not None: 405 | input_shape = inputs_embeds.size()[:-1] 406 | else: 407 | raise ValueError("You have to specify either input_ids or inputs_embeds") 408 | 409 | device = input_ids.device if input_ids is not None else inputs_embeds.device 410 | 411 | if attention_mask is None: 412 | attention_mask = torch.ones(input_shape, device=device) 413 | 414 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 415 | # ourselves in which case we just need to make it broadcastable to all heads. 416 | if attention_mask.dim() == 3: 417 | extended_attention_mask = attention_mask[:, None, :, :] 418 | 419 | # Provided a padding mask of dimensions [batch_size, seq_length] 420 | # - if the model is a decoder, apply a causal mask in addition to the padding mask 421 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] 422 | if attention_mask.dim() == 2: 423 | extended_attention_mask = attention_mask[:, None, None, :] 424 | 425 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 426 | # masked positions, this operation will create a tensor which is 0.0 for 427 | # positions we want to attend and -10000.0 for masked positions. 428 | # Since we are adding it to the raw scores before the softmax, this is 429 | # effectively the same as removing these entirely. 
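# For example, a padding mask [1, 1, 0] becomes the additive mask [0.0, 0.0, -10000.0] below,
# so padded positions contribute a large negative logit and receive ~zero attention weight.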
430 | # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 431 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 432 | 433 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 434 | 435 | embedding_output = self.embeddings( 436 | input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds) 437 | encoder_outputs = self.encoder( 438 | embedding_output, attention_mask=extended_attention_mask, split_lengths=split_lengths) 439 | sequence_output = encoder_outputs[0] 440 | 441 | outputs = (sequence_output, ) + encoder_outputs[1:] # add hidden_states and attentions if they are here 442 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 443 | 444 | 445 | class LabelSmoothingLoss(_Loss): 446 | """ 447 | With label smoothing, 448 | KL-divergence between q_{smoothed ground truth prob.}(w) 449 | and p_{prob. computed by model}(w) is minimized. 450 | """ 451 | 452 | def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'): 453 | assert 0.0 < label_smoothing <= 1.0 454 | self.ignore_index = ignore_index 455 | super(LabelSmoothingLoss, self).__init__( 456 | size_average=size_average, reduce=reduce, reduction=reduction) 457 | 458 | assert label_smoothing > 0 459 | assert tgt_vocab_size > 0 460 | 461 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 462 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 463 | one_hot[self.ignore_index] = 0 464 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 465 | self.confidence = 1.0 - label_smoothing 466 | self.tgt_vocab_size = tgt_vocab_size 467 | 468 | def forward(self, output, target): 469 | """ 470 | output (FloatTensor): batch_size * num_pos * n_classes 471 | target (LongTensor): batch_size * num_pos 472 | """ 473 | assert self.tgt_vocab_size == output.size(2) 474 | batch_size, num_pos = target.size(0), target.size(1) 475 | output = output.view(-1, self.tgt_vocab_size) 476 | target = target.view(-1) 477 | model_prob = self.one_hot.float().repeat(target.size(0), 1) 478 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 479 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 480 | 481 | return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) 482 | 483 | 484 | class BertLMPredictionHead(nn.Module): 485 | def __init__(self, config, decoder_weight): 486 | super(BertLMPredictionHead, self).__init__() 487 | self.transform = BertPredictionHeadTransform(config) 488 | 489 | # The output weights are the same as the input embeddings, but there is 490 | # an output-only bias for each token. 
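# decoder_weight is the word-embedding matrix supplied by the caller (BertForSequenceToSequence
# passes self.bert.embeddings.word_embeddings.weight), i.e. the usual weight-tying scheme:
# logits = hidden_states @ embedding_matrix^T + bias.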
491 | self.decoder_weight = decoder_weight
492 | 
493 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
494 | 
495 | def forward(self, hidden_states):
496 | hidden_states = self.transform(hidden_states)
497 | hidden_states = F.linear(hidden_states, weight=self.decoder_weight, bias=self.bias)
498 | return hidden_states
499 | 
500 | 
501 | class BertOnlyMLMHead(nn.Module):
502 | def __init__(self, config, decoder_weight):
503 | super(BertOnlyMLMHead, self).__init__()
504 | self.predictions = BertLMPredictionHead(config, decoder_weight)
505 | 
506 | def forward(self, sequence_output):
507 | prediction_scores = self.predictions(sequence_output)
508 | return prediction_scores
509 | 
510 | 
511 | class BertForSequenceToSequence(BertPreTrainedForSeq2SeqModel):
512 | def __init__(self, config):
513 | super(BertForSequenceToSequence, self).__init__(config)
514 | self.bert = BertModel(config)
515 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
516 | self.init_weights()
517 | 
518 | self.log_softmax = nn.LogSoftmax(dim=-1)  # explicit dim avoids the implicit-dimension deprecation warning
519 | 
520 | # setattr(config, 'label_smoothing', 0.1)
521 | self.source_type_id = config.source_type_id
522 | self.target_type_id = config.target_type_id
523 | 
524 | if config.label_smoothing > 0:
525 | self.crit_mask_lm_smoothed = LabelSmoothingLoss(
526 | config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none')
527 | self.crit_mask_lm = None
528 | else:
529 | self.crit_mask_lm_smoothed = None
530 | self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none')
531 | 
532 | @staticmethod
533 | def create_mask_and_position_ids(num_tokens, max_len, offset=None):
534 | base_position_matrix = torch.arange(
535 | 0, max_len, dtype=num_tokens.dtype, device=num_tokens.device).view(1, -1)
536 | mask = (base_position_matrix < num_tokens.view(-1, 1)).type_as(num_tokens)
537 | if offset is not None:
538 | base_position_matrix = base_position_matrix + offset.view(-1, 1)
539 | position_ids = base_position_matrix * mask
540 | return mask, position_ids
541 | 
542 | @staticmethod
543 | def create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids):
544 | weight = torch.cat((torch.zeros_like(source_position_ids), target_span_ids, -target_span_ids), dim=1)
545 | from_weight = weight.unsqueeze(-1)
546 | to_weight = weight.unsqueeze(1)
547 | 
548 | true_tokens = (0 <= to_weight) & (torch.cat((source_mask, target_mask, target_mask), dim=1) == 1).unsqueeze(1)
549 | true_tokens_mask = (from_weight >= 0) & true_tokens & (to_weight <= from_weight)
550 | pseudo_tokens_mask = (from_weight < 0) & true_tokens & (-to_weight > from_weight)
551 | pseudo_tokens_mask = pseudo_tokens_mask | ((from_weight < 0) & (to_weight == from_weight))
552 | 
553 | return (true_tokens_mask | pseudo_tokens_mask).type_as(source_mask)
554 | 
555 | def forward(self, source_ids, noisy_source_ids, target_ids, pseudo_ids, num_source_tokens, num_noisy_source_tokens, num_target_tokens, target_span_ids=None):
556 | if getattr(self, 'use_noisy_student', -1) > 0:  # the flag may be set externally (e.g. by the training script); default to the clean source if it is absent
557 | source_ids = noisy_source_ids
558 | num_source_tokens = num_noisy_source_tokens
559 | source_len = source_ids.size(1)
560 | target_len = target_ids.size(1)
561 | pseudo_len = pseudo_ids.size(1)
562 | assert target_len == pseudo_len
563 | assert source_len > 0 and target_len > 0
564 | split_lengths = (source_len, target_len, pseudo_len)
565 | 
566 | input_ids = torch.cat((source_ids, target_ids, pseudo_ids), dim=1)
567 | 
568 | token_type_ids = torch.cat(
569 | (torch.ones_like(source_ids) * self.source_type_id,
570 | 
torch.ones_like(target_ids) * self.target_type_id, 571 | torch.ones_like(pseudo_ids) * self.target_type_id), dim=1) 572 | 573 | source_mask, source_position_ids = \ 574 | self.create_mask_and_position_ids(num_source_tokens, source_len) 575 | target_mask, target_position_ids = \ 576 | self.create_mask_and_position_ids(num_target_tokens, target_len, offset=num_source_tokens) 577 | 578 | position_ids = torch.cat((source_position_ids, target_position_ids, target_position_ids), dim=1) 579 | if target_span_ids is None: 580 | target_span_ids = target_position_ids 581 | attention_mask = self.create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids) 582 | 583 | outputs = self.bert( 584 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 585 | position_ids=position_ids, split_lengths=split_lengths) 586 | 587 | sequence_output = outputs[0] 588 | pseudo_sequence_output = sequence_output[:, source_len + target_len:, ] 589 | 590 | def loss_mask_and_normalize(loss, mask): 591 | mask = mask.type_as(loss) 592 | loss = loss * mask 593 | denominator = torch.sum(mask) + 1e-5 594 | return (loss / denominator).sum() 595 | 596 | prediction_scores_masked = self.cls(pseudo_sequence_output) 597 | 598 | if self.crit_mask_lm_smoothed: 599 | masked_lm_loss = self.crit_mask_lm_smoothed( 600 | F.log_softmax(prediction_scores_masked.float(), dim=-1), target_ids) 601 | else: 602 | masked_lm_loss = self.crit_mask_lm( 603 | prediction_scores_masked.transpose(1, 2).float(), target_ids) 604 | pseudo_lm_loss = loss_mask_and_normalize( 605 | masked_lm_loss.float(), target_mask) 606 | 607 | return pseudo_lm_loss 608 | 609 | 610 | 611 | 612 | 613 | class BertForSequenceToSequence_Distill(BertPreTrainedForSeq2SeqModel): 614 | def __init__(self, config): 615 | super(BertForSequenceToSequence_Distill, self).__init__(config) 616 | self.bert = BertModel(config) 617 | 618 | _hidden_dropout_prob = config.hidden_dropout_prob 619 | _attention_probs_dropout_prob = config.attention_probs_dropout_prob 620 | config.hidden_dropout_prob = config.teacher_dropout_prob 621 | config.attention_probs_dropout_prob = config.teacher_dropout_prob 622 | 623 | self.teacher_model = BertModel(config) 624 | 625 | config.hidden_dropout_prob = _hidden_dropout_prob 626 | config.attention_probs_dropout_prob = _attention_probs_dropout_prob 627 | 628 | self.use_teacher_dropout = config.use_teacher_dropout 629 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 630 | self.teacher_cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 631 | self.init_weights() 632 | self.kd_weight = 0 633 | self.log_softmax = nn.LogSoftmax() 634 | # setattr(config, 'label_smoothing', 0.1) 635 | self.source_type_id = config.source_type_id 636 | self.target_type_id = config.target_type_id 637 | 638 | self.use_random_replace_input = -1 639 | self.random_replace_input_p = -1 640 | self.random_replace_input_k = -1 641 | 642 | self.use_noisy_student = -1 643 | self.use_noisy_teacher = -1 644 | 645 | if config.label_smoothing > 0: 646 | self.crit_mask_lm_smoothed = LabelSmoothingLoss( 647 | config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none') 648 | self.crit_mask_lm = None 649 | else: 650 | self.crit_mask_lm_smoothed = None 651 | self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') 652 | 653 | @staticmethod 654 | def create_mask_and_position_ids(num_tokens, max_len, offset=None): 655 | base_position_matrix = torch.arange( 656 | 0, max_len, dtype=num_tokens.dtype, 
device=num_tokens.device).view(1, -1) 657 | mask = (base_position_matrix < num_tokens.view(-1, 1)).type_as(num_tokens) 658 | if offset is not None: 659 | base_position_matrix = base_position_matrix + offset.view(-1, 1) 660 | position_ids = base_position_matrix * mask 661 | return mask, position_ids 662 | 663 | @staticmethod 664 | def create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids): 665 | weight = torch.cat((torch.zeros_like(source_position_ids), target_span_ids, -target_span_ids), dim=1) 666 | from_weight = weight.unsqueeze(-1) 667 | to_weight = weight.unsqueeze(1) 668 | 669 | true_tokens = (0 <= to_weight) & (torch.cat((source_mask, target_mask, target_mask), dim=1) == 1).unsqueeze(1) 670 | true_tokens_mask = (from_weight >= 0) & true_tokens & (to_weight <= from_weight) 671 | pseudo_tokens_mask = (from_weight < 0) & true_tokens & (-to_weight > from_weight) 672 | pseudo_tokens_mask = pseudo_tokens_mask | ((from_weight < 0) & (to_weight == from_weight)) 673 | 674 | return (true_tokens_mask | pseudo_tokens_mask).type_as(source_mask) 675 | 676 | 677 | def get_kd_loss(self, student_lprobs, teacher_probs): 678 | 679 | # kd_loss = F.kl_div(student_lprobs.view(-1, student_lprobs.size(-1)), teacher_probs.view(-1,student_lprobs.size(-1))) 680 | kd_loss = F.kl_div(student_lprobs, teacher_probs, reduction='none').sum(-1) 681 | return kd_loss 682 | 683 | 684 | def forward(self, source_ids, noisy_source_ids_s, noisy_source_ids_t, target_ids, pseudo_ids, num_source_tokens, num_noisy_source_tokens_s, num_noisy_source_tokens_t, num_target_tokens, target_span_ids=None): 685 | if(self.use_noisy_student>0 ): 686 | if (self.use_random_replace_input > 0): 687 | with torch.no_grad(): 688 | input_embs = self.teacher_model.embeddings(noisy_source_ids_s) 689 | top_words = torch.topk(torch.matmul(input_embs, 690 | self.teacher_model.embeddings.word_embeddings.weight.data.transpose( 691 | 1, 0)), self.random_replace_input_k + 1, -1)[1][:, :, 1:] 692 | 693 | indices = np.random.rand(*top_words.shape).argsort(axis=-1) 694 | indices[indices > 1] = 0 695 | noisy_words = torch.masked_select(top_words, 696 | torch.tensor(indices).to(torch.bool).to(noisy_source_ids_s.device)).reshape( 697 | top_words.shape[0], top_words.shape[1]) 698 | indices = torch.bernoulli(torch.ones(*noisy_words.shape) * (1 - self.random_replace_input_p)).long().to( 699 | noisy_source_ids_s.device) 700 | noisy_x = indices * noisy_source_ids_s + (1 - indices) * noisy_words 701 | noisy_source_ids_s = noisy_x 702 | if(self.use_noisy_teacher>0 ): 703 | if (self.use_random_replace_input > 0): 704 | with torch.no_grad(): 705 | input_embs = self.teacher_model.embeddings(noisy_source_ids_t) 706 | top_words = torch.topk(torch.matmul(input_embs, 707 | self.teacher_model.embeddings.word_embeddings.weight.data.transpose( 708 | 1, 0)), self.random_replace_input_k + 1, -1)[1][:, :, 1:] 709 | 710 | indices = np.random.rand(*top_words.shape).argsort(axis=-1) 711 | indices[indices > 1] = 0 712 | noisy_words = torch.masked_select(top_words, 713 | torch.tensor(indices).to(torch.bool).to(noisy_source_ids_t.device)).reshape( 714 | top_words.shape[0], top_words.shape[1]) 715 | indices = torch.bernoulli(torch.ones(*noisy_words.shape) * (1 - self.random_replace_input_p)).long().to( 716 | noisy_source_ids_t.device) 717 | noisy_x = indices * noisy_source_ids_t + (1 - indices) * noisy_words 718 | noisy_source_ids_t = noisy_x 719 | 720 | if (self.use_noisy_student>0): 721 | student_source_ids = noisy_source_ids_s 722 | 
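# the noisy source may be padded/tokenized differently from the clean one, so the token counts
# used below to build masks and position ids must come from the noisy batch as well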
student_num_source_tokens = num_noisy_source_tokens_s 723 | else: 724 | student_source_ids = source_ids 725 | student_num_source_tokens = num_source_tokens 726 | if(self.use_noisy_teacher>0): 727 | teacher_source_ids = noisy_source_ids_t 728 | teacher_num_source_tokens = num_noisy_source_tokens_t 729 | else: 730 | teacher_source_ids = source_ids 731 | teacher_num_source_tokens = num_source_tokens 732 | 733 | # if (self.use_random_drop_input > 0): 734 | # with torch.no_grad(): 735 | # indices = np.random.rand(*top_words.shape).argsort(axis=-1) 736 | # indices[indices > 1] = 0 737 | # noisy_words = torch.masked_select(top_words, 738 | # torch.tensor(indices).to(torch.bool).to(source_ids.device)).reshape( 739 | # top_words.shape[0], top_words.shape[1]) 740 | # indices = torch.bernoulli(torch.ones(*noisy_words.shape) * (1 - self.random_replace_input_p)).long().to( 741 | # source_ids.device) 742 | # noisy_x = indices * source_ids + (1 - indices) * noisy_words 743 | # source_ids = noisy_x 744 | 745 | student_source_len = student_source_ids.size(1) 746 | teacher_source_len = teacher_source_ids.size(1) 747 | target_len = target_ids.size(1) 748 | pseudo_len = pseudo_ids.size(1) 749 | assert target_len == pseudo_len 750 | assert student_source_len > 0 and target_len > 0 751 | student_split_lengths = (student_source_len, target_len, pseudo_len) 752 | teacher_split_lengths = (teacher_source_len, target_len, pseudo_len) 753 | 754 | student_input_ids = torch.cat((student_source_ids, target_ids, pseudo_ids), dim=1) 755 | teacher_input_ids = torch.cat((teacher_source_ids, target_ids, pseudo_ids), dim=1) 756 | 757 | student_token_type_ids = torch.cat( 758 | (torch.ones_like(student_source_ids) * self.source_type_id, 759 | torch.ones_like(target_ids) * self.target_type_id, 760 | torch.ones_like(pseudo_ids) * self.target_type_id), dim=1) 761 | teacher_token_type_ids = torch.cat( 762 | (torch.ones_like(teacher_source_ids) * self.source_type_id, 763 | torch.ones_like(target_ids) * self.target_type_id, 764 | torch.ones_like(pseudo_ids) * self.target_type_id), dim=1) 765 | 766 | student_source_mask, student_source_position_ids = \ 767 | self.create_mask_and_position_ids(student_num_source_tokens, student_source_len) 768 | student_target_mask, student_target_position_ids = \ 769 | self.create_mask_and_position_ids(num_target_tokens, target_len, offset=student_num_source_tokens) 770 | teacher_source_mask, teacher_source_position_ids = \ 771 | self.create_mask_and_position_ids(teacher_num_source_tokens, teacher_source_len) 772 | teacher_target_mask, teacher_target_position_ids = \ 773 | self.create_mask_and_position_ids(num_target_tokens, target_len, offset=teacher_num_source_tokens) 774 | 775 | student_position_ids = torch.cat((student_source_position_ids, student_target_position_ids, student_target_position_ids), dim=1) 776 | teacher_position_ids = torch.cat((teacher_source_position_ids, teacher_target_position_ids, teacher_target_position_ids), dim=1) 777 | 778 | if target_span_ids is None: 779 | student_target_span_ids = student_target_position_ids 780 | teacher_target_span_ids = teacher_target_position_ids 781 | else: 782 | student_target_span_ids = target_span_ids 783 | teacher_target_span_ids = target_span_ids 784 | student_attention_mask = self.create_attention_mask(student_source_mask, student_target_mask, student_source_position_ids, student_target_span_ids) 785 | teacher_attention_mask = self.create_attention_mask(teacher_source_mask, teacher_target_mask, teacher_source_position_ids, 
teacher_target_span_ids) 786 | 787 | student_outputs = self.bert( 788 | student_input_ids, attention_mask=student_attention_mask, token_type_ids=student_token_type_ids, 789 | position_ids=student_position_ids, split_lengths=student_split_lengths) 790 | student_sequence_output = student_outputs[0] 791 | student_pseudo_sequence_output = student_sequence_output[:, student_source_len + target_len:, ] 792 | student_prediction_scores_masked = self.cls(student_pseudo_sequence_output) 793 | 794 | with torch.no_grad(): 795 | if(self.use_teacher_dropout<0): 796 | self.teacher_model.eval() 797 | self.teacher_cls.eval() 798 | teacher_outputs = self.teacher_model( 799 | teacher_input_ids, attention_mask=teacher_attention_mask, token_type_ids=teacher_token_type_ids, 800 | position_ids=teacher_position_ids, split_lengths=teacher_split_lengths) 801 | teacher_sequence_output = teacher_outputs[0] 802 | teacher_pseudo_sequence_output = teacher_sequence_output[:, teacher_source_len + target_len:, ] 803 | teacher_prediction_scores_masked = self.teacher_cls(teacher_pseudo_sequence_output) 804 | teacher_prediction_scores_masked.detach_() 805 | 806 | def loss_mask_and_normalize(loss, mask): 807 | mask = mask.type_as(loss) 808 | loss = loss * mask 809 | denominator = torch.sum(mask) + 1e-5 810 | return (loss / denominator).sum() 811 | 812 | kd_loss = self.get_kd_loss(F.log_softmax(student_prediction_scores_masked.float(),-1), F.softmax(teacher_prediction_scores_masked.float(),-1)) 813 | if self.crit_mask_lm_smoothed: 814 | masked_lm_loss = self.crit_mask_lm_smoothed( 815 | F.log_softmax(student_prediction_scores_masked.float(), dim=-1), target_ids) 816 | else: 817 | masked_lm_loss = self.crit_mask_lm( 818 | student_prediction_scores_masked.transpose(1, 2).float(), target_ids) 819 | pseudo_lm_loss = loss_mask_and_normalize( 820 | masked_lm_loss.float(), student_target_mask) 821 | pseudo_kd_loss = loss_mask_and_normalize( 822 | kd_loss.float(), student_target_mask) 823 | 824 | pseudo_loss = pseudo_kd_loss * self.kd_weight + pseudo_lm_loss * (1 - self.kd_weight) 825 | return pseudo_loss 826 | 827 | 828 | 829 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1, min_lr=0, fix_lr=0): 830 | """ Create a schedule with a learning rate that decreases linearly after 831 | linearly increasing during a warmup period. 832 | """ 833 | 834 | def lr_lambda(current_step): 835 | if current_step < num_warmup_steps: 836 | return float(current_step) / float(max(1, num_warmup_steps)) 837 | return max( 838 | 0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 839 | ) 840 | 841 | return LambdaLR(optimizer, lr_lambda, last_epoch, min_lr, fix_lr) 842 | 843 | 844 | 845 | class LambdaLR(_LRScheduler): 846 | """Sets the learning rate of each parameter group to the initial lr 847 | times a given function. When last_epoch=-1, sets initial lr as lr. 848 | 849 | Args: 850 | optimizer (Optimizer): Wrapped optimizer. 851 | lr_lambda (function or list): A function which computes a multiplicative 852 | factor given an integer parameter epoch, or a list of such 853 | functions, one for each group in optimizer.param_groups. 854 | last_epoch (int): The index of last epoch. Default: -1. 855 | 856 | Example: 857 | >>> # Assuming optimizer has two groups. 
858 | >>> lambda1 = lambda epoch: epoch // 30 859 | >>> lambda2 = lambda epoch: 0.95 ** epoch 860 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 861 | >>> for epoch in range(100): 862 | >>> train(...) 863 | >>> validate(...) 864 | >>> scheduler.step() 865 | """ 866 | 867 | def __init__(self, optimizer, lr_lambda, last_epoch=-1, min_lr=0, fix_lr=0): 868 | self.optimizer = optimizer 869 | self.min_lr = min_lr 870 | self.fix_lr = fix_lr 871 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): 872 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) 873 | else: 874 | if len(lr_lambda) != len(optimizer.param_groups): 875 | raise ValueError("Expected {} lr_lambdas, but got {}".format( 876 | len(optimizer.param_groups), len(lr_lambda))) 877 | self.lr_lambdas = list(lr_lambda) 878 | self.last_epoch = last_epoch 879 | super(LambdaLR, self).__init__(optimizer, last_epoch) 880 | 881 | 882 | def state_dict(self): 883 | """Returns the state of the scheduler as a :class:`dict`. 884 | 885 | It contains an entry for every variable in self.__dict__ which 886 | is not the optimizer. 887 | The learning rate lambda functions will only be saved if they are callable objects 888 | and not if they are functions or lambdas. 889 | """ 890 | state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} 891 | state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) 892 | 893 | for idx, fn in enumerate(self.lr_lambdas): 894 | if not isinstance(fn, types.FunctionType): 895 | state_dict['lr_lambdas'][idx] = fn.__dict__.copy() 896 | 897 | return state_dict 898 | 899 | def load_state_dict(self, state_dict): 900 | """Loads the schedulers state. 901 | 902 | Arguments: 903 | state_dict (dict): scheduler state. Should be an object returned 904 | from a call to :meth:`state_dict`. 905 | """ 906 | lr_lambdas = state_dict.pop('lr_lambdas') 907 | self.__dict__.update(state_dict) 908 | 909 | for idx, fn in enumerate(lr_lambdas): 910 | if fn is not None: 911 | self.lr_lambdas[idx].__dict__.update(fn) 912 | 913 | def get_lr(self): 914 | if(self.fix_lr>0): 915 | return [self.fix_lr 916 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 917 | else: 918 | return[max(base_lr * lmbda(self.last_epoch),self.min_lr) 919 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 920 | 921 | 922 | --------------------------------------------------------------------------------