├── __init__.py ├── requirements.txt ├── rouge_cli.py ├── sentence_splitter.py ├── convert_model_to_fp16.py ├── save_len_file.py ├── seq2seq_training_args.py ├── README.md ├── generate_augmentation.py ├── augmentation.py ├── run_generation.py ├── postprocess_cnndm.py ├── run_distributed_eval.py ├── LICENSE ├── cl_finetune_trainer.py ├── seq2seq_trainer.py ├── bs_pyrouge.py ├── utils.py └── cl_seq2seq_trainer.py /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | sys.path.insert(1, os.path.dirname(os.path.realpath(__file__))) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch == 1.7.1 3 | tensorboard 4 | transformers == 4.1.1 5 | rouge_score 6 | sacrebleu 7 | fairseq 8 | pyrouge 9 | scipy 10 | pandas -------------------------------------------------------------------------------- /rouge_cli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import fire 16 | 17 | from utils import calculate_rouge, save_json 18 | 19 | 20 | def calculate_rouge_path(pred_path, tgt_path, save_path=None, **kwargs): 21 | """Kwargs will be passed to calculate_rouge""" 22 | pred_lns = [x.strip() for x in open(pred_path).readlines()] 23 | tgt_lns = [x.strip() for x in open(tgt_path).readlines()][: len(pred_lns)] 24 | metrics = calculate_rouge(pred_lns, tgt_lns, **kwargs) 25 | if save_path is not None: 26 | save_json(metrics, save_path, indent=None) 27 | return metrics # these print nicely 28 | 29 | 30 | if __name__ == "__main__": 31 | fire.Fire(calculate_rouge_path) 32 | -------------------------------------------------------------------------------- /sentence_splitter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
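# Usage sketch for the helper defined below (illustrative input; assumes nltk's "punkt"
# tokenizer data is available). It re-splits a flat summary into one sentence per line so
# that rougeLsum is computed per sentence, roughly:
#
#     >>> add_newline_to_end_of_each_sentence("First point. Second point.")
#     'First point.\nSecond point.'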
14 | import re 15 | 16 | from filelock import FileLock 17 | 18 | 19 | try: 20 | import nltk 21 | 22 | NLTK_AVAILABLE = True 23 | except (ImportError, ModuleNotFoundError): 24 | NLTK_AVAILABLE = False 25 | 26 | if NLTK_AVAILABLE: 27 | with FileLock(".lock") as lock: 28 | nltk.download("punkt", quiet=True) 29 | 30 | 31 | def add_newline_to_end_of_each_sentence(x: str) -> str: 32 | """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" 33 | x = re.sub("<n>", "", x) # remove the pegasus newline token "<n>" 34 | assert NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)" 35 | return "\n".join(nltk.sent_tokenize(x)) 36 | -------------------------------------------------------------------------------- /convert_model_to_fp16.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Union 17 | 18 | import fire 19 | import torch 20 | from tqdm import tqdm 21 | 22 | 23 | def convert(src_path: str, map_location: str = "cpu", save_path: Union[str, None] = None) -> None: 24 | """Convert a pytorch_model.bin or model.pt file to torch.float16 for faster downloads, less disk space.""" 25 | state_dict = torch.load(src_path, map_location=map_location) 26 | for k, v in tqdm(state_dict.items()): 27 | if not isinstance(v, torch.Tensor): 28 | raise TypeError("FP16 conversion only works on paths that are saved state dicts, like pytorch_model.bin") 29 | state_dict[k] = v.half() 30 | if save_path is None: # overwrite src_path 31 | save_path = src_path 32 | torch.save(state_dict, save_path) 33 | 34 | 35 | if __name__ == "__main__": 36 | fire.Fire(convert) 37 | -------------------------------------------------------------------------------- /save_len_file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
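# Usage sketch (hypothetical tokenizer name and data directory; the script is exposed
# through fire at the bottom of this file):
#
#     python save_len_file.py facebook/bart-large-cnn data/cnndm
#
# This pickles one length per train/val example (the max of source and target length when
# --consider_target is set, otherwise the source length only), which the sortish sampler can
# use to build batches of similar length and reduce padding.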
15 | 16 | import fire 17 | from torch.utils.data import DataLoader 18 | from tqdm import tqdm 19 | 20 | from transformers import AutoTokenizer 21 | from utils import Seq2SeqDataset, pickle_save 22 | 23 | 24 | def save_len_file( 25 | tokenizer_name, data_dir, max_source_length=1024, max_target_length=1024, consider_target=False, **kwargs 26 | ): 27 | """Save max(src_len, tgt_len) for each example to allow dynamic batching.""" 28 | tok = AutoTokenizer.from_pretrained(tokenizer_name) 29 | train_ds = Seq2SeqDataset(tok, data_dir, max_source_length, max_target_length, type_path="train", **kwargs) 30 | pad = tok.pad_token_id 31 | 32 | def get_lens(ds): 33 | dl = tqdm( 34 | DataLoader(ds, batch_size=512, num_workers=8, shuffle=False, collate_fn=ds.collate_fn), 35 | desc=str(ds.len_file), 36 | ) 37 | max_lens = [] 38 | for batch in dl: 39 | src_lens = batch["input_ids"].ne(pad).sum(1).tolist() 40 | tgt_lens = batch["labels"].ne(pad).sum(1).tolist() 41 | if consider_target: 42 | for src, tgt in zip(src_lens, tgt_lens): 43 | max_lens.append(max(src, tgt)) 44 | else: 45 | max_lens.extend(src_lens) 46 | return max_lens 47 | 48 | train_lens = get_lens(train_ds) 49 | val_ds = Seq2SeqDataset(tok, data_dir, max_source_length, max_target_length, type_path="val", **kwargs) 50 | val_lens = get_lens(val_ds) 51 | pickle_save(train_lens, train_ds.len_file) 52 | pickle_save(val_lens, val_ds.len_file) 53 | 54 | 55 | if __name__ == "__main__": 56 | fire.Fire(save_len_file) 57 | -------------------------------------------------------------------------------- /seq2seq_training_args.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from dataclasses import dataclass, field 17 | from typing import Optional 18 | 19 | from seq2seq_trainer import arg_to_scheduler 20 | from transformers import TrainingArguments 21 | 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass 27 | class Seq2SeqTrainingArguments(TrainingArguments): 28 | """ 29 | Parameters: 30 | label_smoothing (:obj:`float`, `optional`, defaults to 0): 31 | The label smoothing epsilon to apply (if not zero). 32 | sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`): 33 | Whether to use SortishSampler or not. It sorts the inputs according to lengths in order to minimize the padding size. 34 | predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): 35 | Whether to use generate to calculate generative metrics (ROUGE, BLEU).
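    Example (illustrative values; any argument accepted by :class:`~transformers.TrainingArguments`,
    such as ``output_dir``, can be passed as well)::

        args = Seq2SeqTrainingArguments(
            output_dir="./log/xsum",
            label_smoothing=0.1,
            sortish_sampler=True,
            predict_with_generate=True,
        )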
36 | """ 37 | 38 | label_smoothing: Optional[float] = field( 39 | default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."} 40 | ) 41 | sortish_sampler: bool = field(default=False, metadata={"help": "Whether to SortishSamler or not."}) 42 | predict_with_generate: bool = field( 43 | default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} 44 | ) 45 | adafactor: bool = field(default=False, metadata={"help": "whether to use adafactor"}) 46 | encoder_layerdrop: Optional[float] = field( 47 | default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."} 48 | ) 49 | decoder_layerdrop: Optional[float] = field( 50 | default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."} 51 | ) 52 | dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."}) 53 | attention_dropout: Optional[float] = field( 54 | default=None, metadata={"help": "Attention dropout probability. Goes into model.config."} 55 | ) 56 | lr_scheduler: Optional[str] = field( 57 | default="linear", 58 | metadata={"help": f"Which lr scheduler to use. Selected in {sorted(arg_to_scheduler.keys())}"}, 59 | ) 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESACL: Enhanced Seq2Seq Autoencoder via Contrastive Learning for AbstractiveText Summarization 2 | This repo is for our paper "Enhanced Seq2Seq Autoencoder via Contrastive Learning for AbstractiveText Summarization". Our program is building on top of the Huggingface ```transformers``` framework. You can refer to their repo at: https://github.com/huggingface/transformers/tree/master/examples/seq2seq. 3 | 4 | ## Local Setup 5 | Tested with Python 3.7 via virtual environment. Clone the repo, go to the repo folder, setup the virtual environment, and install the required packages: 6 | ```bash 7 | $ python3.7 -m venv venv 8 | $ source venv/bin/activate 9 | $ pip install -r requirements.txt 10 | ``` 11 | 12 | ### Install ```apex``` 13 | Based on the recommendation from HuggingFace, both finetuning and eval are 30% faster with ```--fp16```. For that you need to install ```apex```. 
14 | ```bash 15 | $ git clone https://github.com/NVIDIA/apex 16 | $ cd apex 17 | $ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 18 | ``` 19 | 20 | ## Data 21 | Create a directory for data used in this work named ```data```: 22 | ```bash 23 | $ mkdir data 24 | ``` 25 | 26 | ### CNN/DM 27 | ```bash 28 | $ wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz 29 | $ tar -xzvf cnn_dm_v2.tgz 30 | $ mv cnn_cln data/cnndm 31 | ``` 32 | 33 | ### XSUM 34 | ```bash 35 | $ wget https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz 36 | $ tar -xzvf xsum.tar.gz 37 | $ mv xsum data/xsum 38 | ``` 39 | 40 | ### Generate Augmented Dataset 41 | ```bash 42 | $ python generate_augmentation.py \ 43 | --dataset xsum \ 44 | --n 5 \ 45 | --augmentation1 randomdelete \ 46 | --augmentation2 randomswap 47 | ``` 48 | 49 | ## Training 50 | ### CNN/DM 51 | Our model is warmed up using ```sshleifer/distilbart-cnn-12-6```: 52 | ```bash 53 | $ DATA_DIR=./data/cnndm-augmented/RandominsertionRandominsertion-NumSent-3 54 | $ OUTPUT_DIR=./log/cnndm 55 | 56 | $ python -m torch.distributed.launch --nproc_per_node=3 cl_finetune_trainer.py \ 57 | --data_dir $DATA_DIR \ 58 | --output_dir $OUTPUT_DIR \ 59 | --learning_rate=5e-7 \ 60 | --per_device_train_batch_size 16 \ 61 | --per_device_eval_batch_size 16 \ 62 | --do_train --do_eval \ 63 | --evaluation_strategy steps \ 64 | --freeze_embeds \ 65 | --save_total_limit 10 \ 66 | --save_steps 1000 \ 67 | --logging_steps 1000 \ 68 | --num_train_epochs 5 \ 69 | --model_name_or_path sshleifer/distilbart-cnn-12-6 \ 70 | --alpha 0.2 \ 71 | --temperature 0.5 \ 72 | --freeze_encoder_layer 6 \ 73 | --prediction_loss_only \ 74 | --fp16 75 | ``` 76 | 77 | ### XSUM 78 | ```bash 79 | $ DATA_DIR=./data/xsum-augmented/RandomdeleteRandomswap-NumSent-3 80 | $ OUTPUT_DIR=./log/xsum 81 | 82 | $ python -m torch.distributed.launch --nproc_per_node=3 cl_finetune_trainer.py \ 83 | --data_dir $DATA_DIR \ 84 | --output_dir $OUTPUT_DIR \ 85 | --learning_rate=5e-7 \ 86 | --per_device_train_batch_size 16 \ 87 | --per_device_eval_batch_size 16 \ 88 | --do_train --do_eval \ 89 | --evaluation_strategy steps \ 90 | --freeze_embeds \ 91 | --save_total_limit 10 \ 92 | --save_steps 1000 \ 93 | --logging_steps 1000 \ 94 | --num_train_epochs 5 \ 95 | --model_name_or_path sshleifer/distilbart-xsum-12-6 \ 96 | --alpha 0.2 \ 97 | --temperature 0.5 \ 98 | --freeze_encoder \ 99 | --prediction_loss_only \ 100 | --fp16 101 | ``` 102 | 103 | ## Evaluation 104 | We have released the following checkpoints for pre-trained models as described in the paper: 105 | - [CNN/DM](https://drive.google.com/file/d/1MbLySs5hcxPsSRfPUCzR4AikhtEP08NJ/view?usp=sharing): 106 | - [XSUM](https://drive.google.com/file/d/1SsA8Bstn-VBiH3gDHxU_myFBkNUpNx-D/view?usp=sharing): 107 | 108 | ### CNN/DM 109 | CNN/DM requires an extra postprocessing step. 
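This step (implemented in ```postprocess_cnndm.py```) re-normalizes the tokenization of the generated summaries, drops near-duplicate sentences, and reports ROUGE through ```pyrouge```. The commands below first run distributed generation from a trained checkpoint and then apply the postprocessing: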
110 | ```bash 111 | $ export DATA=cnndm 112 | $ export DATA_DIR=data/$DATA 113 | $ export CHECKPOINT_DIR=./log/$DATA 114 | $ export OUTPUT_DIR=output/$DATA 115 | 116 | $ python -m torch.distributed.launch --nproc_per_node=2 run_distributed_eval.py \ 117 | --model_name sshleifer/distilbart-cnn-12-6 \ 118 | --save_dir $OUTPUT_DIR \ 119 | --data_dir $DATA_DIR \ 120 | --bs 16 \ 121 | --fp16 \ 122 | --use_checkpoint \ 123 | --checkpoint_path $CHECKPOINT_DIR 124 | 125 | $ python postprocess_cnndm.py \ 126 | --src_file $OUTPUT_DIR/test_generations.txt \ 127 | --tgt_file $DATA_DIR/test.target 128 | ``` 129 | 130 | ### XSUM 131 | ```bash 132 | $ export DATA=xsum 133 | $ export DATA_DIR=data/$DATA 134 | $ export CHECKPOINT_DIR=./log/$DATA 135 | $ export OUTPUT_DIR=output/$DATA 136 | 137 | $ python -m torch.distributed.launch --nproc_per_node=3 run_distributed_eval.py \ 138 | --model_name sshleifer/distilbart-xsum-12-6 \ 139 | --save_dir $OUTPUT_DIR \ 140 | --data_dir $DATA_DIR \ 141 | --bs 16 \ 142 | --fp16 \ 143 | --use_checkpoint \ 144 | --checkpoint_path $CHECKPOINT_DIR 145 | ``` -------------------------------------------------------------------------------- /generate_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate Augmented Data 3 | """ 4 | import argparse 5 | import random 6 | 7 | from augmentation import DocumentAugmentation 8 | import os 9 | from shutil import copyfile 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--dataset", type=str, help="dataset name", required=True) 13 | parser.add_argument("--n", type=int, help="number of sentences", required=True) 14 | parser.add_argument("--augmentation1", type=str, help="augmentation method #1", required=True, default=None) 15 | parser.add_argument("--augmentation2", type=str, help="augmentation method #2", required=True, default=None) 16 | parser.add_argument("--generation_model", type=str, help="generation model for language generation", default='gpt2') 17 | parser.add_argument("--fp16", type=bool, help="flag variable to use fp16", default=False) 18 | 19 | args = parser.parse_args() 20 | 21 | # read the parameters 22 | DATA_DIR = args.dataset 23 | N = args.n 24 | if args.augmentation1 is not None and args.augmentation2 is not None: 25 | AUGMENTATION = sorted([args.augmentation1, args.augmentation2]) 26 | else: 27 | print(f"No Valid Augmentation Methods") 28 | 29 | fp16 = args.fp16 30 | model = args.generation_model 31 | 32 | # Update datasets by performing document augmentation to get augmented dataset 33 | if not os.path.isdir(f"./data/{DATA_DIR}-augmented"): 34 | # make a new directory for storing the augmented data 35 | os.mkdir(f"./data/{DATA_DIR}-augmented") 36 | 37 | # set the folder name 38 | FOLDER_DIR = AUGMENTATION[0].capitalize() + AUGMENTATION[1].capitalize() 39 | 40 | if not os.path.isdir(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}"): 41 | os.mkdir(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}") 42 | 43 | for element in ['train']: 44 | with open(f"./data/{DATA_DIR}/{element}.source", "r", encoding='utf8') as document: 45 | for line in document: 46 | sent = [] 47 | for i in range(len(AUGMENTATION)): 48 | method = AUGMENTATION[i] 49 | # set the seed 50 | if i == 0: 51 | random.seed(97) 52 | elif i == 1: 53 | random.seed(41) 54 | augmentation = DocumentAugmentation(n=N, input=line) 55 | if method.lower() == 'randominsertion': 56 | augmentation.RandomInsertionFromDoc() 57 | elif method.lower() == 'randomswap': 58 | 
augmentation.RandomSwap() 59 | elif method.lower() == 'randomdelete': 60 | augmentation.RandomDeletion() 61 | elif method.lower() == 'generation': 62 | augmentation.LanguageGenerationReplacement(fp16=fp16, model=model, num_sent_context=N) 63 | elif method.lower() == 'rotation': 64 | augmentation.DocumentRotation() 65 | # record the augmented sentences 66 | sent.append(augmentation.augmented_sentences) 67 | 68 | # record - document 69 | if not os.path.isfile(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/{element}.source"): 70 | with open(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/{element}.source", "w", 71 | encoding='utf8') as f: 72 | f.write(f"{' '.join(sent[0])}\n") 73 | f.write(f"{' '.join(sent[1])}\n") 74 | else: 75 | with open(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/{element}.source", "a", 76 | encoding='utf8') as f: 77 | f.write(f"{' '.join(sent[0])}\n") 78 | f.write(f"{' '.join(sent[1])}\n") 79 | 80 | with open(f"./data/{DATA_DIR}/{element}.target", "r", encoding='utf8') as document: 81 | for line in document: 82 | # record - summary 83 | if not os.path.isfile(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/{element}.target"): 84 | with open(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/{element}.target", "w", 85 | encoding='utf8') as f: 86 | f.write(line) 87 | f.write(line) 88 | else: 89 | with open(f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/{element}.target", "a", 90 | encoding='utf8') as f: 91 | f.write(line) 92 | f.write(line) 93 | 94 | # copy validation 95 | copyfile(src=f'./data/{DATA_DIR}/val.source', 96 | dst=f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/val.source") 97 | copyfile(src=f'./data/{DATA_DIR}/val.target', 98 | dst=f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/val.target") 99 | 100 | # copy test 101 | copyfile(src=f'./data/{DATA_DIR}/test.source', 102 | dst=f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/test.source") 103 | copyfile(src=f'./data/{DATA_DIR}/test.target', 104 | dst=f"./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}/test.target") 105 | 106 | else: 107 | print(f"there is data already in this path: ./data/{DATA_DIR}-augmented/{FOLDER_DIR}-NumSent-{N}") 108 | -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from run_generation import Generation 3 | from nltk import tokenize 4 | 5 | 6 | class DocumentAugmentation(): 7 | """ 8 | Document Augmentation Approaches 9 | """ 10 | 11 | def __init__(self, n, input): 12 | """ 13 | Initialize the class 14 | :param n: how many sentences are selected for augmentation 15 | :param input: input sequence, string 16 | """ 17 | self.n = n 18 | self.input = input 19 | self.sentences = tokenize.sent_tokenize(input) 20 | 21 | def RandomInsertion(self, fp16, num_sent_context=3, model='gpt2'): 22 | """ 23 | randomly insert a sentence in the input document, which is generated based on its context. Do this n times. 
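        :param fp16: whether to run the generation model in half precision
        :param num_sent_context: how many neighbouring sentences are used as the generation context
        :param model: name of the pretrained language model used for generation (defaults to 'gpt2')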
24 | :return: 25 | """ 26 | self.augmented_sentences = self.sentences 27 | generation = Generation(model_type=model, fp16=fp16) 28 | 29 | for i in range(self.n): 30 | # generate the index for inserting 31 | location = random.randrange(len(self.sentences)) 32 | 33 | before_idx_start, before_idx_end = location - num_sent_context, location 34 | after_idx_start, after_idx_end = location + 1, location + num_sent_context + 1 35 | 36 | context = " ".join( 37 | self.augmented_sentences[max(0, before_idx_start): max(0, before_idx_end)]) + " ".join( 38 | self.augmented_sentences[ 39 | min(len(self.augmented_sentences) - 1, after_idx_start): min(len(self.augmented_sentences) - 1, 40 | after_idx_end)]) 41 | 42 | new_sentence = generation.generate(context) 43 | 44 | if new_sentence[-1] not in ["?", ".", "!"]: 45 | new_sentence += "." 46 | 47 | # insert new_sentence into the self.sentences 48 | if location < 0: 49 | self.augmented_sentences = [new_sentence] + self.sentences 50 | else: 51 | update_sentence = self.sentences[:location + 1] 52 | update_sentence.append(new_sentence) 53 | update_sentence += self.sentences[location + 1:] 54 | self.augmented_sentences = update_sentence 55 | 56 | def RandomSwap(self): 57 | """ 58 | randomly select two sentences in the input document and swap their positions. Do this $n$ times. 59 | :return: 60 | """ 61 | self.augmented_sentences = self.sentences 62 | if len(self.sentences) >= 2: 63 | for i in range(self.n): 64 | # location is a list contains two random numbers selected 65 | location = random.sample(range(len(self.augmented_sentences)), 2) 66 | sent1 = self.augmented_sentences[location[0]] 67 | sent2 = self.augmented_sentences[location[1]] 68 | # swap two sentences 69 | self.augmented_sentences[location[0]], self.augmented_sentences[location[1]] = sent2, sent1 70 | 71 | def RandomDeletion(self): 72 | """ 73 | randomly delete n sentences from the input document. 74 | :return: 75 | """ 76 | self.augmented_sentences = self.sentences 77 | # Here we require that the augmented document should have at least one sentence 78 | if self.n <= len(self.sentences) - 1: 79 | # location is a list contains two random numbers selected 80 | location_delete = random.sample(range(len(self.augmented_sentences)), self.n) 81 | update_sentence = [self.augmented_sentences[i] for i in range(len(self.augmented_sentences)) if 82 | i not in location_delete] 83 | self.augmented_sentences = update_sentence 84 | 85 | def LanguageGenerationReplacement(self, fp16, model="gpt2", num_sent_context=3): 86 | """ 87 | randomly choose n sentences from the input document. 88 | Replace each of these sentences with a newly generated sentence based on its context. 
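        :param fp16: whether to run the generation model in half precision
        :param model: name of the pretrained language model used for generation (defaults to 'gpt2')
        :param num_sent_context: how many neighbouring sentences are used as the generation context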
89 | :return: 90 | """ 91 | self.augmented_sentences = self.sentences 92 | generation = Generation(model_type=model, fp16=fp16) 93 | 94 | if self.n <= len(self.sentences): 95 | location = random.sample(range(len(self.augmented_sentences)), self.n) 96 | update_sentence = [] 97 | for i in range(len(self.augmented_sentences)): 98 | if i not in location: 99 | update_sentence.append(self.augmented_sentences[i]) 100 | else: 101 | before_idx_start, before_idx_end = i - num_sent_context, i 102 | after_idx_start, after_idx_end = i + 1, i + num_sent_context + 1 103 | 104 | context = " ".join( 105 | self.augmented_sentences[max(0, before_idx_start): max(0, before_idx_end)]) + " ".join( 106 | self.augmented_sentences[ 107 | min(len(self.augmented_sentences) - 1, after_idx_start): min(len(self.augmented_sentences) - 1, 108 | after_idx_end)]) 109 | new_sentence = generation.generate(context) 110 | update_sentence.append(new_sentence) 111 | self.augmented_sentences = update_sentence 112 | 113 | def DocumentRotation(self): 114 | """ 115 | randomly select a sentence and rotate the document using this sentence. Do this n times. 116 | :return: 117 | """ 118 | self.augmented_sentences = self.sentences 119 | 120 | # perform experiments n times 121 | for i in range(self.n): 122 | # generate the index for the selected sentence 123 | location = random.randrange(len(self.sentences)) 124 | # rotate the document 125 | self.augmented_sentences = self.augmented_sentences[location + 1:][::-1] + [ 126 | self.augmented_sentences[location]] + self.augmented_sentences[:location][::-1] 127 | 128 | def RandomInsertionFromDoc(self): 129 | """ 130 | Simplified version of RandomInsertion: 131 | randomly insert n sentence in the input document, which is selected from document itself. 132 | :return: 133 | """ 134 | self.augmented_sentences = self.sentences 135 | 136 | for i in range(self.n): 137 | # generate the index for inserting 138 | location = random.randrange(len(self.sentences)) 139 | 140 | # randomly select a sentence from the input document 141 | select_location = random.randrange(len(self.sentences)) 142 | new_sentence = self.sentences[select_location] 143 | 144 | update_sentence = self.sentences[:location + 1] 145 | update_sentence.append(new_sentence) 146 | update_sentence += self.sentences[location + 1:] 147 | self.augmented_sentences = update_sentence 148 | -------------------------------------------------------------------------------- /run_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding=utf-8 3 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
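# Usage sketch for the Generation wrapper defined below (illustrative prompt; the model name must be
# one of the keys of MODEL_CLASSES, and generation falls back to CPU when no GPU is available):
#
#     gen = Generation(model_type="gpt2", length=20)
#     new_sentence = gen.generate("The committee met on Tuesday to review the budget.")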
17 | """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) 18 | """ 19 | 20 | import numpy as np 21 | import torch 22 | 23 | from transformers import ( 24 | CTRLLMHeadModel, 25 | CTRLTokenizer, 26 | GPT2LMHeadModel, 27 | GPT2Tokenizer, 28 | OpenAIGPTLMHeadModel, 29 | OpenAIGPTTokenizer, 30 | TransfoXLLMHeadModel, 31 | TransfoXLTokenizer, 32 | XLMTokenizer, 33 | XLMWithLMHeadModel, 34 | XLNetLMHeadModel, 35 | XLNetTokenizer, 36 | ) 37 | 38 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop 39 | 40 | MODEL_CLASSES = { 41 | "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), 42 | "ctrl": (CTRLLMHeadModel, CTRLTokenizer), 43 | "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 44 | "xlnet-large-cased": (XLNetLMHeadModel, XLNetTokenizer), 45 | "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), 46 | "xlm": (XLMWithLMHeadModel, XLMTokenizer), 47 | } 48 | 49 | def set_seed(args): 50 | np.random.seed(args.seed) 51 | torch.manual_seed(args.seed) 52 | if args.n_gpu > 0: 53 | torch.cuda.manual_seed_all(args.seed) 54 | # 55 | # Functions to prepare models' input 56 | # 57 | 58 | def prepare_ctrl_input(args, _, tokenizer, prompt_text): 59 | # if args.temperature > 0.7: 60 | # logger.info("CTRL typically works better with lower temperatures (and lower top_k).") 61 | 62 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) 63 | # if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): 64 | # logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") 65 | return prompt_text 66 | 67 | 68 | def prepare_xlm_input(args, model, tokenizer, prompt_text): 69 | # kwargs = {"language": None, "mask_token_id": None} 70 | 71 | # Set the language 72 | use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb 73 | if hasattr(model.config, "lang2id") and use_lang_emb: 74 | available_languages = model.config.lang2id.keys() 75 | if args.xlm_language in available_languages: 76 | language = args.xlm_language 77 | else: 78 | language = None 79 | while language not in available_languages: 80 | language = input("Using XLM. 
Select language in " + str(list(available_languages)) + " >>> ") 81 | 82 | model.config.lang_id = model.config.lang2id[language] 83 | # kwargs["language"] = tokenizer.lang2id[language] 84 | 85 | # XLM masked-language modeling (MLM) models need masked token 86 | # is_xlm_mlm = "mlm" in args.model_name_or_path 87 | # if is_xlm_mlm: 88 | # kwargs["mask_token_id"] = tokenizer.mask_token_id 89 | 90 | return prompt_text 91 | 92 | 93 | def prepare_xlnet_input(args, _, tokenizer, prompt_text): 94 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX 95 | prompt_text = prefix + prompt_text 96 | return prompt_text 97 | 98 | 99 | def prepare_transfoxl_input(args, _, tokenizer, prompt_text): 100 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX 101 | prompt_text = prefix + prompt_text 102 | return prompt_text 103 | 104 | 105 | PREPROCESSING_FUNCTIONS = { 106 | "ctrl": prepare_ctrl_input, 107 | "xlm": prepare_xlm_input, 108 | "xlnet": prepare_xlnet_input, 109 | "transfo-xl": prepare_transfoxl_input, 110 | } 111 | 112 | 113 | def adjust_length_to_model(length, max_sequence_length): 114 | if length < 0 and max_sequence_length > 0: 115 | length = max_sequence_length 116 | elif 0 < max_sequence_length < length: 117 | length = max_sequence_length # No generation bigger than model size 118 | elif length < 0: 119 | length = MAX_LENGTH # avoid infinite loop 120 | return length 121 | 122 | 123 | class Generation(): 124 | def __init__(self, model_type, prompt="", length=20, stop_token=None, temperature=1.0, repetition_penalty=1.2, k=0, 125 | p=0.9, prefix="", padding_text="", xlm_language="", seed=42, num_return_sequences=1, fp16=False): 126 | 127 | self.model_type = model_type 128 | self.model_name_or_path = model_type 129 | self.prompt = prompt 130 | self.length = length 131 | self.stop_token = stop_token 132 | self.temperature = temperature 133 | self.repetition_penalty = repetition_penalty 134 | self.k = k 135 | self.p = p 136 | self.prefix = prefix 137 | self.padding_text = padding_text 138 | self.xlm_language = xlm_language 139 | self.seed = seed 140 | self.num_return_sequences = num_return_sequences 141 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 142 | self.n_gpu = 0 if self.device == "cpu" else torch.cuda.device_count() 143 | self.fp16 = fp16 144 | 145 | def generate(self, prompt_text): 146 | # Initialize the model and tokenizer 147 | try: 148 | self.model_type = self.model_type.lower() 149 | model_class, tokenizer_class = MODEL_CLASSES[self.model_type] 150 | except KeyError: 151 | raise KeyError("the model {} you specified is not supported.") 152 | 153 | tokenizer = tokenizer_class.from_pretrained(self.model_name_or_path) 154 | model = model_class.from_pretrained(self.model_name_or_path) 155 | model.to(self.device) 156 | 157 | if self.fp16: 158 | model.half() 159 | 160 | self.length = adjust_length_to_model(self.length, max_sequence_length=model.config.max_position_embeddings) 161 | 162 | # Different models need different input formatting and/or extra arguments 163 | requires_preprocessing = self.model_type in PREPROCESSING_FUNCTIONS.keys() 164 | if requires_preprocessing: 165 | prepare_input = PREPROCESSING_FUNCTIONS.get(self.model_type) 166 | preprocessed_prompt_text = prepare_input(self, model, tokenizer, prompt_text) 167 | 168 | if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: 169 | tokenizer_kwargs = {"add_space_before_punct_symbol": True} 170 | else: 171 | tokenizer_kwargs = 
{} 172 | 173 | encoded_prompt = tokenizer.encode( 174 | preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs 175 | ) 176 | else: 177 | prefix = self.prefix if self.prefix else self.padding_text 178 | encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") 179 | encoded_prompt = encoded_prompt.to(self.device) 180 | 181 | if encoded_prompt.size()[-1] == 0: 182 | input_ids = None 183 | else: 184 | input_ids = encoded_prompt 185 | 186 | output_sequences = model.generate( 187 | input_ids=input_ids, 188 | max_length=self.length + len(encoded_prompt[0]), 189 | temperature=self.temperature, 190 | top_k=self.k, 191 | top_p=self.p, 192 | repetition_penalty=self.repetition_penalty, 193 | do_sample=True, 194 | num_return_sequences=self.num_return_sequences, 195 | ) 196 | 197 | # Remove the batch dimension when returning multiple sequences 198 | if len(output_sequences.shape) > 2: 199 | output_sequences.squeeze_() 200 | 201 | generated_sequences = [] 202 | 203 | for generated_sequence_idx, generated_sequence in enumerate(output_sequences): 204 | # print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1)) 205 | generated_sequence = generated_sequence.tolist() 206 | 207 | # Decode text 208 | text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) 209 | 210 | # Remove all text after the stop token 211 | text = text[: text.find(self.stop_token) if self.stop_token else None] 212 | 213 | # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing 214 | total_sequence = (text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]) 215 | 216 | generated_sequences.append(total_sequence) 217 | 218 | return generated_sequences[0] 219 | -------------------------------------------------------------------------------- /postprocess_cnndm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post-processing for CNN/DM - Modified From ProphetNet: https://github.com/microsoft/ProphetNet 3 | """ 4 | 5 | 6 | import sys 7 | import string 8 | import argparse 9 | import tempfile 10 | import os 11 | import time 12 | import shutil 13 | from bs_pyrouge import Rouge155 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--src_file", type=str, required=True) 17 | parser.add_argument("--tgt_file", type=str, required=True) 18 | parser.add_argument("--duplicate_rate", type=float, default=0.7, 19 | help="If the duplicat rate (compared with history) is large, we can discard the current sentence.") 20 | parser.add_argument("--trunc_len", type=int, default=0, 21 | help="Truncate line by the maximum length.") 22 | 23 | args = parser.parse_args() 24 | 25 | fin = open(args.src_file, 'r', encoding='utf-8') 26 | fgolden = open(args.tgt_file, 'r', encoding='utf-8') 27 | dedup_rate = args.duplicate_rate 28 | trunc_len = args.trunc_len 29 | 30 | def _is_digit(w): 31 | for ch in w: 32 | if not (ch.isdigit() or ch == ','): 33 | return False 34 | return True 35 | 36 | 37 | def fix_tokenization(text): 38 | input_tokens = text.split() 39 | output_tokens = [] 40 | has_left_quote = False 41 | has_left_single_quote = False 42 | 43 | i = 0 44 | prev_dash = False 45 | while i < len(input_tokens): 46 | tok = input_tokens[i] 47 | flag_prev_dash = False 48 | if tok == "\"": 49 | if has_left_quote: 50 | output_tokens.append("''") 51 | else: 52 | output_tokens.append("``") 53 | has_left_quote = not 
has_left_quote 54 | i += 1 55 | elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and \ 56 | input_tokens[i + 1] == "t": 57 | output_tokens[-1] = output_tokens[-1][:-1] 58 | output_tokens.append("n't") 59 | i += 2 60 | elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"): 61 | output_tokens.append("'" + input_tokens[i + 1]) 62 | i += 2 63 | elif tok == "'": 64 | if has_left_single_quote: 65 | output_tokens.append("'") 66 | else: 67 | output_tokens.append("`") 68 | has_left_single_quote = not has_left_single_quote 69 | i += 1 70 | elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".": 71 | output_tokens.append("...") 72 | i += 3 73 | elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len( 74 | input_tokens) - 1 and _is_digit(input_tokens[i + 1]): 75 | # $ 3 , 000 -> $ 3,000 76 | output_tokens[-1] += ',' + input_tokens[i + 1] 77 | i += 2 78 | elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and \ 79 | input_tokens[i + 1].isdigit(): 80 | # 3 . 03 -> $ 3.03 81 | output_tokens[-1] += '.' + input_tokens[i + 1] 82 | i += 2 83 | elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[ 84 | -1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[ 85 | i + 1].isupper() and input_tokens[i + 2] == '.': 86 | # U . N . -> U.N. 87 | k = i + 3 88 | while k + 2 < len(input_tokens): 89 | if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.': 90 | k += 2 91 | else: 92 | break 93 | output_tokens[-1] += ''.join(input_tokens[i:k]) 94 | i += 2 95 | elif tok == "-": 96 | if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-": 97 | output_tokens.append("--") 98 | i += 2 99 | elif i == len(input_tokens) - 1 or i == 0: 100 | output_tokens.append("-") 101 | i += 1 102 | elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation: 103 | output_tokens[-1] += "-" 104 | i += 1 105 | flag_prev_dash = True 106 | else: 107 | output_tokens.append("-") 108 | i += 1 109 | elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation: 110 | output_tokens[-1] += tok 111 | i += 1 112 | else: 113 | output_tokens.append(tok) 114 | i += 1 115 | prev_dash = flag_prev_dash 116 | text = ' '.join([x for x in output_tokens]) 117 | fine_text = text.replace(' ##', '') 118 | return fine_text 119 | 120 | 121 | def remove_duplicate(l_list, duplicate_rate): 122 | tk_list = [l.lower().split() for l in l_list] 123 | r_list = [] 124 | history_set = set() 125 | for i, w_list in enumerate(tk_list): 126 | w_set = set(w_list) 127 | if len(w_set & history_set) / len(w_set) <= duplicate_rate: 128 | r_list.append(l_list[i]) 129 | history_set |= w_set 130 | return r_list 131 | 132 | 133 | def rouge(cand, ref): 134 | temp_dir = tempfile.mkdtemp() 135 | candidates = cand 136 | references = ref 137 | assert len(candidates) == len(references) 138 | 139 | cnt = len(candidates) 140 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 141 | tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) 142 | if not os.path.isdir(tmp_dir): 143 | os.mkdir(tmp_dir) 144 | os.mkdir(tmp_dir + "/candidate") 145 | os.mkdir(tmp_dir + "/reference") 146 | try: 147 | for i in range(cnt): 148 | if len(references[i]) < 1: 149 | continue 
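                # Write each candidate/reference pair to its own numbered file; Rouge155 (pyrouge)
                # pairs them back up via the filename patterns configured below.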
150 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 151 | encoding="utf-8") as f: 152 | f.write(candidates[i]) 153 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 154 | encoding="utf-8") as f: 155 | f.write(references[i]) 156 | r = Rouge155(temp_dir=temp_dir) 157 | r.model_dir = tmp_dir + "/reference/" 158 | r.system_dir = tmp_dir + "/candidate/" 159 | r.model_filename_pattern = 'ref.#ID#.txt' 160 | r.system_filename_pattern = r'cand.(\d+).txt' 161 | rouge_results = r.convert_and_evaluate() 162 | print(rouge_results) 163 | results_dict = r.output_to_dict(rouge_results) 164 | finally: 165 | if os.path.isdir(tmp_dir): 166 | shutil.rmtree(tmp_dir) 167 | return results_dict 168 | 169 | 170 | def rouge_results_to_str(results_dict): 171 | return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( 172 | results_dict["rouge_1_f_score"] * 100, 173 | results_dict["rouge_2_f_score"] * 100, 174 | results_dict["rouge_l_f_score"] * 100, 175 | results_dict["rouge_1_recall"] * 100, 176 | results_dict["rouge_2_recall"] * 100, 177 | results_dict["rouge_l_recall"] * 100 178 | ) 179 | 180 | 181 | def count_tokens(tokens): 182 | counter = {} 183 | for t in tokens: 184 | if t in counter.keys(): 185 | counter[t] += 1 186 | else: 187 | counter[t] = 1 188 | return counter 189 | 190 | 191 | def get_f1(text_a, text_b): 192 | tokens_a = text_a.lower().split() 193 | tokens_b = text_b.lower().split() 194 | if len(tokens_a) == 0 or len(tokens_b) == 0: 195 | return 1 if len(tokens_a) == len(tokens_b) else 0 196 | set_a = count_tokens(tokens_a) 197 | set_b = count_tokens(tokens_b) 198 | match = 0 199 | for token in set_a.keys(): 200 | if token in set_b.keys(): 201 | match += min(set_a[token], set_b[token]) 202 | p = match / len(tokens_a) 203 | r = match / len(tokens_b) 204 | return 2.0 * p * r / (p + r + 1e-5) 205 | 206 | 207 | generated_list = [] 208 | for line in fin: 209 | buf = [] 210 | modified_line = line.strip() 211 | for sentence in modified_line.split('.'): 212 | sentence = fix_tokenization(sentence) 213 | if any(get_f1(sentence, s) > 1.0 for s in buf): 214 | continue 215 | s_len = len(sentence.split()) 216 | if s_len <= 4: 217 | continue 218 | buf.append(sentence) 219 | if dedup_rate < 1: 220 | buf = remove_duplicate(buf, dedup_rate) 221 | if trunc_len: 222 | num_left = trunc_len 223 | trunc_list = [] 224 | for bit in buf: 225 | tk_list = bit.split() 226 | n = min(len(tk_list), num_left) 227 | trunc_list.append(' '.join(tk_list[:n])) 228 | num_left -= n 229 | if num_left <= 0: 230 | break 231 | else: 232 | trunc_list = buf 233 | 234 | generated_list.append("\n".join(trunc_list)) 235 | 236 | golden_list = [] 237 | for line in fgolden: 238 | line = line.strip().replace(".", "\n") 239 | # line = line.strip() 240 | golden_list.append(line) 241 | 242 | scores = rouge(generated_list, golden_list) 243 | print(rouge_results_to_str(scores)) 244 | -------------------------------------------------------------------------------- /run_distributed_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright 2020 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import shutil 18 | import time 19 | from json import JSONDecodeError 20 | from logging import getLogger 21 | from pathlib import Path 22 | from typing import Dict, List 23 | 24 | import torch 25 | from torch.utils.data import DataLoader 26 | from tqdm import tqdm 27 | 28 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PegasusTokenizer 29 | from utils import ( 30 | Seq2SeqDataset, 31 | calculate_bleu, 32 | calculate_rouge, 33 | chunks, 34 | lmap, 35 | load_json, 36 | parse_numeric_n_bool_cl_kwargs, 37 | save_json, 38 | use_task_specific_params, 39 | write_txt_file, 40 | ) 41 | 42 | logger = getLogger(__name__) 43 | 44 | 45 | def eval_data_dir( 46 | data_dir, 47 | save_dir: str, 48 | model_name: str, 49 | bs: int = 8, 50 | max_source_length: int = 1024, 51 | type_path="val", 52 | n_obs=None, 53 | fp16=False, 54 | task="summarization", 55 | local_rank=None, 56 | num_return_sequences=1, 57 | dataset_kwargs: Dict = None, 58 | prefix="", 59 | use_checkpoint=False, 60 | checkpoint_path=None, 61 | **generate_kwargs, 62 | ) -> Dict: 63 | """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json""" 64 | model_name = str(model_name) 65 | assert local_rank is not None 66 | torch.distributed.init_process_group(backend="nccl", rank=local_rank) 67 | 68 | save_dir = Path(save_dir) 69 | save_path = save_dir.joinpath(f"rank_{local_rank}_output.json") 70 | torch.cuda.set_device(local_rank) 71 | 72 | if not use_checkpoint: 73 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda() 74 | else: 75 | model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path).cuda() 76 | 77 | if fp16: 78 | model = model.half() 79 | # update config with task specific params 80 | use_task_specific_params(model, task) 81 | num_beams = generate_kwargs.pop("num_beams", model.config.num_beams) 82 | if num_return_sequences > num_beams: 83 | num_beams = num_return_sequences 84 | 85 | if 'pegasus' in model_name: 86 | tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum', cache_dir="./cache") 87 | else: 88 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./cache") 89 | logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type. 90 | 91 | if max_source_length is None: 92 | max_source_length = tokenizer.model_max_length 93 | if prefix is None: 94 | prefix = prefix or getattr(model.config, "prefix", "") or "" 95 | ds = Seq2SeqDataset( 96 | tokenizer, 97 | data_dir, 98 | max_source_length, 99 | max_target_length=1024, 100 | type_path=type_path, 101 | n_obs=n_obs, 102 | prefix=prefix, 103 | **dataset_kwargs, 104 | ) 105 | # I set shuffle=True for a more accurate progress bar. 106 | # If all the longest samples are first, the prog bar estimate is too high at the beginning. 
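    # make_sortish_sampler batches examples of roughly similar length together (less padding)
    # while keeping some randomness, and with distributed=True it also shards the data across ranks.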
107 | sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True) 108 | data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn) 109 | results = [] 110 | for batch in tqdm(data_loader): 111 | summaries = model.generate( 112 | input_ids=batch["input_ids"].to(model.device), 113 | attention_mask=batch["attention_mask"].to(model.device), 114 | num_return_sequences=num_return_sequences, 115 | num_beams=num_beams, 116 | **generate_kwargs, 117 | ) 118 | preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) 119 | ids = batch["ids"] 120 | if num_return_sequences > 1: 121 | preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq 122 | for i, pred in enumerate(preds): 123 | results.append(dict(pred=pred, id=ids[i].item())) 124 | save_json(results, save_path) 125 | return results, sampler.num_replicas 126 | 127 | 128 | def run_generate(): 129 | parser = argparse.ArgumentParser( 130 | epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate" 131 | ) 132 | parser.add_argument("--data_dir", type=str, help="like cnn_dm/test.source") 133 | parser.add_argument( 134 | "--model_name", 135 | type=str, 136 | help="like facebook/bart-large-cnn,t5-base, etc.", 137 | default="sshleifer/distilbart-xsum-12-3", 138 | ) 139 | parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen") 140 | parser.add_argument("--max_source_length", type=int, default=None) 141 | parser.add_argument( 142 | "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test" 143 | ) 144 | parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics") 145 | parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") 146 | parser.add_argument( 147 | "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch" 148 | ) 149 | 150 | parser.add_argument( 151 | "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all." 152 | ) 153 | parser.add_argument( 154 | "--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return" 155 | ) 156 | parser.add_argument( 157 | "--sync_timeout", 158 | type=int, 159 | default=600, 160 | required=False, 161 | help="How long should master process wait for other processes to finish.", 162 | ) 163 | parser.add_argument("--src_lang", type=str, default=None, required=False) 164 | parser.add_argument("--tgt_lang", type=str, default=None, required=False) 165 | parser.add_argument( 166 | "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples" 167 | ) 168 | parser.add_argument("--fp16", action="store_true") 169 | parser.add_argument("--debug", action="store_true") 170 | parser.add_argument("--use_checkpoint", action="store_true") 171 | parser.add_argument("--checkpoint_path", type=str, default=None, required=False) 172 | start_time = time.time() 173 | args, rest = parser.parse_known_args() 174 | generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest) 175 | if generate_kwargs and args.local_rank <= 0: 176 | print(f"parsed the following generate kwargs: {generate_kwargs}") 177 | json_save_dir = Path(args.save_dir + "_tmp") 178 | Path(json_save_dir).mkdir(exist_ok=True) # this handles locking. 
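    # Each rank writes its own rank_<i>_output.json into this temporary directory; rank 0 later
    # collects and merges them in gather_results_from_each_node() / combine_partial_results().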
179 | intermediate_files = list(json_save_dir.glob("rank_*.json")) 180 | if intermediate_files: 181 | raise ValueError(f"Found files at {json_save_dir} please move or remove them.") 182 | # In theory, a node could finish and save before another node hits this. If this happens, we can address later. 183 | dataset_kwargs = {} 184 | if args.src_lang is not None: 185 | dataset_kwargs["src_lang"] = args.src_lang 186 | if args.tgt_lang is not None: 187 | dataset_kwargs["tgt_lang"] = args.tgt_lang 188 | 189 | Path(args.save_dir).mkdir(exist_ok=True) 190 | results, num_replicas = eval_data_dir( 191 | args.data_dir, 192 | json_save_dir, 193 | args.model_name, 194 | type_path=args.type_path, 195 | bs=args.bs, 196 | fp16=args.fp16, 197 | task=args.task, 198 | local_rank=args.local_rank, 199 | n_obs=args.n_obs, 200 | max_source_length=args.max_source_length, 201 | num_return_sequences=args.num_return_sequences, 202 | prefix=args.prefix, 203 | dataset_kwargs=dataset_kwargs, 204 | use_checkpoint=args.use_checkpoint, 205 | checkpoint_path=args.checkpoint_path, 206 | **generate_kwargs, 207 | ) 208 | 209 | if args.local_rank <= 0: 210 | save_dir = Path(args.save_dir) 211 | save_dir.mkdir(exist_ok=True) 212 | partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout) 213 | preds = combine_partial_results(partial_results) 214 | if args.num_return_sequences > 1: 215 | save_path = save_dir.joinpath("pseudolabel_results.json") 216 | print(f"Saving aggregated results at {save_path}, intermediate in {json_save_dir}/") 217 | save_json(preds, save_path) 218 | return 219 | tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target") 220 | labels = [x.rstrip() for x in open(tgt_file, encoding='utf8').readlines()][: len(preds)] 221 | 222 | # Calculate metrics, save metrics, and save _generations.txt 223 | calc_bleu = "translation" in args.task 224 | score_fn = calculate_bleu if calc_bleu else calculate_rouge 225 | metric_name = "bleu" if calc_bleu else "rouge" 226 | metrics: Dict = score_fn(preds, labels) 227 | metrics["n_obs"] = len(preds) 228 | runtime = time.time() - start_time 229 | metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4) 230 | metrics["n_gpus"] = num_replicas 231 | metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json") 232 | save_json(metrics, metrics_save_path, indent=None) 233 | print(metrics) 234 | write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt")) 235 | if args.debug: 236 | write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target")) 237 | else: 238 | shutil.rmtree(json_save_dir) 239 | 240 | 241 | def combine_partial_results(partial_results) -> List: 242 | """Concatenate partial results into one file, then sort it by id.""" 243 | records = [] 244 | for partial_result in partial_results: 245 | records.extend(partial_result) 246 | records = list(sorted(records, key=lambda x: x["id"])) 247 | preds = [x["pred"] for x in records] 248 | return preds 249 | 250 | 251 | def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]: 252 | # WAIT FOR lots of .json files 253 | start_wait = time.time() 254 | logger.info("waiting for all nodes to finish") 255 | json_data = None 256 | while (time.time() - start_wait) < timeout: 257 | json_files = list(save_dir.glob("rank_*.json")) 258 | if len(json_files) < num_replicas: 259 | continue 260 | try: 261 | # make sure all json files are fully saved 262 | json_data = lmap(load_json, json_files) 263 | return json_data 
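            # A rank that is still writing its output produces a partial JSON and triggers the
            # JSONDecodeError handled below; in that case keep polling until the timeout expires.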
264 | except JSONDecodeError: 265 | continue 266 | else: 267 | raise TimeoutError("Rank 0 gave up on waiting for other processes") 268 | # Unreachable 269 | 270 | 271 | if __name__ == "__main__": 272 | # Usage for MT: 273 | run_generate() 274 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /cl_finetune_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import time 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import transformers 9 | from cl_seq2seq_trainer import Seq2SeqTrainerCL 10 | from seq2seq_training_args import Seq2SeqTrainingArguments 11 | from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed 12 | from transformers.trainer_utils import EvaluationStrategy, is_main_process 13 | # from transformers.training_args import ParallelMode 14 | from utils import ( 15 | Seq2SeqDataCollator, 16 | Seq2SeqDataset, 17 | assert_all_frozen, 18 | build_compute_metrics_fn, 19 | check_output_dir, 20 | freeze_embeds, 21 | freeze_params, 22 | lmap, 23 | save_json, 24 | use_task_specific_params, 25 | write_txt_file, 26 | ) 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | @dataclass 32 | class ModelArguments: 33 | """ 34 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 35 | """ 36 | 37 | model_name_or_path: str = field( 38 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 39 | ) 40 | config_name: Optional[str] = field( 41 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 42 | ) 43 | tokenizer_name: Optional[str] = field( 44 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 45 | ) 46 | cache_dir: Optional[str] = field( 47 | default=None, 48 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 49 | ) 50 | alpha: Optional[float] = field(default=0.5, metadata={"help": "weight for final loss function"}) 51 | sentence_n: Optional[int] = field(default=2, metadata={"help": "sentences for augmentation"}) 52 | temperature: Optional[float] = field(default=0.5, metadata={"help": "temperature for cl loss"}) 53 | 54 | hidden_state_representation: Optional[str] = field(default='cls', metadata={"help": "representation"}) 55 | 56 | freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."}) 57 | freeze_encoder: bool = field(default=False, metadata={"help": "Whether tp freeze the encoder."}) 58 | freeze_decoder: bool = field(default=False, metadata={"help": "Whether tp freeze the decoder."}) 59 | freeze_encoder_layer: int = field(default=-1, metadata={"help": "Freeze the first n layers in the encoder"}) 60 | freeze_decoder_layer: int = field(default=-1, metadata={"help": "Freeze the first n layers in the decoder"}) 61 | eval_metric: str = field(default='loss', metadata={"help": "eval metrics for validation set"}) 62 | 63 | continue_trainer: bool = field(default=False, metadata={"help": "flag variable to continue training"}) 64 | continue_trainer_path: str = field(default=None, metadata={"help": "checkpoint path to continue training"}) 65 | 66 | 67 | @dataclass 68 | class DataTrainingArguments: 69 | """ 70 | Arguments pertaining to what data we are going to input our model for training and eval. 71 | """ 72 | 73 | data_dir: str = field( 74 | metadata={"help": "The input data dir. 
Should contain the .tsv files (or other data files) for the task."} 75 | ) 76 | task: Optional[str] = field( 77 | default="summarization", 78 | metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"}, 79 | ) 80 | max_source_length: Optional[int] = field( 81 | default=1024, 82 | metadata={ 83 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 84 | "than this will be truncated, sequences shorter will be padded." 85 | }, 86 | ) 87 | max_target_length: Optional[int] = field( 88 | default=128, 89 | metadata={ 90 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 91 | "than this will be truncated, sequences shorter will be padded." 92 | }, 93 | ) 94 | val_max_target_length: Optional[int] = field( 95 | default=142, 96 | metadata={ 97 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 98 | "than this will be truncated, sequences shorter will be padded." 99 | }, 100 | ) 101 | test_max_target_length: Optional[int] = field( 102 | default=142, 103 | metadata={ 104 | "help": "The maximum total sequence length for test target text after tokenization. Sequences longer " 105 | "than this will be truncated, sequences shorter will be padded." 106 | }, 107 | ) 108 | n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."}) 109 | n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."}) 110 | n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."}) 111 | src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."}) 112 | tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."}) 113 | eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."}) 114 | ignore_pad_token_for_loss: bool = field( 115 | default=True, 116 | metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."}, 117 | ) 118 | 119 | 120 | def speed_metrics(split, start_time, num_samples): 121 | """ 122 | Measure and return speed performance metrics. 123 | 124 | This function requires a time snapshot `start_time` before the operation to be measured starts and this 125 | function should be run immediately after the operation to be measured has completed. 
126 | 127 | Args: 128 | - split: one of train, val, test 129 | - start_time: operation start time 130 | - num_samples: number of samples processed 131 | 132 | """ 133 | runtime = time.time() - start_time 134 | result = {} 135 | 136 | samples_per_second = 1 / (runtime / num_samples) 137 | result[f"{split}_samples_per_second"] = round(samples_per_second, 3) 138 | result[f"{split}_runtime"] = round(runtime, 4) 139 | 140 | result[f"{split}_n_obs"] = num_samples 141 | return result 142 | 143 | 144 | def handle_metrics(split, metrics, output_dir): 145 | """ 146 | Log and save metrics 147 | 148 | Args: 149 | - split: one of train, val, test 150 | - metrics: metrics dict 151 | - output_dir: where to save the metrics 152 | """ 153 | 154 | logger.info(f"***** {split} metrics *****") 155 | for key, value in metrics.items(): 156 | logger.info(f" {key} = {value}") 157 | save_json(metrics, os.path.join(output_dir, f"{split}_results.json")) 158 | 159 | 160 | def main(): 161 | # load the parameters 162 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 163 | 164 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 165 | # If we pass only one argument to the script and it's the path to a json file, 166 | # let's parse it to get our arguments. 167 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 168 | else: 169 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 170 | 171 | check_output_dir(training_args) 172 | 173 | # Set seed 174 | set_seed(training_args.seed) 175 | 176 | # Load pretrained model and tokenizer 177 | # 178 | # Distributed training: 179 | # The .from_pretrained methods guarantee that only one local process can concurrently 180 | # download model & vocab.
181 | 182 | config = AutoConfig.from_pretrained( 183 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 184 | cache_dir=model_args.cache_dir, 185 | ) 186 | 187 | extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") 188 | for p in extra_model_params: 189 | if getattr(training_args, p, None): 190 | assert hasattr(config, p), f"({config.__class__.__name__}) doesn't have a `{p}` attribute" 191 | setattr(config, p, getattr(training_args, p)) 192 | 193 | tokenizer = AutoTokenizer.from_pretrained( 194 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 195 | cache_dir=model_args.cache_dir, 196 | ) 197 | 198 | model = AutoModelForSeq2SeqLM.from_pretrained( 199 | model_args.model_name_or_path, 200 | from_tf=".ckpt" in model_args.model_name_or_path, 201 | config=config, 202 | cache_dir=model_args.cache_dir, 203 | ) 204 | 205 | # use task specific params 206 | use_task_specific_params(model, data_args.task) 207 | 208 | # set num_beams for evaluation 209 | if data_args.eval_beams is None: 210 | data_args.eval_beams = model.config.num_beams 211 | 212 | # set decoder_start_token_id for MBart 213 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer): 214 | assert ( 215 | data_args.tgt_lang is not None and data_args.src_lang is not None 216 | ), "mBart requires --tgt_lang and --src_lang" 217 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang] 218 | 219 | if model_args.freeze_embeds: 220 | freeze_embeds(model) 221 | logger.info(f"Freeze the embeddings") 222 | 223 | if model_args.freeze_encoder: 224 | freeze_params(model.get_encoder()) 225 | assert_all_frozen(model.get_encoder()) 226 | logger.info(f"Freeze the encoder") 227 | 228 | if model_args.freeze_decoder: 229 | freeze_params(model.get_decoder()) 230 | assert_all_frozen(model.get_decoder()) 231 | logger.info(f"Freeze the decoder") 232 | 233 | # freeze the first N layers in the encoder 234 | if model_args.freeze_encoder_layer > 0: 235 | freeze_params(model.get_encoder().layers[:model_args.freeze_encoder_layer]) 236 | assert_all_frozen(model.get_encoder().layers[:model_args.freeze_encoder_layer]) 237 | freeze_params(model.get_encoder().layernorm_embedding) 238 | assert_all_frozen(model.get_encoder().layernorm_embedding) 239 | logger.info(f"Freeze the first {model_args.freeze_encoder_layer} layers in the encoder") 240 | 241 | if model_args.freeze_decoder_layer > 0: 242 | freeze_params(model.get_decoder().layers[:model_args.freeze_decoder_layer]) 243 | assert_all_frozen(model.get_decoder().layers[:model_args.freeze_decoder_layer]) 244 | logger.info(f"Freeze the first {model_args.freeze_decoder_layer} layers in the decoder") 245 | 246 | dataset_class = Seq2SeqDataset 247 | 248 | # Get datasets 249 | train_dataset = ( 250 | dataset_class( 251 | tokenizer, 252 | type_path="train", 253 | data_dir=data_args.data_dir, 254 | n_obs=data_args.n_train, 255 | max_target_length=data_args.max_target_length, 256 | max_source_length=data_args.max_source_length, 257 | prefix=model.config.prefix or "", 258 | ) 259 | if training_args.do_train 260 | else None 261 | ) 262 | eval_dataset = ( 263 | dataset_class( 264 | tokenizer, 265 | type_path="val", 266 | data_dir=data_args.data_dir, 267 | n_obs=data_args.n_val, 268 | max_target_length=data_args.val_max_target_length, 269 | max_source_length=data_args.max_source_length, 270 | prefix=model.config.prefix or "", 271 | ) 272 | if
training_args.do_eval or training_args.evaluation_strategy != EvaluationStrategy.NO 273 | else None 274 | ) 275 | test_dataset = ( 276 | dataset_class( 277 | tokenizer, 278 | type_path="test", 279 | data_dir=data_args.data_dir, 280 | n_obs=data_args.n_test, 281 | max_target_length=data_args.test_max_target_length, 282 | max_source_length=data_args.max_source_length, 283 | prefix=model.config.prefix or "", 284 | ) 285 | if training_args.do_predict 286 | else None 287 | ) 288 | 289 | # Initialize our Trainer 290 | compute_metrics_fn = ( 291 | build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None 292 | ) 293 | trainer = Seq2SeqTrainerCL( 294 | model=model, 295 | tokenizer=tokenizer, 296 | config=config, 297 | args=training_args, 298 | train_dataset=train_dataset, 299 | eval_dataset=eval_dataset, 300 | data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores), 301 | compute_metrics=compute_metrics_fn, 302 | data_args=data_args, 303 | alpha=model_args.alpha, 304 | temperature=model_args.temperature, 305 | hidden_state_representation=model_args.hidden_state_representation, 306 | eval_metric=model_args.eval_metric, 307 | ) 308 | 309 | all_metrics = {} 310 | # Training 311 | if training_args.do_train: 312 | logger.info("*** Train ***") 313 | 314 | start_time = time.time() 315 | 316 | if not model_args.continue_trainer: 317 | trainer.train( 318 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 319 | ) 320 | else: 321 | trainer.train( 322 | model_path=model_args.continue_trainer_path if os.path.isdir(model_args.continue_trainer_path) else None 323 | ) 324 | 325 | metrics = speed_metrics("train", start_time, data_args.n_train) 326 | 327 | trainer.save_model() # this also saves the tokenizer 328 | 329 | if trainer.is_world_process_zero(): 330 | handle_metrics("train", metrics, training_args.output_dir) 331 | all_metrics.update(metrics) 332 | 333 | # Need to save the state, since Trainer.save_model saves only the tokenizer with the model 334 | trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) 335 | 336 | # For convenience, we also re-save the tokenizer to the same directory, 337 | # so that you can share your model easily on huggingface.co/models =) 338 | tokenizer.save_pretrained(training_args.output_dir) 339 | 340 | # Evaluation 341 | if training_args.do_eval: 342 | logger.info("*** Evaluate ***") 343 | 344 | start_time = time.time() 345 | metrics = trainer.evaluate(metric_key_prefix="val") 346 | metrics.update(speed_metrics("val", start_time, data_args.n_val)) 347 | metrics["val_loss"] = round(metrics["val_loss"], 4) 348 | 349 | if trainer.is_world_process_zero(): 350 | handle_metrics("val", metrics, training_args.output_dir) 351 | all_metrics.update(metrics) 352 | 353 | if training_args.do_predict: 354 | logger.info("*** Predict ***") 355 | 356 | start_time = time.time() 357 | test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test") 358 | metrics = test_output.metrics 359 | metrics.update(speed_metrics("test", start_time, data_args.n_test)) 360 | 361 | if trainer.is_world_process_zero(): 362 | metrics["test_loss"] = round(metrics["test_loss"], 4) 363 | handle_metrics("test", metrics, training_args.output_dir) 364 | all_metrics.update(metrics) 365 | 366 | if training_args.predict_with_generate: 367 | test_preds = tokenizer.batch_decode( 368 | test_output.predictions, skip_special_tokens=True, 
clean_up_tokenization_spaces=True 369 | ) 370 | test_preds = lmap(str.strip, test_preds) 371 | write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt")) 372 | 373 | if trainer.is_world_process_zero(): 374 | save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json")) 375 | 376 | return all_metrics 377 | 378 | 379 | if __name__ == "__main__": 380 | main() 381 | -------------------------------------------------------------------------------- /seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import collections 15 | import os 16 | import re 17 | import shutil 18 | from pathlib import Path 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | import warnings 21 | import numpy as np 22 | 23 | import torch 24 | from torch import nn 25 | from torch.utils.data import DistributedSampler, RandomSampler, DataLoader, Dataset 26 | 27 | from transformers import PreTrainedModel, Trainer, logging, is_ray_available 28 | from transformers.file_utils import is_torch_tpu_available 29 | from transformers.models.fsmt.configuration_fsmt import FSMTConfig 30 | from transformers.optimization import ( 31 | Adafactor, 32 | AdamW, 33 | get_constant_schedule, 34 | get_constant_schedule_with_warmup, 35 | get_cosine_schedule_with_warmup, 36 | get_cosine_with_hard_restarts_schedule_with_warmup, 37 | get_linear_schedule_with_warmup, 38 | get_polynomial_decay_schedule_with_warmup, 39 | ) 40 | from transformers.trainer_pt_utils import get_tpu_sampler, DistributedTensorGatherer, nested_concat, reissue_pt_warnings 41 | from transformers.trainer_utils import PredictionOutput, EvalPrediction, HPSearchBackend 42 | 43 | # from transformers.training_args import ParallelMode 44 | 45 | 46 | logger = logging.get_logger(__name__) 47 | 48 | arg_to_scheduler = { 49 | "linear": get_linear_schedule_with_warmup, 50 | "cosine": get_cosine_schedule_with_warmup, 51 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 52 | "polynomial": get_polynomial_decay_schedule_with_warmup, 53 | "constant": get_constant_schedule, 54 | "constant_w_warmup": get_constant_schedule_with_warmup, 55 | } 56 | 57 | 58 | class Seq2SeqTrainer(Trainer): 59 | def __init__(self, config=None, data_args=None, *args, **kwargs): 60 | super().__init__(*args, **kwargs) 61 | 62 | if config is None: 63 | assert isinstance( 64 | self.model, PreTrainedModel 65 | ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}" 66 | self.config = self._actual_model(self.model).config 67 | else: 68 | self.config = config 69 | 70 | self.data_args = data_args 71 | self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size 72 | 73 | if self.args.label_smoothing != 0 or (self.data_args 
is not None and self.data_args.ignore_pad_token_for_loss): 74 | assert ( 75 | self.config.pad_token_id is not None 76 | ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing." 77 | 78 | if self.config.pad_token_id is None and self.config.eos_token_id is not None: 79 | logger.warn( 80 | f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.." 81 | ) 82 | 83 | if self.args.label_smoothing == 0: 84 | self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) 85 | else: 86 | # dynamically import label_smoothed_nll_loss 87 | from utils import label_smoothed_nll_loss 88 | 89 | self.loss_fn = label_smoothed_nll_loss 90 | 91 | def create_optimizer_and_scheduler(self, num_training_steps: int): 92 | """ 93 | Setup the optimizer and the learning rate scheduler. 94 | 95 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 96 | Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. 97 | """ 98 | if self.optimizer is None: 99 | no_decay = ["bias", "LayerNorm.weight"] 100 | optimizer_grouped_parameters = [ 101 | { 102 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 103 | "weight_decay": self.args.weight_decay, 104 | }, 105 | { 106 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 107 | "weight_decay": 0.0, 108 | }, 109 | ] 110 | if self.args.adafactor: 111 | self.optimizer = Adafactor( 112 | optimizer_grouped_parameters, 113 | lr=self.args.learning_rate, 114 | scale_parameter=False, 115 | relative_step=False, 116 | ) 117 | 118 | else: 119 | self.optimizer = AdamW( 120 | optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon 121 | ) 122 | 123 | if self.lr_scheduler is None: 124 | self.lr_scheduler = self._get_lr_scheduler(num_training_steps) 125 | else: # ignoring --lr_scheduler 126 | logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.") 127 | 128 | def _get_lr_scheduler(self, num_training_steps): 129 | schedule_func = arg_to_scheduler[self.args.lr_scheduler] 130 | if self.args.lr_scheduler == "constant": 131 | scheduler = schedule_func(self.optimizer) 132 | elif self.args.lr_scheduler == "constant_w_warmup": 133 | scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps) 134 | else: 135 | scheduler = schedule_func( 136 | self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps 137 | ) 138 | return scheduler 139 | 140 | def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: 141 | if isinstance(self.train_dataset, torch.utils.data.IterableDataset): 142 | return None 143 | elif is_torch_tpu_available(): 144 | return get_tpu_sampler(self.train_dataset) 145 | else: 146 | if self.args.sortish_sampler: 147 | self.train_dataset.make_sortish_sampler( 148 | self.args.per_device_train_batch_size, 149 | # distributed=(self.args.parallel_mode == "distributed"), 150 | ) 151 | 152 | return ( 153 | RandomSampler(self.train_dataset) 154 | if self.args.local_rank == -1 155 | else DistributedSampler(self.train_dataset) 156 | ) 157 | 158 | def _compute_loss(self, model, inputs, labels): 159 | if self.args.label_smoothing == 0: 160 | if self.data_args is not None and self.data_args.ignore_pad_token_for_loss: 161 | # force training to 
ignore pad token 162 | logits = model(**inputs, use_cache=False)[0] 163 | loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1)) 164 | else: 165 | # compute usual loss via models 166 | loss, logits = model(**inputs, labels=labels, use_cache=False)[:2] 167 | else: 168 | # compute label smoothed loss 169 | logits = model(**inputs, use_cache=False)[0] 170 | lprobs = torch.nn.functional.log_softmax(logits, dim=-1) 171 | loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id) 172 | return loss, logits 173 | 174 | def compute_loss(self, model, inputs): 175 | labels = inputs.pop("labels") 176 | loss, _ = self._compute_loss(model, inputs, labels) 177 | return loss 178 | 179 | def prediction_step( 180 | self, 181 | model: nn.Module, 182 | inputs: Dict[str, Union[torch.Tensor, Any]], 183 | prediction_loss_only: bool, 184 | ignore_keys: Optional[List[str]] = None, 185 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 186 | """ 187 | Perform an evaluation step on :obj:`model` using obj:`inputs`. 188 | 189 | Subclass and override to inject custom behavior. 190 | 191 | Args: 192 | model (:obj:`nn.Module`): 193 | The model to evaluate. 194 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 195 | The inputs and targets of the model. 196 | 197 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 198 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 199 | prediction_loss_only (:obj:`bool`): 200 | Whether or not to return the loss only. 201 | 202 | Return: 203 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 204 | A tuple with the loss, logits and labels (each being optional). 
205 | """ 206 | inputs = self._prepare_inputs(inputs) 207 | 208 | gen_kwargs = { 209 | "max_length": self.data_args.val_max_target_length 210 | if self.data_args is not None 211 | else self.config.max_length, 212 | "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams, 213 | } 214 | 215 | if self.args.predict_with_generate and not self.args.prediction_loss_only: 216 | generated_tokens = self.model.generate( 217 | inputs["input_ids"], 218 | attention_mask=inputs["attention_mask"], 219 | **gen_kwargs, 220 | ) 221 | # in case the batch is shorter than max length, the output should be padded 222 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 223 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 224 | 225 | labels = inputs.pop("labels") 226 | with torch.no_grad(): 227 | # compute loss on predict data 228 | loss, logits = self._compute_loss(model, inputs, labels) 229 | 230 | loss = loss.mean().detach() 231 | if self.args.prediction_loss_only: 232 | return (loss, None, None) 233 | 234 | logits = generated_tokens if self.args.predict_with_generate else logits 235 | 236 | if labels.shape[-1] < gen_kwargs["max_length"]: 237 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 238 | 239 | return (loss, logits, labels) 240 | 241 | def _pad_tensors_to_max_len(self, tensor, max_length): 242 | # If PAD token is not defined at least EOS token has to be defined 243 | pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id 244 | 245 | if pad_token_id is None: 246 | raise ValueError( 247 | f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}" 248 | ) 249 | 250 | padded_tensor = pad_token_id * torch.ones( 251 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 252 | ) 253 | padded_tensor[:, : tensor.shape[-1]] = tensor 254 | return padded_tensor 255 | 256 | def evaluate( 257 | self, 258 | eval_dataset: Optional[Dataset] = None, 259 | ignore_keys: Optional[List[str]] = None, 260 | metric_key_prefix: str = "eval", 261 | ) -> Dict[str, float]: 262 | """ 263 | Run evaluation and returns metrics. 264 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 265 | (pass it to the init :obj:`compute_metrics` argument). 266 | You can also subclass and override this method to inject custom behavior. 267 | Args: 268 | eval_dataset (:obj:`Dataset`, `optional`): 269 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 270 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the 271 | :obj:`__len__` method. 272 | ignore_keys (:obj:`Lst[str]`, `optional`): 273 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 274 | gathering predictions. 275 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): 276 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 277 | "eval_bleu" if the prefix is "eval" (default) 278 | Returns: 279 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 280 | dictionary also contains the epoch number which comes from the training state. 
281 | """ 282 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 283 | raise ValueError("eval_dataset must implement __len__") 284 | 285 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 286 | 287 | output = self.prediction_loop( 288 | eval_dataloader, 289 | description="Evaluation", 290 | # No point gathering the predictions if there are no metrics, otherwise we defer to 291 | # self.args.prediction_loss_only 292 | prediction_loss_only=True if self.compute_metrics is None else None, 293 | ignore_keys=ignore_keys, 294 | metric_key_prefix=metric_key_prefix, 295 | ) 296 | 297 | self.log(output.metrics) 298 | 299 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) 300 | return output.metrics 301 | 302 | def predict( 303 | self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval" 304 | ) -> PredictionOutput: 305 | """ 306 | Run prediction and returns predictions and potential metrics. 307 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 308 | will also return metrics, like in :obj:`evaluate()`. 309 | Args: 310 | test_dataset (:obj:`Dataset`): 311 | Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the 312 | ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` 313 | ignore_keys (:obj:`Lst[str]`, `optional`): 314 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 315 | gathering predictions. 316 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): 317 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 318 | "eval_bleu" if the prefix is "eval" (default) 319 | .. note:: 320 | If your predictions or labels have different sequence length (for instance because you're doing dynamic 321 | padding in a token classification task) the predictions will be padded (on the right) to allow for 322 | concatenation into one array. The padding index is -100. 323 | Returns: `NamedTuple` A namedtuple with the following keys: 324 | - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`. 325 | - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some). 326 | - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset 327 | contained labels). 328 | """ 329 | if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized): 330 | raise ValueError("test_dataset must implement __len__") 331 | 332 | test_dataloader = self.get_test_dataloader(test_dataset) 333 | 334 | return self.prediction_loop( 335 | test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix 336 | ) 337 | 338 | def prediction_loop( 339 | self, 340 | dataloader: DataLoader, 341 | description: str, 342 | prediction_loss_only: Optional[bool] = None, 343 | ignore_keys: Optional[List[str]] = None, 344 | metric_key_prefix: str = "eval", 345 | ) -> PredictionOutput: 346 | """ 347 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. 348 | Works both with or without labels. 
349 | """ 350 | if not isinstance(dataloader.dataset, collections.abc.Sized): 351 | raise ValueError("dataset must implement __len__") 352 | prediction_loss_only = ( 353 | prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only 354 | ) 355 | 356 | model = self.model 357 | # multi-gpu eval 358 | if self.args.n_gpu > 1 and not self.args.model_parallel: 359 | model = torch.nn.DataParallel(model) 360 | # Note: in torch.distributed mode, there's no point in wrapping the model 361 | # inside a DistributedDataParallel as we'll be under `no_grad` anyways. 362 | 363 | batch_size = dataloader.batch_size 364 | num_examples = self.num_examples(dataloader) 365 | logger.info("***** Running %s *****", description) 366 | logger.info(" Num examples = %d", num_examples) 367 | logger.info(" Batch size = %d", batch_size) 368 | losses_host: torch.Tensor = None 369 | preds_host: Union[torch.Tensor, List[torch.Tensor]] = None 370 | labels_host: Union[torch.Tensor, List[torch.Tensor]] = None 371 | 372 | world_size = 1 373 | 374 | if self.args.local_rank != -1: 375 | world_size = torch.distributed.get_world_size() 376 | world_size = max(1, world_size) 377 | 378 | eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) 379 | if not prediction_loss_only: 380 | preds_gatherer = DistributedTensorGatherer(world_size, num_examples) 381 | labels_gatherer = DistributedTensorGatherer(world_size, num_examples) 382 | 383 | model.eval() 384 | 385 | if self.args.past_index >= 0: 386 | self._past = None 387 | 388 | self.callback_handler.eval_dataloader = dataloader 389 | 390 | for step, inputs in enumerate(dataloader): 391 | loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) 392 | if loss is not None: 393 | losses = loss.repeat(batch_size) 394 | losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) 395 | if logits is not None: 396 | preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) 397 | if labels is not None: 398 | labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) 399 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) 400 | 401 | # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
402 | if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: 403 | eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) 404 | if not prediction_loss_only: 405 | preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) 406 | labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) 407 | 408 | # Set back to None to begin a new accumulation 409 | losses_host, preds_host, labels_host = None, None, None 410 | 411 | if self.args.past_index and hasattr(self, "_past"): 412 | # Clean the state at the end of the evaluation loop 413 | delattr(self, "_past") 414 | 415 | # Gather all remaining tensors and put them back on the CPU 416 | eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) 417 | if not prediction_loss_only: 418 | preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) 419 | labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) 420 | 421 | eval_loss = eval_losses_gatherer.finalize() 422 | preds = preds_gatherer.finalize() if not prediction_loss_only else None 423 | label_ids = labels_gatherer.finalize() if not prediction_loss_only else None 424 | 425 | if self.compute_metrics is not None and preds is not None and label_ids is not None: 426 | metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) 427 | else: 428 | metrics = {} 429 | 430 | if eval_loss is not None: 431 | metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() 432 | 433 | # Prefix all keys with metric_key_prefix + '_' 434 | for key in list(metrics.keys()): 435 | if not key.startswith(f"{metric_key_prefix}_"): 436 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 437 | 438 | return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) 439 | 440 | def _save_checkpoint(self, model, trial, metrics=None): 441 | # In all cases (even distributed/parallel), self.model is always a reference 442 | # to the model we want to save. 
443 | if hasattr(model, "module"): 444 | assert model.module is self.model, f"Module {model.module} should be a reference to self.model" 445 | else: 446 | assert model is self.model, f"Model {model} should be a reference to self.model" 447 | 448 | # metrics is a dict like: {'eval_loss': 3.2179477214813232, 'epoch': 0.015873015873015872} 449 | # Save model checkpoint 450 | checkpoint_folder = f"val_avg_loss-{'%.4f' % np.round(metrics['eval_loss'], 4)}-step-{self.state.global_step}" 451 | 452 | output_dir = os.path.join(self.args.output_dir, checkpoint_folder) 453 | 454 | self.store_flos() 455 | self.save_model(output_dir) 456 | 457 | # Save optimizer and scheduler 458 | 459 | if self.is_world_process_zero(): 460 | torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 461 | with warnings.catch_warnings(record=True) as caught_warnings: 462 | torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 463 | reissue_pt_warnings(caught_warnings) 464 | 465 | # Determine the new best metric / best model checkpoint 466 | if metrics is not None and self.args.metric_for_best_model is not None: 467 | metric_to_check = self.args.metric_for_best_model 468 | if not metric_to_check.startswith("eval_"): 469 | metric_to_check = f"eval_{metric_to_check}" 470 | metric_value = metrics[metric_to_check] 471 | 472 | operator = np.greater if self.args.greater_is_better else np.less 473 | if ( 474 | self.state.best_metric is None 475 | or self.state.best_model_checkpoint is None 476 | or operator(metric_value, self.state.best_metric) 477 | ): 478 | self.state.best_metric = metric_value 479 | self.state.best_model_checkpoint = output_dir 480 | 481 | # Save the Trainer state 482 | if self.is_world_process_zero(): 483 | self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) 484 | 485 | # Maybe delete some older checkpoints. 
486 | if self.is_world_process_zero(): 487 | self._rotate_checkpoints() 488 | 489 | 490 | def _sorted_checkpoints(self, checkpoint_prefix="val_avg_loss") -> List[str]: 491 | glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")] 492 | checkpoints_sorted = sorted(glob_checkpoints) 493 | return checkpoints_sorted 494 | 495 | def _rotate_checkpoints(self) -> None: 496 | if self.args.save_total_limit is None or self.args.save_total_limit <= 0: 497 | return 498 | 499 | # Check if we should delete older checkpoint(s) 500 | checkpoints_sorted = self._sorted_checkpoints() 501 | if len(checkpoints_sorted) <= self.args.save_total_limit: 502 | return 503 | saved_checkpoints = checkpoints_sorted[:self.args.save_total_limit] 504 | 505 | for checkpoint in checkpoints_sorted: 506 | if checkpoint not in saved_checkpoints: 507 | logger.info("Deleting checkpoint [{}] due to args.save_total_limit".format(checkpoint)) 508 | shutil.rmtree(checkpoint) 509 | -------------------------------------------------------------------------------- /bs_pyrouge.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals, division 2 | 3 | import os 4 | import re 5 | import codecs 6 | import platform 7 | 8 | from subprocess import check_output 9 | from tempfile import mkdtemp 10 | from functools import partial 11 | 12 | try: 13 | from configparser import ConfigParser 14 | except ImportError: 15 | from ConfigParser import ConfigParser 16 | 17 | from pyrouge.utils import log 18 | from pyrouge.utils.file_utils import verify_dir 19 | 20 | 21 | REMAP = {"-lrb-": "(", "-rrb-": ")", "-lcb-": "{", "-rcb-": "}", 22 | "-lsb-": "[", "-rsb-": "]", "``": '"', "''": '"'} 23 | 24 | 25 | def clean(x): 26 | return re.sub( 27 | r"-lrb-|-rrb-|-lcb-|-rcb-|-lsb-|-rsb-|``|''", 28 | lambda m: REMAP.get(m.group()), x) 29 | 30 | 31 | class DirectoryProcessor: 32 | 33 | @staticmethod 34 | def process(input_dir, output_dir, function): 35 | """ 36 | Apply function to all files in input_dir and save the resulting ouput 37 | files in output_dir. 38 | 39 | """ 40 | if not os.path.exists(output_dir): 41 | os.makedirs(output_dir) 42 | logger = log.get_global_console_logger() 43 | logger.info("Processing files in {}.".format(input_dir)) 44 | input_file_names = os.listdir(input_dir) 45 | for input_file_name in input_file_names: 46 | input_file = os.path.join(input_dir, input_file_name) 47 | with codecs.open(input_file, "r", encoding="UTF-8") as f: 48 | input_string = f.read() 49 | output_string = function(input_string) 50 | output_file = os.path.join(output_dir, input_file_name) 51 | with codecs.open(output_file, "w", encoding="UTF-8") as f: 52 | f.write(clean(output_string.lower())) 53 | logger.info("Saved processed files to {}.".format(output_dir)) 54 | 55 | 56 | class Rouge155(object): 57 | """ 58 | This is a wrapper for the ROUGE 1.5.5 summary evaluation package. 59 | This class is designed to simplify the evaluation process by: 60 | 61 | 1) Converting summaries into a format ROUGE understands. 62 | 2) Generating the ROUGE configuration file automatically based 63 | on filename patterns. 64 | 65 | This class can be used within Python like this: 66 | 67 | rouge = Rouge155() 68 | rouge.system_dir = 'test/systems' 69 | rouge.model_dir = 'test/models' 70 | 71 | # The system filename pattern should contain one group that 72 | # matches the document ID. 
73 | rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html' 74 | 75 | # The model filename pattern has '#ID#' as a placeholder for the 76 | # document ID. If there are multiple model summaries, pyrouge 77 | # will use the provided regex to automatically match them with 78 | # the corresponding system summary. Here, [A-Z] matches 79 | # multiple model summaries for a given #ID#. 80 | rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 81 | 82 | rouge_output = rouge.evaluate() 83 | print(rouge_output) 84 | output_dict = rouge.output_to_dict(rouge_ouput) 85 | print(output_dict) 86 | -> {'rouge_1_f_score': 0.95652, 87 | 'rouge_1_f_score_cb': 0.95652, 88 | 'rouge_1_f_score_ce': 0.95652, 89 | 'rouge_1_precision': 0.95652, 90 | [...] 91 | 92 | 93 | To evaluate multiple systems: 94 | 95 | rouge = Rouge155() 96 | rouge.system_dir = '/PATH/TO/systems' 97 | rouge.model_dir = 'PATH/TO/models' 98 | for system_id in ['id1', 'id2', 'id3']: 99 | rouge.system_filename_pattern = \ 100 | 'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id) 101 | rouge.model_filename_pattern = \ 102 | 'SL.P.10.R.[A-Z].SL062003-#ID#.html' 103 | rouge_output = rouge.evaluate(system_id) 104 | print(rouge_output) 105 | 106 | """ 107 | 108 | def __init__(self, rouge_dir=None, rouge_args=None, temp_dir=None): 109 | """ 110 | Create a Rouge155 object. 111 | 112 | rouge_dir: Directory containing Rouge-1.5.5.pl 113 | rouge_args: Arguments to pass through to ROUGE if you 114 | don't want to use the default pyrouge 115 | arguments. 116 | 117 | """ 118 | self.temp_dir = temp_dir 119 | self.log = log.get_global_console_logger() 120 | self.__set_dir_properties() 121 | self._config_file = None 122 | self._settings_file = self.__get_config_path() 123 | self.__set_rouge_dir(rouge_dir) 124 | self.args = self.__clean_rouge_args(rouge_args) 125 | self._system_filename_pattern = None 126 | self._model_filename_pattern = None 127 | 128 | def save_home_dir(self): 129 | config = ConfigParser() 130 | section = 'pyrouge settings' 131 | config.add_section(section) 132 | config.set(section, 'home_dir', self._home_dir) 133 | with open(self._settings_file, 'w') as f: 134 | config.write(f) 135 | self.log.info("Set ROUGE home directory to {}.".format(self._home_dir)) 136 | 137 | @property 138 | def settings_file(self): 139 | """ 140 | Path of the setttings file, which stores the ROUGE home dir. 141 | 142 | """ 143 | return self._settings_file 144 | 145 | @property 146 | def bin_path(self): 147 | """ 148 | The full path of the ROUGE binary (although it's technically 149 | a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl 150 | 151 | """ 152 | if self._bin_path is None: 153 | raise Exception( 154 | "ROUGE path not set. Please set the ROUGE home directory " 155 | "and ensure that ROUGE-1.5.5.pl exists in it.") 156 | return self._bin_path 157 | 158 | @property 159 | def system_filename_pattern(self): 160 | """ 161 | The regular expression pattern for matching system summary 162 | filenames. The regex string. 163 | 164 | E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system 165 | filenames in the SPL2003/system folder of the ROUGE SPL example 166 | in the "sample-test" folder. 167 | 168 | Currently, there is no support for multiple systems. 
169 | 170 | """ 171 | return self._system_filename_pattern 172 | 173 | @system_filename_pattern.setter 174 | def system_filename_pattern(self, pattern): 175 | self._system_filename_pattern = pattern 176 | 177 | @property 178 | def model_filename_pattern(self): 179 | """ 180 | The regular expression pattern for matching model summary 181 | filenames. The pattern needs to contain the string "#ID#", 182 | which is a placeholder for the document ID. 183 | 184 | E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model 185 | filenames in the SPL2003/system folder of the ROUGE SPL 186 | example in the "sample-test" folder. 187 | 188 | "#ID#" is a placeholder for the document ID which has been 189 | matched by the "(\d+)" part of the system filename pattern. 190 | The different model summaries for a given document ID are 191 | matched by the "[A-Z]" part. 192 | 193 | """ 194 | return self._model_filename_pattern 195 | 196 | @model_filename_pattern.setter 197 | def model_filename_pattern(self, pattern): 198 | self._model_filename_pattern = pattern 199 | 200 | @property 201 | def config_file(self): 202 | return self._config_file 203 | 204 | @config_file.setter 205 | def config_file(self, path): 206 | config_dir, _ = os.path.split(path) 207 | verify_dir(config_dir, "configuration file") 208 | self._config_file = path 209 | 210 | def split_sentences(self): 211 | """ 212 | ROUGE requires texts split into sentences. In case the texts 213 | are not already split, this method can be used. 214 | 215 | """ 216 | from pyrouge.utils.sentence_splitter import PunktSentenceSplitter 217 | self.log.info("Splitting sentences.") 218 | ss = PunktSentenceSplitter() 219 | 220 | def sent_split_to_string(s): return "\n".join(ss.split(s)) 221 | process_func = partial( 222 | DirectoryProcessor.process, function=sent_split_to_string) 223 | self.__process_summaries(process_func) 224 | 225 | @staticmethod 226 | def convert_summaries_to_rouge_format(input_dir, output_dir): 227 | """ 228 | Convert all files in input_dir into a format ROUGE understands 229 | and saves the files to output_dir. The input files are assumed 230 | to be plain text with one sentence per line. 231 | 232 | input_dir: Path of directory containing the input files. 233 | output_dir: Path of directory in which the converted files 234 | will be saved. 235 | 236 | """ 237 | DirectoryProcessor.process( 238 | input_dir, output_dir, Rouge155.convert_text_to_rouge_format) 239 | 240 | @staticmethod 241 | def convert_text_to_rouge_format(text, title="dummy title"): 242 | """ 243 | Convert a text to a format ROUGE understands. The text is 244 | assumed to contain one sentence per line. 245 | 246 | text: The text to convert, containg one sentence per line. 247 | title: Optional title for the text. The title will appear 248 | in the converted file, but doesn't seem to have 249 | any other relevance. 250 | 251 | Returns: The converted text as string. 
252 | 253 | """ 254 | sentences = text.split("\n") 255 | sent_elems = [ 256 | "<a name=\"{i}\">[{i}]</a> " 257 | "<a href=\"#{i}\" id={i}>{text}</a>".format(i=i, text=sent) 258 | for i, sent in enumerate(sentences, start=1)] 259 | html = """<html> 260 | <head> 261 | <title>{title}</title> 262 | </head> 263 | <body bgcolor="white"> 264 | {elems} 265 | </body> 266 | </html>""".format(title=title, elems="\n".join(sent_elems)) 267 | 268 | return html 269 | 270 | @staticmethod 271 | def write_config_static(system_dir, system_filename_pattern, 272 | model_dir, model_filename_pattern, 273 | config_file_path, system_id=None): 274 | """ 275 | Write the ROUGE configuration file, which is basically a list 276 | of system summary files and their corresponding model summary 277 | files. 278 | 279 | pyrouge uses regular expressions to automatically find the 280 | matching model summary files for a given system summary file 281 | (cf. docstrings for system_filename_pattern and 282 | model_filename_pattern). 283 | 284 | system_dir: Path of directory containing 285 | system summaries. 286 | system_filename_pattern: Regex string for matching 287 | system summary filenames. 288 | model_dir: Path of directory containing 289 | model summaries. 290 | model_filename_pattern: Regex string for matching model 291 | summary filenames. 292 | config_file_path: Path of the configuration file. 293 | system_id: Optional system ID string which 294 | will appear in the ROUGE output. 295 | 296 | """ 297 | system_filenames = [f for f in os.listdir(system_dir)] 298 | system_models_tuples = [] 299 | 300 | system_filename_pattern = re.compile(system_filename_pattern) 301 | for system_filename in sorted(system_filenames): 302 | match = system_filename_pattern.match(system_filename) 303 | if match: 304 | id = match.groups(0)[0] 305 | model_filenames = [model_filename_pattern.replace('#ID#', id)] 306 | # model_filenames = Rouge155.__get_model_filenames_for_id( 307 | # id, model_dir, model_filename_pattern) 308 | system_models_tuples.append( 309 | (system_filename, sorted(model_filenames))) 310 | if not system_models_tuples: 311 | raise Exception( 312 | "Did not find any files matching the pattern {} " 313 | "in the system summaries directory {}.".format( 314 | system_filename_pattern.pattern, system_dir)) 315 | 316 | with codecs.open(config_file_path, 'w', encoding='utf-8') as f: 317 | f.write('<ROUGE-EVAL version="1.0">') 318 | for task_id, (system_filename, model_filenames) in enumerate( 319 | system_models_tuples, start=1): 320 | 321 | eval_string = Rouge155.__get_eval_string( 322 | task_id, system_id, 323 | system_dir, system_filename, 324 | model_dir, model_filenames) 325 | f.write(eval_string) 326 | f.write("</ROUGE-EVAL>") 327 | 328 | def write_config(self, config_file_path=None, system_id=None): 329 | """ 330 | Write the ROUGE configuration file, which is basically a list 331 | of system summary files and their matching model summary files. 332 | 333 | This is a non-static version of write_config_file_static(). 334 | 335 | config_file_path: Path of the configuration file. 336 | system_id: Optional system ID string which will 337 | appear in the ROUGE output.
338 | 339 | """ 340 | if not system_id: 341 | system_id = 1 342 | if (not config_file_path) or (not self._config_dir): 343 | self._config_dir = mkdtemp(dir=self.temp_dir) 344 | config_filename = "rouge_conf.xml" 345 | else: 346 | config_dir, config_filename = os.path.split(config_file_path) 347 | verify_dir(config_dir, "configuration file") 348 | self._config_file = os.path.join(self._config_dir, config_filename) 349 | Rouge155.write_config_static( 350 | self._system_dir, self._system_filename_pattern, 351 | self._model_dir, self._model_filename_pattern, 352 | self._config_file, system_id) 353 | self.log.info( 354 | "Written ROUGE configuration to {}".format(self._config_file)) 355 | 356 | def evaluate(self, system_id=1, rouge_args=None): 357 | """ 358 | Run ROUGE to evaluate the system summaries in system_dir against 359 | the model summaries in model_dir. The summaries are assumed to 360 | be in the one-sentence-per-line HTML format ROUGE understands. 361 | 362 | system_id: Optional system ID which will be printed in 363 | ROUGE's output. 364 | 365 | Returns: Rouge output as string. 366 | 367 | """ 368 | self.write_config(system_id=system_id) 369 | options = self.__get_options(rouge_args) 370 | command = [self._bin_path] + options 371 | self.log.info( 372 | "Running ROUGE with command {}".format(" ".join(command))) 373 | rouge_output = check_output(command).decode("UTF-8") 374 | return rouge_output 375 | 376 | def convert_and_evaluate(self, system_id=1, 377 | split_sentences=False, rouge_args=None): 378 | """ 379 | Convert plain text summaries to ROUGE format and run ROUGE to 380 | evaluate the system summaries in system_dir against the model 381 | summaries in model_dir. Optionally split texts into sentences 382 | in case they aren't already. 383 | 384 | This is just a convenience method combining 385 | convert_summaries_to_rouge_format() and evaluate(). 386 | 387 | split_sentences: Optional argument specifying if 388 | sentences should be split. 389 | system_id: Optional system ID which will be printed 390 | in ROUGE's output. 391 | 392 | Returns: ROUGE output as string. 393 | 394 | """ 395 | if split_sentences: 396 | self.split_sentences() 397 | self.__write_summaries() 398 | rouge_output = self.evaluate(system_id, rouge_args) 399 | return rouge_output 400 | 401 | def output_to_dict(self, output): 402 | """ 403 | Convert the ROUGE output into python dictionary for further 404 | processing. 405 | 406 | """ 407 | # 0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632) 408 | pattern = re.compile( 409 | r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) " 410 | r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)") 411 | results = {} 412 | for line in output.split("\n"): 413 | match = pattern.match(line) 414 | if match: 415 | sys_id, rouge_type, measure, result, conf_begin, conf_end = \ 416 | match.groups() 417 | measure = { 418 | 'Average_R': 'recall', 419 | 'Average_P': 'precision', 420 | 'Average_F': 'f_score' 421 | }[measure] 422 | rouge_type = rouge_type.lower().replace("-", '_') 423 | key = "{}_{}".format(rouge_type, measure) 424 | results[key] = float(result) 425 | results["{}_cb".format(key)] = float(conf_begin) 426 | results["{}_ce".format(key)] = float(conf_end) 427 | return results 428 | 429 | ################################################################### 430 | # Private methods 431 | 432 | def __set_rouge_dir(self, home_dir=None): 433 | """ 434 | Verfify presence of ROUGE-1.5.5.pl and data folder, and set 435 | those paths. 
436 | 437 | """ 438 | if not home_dir: 439 | self._home_dir = self.__get_rouge_home_dir_from_settings() 440 | else: 441 | self._home_dir = home_dir 442 | self.save_home_dir() 443 | self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl') 444 | self.data_dir = os.path.join(self._home_dir, 'data') 445 | if not os.path.exists(self._bin_path): 446 | raise Exception( 447 | "ROUGE binary not found at {}. Please set the " 448 | "correct path by running pyrouge_set_rouge_path " 449 | "/path/to/rouge/home.".format(self._bin_path)) 450 | 451 | def __get_rouge_home_dir_from_settings(self): 452 | config = ConfigParser() 453 | with open(self._settings_file) as f: 454 | if hasattr(config, "read_file"): 455 | config.read_file(f) 456 | else: 457 | # use deprecated python 2.x method 458 | config.readfp(f) 459 | rouge_home_dir = config.get('pyrouge settings', 'home_dir') 460 | return rouge_home_dir 461 | 462 | @staticmethod 463 | def __get_eval_string( 464 | task_id, system_id, 465 | system_dir, system_filename, 466 | model_dir, model_filenames): 467 | """ 468 | ROUGE can evaluate several system summaries for a given text 469 | against several model summaries, i.e. there is an m-to-n 470 | relation between system and model summaries. The system 471 | summaries are listed in the tag and the model summaries 472 | in the tag. pyrouge currently only supports one system 473 | summary per text, i.e. it assumes a 1-to-n relation between 474 | system and model summaries. 475 | 476 | """ 477 | peer_elems = "
<P ID=\"{id}\">{name}</P>
".format( 478 | id=system_id, name=system_filename) 479 | 480 | model_elems = ["{name}".format( 481 | id=chr(65 + i), name=name) 482 | for i, name in enumerate(model_filenames)] 483 | 484 | model_elems = "\n\t\t\t".join(model_elems) 485 | eval_string = """ 486 | 487 | {model_root} 488 | {peer_root} 489 | 490 | 491 | 492 | {peer_elems} 493 | 494 | 495 | {model_elems} 496 | 497 | 498 | """.format( 499 | task_id=task_id, 500 | model_root=model_dir, model_elems=model_elems, 501 | peer_root=system_dir, peer_elems=peer_elems) 502 | return eval_string 503 | 504 | def __process_summaries(self, process_func): 505 | """ 506 | Helper method that applies process_func to the files in the 507 | system and model folders and saves the resulting files to new 508 | system and model folders. 509 | 510 | """ 511 | temp_dir = mkdtemp(dir=self.temp_dir) 512 | new_system_dir = os.path.join(temp_dir, "system") 513 | os.mkdir(new_system_dir) 514 | new_model_dir = os.path.join(temp_dir, "model") 515 | os.mkdir(new_model_dir) 516 | self.log.info( 517 | "Processing summaries. Saving system files to {} and " 518 | "model files to {}.".format(new_system_dir, new_model_dir)) 519 | process_func(self._system_dir, new_system_dir) 520 | process_func(self._model_dir, new_model_dir) 521 | self._system_dir = new_system_dir 522 | self._model_dir = new_model_dir 523 | 524 | def __write_summaries(self): 525 | self.log.info("Writing summaries.") 526 | self.__process_summaries(self.convert_summaries_to_rouge_format) 527 | 528 | @staticmethod 529 | def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern): 530 | pattern = re.compile(model_filenames_pattern.replace('#ID#', id)) 531 | model_filenames = [ 532 | f for f in os.listdir(model_dir) if pattern.match(f)] 533 | if not model_filenames: 534 | raise Exception( 535 | "Could not find any model summaries for the system" 536 | " summary with ID {}. Specified model filename pattern was: " 537 | "{}".format(id, model_filenames_pattern)) 538 | return model_filenames 539 | 540 | def __get_options(self, rouge_args=None): 541 | """ 542 | Get supplied command line arguments for ROUGE or use default 543 | ones. 544 | 545 | """ 546 | if self.args: 547 | options = self.args.split() 548 | elif rouge_args: 549 | options = rouge_args.split() 550 | else: 551 | options = [ 552 | '-e', self._data_dir, 553 | '-c', 95, 554 | # '-2', 555 | # '-1', 556 | # '-U', 557 | '-m', 558 | # '-v', 559 | '-r', 1000, 560 | '-n', 2, 561 | # '-w', 1.2, 562 | '-a', 563 | ] 564 | options = list(map(str, options)) 565 | 566 | options = self.__add_config_option(options) 567 | return options 568 | 569 | def __create_dir_property(self, dir_name, docstring): 570 | """ 571 | Generate getter and setter for a directory property. 572 | 573 | """ 574 | property_name = "{}_dir".format(dir_name) 575 | private_name = "_" + property_name 576 | setattr(self, private_name, None) 577 | 578 | def fget(self): 579 | return getattr(self, private_name) 580 | 581 | def fset(self, path): 582 | verify_dir(path, dir_name) 583 | setattr(self, private_name, path) 584 | 585 | p = property(fget=fget, fset=fset, doc=docstring) 586 | setattr(self.__class__, property_name, p) 587 | 588 | def __set_dir_properties(self): 589 | """ 590 | Automatically generate the properties for directories. 
591 | 592 | """ 593 | directories = [ 594 | ("home", "The ROUGE home directory."), 595 | ("data", "The path of the ROUGE 'data' directory."), 596 | ("system", "Path of the directory containing system summaries."), 597 | ("model", "Path of the directory containing model summaries."), 598 | ] 599 | for (dirname, docstring) in directories: 600 | self.__create_dir_property(dirname, docstring) 601 | 602 | def __clean_rouge_args(self, rouge_args): 603 | """ 604 | Remove enclosing quotation marks, if any. 605 | 606 | """ 607 | if not rouge_args: 608 | return 609 | quot_mark_pattern = re.compile('"(.+)"') 610 | match = quot_mark_pattern.match(rouge_args) 611 | if match: 612 | cleaned_args = match.group(1) 613 | return cleaned_args 614 | else: 615 | return rouge_args 616 | 617 | def __add_config_option(self, options): 618 | return options + [self._config_file] 619 | 620 | def __get_config_path(self): 621 | if platform.system() == "Windows": 622 | parent_dir = os.getenv("APPDATA") 623 | config_dir_name = "pyrouge" 624 | elif os.name == "posix": 625 | parent_dir = os.path.expanduser("~") 626 | config_dir_name = ".pyrouge" 627 | else: 628 | parent_dir = os.path.dirname(__file__) 629 | config_dir_name = "" 630 | config_dir = os.path.join(parent_dir, config_dir_name) 631 | if not os.path.exists(config_dir): 632 | os.makedirs(config_dir) 633 | return os.path.join(config_dir, 'settings.ini') 634 | 635 | 636 | if __name__ == "__main__": 637 | import argparse 638 | from utils.argparsers import rouge_path_parser 639 | 640 | parser = argparse.ArgumentParser(parents=[rouge_path_parser]) 641 | args = parser.parse_args() 642 | 643 | rouge = Rouge155(args.rouge_home) 644 | rouge.save_home_dir() 645 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import itertools 16 | import json 17 | import linecache 18 | import math 19 | import os 20 | import pickle 21 | import socket 22 | from logging import getLogger 23 | from pathlib import Path 24 | from typing import Callable, Dict, Iterable, List, Tuple, Union 25 | 26 | from augmentation import DocumentAugmentation 27 | 28 | import numpy as np 29 | import torch 30 | import torch.distributed as dist 31 | from rouge_score import rouge_scorer, scoring 32 | from sacrebleu import corpus_bleu 33 | from torch import nn 34 | from torch.utils.data import Dataset, Sampler 35 | 36 | from sentence_splitter import add_newline_to_end_of_each_sentence 37 | from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer 38 | from transformers.file_utils import cached_property 39 | from transformers.models.bart.modeling_bart import shift_tokens_right 40 | import random 41 | 42 | try: 43 | from fairseq.data.data_utils import batch_by_size 44 | 45 | FAIRSEQ_AVAILABLE = True 46 | except (ImportError, ModuleNotFoundError): 47 | FAIRSEQ_AVAILABLE = False 48 | 49 | 50 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 51 | """From fairseq""" 52 | if target.dim() == lprobs.dim() - 1: 53 | target = target.unsqueeze(-1) 54 | nll_loss = -lprobs.gather(dim=-1, index=target) 55 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 56 | if ignore_index is not None: 57 | pad_mask = target.eq(ignore_index) 58 | nll_loss.masked_fill_(pad_mask, 0.0) 59 | smooth_loss.masked_fill_(pad_mask, 0.0) 60 | else: 61 | nll_loss = nll_loss.squeeze(-1) 62 | smooth_loss = smooth_loss.squeeze(-1) 63 | 64 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 65 | smooth_loss = smooth_loss.sum() 66 | eps_i = epsilon / lprobs.size(-1) 67 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 68 | return loss, nll_loss 69 | 70 | 71 | def lmap(f: Callable, x: Iterable) -> List: 72 | """list(map(f, x))""" 73 | return list(map(f, x)) 74 | 75 | 76 | def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict: 77 | """Uses sacrebleu's corpus_bleu implementation.""" 78 | return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)} 79 | 80 | 81 | def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]: 82 | def non_pad_len(tokens: np.ndarray) -> int: 83 | return np.count_nonzero(tokens != tokenizer.pad_token_id) 84 | 85 | def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]: 86 | pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True) 87 | label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) 88 | pred_str = lmap(str.strip, pred_str) 89 | label_str = lmap(str.strip, label_str) 90 | return pred_str, label_str 91 | 92 | def summarization_metrics(pred: EvalPrediction) -> Dict: 93 | pred_str, label_str = decode_pred(pred) 94 | rouge: Dict = calculate_rouge(pred_str, label_str) 95 | summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) 96 | rouge.update({"gen_len": summ_len}) 97 | return rouge 98 | 99 | def translation_metrics(pred: EvalPrediction) -> Dict: 100 | pred_str, label_str = decode_pred(pred) 101 | bleu: Dict = calculate_bleu(pred_str, label_str) 102 | gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1) 103 | bleu.update({"gen_len": gen_len}) 104 | return bleu 105 | 106 | compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics 107 | return compute_metrics_fn 108 | 109 
| 110 | def trim_batch( 111 | input_ids, 112 | pad_token_id, 113 | attention_mask=None, 114 | ): 115 | """Remove columns that are populated exclusively by pad_token_id""" 116 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 117 | if attention_mask is None: 118 | return input_ids[:, keep_column_mask] 119 | else: 120 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 121 | 122 | 123 | class AbstractSeq2SeqDataset(Dataset): 124 | def __init__( 125 | self, 126 | tokenizer, 127 | data_dir, 128 | max_source_length, 129 | max_target_length, 130 | type_path="train", 131 | n_obs=None, 132 | prefix="", 133 | **dataset_kwargs 134 | ): 135 | super().__init__() 136 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 137 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 138 | self.len_file = Path(data_dir).joinpath(type_path + ".len") 139 | if os.path.exists(self.len_file): 140 | self.src_lens = pickle_load(self.len_file) 141 | self.used_char_len = False 142 | else: 143 | self.src_lens = self.get_char_lens(self.src_file) 144 | self.used_char_len = True 145 | self.max_source_length = max_source_length 146 | self.max_target_length = max_target_length 147 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 148 | self.tokenizer = tokenizer 149 | self.prefix = prefix if prefix is not None else "" 150 | 151 | if n_obs is not None: 152 | self.src_lens = self.src_lens[:n_obs] 153 | self.pad_token_id = self.tokenizer.pad_token_id 154 | self.dataset_kwargs = dataset_kwargs 155 | dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {}) 156 | 157 | def __len__(self): 158 | return len(self.src_lens) 159 | 160 | @staticmethod 161 | def get_char_lens(data_file): 162 | return [len(x) for x in Path(data_file).open(encoding='utf8').readlines()] 163 | 164 | @cached_property 165 | def tgt_lens(self): 166 | """Length in characters of target documents""" 167 | return self.get_char_lens(self.tgt_file) 168 | 169 | def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs): 170 | if distributed: 171 | return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs) 172 | else: 173 | return SortishSampler(self.src_lens, batch_size, shuffle=shuffle) 174 | 175 | def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs): 176 | assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`" 177 | assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler" 178 | sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False)) 179 | 180 | def num_tokens_in_example(i): 181 | return min(self.src_lens[i], self.max_target_length) 182 | 183 | # call fairseq cython function 184 | batch_sampler: List[List[int]] = batch_by_size( 185 | sorted_indices, 186 | num_tokens_fn=num_tokens_in_example, 187 | max_tokens=max_tokens_per_batch, 188 | required_batch_size_multiple=64, 189 | ) 190 | shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))] 191 | # move the largest batch to the front to OOM quickly (uses an approximation for padding) 192 | approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches] 193 | largest_batch_idx = np.argmax(approximate_toks_per_batch) 194 | shuffled_batches[0], shuffled_batches[largest_batch_idx] = ( 195 | shuffled_batches[largest_batch_idx], 196 | shuffled_batches[0], 197 | ) 198 | 
return shuffled_batches 199 | 200 | def __getitem__(self, item): 201 | raise NotImplementedError("You must implement this") 202 | 203 | def collate_fn(self, batch): 204 | raise NotImplementedError("You must implement this") 205 | 206 | 207 | class LegacySeq2SeqDataset(AbstractSeq2SeqDataset): 208 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 209 | """Call tokenizer on src and tgt_lines""" 210 | index = index + 1 # linecache starts at 1 211 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 212 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 213 | assert source_line, f"empty source line for index {index}" 214 | assert tgt_line, f"empty tgt line for index {index}" 215 | source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length) 216 | target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length) 217 | 218 | source_ids = source_inputs["input_ids"].squeeze() 219 | target_ids = target_inputs["input_ids"].squeeze() 220 | src_mask = source_inputs["attention_mask"].squeeze() 221 | return { 222 | "input_ids": source_ids, 223 | "attention_mask": src_mask, 224 | "labels": target_ids, 225 | } 226 | 227 | def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 228 | """Only used by LegacyDataset""" 229 | return tokenizer( 230 | [line], 231 | max_length=max_length, 232 | padding="max_length" if pad_to_max_length else None, 233 | truncation=True, 234 | return_tensors=return_tensors, 235 | **self.dataset_kwargs, 236 | ) 237 | 238 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 239 | input_ids = torch.stack([x["input_ids"] for x in batch]) 240 | masks = torch.stack([x["attention_mask"] for x in batch]) 241 | target_ids = torch.stack([x["labels"] for x in batch]) 242 | pad_token_id = self.pad_token_id 243 | y = trim_batch(target_ids, pad_token_id) 244 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 245 | batch = { 246 | "input_ids": source_ids, 247 | "attention_mask": source_mask, 248 | "labels": y, 249 | } 250 | return batch 251 | 252 | 253 | class Seq2SeqDataset(AbstractSeq2SeqDataset): 254 | """A dataset that calls prepare_seq2seq_batch.""" 255 | 256 | def __getitem__(self, index) -> Dict[str, str]: 257 | index = index + 1 # linecache starts at 1 258 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 259 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 260 | assert source_line, f"empty source line for index {index}" 261 | assert tgt_line, f"empty tgt line for index {index}" 262 | return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1} 263 | 264 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 265 | """Call prepare_seq2seq_batch.""" 266 | batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( 267 | [x["src_texts"] for x in batch], 268 | tgt_texts=[x["tgt_texts"] for x in batch], 269 | max_length=self.max_source_length, 270 | max_target_length=self.max_target_length, 271 | return_tensors="pt", 272 | **self.dataset_kwargs, 273 | ).data 274 | batch_encoding["ids"] = torch.tensor([x["id"] for x in batch]) 275 | return batch_encoding 276 | 277 | class Seq2SeqDataCollator: 278 | def __init__(self, tokenizer, data_args, tpu_num_cores=None): 279 | self.tokenizer = tokenizer 280 | self.pad_token_id = tokenizer.pad_token_id 281 | assert ( 282 | self.pad_token_id is not None 283 | ), f"pad_token_id is 
not defined for ({self.tokenizer.__class__.__name__}), it must be defined." 284 | self.data_args = data_args 285 | self.tpu_num_cores = tpu_num_cores 286 | self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 287 | if data_args.src_lang is not None: 288 | self.dataset_kwargs["src_lang"] = data_args.src_lang 289 | if data_args.tgt_lang is not None: 290 | self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang 291 | 292 | def __call__(self, batch) -> Dict[str, torch.Tensor]: 293 | if hasattr(self.tokenizer, "prepare_seq2seq_batch"): 294 | batch = self._encode(batch) 295 | input_ids, attention_mask, labels = ( 296 | batch["input_ids"], 297 | batch["attention_mask"], 298 | batch["labels"], 299 | ) 300 | else: 301 | input_ids = torch.stack([x["input_ids"] for x in batch]) 302 | attention_mask = torch.stack([x["attention_mask"] for x in batch]) 303 | labels = torch.stack([x["labels"] for x in batch]) 304 | 305 | labels = trim_batch(labels, self.pad_token_id) 306 | input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask) 307 | 308 | if isinstance(self.tokenizer, T5Tokenizer): 309 | decoder_input_ids = self._shift_right_t5(labels) 310 | else: 311 | decoder_input_ids = shift_tokens_right(labels, self.pad_token_id) 312 | 313 | batch = { 314 | "input_ids": input_ids, 315 | "attention_mask": attention_mask, 316 | "decoder_input_ids": decoder_input_ids, 317 | "labels": labels, 318 | } 319 | return batch 320 | 321 | def _shift_right_t5(self, input_ids): 322 | # shift inputs to the right 323 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 324 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 325 | shifted_input_ids[..., 0] = self.pad_token_id 326 | return shifted_input_ids 327 | 328 | def _encode(self, batch) -> Dict[str, torch.Tensor]: 329 | batch_encoding = self.tokenizer.prepare_seq2seq_batch( 330 | [x["src_texts"] for x in batch], 331 | tgt_texts=[x["tgt_texts"] for x in batch], 332 | max_length=self.data_args.max_source_length, 333 | max_target_length=self.data_args.max_target_length, 334 | padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack 335 | return_tensors="pt", 336 | **self.dataset_kwargs, 337 | ) 338 | return batch_encoding.data 339 | 340 | 341 | class SortishSampler(Sampler): 342 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 343 | 344 | def __init__(self, data, batch_size, shuffle=True): 345 | self.data, self.bs, self.shuffle = data, batch_size, shuffle 346 | 347 | def __len__(self) -> int: 348 | return len(self.data) 349 | 350 | def __iter__(self): 351 | return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)) 352 | 353 | 354 | def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array: 355 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 
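    # The shuffled strategy below, roughly: (1) randomly permute all indices, (2) group them into
    # "megabatches" of bs * 50 and sort each megabatch by length, longest first, (3) re-chunk into
    # batches of size bs, (4) move the batch containing the longest example to the front so an OOM
    # surfaces early, (5) shuffle the remaining batches and concatenate.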
356 | if not shuffle: 357 | return np.argsort(np.array(data) * -1) 358 | 359 | def key_fn(i): 360 | return data[i] 361 | 362 | idxs = np.random.permutation(len(data)) 363 | sz = bs * 50 364 | ck_idx = [idxs[i: i + sz] for i in range(0, len(idxs), sz)] 365 | sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx]) 366 | sz = bs 367 | ck_idx = [sort_idx[i: i + sz] for i in range(0, len(sort_idx), sz)] 368 | max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 369 | ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 370 | sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) 371 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 372 | return sort_idx 373 | 374 | 375 | class DistributedSortishSampler(Sampler): 376 | """Copied from torch DistributedSampler""" 377 | 378 | def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True): 379 | if num_replicas is None: 380 | if not dist.is_available(): 381 | raise RuntimeError("Requires distributed package to be available") 382 | num_replicas = dist.get_world_size() 383 | if rank is None: 384 | if not dist.is_available(): 385 | raise RuntimeError("Requires distributed package to be available") 386 | rank = dist.get_rank() 387 | self.dataset = dataset 388 | self.num_replicas = num_replicas 389 | self.rank = rank 390 | self.epoch = 0 391 | if add_extra_examples: 392 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 393 | self.total_size = self.num_samples * self.num_replicas 394 | else: 395 | self.total_size = len(dataset) 396 | self.num_samples = len(self.available_indices) 397 | self.batch_size = batch_size 398 | self.add_extra_examples = add_extra_examples 399 | self.shuffle = shuffle 400 | 401 | def __iter__(self) -> Iterable: 402 | g = torch.Generator() 403 | g.manual_seed(self.epoch) 404 | 405 | sortish_data = [self.dataset.src_lens[i] for i in self.available_indices] 406 | sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle) 407 | indices = [self.available_indices[i] for i in sortish_indices] 408 | assert len(indices) == self.num_samples 409 | return iter(indices) 410 | 411 | @cached_property 412 | def available_indices(self) -> np.array: 413 | indices = list(range(len(self.dataset))) 414 | # add extra samples to make it evenly divisible 415 | indices += indices[: (self.total_size - len(indices))] 416 | assert len(indices) == self.total_size 417 | # subsample 418 | available_indices = indices[self.rank: self.total_size: self.num_replicas] 419 | return available_indices 420 | 421 | def __len__(self): 422 | return self.num_samples 423 | 424 | def set_epoch(self, epoch): 425 | self.epoch = epoch 426 | 427 | 428 | logger = getLogger(__name__) 429 | 430 | 431 | def use_task_specific_params(model, task): 432 | """Update config with summarization specific params.""" 433 | task_specific_params = model.config.task_specific_params 434 | 435 | if task_specific_params is not None: 436 | pars = task_specific_params.get(task, {}) 437 | logger.info(f"using task specific params for {task}: {pars}") 438 | model.config.update(pars) 439 | 440 | 441 | def pickle_load(path): 442 | """pickle.load(path)""" 443 | with open(path, "rb") as f: 444 | return pickle.load(f) 445 | 446 | 447 | def pickle_save(obj, path): 448 | """pickle.dump(obj, path)""" 449 | with open(path, "wb") as f: 450 | return 
pickle.dump(obj, f) 451 | 452 | 453 | def flatten_list(summary_ids: List[List]): 454 | return [x for x in itertools.chain.from_iterable(summary_ids)] 455 | 456 | 457 | def save_json(content, path, indent=4, **json_dump_kwargs): 458 | with open(path, "w") as f: 459 | json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs) 460 | 461 | 462 | def load_json(path): 463 | with open(path) as f: 464 | return json.load(f) 465 | 466 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 467 | 468 | 469 | def extract_rouge_mid_statistics(dct): 470 | new_dict = {} 471 | for k1, v1 in dct.items(): 472 | mid = v1.mid 473 | new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]} 474 | return new_dict 475 | 476 | 477 | def calculate_rouge( 478 | pred_lns: List[str], 479 | tgt_lns: List[str], 480 | use_stemmer=True, 481 | rouge_keys=ROUGE_KEYS, 482 | return_precision_and_recall=False, 483 | bootstrap_aggregation=True, 484 | newline_sep=True, 485 | ) -> Dict: 486 | """Calculate rouge using rouge_scorer package. 487 | 488 | Args: 489 | pred_lns: list of summaries generated by model 490 | tgt_lns: list of groundtruth summaries (e.g. contents of val.target) 491 | use_stemmer: Bool indicating whether Porter stemmer should be used to 492 | strip word suffixes to improve matching. 493 | rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum 494 | return_precision_and_recall: (False) whether to also return precision and recall. 495 | bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False 496 | this function returns a collections.defaultdict[metric: list of values for each observation for each subscore] 497 | newline_sep: (default=True) whether to add a newline between sentences. This is essential for calculating rougeL 498 | on multi sentence summaries (CNN/DM dataset).
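        Illustrative usage (the scores depend on the inputs and are shown as ...):

            calculate_rouge(["the cat sat on the mat"], ["a cat was sitting on the mat"])
            # -> {"rouge1": ..., "rouge2": ..., "rougeL": ..., "rougeLsum": ...}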
499 | 500 | Returns: 501 | Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys 502 | 503 | """ 504 | scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer) 505 | aggregator = scoring.BootstrapAggregator() 506 | for pred, tgt in zip(tgt_lns, pred_lns): 507 | # rougeLsum expects "\n" separated sentences within a summary 508 | if newline_sep: 509 | pred = add_newline_to_end_of_each_sentence(pred) 510 | tgt = add_newline_to_end_of_each_sentence(tgt) 511 | scores = scorer.score(pred, tgt) 512 | aggregator.add_scores(scores) 513 | 514 | if bootstrap_aggregation: 515 | result = aggregator.aggregate() 516 | if return_precision_and_recall: 517 | return extract_rouge_mid_statistics(result) # here we return dict 518 | else: 519 | return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()} 520 | 521 | else: 522 | return aggregator._scores # here we return defaultdict(list) 523 | 524 | 525 | # Utilities for freezing parameters and checking whether they are frozen 526 | 527 | 528 | def freeze_params(model: nn.Module): 529 | """Set requires_grad=False for each of model.parameters()""" 530 | for par in model.parameters(): 531 | par.requires_grad = False 532 | 533 | 534 | def freeze_embeds(model): 535 | """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" 536 | model_type = model.config.model_type 537 | 538 | if model_type == "t5": 539 | freeze_params(model.shared) 540 | for d in [model.encoder, model.decoder]: 541 | freeze_params(d.embed_tokens) 542 | elif model_type == "fsmt": 543 | for d in [model.model.encoder, model.model.decoder]: 544 | freeze_params(d.embed_positions) 545 | freeze_params(d.embed_tokens) 546 | elif model_type == 'prophetnet': 547 | for d in [model.prophetnet.encoder, model.prophetnet.decoder]: 548 | freeze_params(d.position_embeddings) 549 | freeze_params(d.word_embeddings) 550 | else: 551 | freeze_params(model.model.shared) 552 | for d in [model.model.encoder, model.model.decoder]: 553 | freeze_params(d.embed_positions) 554 | freeze_params(d.embed_tokens) 555 | 556 | 557 | def grad_status(model: nn.Module) -> Iterable: 558 | return (par.requires_grad for par in model.parameters()) 559 | 560 | 561 | def any_requires_grad(model: nn.Module) -> bool: 562 | return any(grad_status(model)) 563 | 564 | 565 | def assert_all_frozen(model): 566 | model_grads: List[bool] = list(grad_status(model)) 567 | n_require_grad = sum(lmap(int, model_grads)) 568 | npars = len(model_grads) 569 | assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad" 570 | 571 | 572 | def assert_not_all_frozen(model): 573 | model_grads: List[bool] = list(grad_status(model)) 574 | npars = len(model_grads) 575 | assert any(model_grads), f"none of {npars} weights require grad" 576 | 577 | 578 | def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]: 579 | """ 580 | Parse an argv list of unspecified command line args to a dict. 581 | Assumes all values are either numeric or boolean in the form of true/false. 
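    e.g. ["--num_beams", "4", "--length_penalty", "0.6", "--early_stopping", "true"]
    is parsed to {"num_beams": 4, "length_penalty": 0.6, "early_stopping": True}.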
582 | """ 583 | result = {} 584 | assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}" 585 | num_pairs = len(unparsed_args) // 2 586 | for pair_num in range(num_pairs): 587 | i = 2 * pair_num 588 | assert unparsed_args[i].startswith("--") 589 | if unparsed_args[i + 1].lower() == "true": 590 | value = True 591 | elif unparsed_args[i + 1].lower() == "false": 592 | value = False 593 | else: 594 | try: 595 | value = int(unparsed_args[i + 1]) 596 | except ValueError: 597 | value = float(unparsed_args[i + 1]) # this can raise another informative ValueError 598 | 599 | result[unparsed_args[i][2:]] = value 600 | return result 601 | 602 | 603 | def write_txt_file(ordered_tgt, path): 604 | f = Path(path).open("w", encoding='utf8') 605 | for ln in ordered_tgt: 606 | f.write(ln + "\n") 607 | f.flush() 608 | 609 | 610 | def chunks(lst, n): 611 | """Yield successive n-sized chunks from lst.""" 612 | for i in range(0, len(lst), n): 613 | yield lst[i: i + n] 614 | 615 | 616 | def check_output_dir(args, expected_items=0): 617 | """ 618 | Checks whether to bail out if output_dir already exists and has more than expected_items in it 619 | 620 | `args`: needs to have the following attributes of `args`: 621 | - output_dir 622 | - do_train 623 | - overwrite_output_dir 624 | 625 | `expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM) 626 | """ 627 | if ( 628 | os.path.exists(args.output_dir) 629 | and len(os.listdir(args.output_dir)) > expected_items 630 | and args.do_train 631 | and not args.overwrite_output_dir 632 | ): 633 | raise ValueError( 634 | f"Output directory ({args.output_dir}) already exists and " 635 | f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). " 636 | "Use --overwrite_output_dir to overcome." 
637 | ) 638 | 639 | class CLSeq2SeqDataset(Dataset): 640 | def __init__( 641 | self, 642 | tokenizer, 643 | data_dir1, 644 | data_dir2, 645 | max_source_length, 646 | max_target_length, 647 | type_path="train", 648 | n_obs=None, 649 | prefix="", 650 | **dataset_kwargs 651 | ): 652 | super().__init__() 653 | # read the first data 654 | self.src_file1 = Path(data_dir1).joinpath(type_path + ".source") 655 | self.tgt_file1 = Path(data_dir1).joinpath(type_path + ".target") 656 | 657 | # read the second data 658 | self.src_file2 = Path(data_dir2).joinpath(type_path + ".source") 659 | self.tgt_file2 = Path(data_dir2).joinpath(type_path + ".target") 660 | 661 | self.src_lens = self.get_char_lens(self.src_file1) 662 | self.used_char_len = True 663 | 664 | self.max_source_length = max_source_length 665 | self.max_target_length = max_target_length 666 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 667 | self.tokenizer = tokenizer 668 | self.prefix = prefix if prefix is not None else "" 669 | 670 | if n_obs is not None: 671 | self.src_lens = self.src_lens[:n_obs] 672 | self.pad_token_id = self.tokenizer.pad_token_id 673 | self.dataset_kwargs = dataset_kwargs 674 | dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {}) 675 | 676 | def __len__(self): 677 | return len(self.src_lens) 678 | 679 | def __getitem__(self, index) -> Dict[str, str]: 680 | index = index + 1 # linecache starts at 1 681 | # read the first data 682 | source_line1 = self.prefix + linecache.getline(str(self.src_file1), index).rstrip("\n") 683 | tgt_line1 = linecache.getline(str(self.tgt_file1), index).rstrip("\n") 684 | # read the second data 685 | source_line2 = self.prefix + linecache.getline(str(self.src_file2), index).rstrip("\n") 686 | tgt_line2 = linecache.getline(str(self.tgt_file2), index).rstrip("\n") 687 | assert source_line1, f"empty source line for index {index}" 688 | assert tgt_line1, f"empty tgt line for index {index}" 689 | assert source_line2, f"empty source line for index {index}" 690 | assert tgt_line2, f"empty tgt line for index {index}" 691 | 692 | return {"tgt_texts": (tgt_line1, tgt_line2), "src_texts": (source_line1, source_line2), "id": index - 1} 693 | 694 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 695 | """Call prepare_seq2seq_batch.""" 696 | # encode the first data 697 | batch_encoding1: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( 698 | [x["src_texts"][0] for x in batch], 699 | tgt_texts=[x["tgt_texts"][0] for x in batch], 700 | max_length=self.max_source_length, 701 | max_target_length=self.max_target_length, 702 | return_tensors="pt", 703 | **self.dataset_kwargs, 704 | ).data 705 | batch_encoding1["ids"] = torch.tensor([x["id"] for x in batch]) 706 | # encode the second data 707 | batch_encoding2: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( 708 | [x["src_texts"][1] for x in batch], 709 | tgt_texts=[x["tgt_texts"][1] for x in batch], 710 | max_length=self.max_source_length, 711 | max_target_length=self.max_target_length, 712 | return_tensors="pt", 713 | **self.dataset_kwargs, 714 | ).data 715 | batch_encoding2["ids"] = torch.tensor([x["id"] for x in batch]) 716 | # combine two dict together 717 | batch_encoding: Dict[str, (torch.Tensor, torch.Tensor)] = {key: (batch_encoding1[key], batch_encoding2[key]) for 718 | key in batch_encoding1} 719 | 720 | return batch_encoding 721 | 722 | @staticmethod 723 | def get_char_lens(data_file): 724 | return [len(x) for x in 
Path(data_file).open(encoding='utf8').readlines()] 725 | 726 | 727 | class CLSeq2SeqDataCollator: 728 | def __init__(self, tokenizer, data_args, tpu_num_cores=None): 729 | self.tokenizer = tokenizer 730 | self.pad_token_id = tokenizer.pad_token_id 731 | assert ( 732 | self.pad_token_id is not None 733 | ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined." 734 | self.data_args = data_args 735 | self.tpu_num_cores = tpu_num_cores 736 | self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 737 | if data_args.src_lang is not None: 738 | self.dataset_kwargs["src_lang"] = data_args.src_lang 739 | if data_args.tgt_lang is not None: 740 | self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang 741 | 742 | def __call__(self, batch): 743 | # here batch is tuple (torch.Tensor, torch.Tensor) 744 | batch = self._encode(batch) 745 | # input_ids: (input ids for the first instance, input ids for the second instance) 746 | input_ids = (batch["input_ids"][0], batch["input_ids"][1]) 747 | # attention_mask = (batch["attention_mask"][0], batch["attention_mask"][1]) 748 | # labels = (batch["labels"][0], batch["labels"][1]) 749 | # 750 | # decoder_input_ids = ( 751 | # shift_tokens_right(labels[0], self.pad_token_id), shift_tokens_right(labels[1], self.pad_token_id)) 752 | 753 | batch = { 754 | "input_ids": input_ids, 755 | # "attention_mask": attention_mask, 756 | # "decoder_input_ids": decoder_input_ids, 757 | # "labels": labels, 758 | } 759 | return batch 760 | 761 | def _encode(self, batch): 762 | # encode the first data 763 | batch_encoding1 = self.tokenizer.prepare_seq2seq_batch( 764 | [x["src_texts"][0] for x in batch], 765 | # tgt_texts=[x["tgt_texts"][0] for x in batch], 766 | max_length=self.data_args.max_source_length, 767 | max_target_length=self.data_args.max_target_length, 768 | padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack 769 | return_tensors="pt", 770 | **self.dataset_kwargs, 771 | ).data 772 | batch_encoding1["ids"] = torch.tensor([x["id"] for x in batch]) 773 | # encode the second data 774 | batch_encoding2 = self.tokenizer.prepare_seq2seq_batch( 775 | [x["src_texts"][1] for x in batch], 776 | # tgt_texts=[x["tgt_texts"][1] for x in batch], 777 | max_length=self.data_args.max_source_length, 778 | max_target_length=self.data_args.max_target_length, 779 | padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack 780 | return_tensors="pt", 781 | **self.dataset_kwargs, 782 | ).data 783 | batch_encoding2["ids"] = torch.tensor([x["id"] for x in batch]) 784 | # combine two dict together 785 | batch_encoding: Dict[str, (torch.Tensor, torch.Tensor)] = {key: (batch_encoding1[key], batch_encoding2[key]) for 786 | key in batch_encoding1} 787 | return batch_encoding 788 | 789 | class CLSeq2SeqDatasetSingleAugmentation(Dataset): 790 | def __init__( 791 | self, 792 | tokenizer, 793 | data_dir, 794 | max_source_length, 795 | max_target_length, 796 | augmentation1, 797 | augmentation2, 798 | n=3, 799 | generation_model="gpt2", 800 | type_path="train", 801 | n_obs=None, 802 | prefix="", 803 | **dataset_kwargs 804 | ): 805 | super().__init__() 806 | # read the original data 807 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 808 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 809 | 810 | self.src_lens = self.get_char_lens(self.src_file) 811 | self.used_char_len = True 812 | 813 | self.max_source_length = max_source_length 814 | 
self.max_target_length = max_target_length 815 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 816 | self.tokenizer = tokenizer 817 | self.prefix = prefix if prefix is not None else "" 818 | 819 | self.n = n 820 | self.augmentation = [augmentation1, augmentation2] 821 | self.generation_model = generation_model 822 | 823 | if n_obs is not None: 824 | self.src_lens = self.src_lens[:n_obs] 825 | self.pad_token_id = self.tokenizer.pad_token_id 826 | self.dataset_kwargs = dataset_kwargs 827 | dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {}) 828 | 829 | def __len__(self): 830 | return len(self.src_lens) 831 | 832 | def __getitem__(self, index) -> Dict[str, str]: 833 | index = index + 1 # linecache starts at 1 834 | # read the original data 835 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 836 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 837 | 838 | augmented_sent = [] 839 | 840 | for i in range(len(self.augmentation)): 841 | # identify the augmentation method 842 | method = self.augmentation[i] 843 | # if two augmentation methods are the same, we should assign different random numbers 844 | if i < 1: 845 | random.seed(97) 846 | else: 847 | random.seed(41) 848 | # perform text document augmentation 849 | source_line_augmentation = DocumentAugmentation(n=self.n, input=source_line) 850 | 851 | if method == 'RandomInsertionFromDoc': 852 | source_line_augmentation.RandomInsertionFromDoc() 853 | sent = source_line_augmentation.augmented_sentences 854 | 855 | if method == 'DocumentRotation': 856 | source_line_augmentation.DocumentRotation() 857 | sent = source_line_augmentation.augmented_sentences 858 | 859 | if method == 'RandomSwap': 860 | source_line_augmentation.RandomSwap() 861 | sent = source_line_augmentation.augmented_sentences 862 | 863 | if method == 'RandomDeletion': 864 | source_line_augmentation.RandomDeletion() 865 | sent = source_line_augmentation.augmented_sentences 866 | 867 | augmented_sent.append(' '.join(sent)) 868 | 869 | source_line1 = augmented_sent[0] 870 | source_line2 = augmented_sent[1] 871 | 872 | assert source_line1, f"empty source line for index {index}" 873 | assert source_line2, f"empty source line for index {index}" 874 | assert tgt_line, f"empty tgt line for index {index}" 875 | 876 | # return {"tgt_texts": (tgt_line, tgt_line), "src_texts": (source_line1, source_line2), "id": index - 1} 877 | return {"src_texts": (source_line1, source_line2), "id": index - 1} 878 | 879 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 880 | """Call prepare_seq2seq_batch.""" 881 | # encode the first data 882 | batch_encoding1: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( 883 | [x["src_texts"][0] for x in batch], 884 | # tgt_texts=[x["tgt_texts"][0] for x in batch], 885 | max_length=self.max_source_length, 886 | max_target_length=self.max_target_length, 887 | return_tensors="pt", 888 | **self.dataset_kwargs, 889 | ).data 890 | batch_encoding1["ids"] = torch.tensor([x["id"] for x in batch]) 891 | # encode the second data 892 | batch_encoding2: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch( 893 | [x["src_texts"][1] for x in batch], 894 | # tgt_texts=[x["tgt_texts"][1] for x in batch], 895 | max_length=self.max_source_length, 896 | max_target_length=self.max_target_length, 897 | return_tensors="pt", 898 | **self.dataset_kwargs, 899 | ).data 900 | batch_encoding2["ids"] = torch.tensor([x["id"] for x in 
batch]) 901 | # combine two dict together 902 | batch_encoding: Dict[str, (torch.Tensor, torch.Tensor)] = {key: (batch_encoding1[key], batch_encoding2[key]) for 903 | key in batch_encoding1} 904 | 905 | return batch_encoding 906 | 907 | @staticmethod 908 | def get_char_lens(data_file): 909 | return [len(x) for x in Path(data_file).open(encoding='utf8').readlines()] -------------------------------------------------------------------------------- /cl_seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import collections 15 | import math 16 | import os 17 | import re 18 | import shutil 19 | from pathlib import Path 20 | from typing import Any, Dict, List, Optional, Tuple, Union 21 | import warnings 22 | import numpy as np 23 | from packaging import version 24 | import torch 25 | from torch import nn 26 | from torch.nn import CrossEntropyLoss 27 | from torch.utils.data import DistributedSampler, RandomSampler, DataLoader, Dataset 28 | import torch.nn.functional as F 29 | from loss.nt_xent import NTXentLoss 30 | 31 | from transformers import PreTrainedModel, Trainer, logging 32 | from transformers.file_utils import is_torch_tpu_available, WEIGHTS_NAME 33 | from transformers.integrations import is_fairscale_available, hp_params, is_tensorboard_available 34 | from transformers.models.fsmt.configuration_fsmt import FSMTConfig 35 | from transformers.optimization import ( 36 | Adafactor, 37 | AdamW, 38 | get_constant_schedule, 39 | get_constant_schedule_with_warmup, 40 | get_cosine_schedule_with_warmup, 41 | get_cosine_with_hard_restarts_schedule_with_warmup, 42 | get_linear_schedule_with_warmup, 43 | get_polynomial_decay_schedule_with_warmup, 44 | ) 45 | from transformers.trainer_pt_utils import get_tpu_sampler, DistributedTensorGatherer, nested_concat, reissue_pt_warnings 46 | from transformers.trainer_utils import PredictionOutput, EvalPrediction, HPSearchBackend, set_seed, TrainOutput 47 | from transformers.trainer_callback import TrainerState, DefaultFlowCallback 48 | 49 | from utils import calculate_rouge, lmap 50 | 51 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 52 | 53 | # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex 54 | if version.parse(torch.__version__) < version.parse("1.6"): 55 | from transformers.file_utils import is_apex_available 56 | 57 | if is_apex_available(): 58 | from apex import amp 59 | else: 60 | _is_native_amp_available = True 61 | from torch.cuda.amp import autocast 62 | 63 | if is_torch_tpu_available(): 64 | import torch_xla.core.xla_model as xm 65 | import torch_xla.debug.metrics as met 66 | import torch_xla.distributed.parallel_loader as pl 67 | 68 | if is_tensorboard_available(): 69 | from transformers.integrations import TensorBoardCallback 70 | 71 | DEFAULT_CALLBACKS.append(TensorBoardCallback) 72 | 73 | if is_fairscale_available(): 74 | from 
fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP 75 | from fairscale.optim import OSS 76 | from fairscale.optim.grad_scaler import ShardedGradScaler 77 | 78 | # from transformers.training_args import ParallelMode 79 | 80 | 81 | logger = logging.get_logger(__name__) 82 | 83 | arg_to_scheduler = { 84 | "linear": get_linear_schedule_with_warmup, 85 | "cosine": get_cosine_schedule_with_warmup, 86 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 87 | "polynomial": get_polynomial_decay_schedule_with_warmup, 88 | "constant": get_constant_schedule, 89 | "constant_w_warmup": get_constant_schedule_with_warmup, 90 | } 91 | 92 | 93 | class Seq2SeqTrainerCL(Trainer): 94 | def __init__(self, alpha=0.5, temperature=0.5, eval_metric='loss', hidden_state_representation='cls', tokenizer=None, 95 | config=None, data_args=None, *args, **kwargs): 96 | super().__init__(*args, **kwargs) 97 | self.alpha = alpha 98 | self.tokenizer = tokenizer 99 | self.temperature = temperature 100 | self.hidden_state_representation = hidden_state_representation 101 | self.eval_metric = eval_metric 102 | 103 | 104 | if config is None: 105 | assert isinstance( 106 | self.model, PreTrainedModel 107 | ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}" 108 | self.config = self._actual_model(self.model).config 109 | else: 110 | self.config = config 111 | 112 | self.data_args = data_args 113 | self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size 114 | 115 | if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss): 116 | assert ( 117 | self.config.pad_token_id is not None 118 | ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing." 119 | 120 | if self.config.pad_token_id is None and self.config.eos_token_id is not None: 121 | logger.warn( 122 | f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.." 123 | ) 124 | 125 | if self.args.label_smoothing == 0: 126 | self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) 127 | else: 128 | # dynamically import label_smoothed_nll_loss 129 | from utils import label_smoothed_nll_loss 130 | 131 | self.loss_fn = label_smoothed_nll_loss 132 | 133 | def create_optimizer_and_scheduler(self, num_training_steps: int): 134 | """ 135 | Setup the optimizer and the learning rate scheduler. 136 | 137 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 138 | Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. 
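        With the defaults below, parameters whose names contain "bias" or "LayerNorm.weight" get no
        weight decay, the optimizer is AdamW (or Adafactor when --adafactor is set), and the learning
        rate schedule is looked up in ``arg_to_scheduler`` via ``--lr_scheduler`` (e.g. "linear",
        "cosine", "polynomial", "constant_w_warmup").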
139 | """ 140 | if self.optimizer is None: 141 | no_decay = ["bias", "LayerNorm.weight"] 142 | optimizer_grouped_parameters = [ 143 | { 144 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 145 | "weight_decay": self.args.weight_decay, 146 | }, 147 | { 148 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 149 | "weight_decay": 0.0, 150 | }, 151 | ] 152 | if self.args.adafactor: 153 | self.optimizer = Adafactor( 154 | optimizer_grouped_parameters, 155 | lr=self.args.learning_rate, 156 | scale_parameter=False, 157 | relative_step=False, 158 | ) 159 | 160 | else: 161 | self.optimizer = AdamW( 162 | optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon 163 | ) 164 | 165 | if self.lr_scheduler is None: 166 | self.lr_scheduler = self._get_lr_scheduler(num_training_steps) 167 | else: # ignoring --lr_scheduler 168 | logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.") 169 | 170 | def _get_lr_scheduler(self, num_training_steps): 171 | schedule_func = arg_to_scheduler[self.args.lr_scheduler] 172 | if self.args.lr_scheduler == "constant": 173 | scheduler = schedule_func(self.optimizer) 174 | elif self.args.lr_scheduler == "constant_w_warmup": 175 | scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps) 176 | else: 177 | scheduler = schedule_func( 178 | self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps 179 | ) 180 | return scheduler 181 | 182 | def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: 183 | if isinstance(self.train_dataset, torch.utils.data.IterableDataset): 184 | return None 185 | elif is_torch_tpu_available(): 186 | return get_tpu_sampler(self.train_dataset) 187 | else: 188 | if self.args.sortish_sampler: 189 | self.train_dataset.make_sortish_sampler( 190 | self.args.per_device_train_batch_size, 191 | # distributed=(self.args.parallel_mode == "distributed"), 192 | ) 193 | 194 | return ( 195 | RandomSampler(self.train_dataset) 196 | if self.args.local_rank == -1 197 | else DistributedSampler(self.train_dataset) 198 | ) 199 | 200 | def _compute_loss(self, model, inputs, labels): 201 | if self.args.label_smoothing == 0: 202 | if self.data_args is not None and self.data_args.ignore_pad_token_for_loss: 203 | # force training to ignore pad token 204 | logits = model(**inputs, use_cache=False)[0] 205 | loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1)) 206 | else: 207 | # compute usual loss via models 208 | loss, logits = model(**inputs, labels=labels, use_cache=False)[:2] 209 | else: 210 | # compute label smoothed loss 211 | logits = model(**inputs, use_cache=False)[0] 212 | lprobs = torch.nn.functional.log_softmax(logits, dim=-1) 213 | loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id) 214 | return loss, logits 215 | 216 | def compute_loss(self, model, inputs): 217 | labels = inputs.pop("labels") 218 | loss, _ = self._compute_loss(model, inputs, labels) 219 | return loss 220 | 221 | def prediction_step( 222 | self, 223 | model: nn.Module, 224 | inputs: Dict[str, Union[torch.Tensor, Any]], 225 | prediction_loss_only: bool, 226 | ignore_keys: Optional[List[str]] = None, 227 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 228 | """ 229 | Perform an evaluation step on :obj:`model` using obj:`inputs`. 
230 | 231 | Subclass and override to inject custom behavior. 232 | 233 | Args: 234 | model (:obj:`nn.Module`): 235 | The model to evaluate. 236 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 237 | The inputs and targets of the model. 238 | 239 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 240 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 241 | prediction_loss_only (:obj:`bool`): 242 | Whether or not to return the loss only. 243 | 244 | Return: 245 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 246 | A tuple with the loss, logits and labels (each being optional). 247 | """ 248 | inputs = self._prepare_inputs(inputs) 249 | 250 | gen_kwargs = { 251 | "max_length": self.data_args.val_max_target_length 252 | if self.data_args is not None 253 | else self.config.max_length, 254 | "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams, 255 | } 256 | 257 | # if self.args.predict_with_generate and not self.args.prediction_loss_only: 258 | # generated_tokens = self.model.generate( 259 | # inputs["input_ids"], 260 | # attention_mask=inputs["attention_mask"], 261 | # **gen_kwargs, 262 | # ) 263 | # # in case the batch is shorter than max length, the output should be padded 264 | # if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 265 | # generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 266 | with torch.no_grad(): 267 | if 'rouge' in self.eval_metric: 268 | generated_tokens = self.model.generate( 269 | inputs["input_ids"], 270 | attention_mask=inputs["attention_mask"], 271 | **gen_kwargs, 272 | ) 273 | preds = self.ids_to_clean_text(generated_tokens) 274 | y = self.trim_batch(inputs["decoder_input_ids"], self.config.pad_token_id) 275 | target = self.ids_to_clean_text(y) 276 | rouge_dict = calculate_rouge(preds, target, rouge_keys=["rouge1", "rouge2", "rougeL"]) 277 | rouge2 = rouge_dict['rouge2'] 278 | loss = torch.tensor(rouge2).to(generated_tokens.device) 279 | elif 'loss' in self.eval_metric: 280 | labels = inputs.pop("labels") 281 | # compute loss on predict data 282 | loss, logits = self._compute_loss(model, inputs, labels) 283 | else: 284 | print(f"Please define the loss function for evaluation set") 285 | 286 | loss = loss.mean().detach() 287 | 288 | if self.args.prediction_loss_only: 289 | return (loss, None, None) 290 | 291 | # logits = generated_tokens if self.args.predict_with_generate else logits 292 | 293 | if labels.shape[-1] < gen_kwargs["max_length"]: 294 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 295 | 296 | return (loss, logits, labels) 297 | 298 | def ids_to_clean_text(self, generated_ids: List[int]): 299 | gen_text = self.tokenizer.batch_decode( 300 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 301 | ) 302 | return lmap(str.strip, gen_text) 303 | 304 | def trim_batch(self, input_ids, pad_token_id, attention_mask=None, ): 305 | """Remove columns that are populated exclusively by pad_token_id""" 306 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 307 | if attention_mask is None: 308 | return input_ids[:, keep_column_mask] 309 | else: 310 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 311 | 312 | def _pad_tensors_to_max_len(self, tensor, max_length): 313 | # If PAD token is not defined at least EOS token has to be defined 314 | pad_token_id = 
self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id 315 | 316 | if pad_token_id is None: 317 | raise ValueError( 318 | f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}" 319 | ) 320 | 321 | padded_tensor = pad_token_id * torch.ones( 322 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 323 | ) 324 | padded_tensor[:, : tensor.shape[-1]] = tensor 325 | return padded_tensor 326 | 327 | def evaluate( 328 | self, 329 | eval_dataset: Optional[Dataset] = None, 330 | ignore_keys: Optional[List[str]] = None, 331 | metric_key_prefix: str = "eval", 332 | ) -> Dict[str, float]: 333 | """ 334 | Run evaluation and returns metrics. 335 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 336 | (pass it to the init :obj:`compute_metrics` argument). 337 | You can also subclass and override this method to inject custom behavior. 338 | Args: 339 | eval_dataset (:obj:`Dataset`, `optional`): 340 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 341 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the 342 | :obj:`__len__` method. 343 | ignore_keys (:obj:`Lst[str]`, `optional`): 344 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 345 | gathering predictions. 346 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): 347 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 348 | "eval_bleu" if the prefix is "eval" (default) 349 | Returns: 350 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 351 | dictionary also contains the epoch number which comes from the training state. 352 | """ 353 | if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 354 | raise ValueError("eval_dataset must implement __len__") 355 | 356 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 357 | 358 | output = self.prediction_loop( 359 | eval_dataloader, 360 | description="Evaluation", 361 | # No point gathering the predictions if there are no metrics, otherwise we defer to 362 | # self.args.prediction_loss_only 363 | prediction_loss_only=True if self.compute_metrics is None else None, 364 | ignore_keys=ignore_keys, 365 | metric_key_prefix=metric_key_prefix, 366 | ) 367 | 368 | self.log(output.metrics) 369 | 370 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) 371 | return output.metrics 372 | 373 | def prediction_loop( 374 | self, 375 | dataloader: DataLoader, 376 | description: str, 377 | prediction_loss_only: Optional[bool] = None, 378 | ignore_keys: Optional[List[str]] = None, 379 | metric_key_prefix: str = "eval", 380 | ) -> PredictionOutput: 381 | """ 382 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. 383 | Works both with or without labels. 
384 | """ 385 | if not isinstance(dataloader.dataset, collections.abc.Sized): 386 | raise ValueError("dataset must implement __len__") 387 | prediction_loss_only = ( 388 | prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only 389 | ) 390 | 391 | model = self.model 392 | # multi-gpu eval 393 | if self.args.n_gpu > 1 and not self.args.model_parallel: 394 | model = torch.nn.DataParallel(model) 395 | # Note: in torch.distributed mode, there's no point in wrapping the model 396 | # inside a DistributedDataParallel as we'll be under `no_grad` anyways. 397 | 398 | batch_size = dataloader.batch_size 399 | num_examples = self.num_examples(dataloader) 400 | logger.info("***** Running %s *****", description) 401 | logger.info(" Num examples = %d", num_examples) 402 | logger.info(" Batch size = %d", batch_size) 403 | losses_host: torch.Tensor = None 404 | preds_host: Union[torch.Tensor, List[torch.Tensor]] = None 405 | labels_host: Union[torch.Tensor, List[torch.Tensor]] = None 406 | 407 | world_size = 1 408 | 409 | if self.args.local_rank != -1: 410 | world_size = torch.distributed.get_world_size() 411 | world_size = max(1, world_size) 412 | 413 | eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) 414 | if not prediction_loss_only: 415 | preds_gatherer = DistributedTensorGatherer(world_size, num_examples) 416 | labels_gatherer = DistributedTensorGatherer(world_size, num_examples) 417 | 418 | model.eval() 419 | 420 | if self.args.past_index >= 0: 421 | self._past = None 422 | 423 | self.callback_handler.eval_dataloader = dataloader 424 | 425 | for step, inputs in enumerate(dataloader): 426 | loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) 427 | if loss is not None: 428 | losses = loss.repeat(batch_size) 429 | losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) 430 | if logits is not None: 431 | preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) 432 | if labels is not None: 433 | labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) 434 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) 435 | 436 | # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
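# (Setting --eval_accumulation_steps bounds GPU memory during evaluation: every that many steps the accumulated losses/preds/labels are handed to the host-side gatherers and the device buffers are reset below.)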
437 | if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: 438 | eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) 439 | if not prediction_loss_only: 440 | preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) 441 | labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) 442 | 443 | # Set back to None to begin a new accumulation 444 | losses_host, preds_host, labels_host = None, None, None 445 | 446 | if self.args.past_index and hasattr(self, "_past"): 447 | # Clean the state at the end of the evaluation loop 448 | delattr(self, "_past") 449 | 450 | # Gather all remaining tensors and put them back on the CPU 451 | eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) 452 | if not prediction_loss_only: 453 | preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) 454 | labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) 455 | 456 | eval_loss = eval_losses_gatherer.finalize() 457 | preds = preds_gatherer.finalize() if not prediction_loss_only else None 458 | label_ids = labels_gatherer.finalize() if not prediction_loss_only else None 459 | 460 | if self.compute_metrics is not None and preds is not None and label_ids is not None: 461 | metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) 462 | else: 463 | metrics = {} 464 | 465 | if eval_loss is not None: 466 | if 'rouge' in self.eval_metric: 467 | metrics[f"{metric_key_prefix}_rouge"] = eval_loss.mean().item() 468 | else: 469 | metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() 470 | 471 | # Prefix all keys with metric_key_prefix + '_' 472 | for key in list(metrics.keys()): 473 | if not key.startswith(f"{metric_key_prefix}_"): 474 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 475 | 476 | return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) 477 | 478 | def _save_checkpoint(self, model, trial, metrics=None): 479 | # In all cases (even distributed/parallel), self.model is always a reference 480 | # to the model we want to save. 
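# (DataParallel / DistributedDataParallel wrap the underlying model as `model.module`, hence the unwrapping check below.)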
481 | if hasattr(model, "module"): 482 | assert model.module is self.model, f"Module {model.module} should be a reference to self.model" 483 | else: 484 | assert model is self.model, f"Model {model} should be a reference to self.model" 485 | 486 | # Save model checkpoint 487 | if 'rouge' in self.eval_metric: 488 | checkpoint_folder = f"val-rouge-{'%.4f' % np.round(metrics['eval_rouge'], 4)}-step-{self.state.global_step}" 489 | else: 490 | checkpoint_folder = f"val-loss-{'%.4f' % np.round(metrics['eval_loss'], 4)}-step-{self.state.global_step}" 491 | 492 | output_dir = os.path.join(self.args.output_dir, checkpoint_folder) 493 | 494 | self.store_flos() 495 | self.save_model(output_dir) 496 | 497 | # Save optimizer and scheduler 498 | 499 | if self.is_world_process_zero(): 500 | torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 501 | with warnings.catch_warnings(record=True) as caught_warnings: 502 | torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 503 | reissue_pt_warnings(caught_warnings) 504 | 505 | # Determine the new best metric / best model checkpoint 506 | if metrics is not None and self.args.metric_for_best_model is not None: 507 | metric_to_check = self.args.metric_for_best_model 508 | if not metric_to_check.startswith("eval_"): 509 | metric_to_check = f"eval_{metric_to_check}" 510 | metric_value = metrics[metric_to_check] 511 | 512 | operator = np.greater if self.args.greater_is_better else np.less 513 | if ( 514 | self.state.best_metric is None 515 | or self.state.best_model_checkpoint is None 516 | or operator(metric_value, self.state.best_metric) 517 | ): 518 | self.state.best_metric = metric_value 519 | self.state.best_model_checkpoint = output_dir 520 | 521 | # Save the Trainer state 522 | if self.is_world_process_zero(): 523 | self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) 524 | 525 | # Maybe delete some older checkpoints. 526 | if self.is_world_process_zero(): 527 | self._rotate_checkpoints() 528 | 529 | def _sorted_checkpoints(self, checkpoint_prefix) -> List[str]: 530 | glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")] 531 | if 'rouge' in self.eval_metric: 532 | checkpoints_sorted = sorted(glob_checkpoints, reverse=True) 533 | else: 534 | checkpoints_sorted = sorted(glob_checkpoints, reverse=False) 535 | return checkpoints_sorted 536 | 537 | def _rotate_checkpoints(self) -> None: 538 | if self.args.save_total_limit is None or self.args.save_total_limit <= 0: 539 | return 540 | # Check if we should delete older checkpoint(s) 541 | if 'rouge' in self.eval_metric: 542 | checkpoints_sorted = self._sorted_checkpoints(checkpoint_prefix="val-rouge") 543 | else: 544 | checkpoints_sorted = self._sorted_checkpoints(checkpoint_prefix="val-loss") 545 | if len(checkpoints_sorted) <= self.args.save_total_limit: 546 | return 547 | saved_checkpoints = checkpoints_sorted[:self.args.save_total_limit] 548 | 549 | for checkpoint in checkpoints_sorted: 550 | if checkpoint not in saved_checkpoints: 551 | logger.info("Deleting checkpoint [{}] due to args.save_total_limit".format(checkpoint)) 552 | shutil.rmtree(checkpoint) 553 | 554 | def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None): 555 | """ 556 | Main training entry point. 557 | 558 | Args: 559 | model_path (:obj:`str`, `optional`): 560 | Local path to the model if the model to train has been instantiated from a local path. 
If present, 561 | training will resume from the optimizer/scheduler states loaded here. 562 | trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): 563 | The trial run or the hyperparameter dictionary for hyperparameter search. 564 | """ 565 | # This might change the seed so needs to run first. 566 | self._hp_search_setup(trial) 567 | 568 | # Model re-init 569 | if self.model_init is not None: 570 | # Seed must be set before instantiating the model when using model_init. 571 | set_seed(self.args.seed) 572 | 573 | model = self.call_model_init(trial) 574 | 575 | if not self.args.model_parallel: 576 | self.model = model.to(self.args.device) 577 | 578 | # Reinitializes optimizer and scheduler 579 | self.optimizer, self.lr_scheduler = None, None 580 | 581 | # Keeping track whether we can can len() on the dataset or not 582 | train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) 583 | 584 | # Data loader and number of training steps 585 | train_dataloader = self.get_train_dataloader() 586 | 587 | # Setting up training control variables: 588 | # number of training epochs: num_train_epochs 589 | # number of training steps per epoch: num_update_steps_per_epoch 590 | # total number of training steps to execute: max_steps 591 | if train_dataset_is_sized: 592 | num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps 593 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 594 | if self.args.max_steps > 0: 595 | max_steps = self.args.max_steps 596 | num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int( 597 | self.args.max_steps % num_update_steps_per_epoch > 0 598 | ) 599 | else: 600 | max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch) 601 | num_train_epochs = math.ceil(self.args.num_train_epochs) 602 | else: 603 | # see __init__. max_steps is set when the dataset has no __len__ 604 | max_steps = self.args.max_steps 605 | num_train_epochs = 1 606 | num_update_steps_per_epoch = max_steps 607 | 608 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 609 | self.state = TrainerState() 610 | self.state.is_hyper_param_search = trial is not None 611 | 612 | # Check if saved optimizer or scheduler states exist 613 | self._load_optimizer_and_scheduler(model_path) 614 | 615 | # Mixed precision training with apex (torch < 1.6) 616 | model = self.model 617 | if self.use_apex: 618 | model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) 619 | 620 | # Multi-gpu training (should be after apex fp16 initialization) 621 | if self.args.n_gpu > 1 and not self.args.model_parallel: 622 | model = torch.nn.DataParallel(model) 623 | 624 | # Distributed training (should be after apex fp16 initialization) 625 | if self.sharded_dpp: 626 | model = ShardedDDP(model, self.optimizer) 627 | elif self.args.local_rank != -1: 628 | model = torch.nn.parallel.DistributedDataParallel( 629 | model, 630 | device_ids=[self.args.local_rank], 631 | output_device=self.args.local_rank, 632 | find_unused_parameters=( 633 | not getattr(model.config, "gradient_checkpointing", False) 634 | if isinstance(model, PreTrainedModel) 635 | else True 636 | ), 637 | ) 638 | 639 | # Train! 
640 | if is_torch_tpu_available(): 641 | total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() 642 | else: 643 | total_train_batch_size = ( 644 | self.args.train_batch_size 645 | * self.args.gradient_accumulation_steps 646 | * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1) 647 | ) 648 | 649 | num_examples = ( 650 | self.num_examples(train_dataloader) 651 | if train_dataset_is_sized 652 | else total_train_batch_size * self.args.max_steps 653 | ) 654 | 655 | logger.info("***** Running training *****") 656 | logger.info(f" Num examples = {num_examples}") 657 | logger.info(f" Num Epochs = {num_train_epochs}") 658 | logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") 659 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") 660 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 661 | logger.info(f" Total optimization steps = {max_steps}") 662 | 663 | self.state.epoch = 0 664 | epochs_trained = 0 665 | steps_trained_in_current_epoch = 0 666 | 667 | # Check if continuing training from a checkpoint 668 | if model_path and os.path.isfile(os.path.join(model_path, "trainer_state.json")): 669 | self.state = TrainerState.load_from_json(os.path.join(model_path, "trainer_state.json")) 670 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 671 | if not self.args.ignore_data_skip: 672 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) 673 | steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps 674 | else: 675 | steps_trained_in_current_epoch = 0 676 | 677 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 678 | logger.info(f" Continuing training from epoch {epochs_trained}") 679 | logger.info(f" Continuing training from global step {self.state.global_step}") 680 | if not self.args.ignore_data_skip: 681 | logger.info( 682 | f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " 683 | "batches in the first epoch." 684 | ) 685 | 686 | # Update the references 687 | self.callback_handler.model = self.model 688 | self.callback_handler.optimizer = self.optimizer 689 | self.callback_handler.lr_scheduler = self.lr_scheduler 690 | self.callback_handler.train_dataloader = train_dataloader 691 | self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None 692 | self.state.trial_params = hp_params(trial) if trial is not None else None 693 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 694 | # to set this after the load. 
695 | self.state.max_steps = max_steps 696 | self.state.num_train_epochs = num_train_epochs 697 | self.state.is_local_process_zero = self.is_local_process_zero() 698 | self.state.is_world_process_zero = self.is_world_process_zero() 699 | 700 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 701 | tr_loss = torch.tensor(0.0).to(self.args.device) 702 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 703 | self._total_loss_scalar = 0.0 704 | self._globalstep_last_logged = 0 705 | self._total_flos = self.state.total_flos 706 | model.zero_grad() 707 | 708 | self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control) 709 | 710 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 711 | if not self.args.ignore_data_skip: 712 | for epoch in range(epochs_trained): 713 | # We just need to begin an iteration to create the randomization of the sampler. 714 | for _ in train_dataloader: 715 | break 716 | 717 | for epoch in range(epochs_trained, num_train_epochs): 718 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 719 | train_dataloader.sampler.set_epoch(epoch) 720 | 721 | if is_torch_tpu_available(): 722 | parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader( 723 | self.args.device 724 | ) 725 | epoch_iterator = parallel_loader 726 | else: 727 | epoch_iterator = train_dataloader 728 | 729 | # Reset the past mems state at the beginning of each epoch if necessary. 730 | if self.args.past_index >= 0: 731 | self._past = None 732 | 733 | steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps 734 | self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) 735 | 736 | for step, inputs in enumerate(epoch_iterator): 737 | 738 | # Skip past any already trained steps if resuming training 739 | if steps_trained_in_current_epoch > 0: 740 | steps_trained_in_current_epoch -= 1 741 | continue 742 | 743 | if (step + 1) % self.args.gradient_accumulation_steps == 0: 744 | self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) 745 | 746 | if ((step + 1) % self.args.gradient_accumulation_steps != 0) and self.args.local_rank != -1: 747 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. 
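# (More precisely, backward still runs inside `no_sync()`; it is the gradient all-reduce that is skipped and deferred to the accumulation boundary where `optimizer.step()` is called.)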
748 | with model.no_sync(): 749 | tr_loss += self.training_step(model, inputs) 750 | else: 751 | tr_loss += self.training_step(model, inputs) 752 | self._total_flos += self.floating_point_ops(inputs) 753 | 754 | if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( 755 | # last step in epoch but step is always smaller than gradient_accumulation_steps 756 | steps_in_epoch <= self.args.gradient_accumulation_steps 757 | and (step + 1) == steps_in_epoch 758 | ): 759 | # Gradient clipping 760 | if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0: 761 | if self.use_amp: 762 | # AMP: gradients need unscaling 763 | self.scaler.unscale_(self.optimizer) 764 | 765 | if hasattr(self.optimizer, "clip_grad_norm"): 766 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 767 | self.optimizer.clip_grad_norm(self.args.max_grad_norm) 768 | else: 769 | # Revert to normal clipping otherwise, handling Apex or full precision 770 | torch.nn.utils.clip_grad_norm_( 771 | amp.master_params(self.optimizer) if self.use_apex else model.parameters(), 772 | self.args.max_grad_norm, 773 | ) 774 | 775 | # Optimizer step 776 | if is_torch_tpu_available(): 777 | xm.optimizer_step(self.optimizer) 778 | elif self.use_amp: 779 | self.scaler.step(self.optimizer) 780 | self.scaler.update() 781 | else: 782 | self.optimizer.step() 783 | 784 | self.lr_scheduler.step() 785 | model.zero_grad() 786 | self.state.global_step += 1 787 | self.state.epoch = epoch + (step + 1) / steps_in_epoch 788 | self.control = self.callback_handler.on_step_end(self.args, self.state, self.control) 789 | 790 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) 791 | 792 | if self.control.should_epoch_stop or self.control.should_training_stop: 793 | break 794 | 795 | self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control) 796 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) 797 | 798 | if self.args.tpu_metrics_debug or self.args.debug: 799 | if is_torch_tpu_available(): 800 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 801 | xm.master_print(met.metrics_report()) 802 | else: 803 | logger.warning( 804 | "You enabled PyTorch/XLA debug metrics but you don't have a TPU " 805 | "configured. Check your training configuration if this is unexpected." 806 | ) 807 | if self.control.should_training_stop: 808 | break 809 | 810 | if self.args.past_index and hasattr(self, "_past"): 811 | # Clean the state at the end of training 812 | delattr(self, "_past") 813 | 814 | logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") 815 | if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None: 816 | logger.info( 817 | f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." 
818 | ) 819 | if isinstance(self.model, PreTrainedModel): 820 | self.model = self.model.from_pretrained(self.state.best_model_checkpoint) 821 | if not self.args.model_parallel: 822 | self.model = self.model.to(self.args.device) 823 | else: 824 | state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) 825 | self.model.load_state_dict(state_dict) 826 | 827 | if self._total_flos is not None: 828 | self.store_flos() 829 | self.log({"total_flos": self.state.total_flos}) 830 | 831 | self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) 832 | # add remaining tr_loss 833 | self._total_loss_scalar += tr_loss.item() 834 | 835 | return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step) 836 | 837 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 838 | """ 839 | Perform a training step on a batch of inputs. 840 | 841 | Subclass and override to inject custom behavior. 842 | 843 | Args: 844 | model (:obj:`nn.Module`): 845 | The model to train. 846 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 847 | The inputs and targets of the model. 848 | 849 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 850 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 851 | 852 | Return: 853 | :obj:`torch.Tensor`: The tensor with training loss on this batch. 854 | """ 855 | 856 | model.train() 857 | inputs = self._prepare_inputs(inputs) 858 | 859 | if self.use_amp: 860 | with autocast(): 861 | loss = self.train_compute_loss(model, inputs) 862 | else: 863 | loss = self.train_compute_loss(model, inputs) 864 | 865 | if self.args.n_gpu > 1: 866 | loss = loss.mean() # mean() to average on multi-gpu parallel training 867 | 868 | if self.args.gradient_accumulation_steps > 1: 869 | loss = loss / self.args.gradient_accumulation_steps 870 | 871 | if self.use_amp: 872 | self.scaler.scale(loss).backward() 873 | elif self.use_apex: 874 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 875 | scaled_loss.backward() 876 | else: 877 | loss.backward() 878 | 879 | return loss.detach() 880 | 881 | def train_compute_loss(self, model, inputs): 882 | """ 883 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 
884 | """ 885 | assert self.ids_to_clean_text(inputs['labels'])[::2] == self.ids_to_clean_text(inputs['labels'])[1::2] 886 | 887 | outputs = model(**inputs) 888 | finetune_loss = outputs['loss'] 889 | 890 | # loss from contrastive learning 891 | last_hidden_state = outputs["encoder_last_hidden_state"] 892 | 893 | cl_loss = self.cl_loss_compute(last_hidden_state) 894 | 895 | # final loss: alpha * cl_loss + (1-alpha) * finetune_loss 896 | loss = self.alpha * cl_loss + (1 - self.alpha) * finetune_loss 897 | 898 | return loss 899 | 900 | def cl_loss_compute(self, last_hidden_state): 901 | if self.hidden_state_representation == 'cls': 902 | xi = last_hidden_state[::2][::,0] 903 | xj = last_hidden_state[1::2][::,0] 904 | elif self.hidden_state_representation == 'average': 905 | xi = torch.mean(last_hidden_state[::2], dim=1) 906 | xj = torch.mean(last_hidden_state[1::2], dim=1) 907 | assert xi.size() == xj.size() 908 | cl_loss = self._step(xi, xj) 909 | 910 | return cl_loss 911 | 912 | def _step(self, xi, xj): 913 | # get the projection 914 | xis = self.projection(xi) 915 | xjs = self.projection(xj) 916 | 917 | # normalize projection feature vectors 918 | zis = F.normalize(xis, dim=1) 919 | zjs = F.normalize(xjs, dim=1) 920 | 921 | # initialize the loss 922 | self.nt_xent_criterion = NTXentLoss(device=zis.device, batch_size=zjs.size()[0], 923 | temperature=self.temperature, 924 | use_cosine_similarity=True) 925 | loss = self.nt_xent_criterion(zis, zjs) 926 | return loss 927 | 928 | def projection(self, x): 929 | # projection MLP 930 | self.l1 = nn.Linear(x.size()[1], x.size()[1]).to(x.device) 931 | self.l2 = nn.Linear(x.size()[1], 128).to(x.device) 932 | 933 | x = self.l1(x) 934 | x = F.relu(x) 935 | x = self.l2(x) 936 | return x 937 | 938 | def get_train_dataloader(self) -> DataLoader: 939 | """ 940 | Returns the training :class:`~torch.utils.data.DataLoader`. 941 | 942 | Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler (adapted 943 | to distributed training if necessary) otherwise. 944 | 945 | Subclass and override this method if you want to inject some custom behavior. 946 | """ 947 | if self.train_dataset is None: 948 | raise ValueError("Trainer: training requires a train_dataset.") 949 | 950 | return DataLoader( 951 | self.train_dataset, 952 | batch_size=self.args.train_batch_size, 953 | sampler=None, 954 | collate_fn=self.data_collator, 955 | drop_last=True, 956 | # drop_last=self.args.dataloader_drop_last, 957 | num_workers=self.args.dataloader_num_workers, 958 | shuffle=False, 959 | ) 960 | 961 | --------------------------------------------------------------------------------
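Note on the contrastive objective implemented by `train_compute_loss`, `cl_loss_compute`, `_step`, and `projection` above: the encoder's last hidden states for the two views of each example (interleaved at even/odd batch positions) are pooled, projected through a small MLP, L2-normalized, and scored with `NTXentLoss`, and the result is mixed with the ordinary seq2seq loss as `alpha * cl_loss + (1 - alpha) * finetune_loss`. `NTXentLoss` itself is not defined in this file, so the snippet below is only a minimal, self-contained sketch of a SimCLR-style NT-Xent computation; the helper name `nt_xent`, the toy shapes, and the placeholder fine-tuning loss are illustrative assumptions, not the repository's implementation.

# Minimal NT-Xent sketch (SimCLR-style), assuming L2-normalized projections zis/zjs of shape (n, d).
import torch
import torch.nn.functional as F


def nt_xent(zis: torch.Tensor, zjs: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Contrastive loss over n positive pairs (zis[k], zjs[k]); every other row in the batch is a negative."""
    n = zis.size(0)
    z = torch.cat([zis, zjs], dim=0)                 # (2n, d)
    sim = (z @ z.t()) / temperature                  # cosine similarities, since rows are normalized
    sim.fill_diagonal_(float("-inf"))                # a sample is never its own positive
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)             # row i should pick out its paired view


# Toy usage: a batch of 4 encoder outputs = 2 augmented pairs, mean-pooled to hidden size 16.
pooled = torch.randn(4, 16)
xi, xj = F.normalize(pooled[::2], dim=1), F.normalize(pooled[1::2], dim=1)
cl_loss = nt_xent(xi, xj, temperature=0.5)
finetune_loss = torch.tensor(2.3)                    # placeholder for the seq2seq cross-entropy
alpha = 0.5                                          # mixing weight, as in train_compute_loss
total_loss = alpha * cl_loss + (1 - alpha) * finetune_loss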