├── source
│   ├── __init__.py
│   ├── common.py
│   ├── generate.py
│   ├── train.py
│   └── encoder_decoder.py
├── .gitignore
├── requirements.txt
├── train_rebart.sh
├── LICENSE
├── README.md
└── eval
    └── evaluation.py

/source/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 |
4 |
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.13.0
2 | pandas==1.3.3
3 | regex==2021.4.4
4 | requests==2.26.0
5 | rouge-score==0.0.4
6 | sacremoses==0.0.45
7 | scipy==1.7.0
8 | sentence-transformers==2.0.0
9 | sentencepiece==0.1.96
10 | tensorboardX==2.4
11 | tokenizers==0.8.0rc4
12 | torch==1.9.1
13 | torchvision==0.10.0
14 | tqdm==4.62.3
15 | transformers==3.0.0
16 |
--------------------------------------------------------------------------------
/train_rebart.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | DATA_DIR="data/arxiv-abs"
5 | OUT_DIR="outputs/reorder_exp/bart-large_arxiv"
6 |
7 | mkdir -p ${OUT_DIR}
8 | cp $0 ${OUT_DIR}
9 |
10 | python -m source.encoder_decoder \
11 |     --train_file ${DATA_DIR}/train.jsonl \
12 |     --eval_data_file ${DATA_DIR}/dev.jsonl \
13 |     --out_dir $OUT_DIR \
14 |     --model_type facebook/bart-large \
15 |     --model_name_or_path facebook/bart-large \
16 |     --device 1 \
17 |     --do_train \
18 |     --do_eval \
19 |     --save_total_limit 1 \
20 |     --num_train_epochs 1 \
21 |     --logging_steps 3000 \
22 |     --gradient_accumulation_steps 8 \
23 |     --train_batch_size 4 \
24 |     --eval_batch_size 8 \
25 |     --overwrite_out_dir \
26 |     --max_input_length 1024 \
27 |     --max_output_length 40 \
28 |     --task index_with_sep \
29 |     $@
30 | #--overwrite_cache \
31 |
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 FaezeBr
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/source/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | import torch.nn as nn
5 | from transformers import AutoModelWithLMHead, AutoTokenizer
6 |
7 |
8 | def init_model(model_name: str, device, do_lower_case: bool = False, args=None):
9 |     """
10 |     Initialize a pre-trained LM
11 |     :param model_name: from MODEL_CLASSES
12 |     :param device: CUDA / CPU device
13 |     :param do_lower_case: whether the model is lower cased or not
14 |     :return: the model and tokenizer
15 |     """
16 |     tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lower_case)
17 |     model = AutoModelWithLMHead.from_pretrained(model_name)
18 |
19 |     # uncomment for using data parallel
20 |     # special_tokens = ["[shuffled]", "[orig]", ""]
21 |     # extra_specials = [f"" for i in range(args.max_output_length)]
22 |     # special_tokens += extra_specials
23 |     # tokenizer.pad_token = ""
24 |     # tokenizer.eos_token = ""
25 |     # tokenizer.add_tokens(special_tokens)
26 |     #
27 |     # model.resize_token_embeddings(len(tokenizer))
28 |     # model = nn.DataParallel(model, device_ids = [1, 2])
29 |     model.to(device)
30 |     model.eval()
31 |     return tokenizer, model
32 |
33 |
34 | ### Reordering task
35 | def load_data(in_file, task="in_shuf"):
36 |     """
37 |     Loads the dataset file:
38 |     in_file: jsonl file
39 |     Returns a list of tuples (input, output)
40 |     """
41 |     all_lines = []
42 |     with open(in_file, "r", encoding="utf-8") as f:
43 |         for line in f:
44 |             all_lines.append(json.loads(line))
45 |     if task == "index_with_sep":
46 |         examples = [
47 |             (
48 |                 f"[shuffled] {' '.join([' '.join((f'', sent)) for i, sent in zip(list(range(len(line['orig_sents']))), line['shuf_sents'])])} [orig]",
49 |                 f"{' '.join(line['orig_sents'])} ",
50 |             )
51 |             for line in all_lines
52 |         ]
53 |     else:
54 |         examples = [
55 |             (
56 |                 f"[shuffled] {line['shuf_sents'].rstrip(' ') if type(line['shuf_sents']) == str else ' '.join(line['shuf_sents'])} [orig]",
57 |                 f"{line['orig_sents'].rstrip(' ') if type(line['orig_sents']) == str else ' '.join(line['orig_sents'])} ",
58 |             )
59 |             for line in all_lines
60 |         ]
61 |     return examples
62 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Is Everything in Order? A Simple Way to Order Sentences
2 |
3 | This repo contains code for the EMNLP 2021 paper:
4 |
5 | **Is Everything in Order? A Simple Way to Order Sentences**
6 |
7 | *Somnath Basu Roy Chowdhury\*, Faeze Brahman\*, Snigdha Chaturvedi* EMNLP 2021
8 |
9 | [Link to paper](https://arxiv.org/pdf/2104.07064.pdf)
10 |
11 | ### Prerequisites
12 |
13 | Please create a fresh conda environment and run:
14 |
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ### Datasets
20 |
21 | First, create the dataset splits and put them in the `./data` folder.
22 |
23 | Please find the links for the various datasets: [arXiv](https://drive.google.com/drive/folders/0B-mnK8kniGAiNVB6WTQ4bmdyamc), [Wiki Movie Plots](https://www.kaggle.com/jrobischon/wikipedia-movie-plots), [SIND](http://visionandlanguage.net/VIST/dataset.html), [NSF](https://archive.ics.uci.edu/ml/datasets/NSF+Research+Award+Abstracts+1990-2003), [ROCStories](https://www.cs.rochester.edu/nlp/rocstories/), [NeurIPS](https://www.kaggle.com/benhamner/nips-papers), [AAN](https://github.com/EagleW/ACL_titles_abstracts_dataset).
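By default, `train_rebart.sh` and the generation command further down expect the splits under `data/arxiv-abs/{train,dev,test}.jsonl`. A minimal sanity check for a prepared split might look like the sketch below (the paths are just the script defaults, not a requirement; the two fields are the ones described in the next paragraph):

```
import json
from pathlib import Path

data_dir = Path("data/arxiv-abs")  # default DATA_DIR used in train_rebart.sh; adjust for other datasets
for split in ("train", "dev", "test"):
    with open(data_dir / f"{split}.jsonl", encoding="utf-8") as f:
        record = json.loads(next(f))  # first example of the split
    # each line must carry both fields used by source/common.py:load_data
    assert {"orig_sents", "shuf_sents"} <= record.keys(), f"missing fields in {split}"
    print(f"{split}: first example has {len(record['shuf_sents'])} shuffled sentences")
```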
24 |
25 | All datasets should be formatted as jsonl files, where each line is a JSON object containing two fields: `orig_sents` and `shuf_sents`. `orig_sents` is a list of markers [y1, y2, ..., yN], where yi denotes the position of the i-th sentence of the corresponding ordered sequence in the shuffled input (`shuf_sents`). An example for ROCStories is provided [here](https://drive.google.com/drive/folders/1bY7CvXF1q2kgpmtXWtD0NT3bFRfLHpV1?usp=sharing).
26 |
27 | The exact data used in our experiments can be found [here](https://drive.google.com/file/d/17r9D_l-jdhHhpLsa86FGuWgeLgeJkQ19/view?usp=sharing).
28 |
29 | ### Train the ReBART model:
30 |
31 | To train the ReBART model, run the following command:
32 |
33 | ```
34 | bash train_rebart.sh
35 | ```
36 | You can specify the hyper-parameters inside the bash script.
37 |
38 | ### Generate
39 |
40 | To generate the outputs (position markers) using the trained model, run the following commands:
41 |
42 | ```
43 | export DATA_DIR="data/arxiv-abs"
44 | export MODEL_PATH="outputs/reorder_exp/bart-large_arxiv"
45 | python source/generate.py --in_file $DATA_DIR/test.jsonl --out_file $MODEL_PATH/test_bart_greedy.jsonl --model_name_or_path $MODEL_PATH --beams 1 --max_length 40 --task index_with_sep --device 0
46 | ```
47 |
48 | ### Evaluate
49 |
50 | To evaluate the model and get the performance metrics, run:
51 |
52 | ```
53 | python eval/evaluation.py --output_path $MODEL_PATH/test_bart_greedy.jsonl
54 | ```
55 |
56 |
57 | ### Citation
58 |
59 | If you use our work, please cite us using:
60 |
61 | ```
62 | @inproceedings{basu-roy-chowdhury-etal-2021-everything,
63 |     title = "Is Everything in Order? A Simple Way to Order Sentences",
64 |     author = "Basu Roy Chowdhury, Somnath and
65 |       Brahman, Faeze and
66 |       Chaturvedi, Snigdha",
67 |     booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
68 |     month = nov,
69 |     year = "2021",
70 |     address = "Online and Punta Cana, Dominican Republic",
71 |     publisher = "Association for Computational Linguistics",
72 |     url = "https://aclanthology.org/2021.emnlp-main.841",
73 |     doi = "10.18653/v1/2021.emnlp-main.841",
74 |     pages = "10769--10779",
75 | }
76 | ```
77 |
78 |
--------------------------------------------------------------------------------
/source/generate.py:
--------------------------------------------------------------------------------
1 | """ 2 | Adapted from https://github.com/huggingface/transformers/blob/master/examples/run_generation.py 3 | """ 4 | import re 5 | import json 6 | import tqdm 7 | import torch 8 | import logging 9 | import argparse 10 | 11 | logging.basicConfig( 12 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 13 | datefmt="%m/%d/%Y %H:%M:%S", 14 | level=logging.INFO, 15 | ) 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | from common import init_model, load_data 20 | 21 | 22 | def main() -> None: 23 | """ 24 | Generate outputs 25 | """ 26 | parser = argparse.ArgumentParser() 27 | 28 | # Required 29 | parser.add_argument( 30 | "--in_file", 31 | default=None, 32 | type=str, 33 | required=True, 34 | help="The input json file", 35 | ) 36 | parser.add_argument( 37 | "--out_file", 38 | default=None, 39 | type=str, 40 | required=True, 41 | help="out jsonl file", 42 | ) 43 | parser.add_argument( 44 | "--model_name_or_path", 45 | default="gpt2", 46 | type=str, 47 | help="LM checkpoint for initialization.", 48 | ) 49 | 50 | # Optional 51 | parser.add_argument( 52 | "--max_length", default=40, type=int, required=False, help="Maximum
text length" 53 | ) 54 | parser.add_argument( 55 | "--k", default=0, type=int, required=False, help="k for top k sampling" 56 | ) 57 | parser.add_argument( 58 | "--p", default=0, type=float, required=False, help="p for nucleus sampling" 59 | ) 60 | parser.add_argument( 61 | "--beams", default=0, type=int, required=False, help="beams for beam search" 62 | ) 63 | parser.add_argument( 64 | "--temperature", 65 | default=1.0, 66 | type=float, 67 | required=False, 68 | help="temperature for sampling", 69 | ) 70 | parser.add_argument( 71 | "--device", default="cpu", type=str, help="GPU number or 'cpu'." 72 | ) 73 | parser.add_argument( 74 | "--task", 75 | default="", 76 | type=str, 77 | help="what is the task?" 78 | ) 79 | args = parser.parse_args() 80 | logger.debug(args) 81 | 82 | if ( 83 | (args.k == args.p == args.beams == 0) 84 | or (args.k != 0 and args.p != 0) 85 | or (args.beams != 0 and args.p != 0) 86 | or (args.beams != 0 and args.k != 0) 87 | ): 88 | raise ValueError( 89 | "Exactly one of p, k, and beams should be set to a non-zero value." 90 | ) 91 | 92 | device = torch.device( 93 | f"cuda:{args.device}" 94 | if torch.cuda.is_available() and args.device != "cpu" 95 | else "cpu" 96 | ) 97 | logger.debug(f"Initializing {args.device}") 98 | 99 | tokenizer, model = init_model(args.model_name_or_path, device) 100 | 101 | examples = load_data(args.in_file, args.task) 102 | 103 | logger.info(examples[:5]) 104 | 105 | special_tokens = ["[shuffled]", "[orig]", ""] 106 | extra_specials = [f"" for i in range(args.max_length)] 107 | special_tokens += extra_specials 108 | 109 | 110 | with open(args.out_file, "w") as f_out: 111 | for input, output in tqdm.tqdm(examples): 112 | try: 113 | preds = generate_conditional( 114 | tokenizer, 115 | model, 116 | args, 117 | input, 118 | device, 119 | ) 120 | 121 | # Remove any word that has "]" or "[" in it 122 | preds = [re.sub(r"(\w*\])", "", pred) for pred in preds] 123 | preds = [re.sub(r"(\[\w*)", "", pred) for pred in preds] 124 | preds = [re.sub(" +", " ", pred).strip() for pred in preds] 125 | 126 | except Exception as exp: 127 | logger.info(exp) 128 | preds = [] 129 | 130 | f_out.write( 131 | json.dumps({"input": input, "gold": output, "predictions": preds}) 132 | + "\n" 133 | ) 134 | 135 | 136 | def generate_conditional(tokenizer, model, args, input, device): 137 | """ 138 | Generate a sequence with models like Bart and T5 139 | """ 140 | input_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input)) 141 | decoder_start_token_id = input_ids[-1] 142 | input_ids = torch.tensor([input_ids]).to(device) 143 | max_length = args.max_length 144 | 145 | 146 | outputs = model.generate( 147 | input_ids, 148 | do_sample=args.beams == 0, 149 | max_length=max_length, 150 | min_length=5, 151 | temperature=args.temperature, 152 | top_p=args.p if args.p > 0 else None, 153 | top_k=args.k if args.k > 0 else None, 154 | num_beams=args.beams if args.beams > 0 else None, 155 | early_stopping=True, 156 | no_repeat_ngram_size=2, 157 | eos_token_id=tokenizer.eos_token_id, 158 | decoder_start_token_id=decoder_start_token_id, 159 | num_return_sequences=1 #max(1, args.beams) 160 | ) 161 | 162 | 163 | preds = [tokenizer.decode( 164 | output, skip_special_tokens=False, clean_up_tokenization_spaces=False) for output in outputs] 165 | 166 | 167 | return preds 168 | 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /eval/evaluation.py: 
-------------------------------------------------------------------------------- 1 | 2 | import json 3 | import sys 4 | import numpy as np 5 | import json 6 | import re 7 | import math 8 | import nltk 9 | import argparse 10 | 11 | from scipy.stats import kendalltau 12 | from tqdm import tqdm 13 | 14 | 15 | def kendall_tau(order, ground_truth): 16 | """ 17 | Computes the kendall's tau metric 18 | between the predicted sentence order and true order 19 | 20 | Input: 21 | order: list of ints denoting the predicted output order 22 | ground_truth: list of ints denoting the true sentence order 23 | 24 | Returns: 25 | kendall's tau - float 26 | """ 27 | 28 | if len(ground_truth) == 1: 29 | if ground_truth[0] == order[0]: 30 | return 1.0 31 | 32 | reorder_dict = {} 33 | 34 | for i in range(len(ground_truth)): 35 | reorder_dict[ground_truth[i]] = i 36 | 37 | new_order = [0] * len(order) 38 | for i in range(len(new_order)): 39 | if order[i] in reorder_dict.keys(): 40 | new_order[i] = reorder_dict[order[i]] 41 | 42 | corr, _ = kendalltau(new_order, list(range(len(order)))) 43 | return corr 44 | 45 | def lcs(X , Y): 46 | """ 47 | Computes the longest common subsequence between two sequences 48 | 49 | Input: 50 | X: list of ints 51 | Y: list of ints 52 | 53 | Returns: 54 | LCS: int 55 | """ 56 | m = len(X) 57 | n = len(Y) 58 | 59 | L = [[None]*(n+1) for i in range(m+1)] 60 | 61 | for i in range(m+1): 62 | for j in range(n+1): 63 | if i == 0 or j == 0 : 64 | L[i][j] = 0 65 | elif X[i-1] == Y[j-1]: 66 | L[i][j] = L[i-1][j-1]+1 67 | else: 68 | L[i][j] = max(L[i-1][j] , L[i][j-1]) 69 | 70 | return L[m][n] 71 | 72 | 73 | def skip_bigrams(arr): 74 | """ 75 | Utility function for Rouge-S metric 76 | """ 77 | bigrams = set() 78 | for i in range(len(arr)): 79 | for j in range(i+1, len(arr)): 80 | bigrams.add((arr[i], arr[j])) 81 | return bigrams 82 | 83 | def rouge_s(gold, pred): 84 | """ 85 | Rouge-S metric between two sequence 86 | 87 | Input: 88 | gold: list of ints 89 | pred: list of ints 90 | 91 | Returns: 92 | Rouge-S score 93 | """ 94 | 95 | if len(gold) == 1 or len(pred) == 1: 96 | return int(gold[0] == pred[0]) 97 | 98 | gold_bigrams = skip_bigrams(gold) 99 | pred_bigrams = skip_bigrams(pred) 100 | 101 | total = len(gold_bigrams) 102 | same = len(gold_bigrams.intersection(pred_bigrams)) 103 | return (same / total) 104 | 105 | 106 | def clean_output(gold, predictions): 107 | """ 108 | Utility function to clean generated output from BART 109 | """ 110 | 111 | label = gold.replace("", "").strip() 112 | labels = [int(id_[2:-1]) for id_ in label.split()] 113 | 114 | # handle cases when output is empty 115 | if len(predictions) == 0: 116 | return labels, [] 117 | 118 | preds = [] 119 | for p in predictions[0].split(): 120 | pos = re.findall('\\d+', p) 121 | if len(pos) == 1: 122 | preds.append(int(pos[0])) 123 | return labels, preds 124 | 125 | 126 | def evaluate(filename): 127 | """ 128 | Evaluation iterator function. Generates all metrics 129 | by calling the functions for every instance. 
130 | 131 | Input: 132 | filename: file name of the generated output 133 | 134 | Returns: None 135 | """ 136 | 137 | acc, PMR, kendall_score, LCS, rouge = 0, 0, 0, 0, 0 138 | total, total_sents = 0, 0 139 | 140 | err = 0 141 | 142 | with open(filename) as file: 143 | lines = file.readlines() 144 | for line in tqdm(lines): 145 | entry = json.loads(line.strip()) 146 | gold, predictions = clean_output(entry["gold"], entry["predictions"]) 147 | 148 | total += 1 149 | total_sents += len(gold) 150 | 151 | 152 | if len(predictions) == 0: 153 | err += 1 154 | continue 155 | 156 | LCS += lcs(gold, predictions) 157 | 158 | rouge += rouge_s(gold, predictions) 159 | 160 | if predictions == gold: 161 | PMR += 1 162 | 163 | tau_score = kendall_tau(predictions, gold) 164 | 165 | # handle cases of empty output 166 | if math.isnan(tau_score): 167 | err += 1 168 | tau_score = 0 169 | 170 | kendall_score += tau_score 171 | 172 | # Compute sentence level statistics 173 | for i in range(min(len(gold), len(predictions))): 174 | if gold[i] == predictions[i]: 175 | acc += 1 176 | 177 | print(f" {err} sample(s) were not processed") 178 | print(" Accuracy: {:.6f}".format(acc / total_sents)) 179 | print(" PMR: {:.6f}".format(PMR / total)) 180 | print(" Kendall's Tau: {:.6f}".format(kendall_score / total)) 181 | print(" LCS: {:.6f}".format(LCS / total_sents)) 182 | print(" Rouge-S: {:.6f}".format(rouge / total)) 183 | 184 | def main(): 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument( 187 | "--output_path", type = str, required=True 188 | ) 189 | 190 | args = parser.parse_args() 191 | evaluate(args.output_path) 192 | 193 | if __name__ == "__main__": 194 | main() -------------------------------------------------------------------------------- /source/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on https://github.com/huggingface/transformers/blob/master/examples/run_lm_finetuning.py: 3 | fine-tuning language models on a text file using a causal language modeling (CLM) loss. 
4 | """ 5 | import os 6 | import re 7 | import glob 8 | import torch 9 | import random 10 | import shutil 11 | import pickle 12 | import logging 13 | import argparse 14 | import numpy as np 15 | 16 | from tqdm import tqdm, trange 17 | from torch.nn import CrossEntropyLoss 18 | from transformers import AdamW, get_linear_schedule_with_warmup 19 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 20 | 21 | from source.common import init_model, load_data 22 | 23 | 24 | try: 25 | from torch.utils.tensorboard import SummaryWriter 26 | except ImportError: 27 | from tensorboardX import SummaryWriter 28 | 29 | 30 | logging.basicConfig( 31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 32 | datefmt="%m/%d/%Y %H:%M:%S", 33 | level=logging.DEBUG, 34 | ) 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | def set_seed(args): 39 | """ 40 | Set the random seed for reproducibility 41 | """ 42 | random.seed(args.seed) 43 | np.random.seed(args.seed) 44 | torch.manual_seed(args.seed) 45 | if torch.cuda.is_available(): 46 | torch.cuda.manual_seed_all(args.seed) 47 | 48 | def get_loss(args, batch, model): 49 | """ 50 | Compute this batch loss 51 | """ 52 | token_ids = batch["examples"].to(args.device) 53 | input_mask = batch["input_mask"].to(args.device) 54 | 55 | # We don't send labels to model.forward because we want to compute per token loss 56 | lm_logits = model(token_ids, attention_mask=input_mask)[0] 57 | shift_logits = lm_logits[..., :-1, :].contiguous() 58 | batch_size, max_length, vocab_size = shift_logits.shape 59 | 60 | # Compute loss for each instance and each token 61 | loss_fct = CrossEntropyLoss(reduction="none") 62 | shift_logits = shift_logits.view(-1, vocab_size) 63 | shift_labels = token_ids[..., 1:].contiguous().view(-1) 64 | loss = loss_fct(shift_logits, shift_labels).view(batch_size, max_length) 65 | 66 | # Only consider non padded tokens 67 | loss_mask = input_mask[..., :-1].contiguous() 68 | loss = torch.mul(loss_mask, loss) # [batch_size, max_length] 69 | 70 | return loss 71 | 72 | def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False): 73 | """ 74 | Keep a maximum of args.save_total_limit checkpoints. 
75 | """ 76 | if not args.save_total_limit: 77 | return 78 | 79 | if args.save_total_limit <= 0: 80 | return 81 | 82 | # Check if we should delete older checkpoint(s) 83 | glob_checkpoints = glob.glob( 84 | os.path.join(args.out_dir, "{}-*".format(checkpoint_prefix)) 85 | ) 86 | if len(glob_checkpoints) <= args.save_total_limit: 87 | return 88 | 89 | ordering_and_checkpoint_path = [] 90 | for path in glob_checkpoints: 91 | if use_mtime: 92 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) 93 | else: 94 | regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) 95 | if regex_match and regex_match.groups(): 96 | ordering_and_checkpoint_path.append( 97 | (int(regex_match.groups()[0]), path) 98 | ) 99 | 100 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 101 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 102 | number_of_checkpoints_to_delete = max( 103 | 0, len(checkpoints_sorted) - args.save_total_limit 104 | ) 105 | checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] 106 | for checkpoint in checkpoints_to_be_deleted: 107 | logger.info( 108 | "Deleting older checkpoint [{}] due to args.save_total_limit".format( 109 | checkpoint 110 | ) 111 | ) 112 | shutil.rmtree(checkpoint) 113 | 114 | 115 | def train(args, train_dataset, model, tokenizer, loss_fnc=get_loss, eval_dataset=None): 116 | """ 117 | Train the model. 118 | """ 119 | tb_writer = SummaryWriter() 120 | train_sampler = RandomSampler(train_dataset) 121 | train_dataloader = DataLoader( 122 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size 123 | ) 124 | 125 | # Set the number of steps based on the num_epochs * len(train) or args.max_steps if specified. 126 | if args.max_steps > 0: 127 | t_total = args.max_steps 128 | args.num_train_epochs = ( 129 | args.max_steps 130 | // (len(train_dataloader) // args.gradient_accumulation_steps) 131 | + 1 132 | ) 133 | else: 134 | t_total = ( 135 | len(train_dataloader) 136 | // args.gradient_accumulation_steps 137 | * args.num_train_epochs 138 | ) 139 | 140 | # Prepare optimizer and scheduler (linear warmup and decay) 141 | no_decay = ["bias", "LayerNorm.weight"] 142 | optimizer_grouped_parameters = [ 143 | { 144 | "params": [ 145 | p 146 | for n, p in model.named_parameters() 147 | if not any(nd in n for nd in no_decay) 148 | ], 149 | "weight_decay": args.weight_decay, 150 | }, 151 | { 152 | "params": [ 153 | p 154 | for n, p in model.named_parameters() 155 | if any(nd in n for nd in no_decay) 156 | ], 157 | "weight_decay": 0.0, 158 | }, 159 | ] 160 | 161 | optimizer = AdamW( 162 | optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon 163 | ) 164 | scheduler = get_linear_schedule_with_warmup( 165 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 166 | ) 167 | 168 | # Check if saved optimizer or scheduler states exist and load from there 169 | if os.path.isfile( 170 | os.path.join(args.model_name_or_path, "optimizer.pt") 171 | ) and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")): 172 | optimizer.load_state_dict( 173 | torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")) 174 | ) 175 | scheduler.load_state_dict( 176 | torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")) 177 | ) 178 | 179 | # Train 180 | total_batch_size = args.train_batch_size * args.gradient_accumulation_steps 181 | logger.info("***** Running training *****") 182 | logger.info(f" Num examples = {len(train_dataset)}") 
183 | logger.info(f" Num Epochs = {args.num_train_epochs}") 184 | logger.info(f" Instantaneous batch size = {args.train_batch_size}") 185 | logger.info( 186 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 187 | ) 188 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 189 | logger.info(f" Total optimization steps = {t_total}") 190 | 191 | global_step = 0 192 | epochs_trained = 0 193 | steps_trained_in_current_epoch = 0 194 | 195 | # Check if continuing training from a checkpoint 196 | if os.path.exists(args.model_name_or_path): 197 | try: 198 | # set global_step to global_step of last saved checkpoint from model path 199 | checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] 200 | global_step = int(checkpoint_suffix) 201 | epochs_trained = global_step // ( 202 | len(train_dataloader) // args.gradient_accumulation_steps 203 | ) 204 | steps_trained_in_current_epoch = global_step % ( 205 | len(train_dataloader) // args.gradient_accumulation_steps 206 | ) 207 | 208 | logger.info( 209 | " Continuing training from checkpoint, will skip to saved global_step" 210 | ) 211 | logger.info(f" Continuing training from epoch {epochs_trained}") 212 | logger.info(f" Continuing training from global step {global_step}") 213 | logger.info( 214 | f" Will skip the first {steps_trained_in_current_epoch} steps in the first epoch" 215 | ) 216 | except ValueError: 217 | logger.info(" Starting fine-tuning.") 218 | 219 | tr_loss, logging_loss = 0.0, 0.0 220 | 221 | model_to_resize = model.module if hasattr(model, "module") else model 222 | model_to_resize.resize_token_embeddings(len(tokenizer)) 223 | 224 | model.zero_grad() 225 | train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch") 226 | set_seed(args) # Added here for reproducibility 227 | 228 | for _ in train_iterator: 229 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 230 | for step, batch in enumerate(epoch_iterator): 231 | 232 | # Skip past any already trained steps if resuming training 233 | if steps_trained_in_current_epoch > 0: 234 | steps_trained_in_current_epoch -= 1 235 | continue 236 | 237 | model.train() 238 | 239 | # Take the loss only for the part after the input (as in seq2seq architecture) 240 | loss = loss_fnc(args, batch, model) 241 | loss = loss.mean() 242 | 243 | if args.gradient_accumulation_steps > 1: 244 | loss = loss / args.gradient_accumulation_steps 245 | 246 | loss.backward() 247 | 248 | tr_loss += loss.item() 249 | if (step + 1) % args.gradient_accumulation_steps == 0: 250 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 251 | optimizer.step() 252 | scheduler.step() # Update learning rate schedule 253 | model.zero_grad() 254 | global_step += 1 255 | 256 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 257 | # Log metrics 258 | if args.eval_during_train: 259 | results = evaluate(eval_dataset, args, model, loss_fnc=loss_fnc) 260 | for key, value in results.items(): 261 | tb_writer.add_scalar( 262 | "eval_{}".format(key), value, global_step 263 | ) 264 | tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) 265 | tb_writer.add_scalar( 266 | "loss", 267 | (tr_loss - logging_loss) / args.logging_steps, 268 | global_step, 269 | ) 270 | logging_loss = tr_loss 271 | 272 | if args.save_steps > 0 and global_step % args.save_steps == 0: 273 | checkpoint_prefix = "checkpoint" 274 | 275 | # Save model checkpoint 276 | out_dir = os.path.join( 277 | args.out_dir, 
"{}-{}".format(checkpoint_prefix, global_step) 278 | ) 279 | 280 | if not os.path.exists(out_dir): 281 | os.makedirs(out_dir) 282 | 283 | model_to_save = model.module if hasattr(model, "module") else model 284 | model_to_save.save_pretrained(out_dir) 285 | tokenizer.save_pretrained(out_dir) 286 | torch.save(args, os.path.join(out_dir, "training_args.bin")) 287 | logger.info("Saving model checkpoint to %s", out_dir) 288 | 289 | _rotate_checkpoints(args, checkpoint_prefix) 290 | 291 | torch.save( 292 | optimizer.state_dict(), os.path.join(out_dir, "optimizer.pt") 293 | ) 294 | torch.save( 295 | scheduler.state_dict(), os.path.join(out_dir, "scheduler.pt") 296 | ) 297 | logger.info("Saving optimizer and scheduler states to %s", out_dir) 298 | 299 | if 0 < args.max_steps < global_step: 300 | epoch_iterator.close() 301 | break 302 | 303 | if 0 < args.max_steps < global_step: 304 | train_iterator.close() 305 | break 306 | 307 | tb_writer.close() 308 | return global_step, tr_loss / global_step 309 | 310 | 311 | def evaluate(eval_dataset, args, model, prefix="", loss_fnc=get_loss): 312 | """ 313 | Evaluation 314 | """ 315 | eval_out_dir = args.out_dir 316 | 317 | if not os.path.exists(eval_out_dir): 318 | os.makedirs(eval_out_dir) 319 | 320 | eval_sampler = SequentialSampler(eval_dataset) 321 | eval_dataloader = DataLoader( 322 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size 323 | ) 324 | 325 | logger.info(f"***** Running evaluation {prefix} *****") 326 | logger.info(f" Num examples = {len(eval_dataset)}") 327 | logger.info(f" Batch size = {args.eval_batch_size}") 328 | micro_loss = macro_loss = 0.0 329 | num_tokens = num_batches = 0 330 | model.eval() 331 | 332 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 333 | with torch.no_grad(): 334 | batch_loss = loss_fnc(args, batch, model) 335 | macro_loss += batch_loss.mean().item() 336 | micro_loss += batch_loss.sum().item() 337 | num_tokens += batch_loss.view(-1).shape[0] 338 | num_batches += 1 339 | 340 | macro_perplexity = torch.exp(torch.tensor(macro_loss / num_batches)) 341 | micro_perplexity = torch.exp(torch.tensor(micro_loss / num_tokens)) 342 | 343 | result = { 344 | "macro_perplexity": macro_perplexity, 345 | "micro_perplexity": micro_perplexity, 346 | } 347 | 348 | output_eval_file = os.path.join(eval_out_dir, prefix, "eval_results.txt") 349 | with open(output_eval_file, "w") as writer: 350 | logger.info(f"***** Eval results {prefix} *****") 351 | for key in sorted(result.keys()): 352 | logger.info(f" {key} = {result[key]}") 353 | writer.write(f"{key} = {result[key]}\n") 354 | 355 | return result 356 | 357 | 358 | if __name__ == "__main__": 359 | main() 360 | -------------------------------------------------------------------------------- /source/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | fine-tuning the encoder-decoder BART/T5 model. 
3 | """ 4 | import os 5 | import torch 6 | import pickle 7 | import logging 8 | import argparse 9 | 10 | from torch.utils.data import Dataset 11 | from torch.nn import CrossEntropyLoss 12 | 13 | from source.common import init_model, load_data 14 | from source.train import evaluate, train, set_seed 15 | 16 | 17 | logging.basicConfig( 18 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 19 | datefmt="%m/%d/%Y %H:%M:%S", 20 | level=logging.DEBUG, 21 | ) 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class EncoderDecoderTextDataset(Dataset): 26 | def __init__(self, tokenizer, args, file_path, block_size=512): 27 | print(file_path) 28 | assert os.path.isfile(file_path) 29 | directory, filename = os.path.split(file_path) 30 | filename = f"{os.path.basename(args.model_type)}_cached_{block_size}_{filename}{'_' + args.task if args.task else ''}" 31 | cached_features_file = os.path.join(directory, filename) 32 | 33 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 34 | logger.info(f"Loading features from cached file {cached_features_file}") 35 | with open(cached_features_file, "rb") as handle: 36 | self.examples = pickle.load(handle) 37 | else: 38 | logger.info("Converting to token IDs") 39 | examples = load_data(file_path, args.task) 40 | logger.info(examples[:5]) 41 | 42 | # Add prefix to the output so we can predict the first real token in the decoder 43 | inputs = [ 44 | tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ex[0])) 45 | for ex in examples 46 | ] 47 | outputs = [ 48 | [inputs[i][-1]] 49 | + tokenizer.convert_tokens_to_ids(tokenizer.tokenize(ex[1])) 50 | for i, ex in enumerate(examples) 51 | ] 52 | 53 | # Pad 54 | max_input_length = min( 55 | args.max_input_length, max([len(ex) for ex in inputs]) 56 | ) 57 | max_output_length = min( 58 | args.max_output_length, max([len(ex) for ex in outputs]) 59 | ) 60 | 61 | input_lengths = [min(len(ex), max_input_length) for ex in inputs] 62 | output_lengths = [min(len(ex), max_output_length) for ex in outputs] 63 | 64 | inputs = [tokenizer.encode( 65 | ex, add_special_tokens=False, max_length=max_input_length, pad_to_max_length=True) 66 | for ex in inputs] 67 | 68 | outputs = [tokenizer.encode( 69 | ex, add_special_tokens=False, max_length=max_output_length, pad_to_max_length=True) 70 | for ex in outputs] 71 | 72 | self.examples = { 73 | "inputs": inputs, 74 | "outputs": outputs, 75 | "input_lengths": input_lengths, 76 | "output_lengths": output_lengths, 77 | } 78 | 79 | logger.info(f"Saving features into cached file {cached_features_file}") 80 | with open(cached_features_file, "wb") as handle: 81 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 82 | 83 | def __len__(self): 84 | return len(self.examples["input_lengths"]) 85 | 86 | def __getitem__(self, item): 87 | inputs = torch.tensor(self.examples["inputs"][item]) 88 | outputs = torch.tensor(self.examples["outputs"][item]) 89 | 90 | max_length = inputs.shape[0] 91 | input_lengths = self.examples["input_lengths"][item] 92 | input_mask = torch.tensor([1] * input_lengths + [0] * (max_length - input_lengths)) 93 | 94 | max_length = outputs.shape[0] 95 | output_lengths = self.examples["output_lengths"][item] 96 | output_mask = torch.tensor([1] * output_lengths + [0] * (max_length - output_lengths)) 97 | 98 | return { 99 | "inputs": inputs, 100 | "input_mask": input_mask, 101 | "outputs": outputs, 102 | "output_mask": output_mask, 103 | } 104 | 105 | 106 | def get_loss(args, batch, model): 107 | """ 108 | Compute this batch loss 109 
| """ 110 | input_ids = batch["inputs"].to(args.device) 111 | input_mask = batch["input_mask"].to(args.device) 112 | target_ids = batch["outputs"].to(args.device) 113 | output_mask = batch["output_mask"].to(args.device) 114 | decoder_input_ids = target_ids[:, :-1].contiguous() 115 | 116 | # We don't send labels to model.forward because we want to compute per token loss 117 | lm_logits = model( 118 | input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids, use_cache=False 119 | )[0] # use_cache=false is added for HF > 3.0 120 | batch_size, max_length, vocab_size = lm_logits.shape 121 | 122 | # Compute loss for each instance and each token 123 | loss_fct = CrossEntropyLoss(reduction="none") 124 | lm_labels = target_ids[:, 1:].clone().contiguous() 125 | lm_labels[target_ids[:, 1:] == args.pad_token_id] = -100 126 | loss = loss_fct(lm_logits.view(-1, vocab_size), lm_labels.view(-1)).view( 127 | batch_size, max_length 128 | ) 129 | 130 | # Only consider non padded tokens 131 | loss_mask = output_mask[..., :-1].contiguous() 132 | loss = torch.mul(loss_mask, loss) # [batch_size, max_length] 133 | return loss 134 | 135 | 136 | def main(): 137 | parser = argparse.ArgumentParser() 138 | 139 | # Required parameters 140 | parser.add_argument( 141 | "--out_dir", 142 | default=None, 143 | type=str, 144 | required=True, 145 | help="Out directory for checkpoints.", 146 | ) 147 | 148 | # Other parameters 149 | parser.add_argument( 150 | "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer." 151 | ) 152 | parser.add_argument( 153 | "--device", default="cpu", type=str, help="GPU number or 'cpu'." 154 | ) 155 | parser.add_argument( 156 | "--do_eval", action="store_true", help="Whether to run eval on the dev set." 157 | ) 158 | parser.add_argument( 159 | "--do_lower_case", 160 | action="store_true", 161 | help="Set this flag if you are using an uncased model.", 162 | ) 163 | parser.add_argument( 164 | "--do_train", action="store_true", help="Whether to run training." 165 | ) 166 | parser.add_argument( 167 | "--eval_batch_size", default=64, type=int, help="Batch size for evaluation." 168 | ) 169 | parser.add_argument( 170 | "--eval_data_file", 171 | type=str, 172 | required=True, 173 | help="The input CSV validation file." 174 | ) 175 | parser.add_argument( 176 | "--eval_during_train", 177 | action="store_true", 178 | help="Evaluate at each train logging step.", 179 | ) 180 | parser.add_argument( 181 | "--gradient_accumulation_steps", 182 | type=int, 183 | default=1, 184 | help="Steps before backward pass.", 185 | ) 186 | parser.add_argument( 187 | "--learning_rate", 188 | default=5e-6, 189 | type=float, 190 | help="The initial learning rate for Adam.", 191 | ) 192 | parser.add_argument( 193 | "--logging_steps", 194 | type=int, 195 | default=-1, 196 | help="Log every X updates steps (default after each epoch).", 197 | ) 198 | parser.add_argument( 199 | "--max_input_length", 200 | default=140, 201 | type=int, 202 | help="Maximum input event length in words.", 203 | ) 204 | parser.add_argument( 205 | "--max_output_length", 206 | default=120, 207 | type=int, 208 | help="Maximum output event length in words.", 209 | ) 210 | parser.add_argument( 211 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
212 | ) 213 | parser.add_argument( 214 | "--max_steps", 215 | default=-1, 216 | type=int, 217 | help="If > 0: total number of training steps to perform.", 218 | ) 219 | parser.add_argument( 220 | "--model_name_or_path", 221 | default="bart-large", 222 | type=str, 223 | help="LM checkpoint for initialization.", 224 | ) 225 | parser.add_argument( 226 | "--model_type", 227 | default="", 228 | type=str, 229 | help="which family of LM, e.g. gpt, gpt-xl, ....", 230 | ) 231 | parser.add_argument( 232 | "--num_train_epochs", 233 | default=2.0, 234 | type=float, 235 | help="Number of training epochs to perform.", 236 | ) 237 | parser.add_argument( 238 | "--overwrite_cache", action="store_true", help="Overwrite the cached data." 239 | ) 240 | parser.add_argument( 241 | "--overwrite_out_dir", 242 | action="store_true", 243 | help="Overwrite the output directory.", 244 | ) 245 | parser.add_argument( 246 | "--continue_training", 247 | action="store_true", 248 | help="Continue training from the last checkpoint.", 249 | ) 250 | parser.add_argument( 251 | "--save_steps", 252 | type=int, 253 | default=-1, 254 | help="Save checkpoint every X updates steps (default after each epoch).", 255 | ) 256 | parser.add_argument( 257 | "--save_total_limit", 258 | type=int, 259 | default=None, 260 | help="Maximum number of checkpoints to keep", 261 | ) 262 | parser.add_argument( 263 | "--seed", type=int, default=42, help="Random seed for initialization." 264 | ) 265 | parser.add_argument( 266 | "--train_batch_size", default=64, type=int, help="Batch size for training." 267 | ) 268 | parser.add_argument( 269 | "--train_file", 270 | type=str, 271 | required=False, 272 | help="The input CSV train file." 273 | ) 274 | parser.add_argument( 275 | "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps." 276 | ) 277 | parser.add_argument( 278 | "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some." 279 | ) 280 | parser.add_argument( 281 | "--task", 282 | type=str, 283 | help="what is the task?" 284 | ) 285 | args = parser.parse_args() 286 | 287 | if args.eval_data_file is None and args.do_eval: 288 | raise ValueError( 289 | "Cannot do evaluation without an evaluation data file. Either supply --eval_data_file " 290 | "or remove the --do_eval argument." 291 | ) 292 | 293 | if ( 294 | os.path.exists(args.out_dir) 295 | and len(os.listdir(args.out_dir)) > 1 296 | and args.do_train 297 | and not args.overwrite_out_dir 298 | and not args.continue_training 299 | ): 300 | raise ValueError( 301 | f"Output directory {args.out_dir} already exists and is not empty. " 302 | f"Use --overwrite_out_dir or --continue_training." 
303 | ) 304 | 305 | # Setup device 306 | device = torch.device( 307 | f"cuda:{args.device}" 308 | if torch.cuda.is_available() and args.device != "cpu" 309 | else "cpu" 310 | ) 311 | 312 | # Set seed 313 | set_seed(args) 314 | 315 | # Load the models 316 | if args.continue_training: 317 | args.model_name_or_path = args.out_dir 318 | # Delete the current results file 319 | else: 320 | eval_results_file = os.path.join(args.out_dir, "eval_results.txt") 321 | if os.path.exists(eval_results_file): 322 | os.remove(eval_results_file) 323 | 324 | args.device = "cpu" 325 | tokenizer, model = init_model( 326 | args.model_name_or_path, device=args.device, do_lower_case=args.do_lower_case, args = args 327 | ) 328 | 329 | args.pad_token_id = tokenizer.pad_token_id 330 | logger.info(f"Pad token ID: {args.pad_token_id}") 331 | args.block_size = tokenizer.max_len_single_sentence 332 | logger.info(f"Training/evaluation parameters {args}") 333 | 334 | eval_dataset = None 335 | if args.do_eval or args.eval_during_train: 336 | eval_dataset = EncoderDecoderTextDataset( 337 | tokenizer, args, file_path=args.eval_data_file, block_size=args.block_size) 338 | 339 | # Add special tokens (if loading a model before fine-tuning) 340 | if args.do_train and not args.continue_training: 341 | special_tokens = ["[shuffled]", "[orig]", ""] 342 | extra_specials = [f"" for i in range(args.max_output_length)] 343 | special_tokens += extra_specials 344 | tokenizer.pad_token = "" 345 | tokenizer.eos_token = "" 346 | tokenizer.add_tokens(special_tokens) 347 | model.resize_token_embeddings(len(tokenizer)) 348 | 349 | args.pad_token_id = tokenizer.pad_token_id 350 | 351 | # resize_token_embeddings for Bart doesn't work if the model is already on the device 352 | args.device = device 353 | model.to(args.device) 354 | 355 | # Training 356 | if args.do_train: 357 | train_dataset = EncoderDecoderTextDataset( 358 | tokenizer, 359 | args, 360 | file_path=args.train_file, 361 | block_size=args.block_size, 362 | ) 363 | global_step, tr_loss = train( 364 | args, 365 | train_dataset, 366 | model, 367 | tokenizer, 368 | loss_fnc=get_loss, 369 | eval_dataset=eval_dataset, 370 | ) 371 | logger.info(f" global_step = {global_step}, average loss = {tr_loss}") 372 | 373 | # Create output directory if needed 374 | if not os.path.exists(args.out_dir): 375 | os.makedirs(args.out_dir) 376 | 377 | logger.info(f"Saving model checkpoint to {args.out_dir}") 378 | 379 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 
380 | # They can then be reloaded using `from_pretrained()` 381 | model_to_save = model.module if hasattr(model, "module") else model 382 | model_to_save.save_pretrained(args.out_dir) 383 | tokenizer.save_pretrained(args.out_dir) 384 | 385 | # Good practice: save your training arguments together with the trained model 386 | torch.save(args, os.path.join(args.out_dir, "training_args.bin")) 387 | 388 | # Load a trained model and vocabulary that you have fine-tuned 389 | tokenizer, model = init_model( 390 | args.out_dir, device=args.device, do_lower_case=args.do_lower_case, args=args 391 | ) 392 | args.block_size = tokenizer.max_len_single_sentence 393 | model.to(args.device) 394 | 395 | # Evaluation 396 | results = {} 397 | if args.do_eval: 398 | checkpoint = args.out_dir 399 | logger.info(f"Evaluate the following checkpoint: {checkpoint}") 400 | prefix = ( 401 | checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" 402 | ) 403 | _, model = init_model( 404 | checkpoint, device=args.device, do_lower_case=args.do_lower_case, args=args 405 | ) 406 | 407 | model.to(args.device) 408 | result = evaluate(eval_dataset, args, model, prefix=prefix, loss_fnc=get_loss) 409 | results.update(result) 410 | 411 | return results 412 | 413 | 414 | if __name__ == "__main__": 415 | main() 416 | --------------------------------------------------------------------------------
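The repository stops at producing and scoring position markers; turning a prediction back into ordered text amounts to indexing into the shuffled sentences. A minimal sketch, assuming the jsonl produced by `source/generate.py` and mirroring the integer parsing in `eval/evaluation.py` (the sentences and prediction string below are made up for illustration):

```
import re

def reorder(shuf_sents, prediction):
    # The i-th predicted index is the position, in the shuffled input, of the
    # i-th sentence of the reconstructed (ordered) text.
    order = [int(m) for m in re.findall(r"\d+", prediction)]
    order = [i for i in order if i < len(shuf_sents)]  # drop out-of-range indices
    return [shuf_sents[i] for i in order]

# toy example
shuffled = ["He opened it at once.", "A letter arrived.", "It was good news."]
print(" ".join(reorder(shuffled, "1 0 2")))
# -> A letter arrived. He opened it at once. It was good news.
```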