├── LICENSE ├── README.md ├── billboard ├── README.md ├── evaluate.py ├── generator-output.jsonl ├── output_scores_coherence.txt ├── output_scores_consistency.txt ├── output_scores_fluency.txt ├── output_scores_overall.txt ├── output_scores_relevance.txt ├── reference-file.jsonl ├── run.sh └── source-file.jsonl ├── evaluation_tasks ├── README.md ├── train_continual.sh ├── train_multi.sh └── train_seq2seq.py ├── examples.py ├── figures ├── UniEval.png ├── evaluation.png └── intermediate.png ├── intermediate_tasks ├── README.md ├── data_info.txt ├── train_inter.sh └── train_seq2seq.py ├── metric ├── evaluator.py └── scorer.py ├── pseudo_data_summ.py ├── reproduce ├── README.md ├── correlation.py ├── data │ ├── data2text │ │ ├── sfhot.json │ │ └── sfres.json │ ├── dialogue │ │ └── topical_chat.json │ ├── fact │ │ ├── qags_cnndm.json │ │ └── qags_xsum.json │ └── summarization │ │ └── summeval.json ├── data_utils.py ├── eval_data2text.sh ├── eval_dialogue.sh ├── eval_fact.sh ├── eval_summarization.sh ├── predict_score.py └── unieval_predict │ ├── data2text │ ├── sfhot_result.json │ └── sfres_result.json │ ├── dialogue │ └── topical_chat_result.json │ ├── fact │ ├── qags_cnndm_result.json │ └── qags_xsum_result.json │ └── summarization │ └── summeval_result.json ├── requirements.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ming Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UniEval 2 | 3 | This repository maintains code, data and pre-trained evaluators for EMNLP 2022 paper 4 | 5 | *[Towards a Unified Multi-Dimensional Evaluator for Text Generation](https://arxiv.org/abs/2210.07197)* 6 | 7 | ## Overview 8 | 9 | **Multi-dimensional evaluation** is the dominant paradigm for human evaluation in Natural Language Generation (NLG), i.e., evaluating the generated text from multiple explainable dimensions, such as *coherence* and *fluency*. 10 | 11 | However, automatic evaluation in NLG is still dominated by similarity-based metrics (e.g., ROUGE, BLEU), but they are not sufficient to portray the difference between the advanced generation models. 
12 | 13 | Therefore, we propose **UniEval** to bridge this gap so that a more comprehensive and fine-grained evaluation of NLG systems can be achieved. 14 | 15 | ## Method 16 |

17 | ![method](./figures/UniEval.png) 18 |

19 | 20 | We convert all evaluation tasks of different dimensions into Boolean QA problems and utilize the model to answer with “Yes” or “No”. 21 | 22 | 23 | This unified QA format allows the model to incorporate external knowledge from multiple related tasks, i.e., intermediate multi-task learning in the Figure. The code and data for intermediate pre-training can be found in the [intermediate_tasks](./intermediate_tasks) folder. 24 | 25 | Then we construct pseudo data for each dimension and train them sequentially to obtain **UniEval**. Details about unsupervised learning on evaluation tasks can be found in the [evaluation_tasks](./evaluation_tasks) folder. 26 | 27 | 28 | ## Get Multi-Dimenisonal Scores 29 | 30 | ### Environment 31 | ``` 32 | git clone https://github.com/maszhongming/UniEval.git 33 | cd UniEval 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ### Pre-trained Evaluators 38 | We release four pre-trained evaluators for different NLG tasks as follows: 39 | 40 | - [unieval-sum](https://huggingface.co/MingZhong/unieval-sum) evaluates *coherence*, *consistency*, *fluency* and *relevance* for text summarization. It can also used to evaluate *naturalness* and *informativeness* for data-to-text. 41 | - [unieval-dialog](https://huggingface.co/MingZhong/unieval-dialog) evaluates *naturalness*, *coherence*, *engagingness*, *groundedness* and *understandability* for dialogue response generation. 42 | - [unieval-fact](https://huggingface.co/MingZhong/unieval-fact) is specifically used to evaluate factual consistency. 43 | - [unieval-intermediate](https://huggingface.co/MingZhong/unieval-intermediate) is obtained after intermediate pre-training. It can be viewed as a Boolean answer generator. 44 | 45 | ### Get Scores for Summarization 46 | Example usage for summarization is shown below. 47 | ```python 48 | from utils import convert_to_json 49 | from metric.evaluator import get_evaluator 50 | 51 | task = 'summarization' 52 | 53 | # a list of source documents 54 | src_list = ['Peter and Elizabeth took a taxi to attend the night party in the city. \ 55 | While in the party, Elizabeth collapsed and was rushed to the hospital.'] 56 | # a list of human-annotated reference summaries 57 | ref_list = ['Elizabeth was hospitalized after attending a party with Peter.'] 58 | # a list of model outputs to be evaluataed 59 | output_list = ['Peter and Elizabeth attend party city. Elizabeth rushed hospital.'] 60 | 61 | # Prepare data for pre-trained evaluators 62 | data = convert_to_json(output_list=output_list, 63 | src_list=src_list, ref_list=ref_list) 64 | # Initialize evaluator for a specific task 65 | evaluator = get_evaluator(task) 66 | # Get multi-dimensional evaluation scores 67 | eval_scores = evaluator.evaluate(data, print_result=True) 68 | ``` 69 | eval_scores contains the scores of all dimensions for each sample. The printed average scores should look like: 70 | ``` 71 | +-------------+----------+ 72 | | Dimensions | Score | 73 | +-------------+----------+ 74 | | coherence | 0.948185 | 75 | | consistency | 0.883036 | 76 | | fluency | 0.42928 | 77 | | relevance | 0.636075 | 78 | | overall | 0.724144 | 79 | +-------------+----------+ 80 | ``` 81 | Overall score here can be customized as a combination of scores based on different dimensions. The default is the average score of all dimensions. 82 | 83 | Notably, because the different dimensions have different focuses, they usually require different content as input. 
For summarization, the inputs when evaluating the four dimensions are as follows: 84 | 85 | - *coherence*: output_list, src_list 86 | - *consistency*: output_list, src_list 87 | - *fluency*: output_list 88 | - *relevance*: output_list, ref_list 89 | 90 | Therefore, **UniEval** is a reference-free evaluator in all dimensions except *relevance*. So it is also possible to evaluate the generated summaries without reference as: 91 | 92 | ```python 93 | eval_scores = evaluator.evaluate(data, dims=['coherence', 'consistency', 'fluency'], 94 | overall=False, print_result=True) 95 | ``` 96 | 97 | ### Get Scores for Dialogue 98 | Example usage for dialogue response generation is shown below. 99 | ```python 100 | from utils import convert_to_json 101 | from metric.evaluator import get_evaluator 102 | 103 | task = 'dialogue' 104 | 105 | # a list of dialogue histories 106 | src_list = ['hi , do you know much about the internet ? \n i know a lot about different sites and some website design , how about you ? \n\n'] 107 | # a list of additional context that should be included into the generated response 108 | context_list = ['the 3 horizontal line menu on apps and websites is called a hamburger button .\n'] 109 | # a list of model outputs to be evaluated 110 | output_list = ['i do too . did you know the 3 horizontal line menu on apps and websites is called the hamburger button ?'] 111 | 112 | # Prepare data for pre-trained evaluators 113 | data = convert_to_json(output_list=output_list, 114 | src_list=src_list, context_list=context_list) 115 | # Initialize evaluator for a specific task 116 | evaluator = get_evaluator(task) 117 | # Get multi-dimensional evaluation scores 118 | eval_scores = evaluator.evaluate(data, print_result=True) 119 | ``` 120 | The results should be: 121 | ``` 122 | +-------------------+----------+ 123 | | Dimensions | Score | 124 | +-------------------+----------+ 125 | | naturalness | 0.950218 | 126 | | coherence | 0.973135 | 127 | | engagingness | 1.750486 | 128 | | groundedness | 0.999566 | 129 | | understandability | 0.946209 | 130 | | overall | 1.123923 | 131 | +-------------------+----------+ 132 | ``` 133 | *engagingness* is the only dimension that uses summation scores, as it indicates the total volume of interesting fact presented in the response. Therefore, the scoring range for *engagingness* is [0, +∞), while all others are [0, 1]. 134 | 135 | Please keep the format of the input dialogue consistent with [topical_chat.json](./reproduce/data/dialogue/topical_chat.json), i.e. use `\n` to separate the different turns in the dialogue history and end it with `\n\n`. In addition, each context also ends with `\n`. 136 | 137 | **UniEval** is a reference-free evaluator for dialogue response generation. The input content for each dimension is: 138 | 139 | - *naturalness*: output_list 140 | - *coherence*: output_list, src_list 141 | - *engagingness*: output_list, src_list, context_list 142 | - *groundedness*: output_list, context_list 143 | - *understandability*: output_list 144 | 145 | ### Get Factual Consistency Score 146 | **UniEval** can also act as a high-performance single-dimensional evaluator, such as achieving the best correlation when evaluating factual consistency (see Tables 3 and 9 in the paper). Example usage for factual consistency detection is shown below. 
147 | ```python 148 | from utils import convert_to_json 149 | from metric.evaluator import get_evaluator 150 | 151 | task = 'fact' 152 | 153 | # a list of source documents 154 | src_list = ['Peter and Elizabeth took a taxi to attend the night party in the city. \ 155 | While in the party, Elizabeth collapsed and was rushed to the hospital.'] 156 | # a list of model outputs (claims) to be evaluataed 157 | output_list = ['Tom was rushed to hospital.'] 158 | 159 | # Prepare data for pre-trained evaluators 160 | data = convert_to_json(output_list=output_list, src_list=src_list) 161 | # Initialize evaluator for a specific task 162 | evaluator = get_evaluator(task) 163 | # Get factual consistency scores 164 | eval_scores = evaluator.evaluate(data, print_result=True) 165 | ``` 166 | The results only include one dimension: 167 | ``` 168 | +-------------+----------+ 169 | | Dimensions | Score | 170 | +-------------+----------+ 171 | | consistency | 0.025441 | 172 | +-------------+----------+ 173 | ``` 174 | 175 | ### Transfer to Other NLG Tasks 176 | **UniEval** also demonstrates the ability to transfer to new NLG tasks. We provide instructions for two scenarios: 177 | 178 | 1. Transfer to other dimensions 179 | 180 | (a) If the new dimension is close to one of UniEval's existing dimensions, you can directly evaluate it with the corresponding evaluator and specify the desired dimension. 181 | 182 | (b) If the new dimension requires a different input or question description, please modify the `add_question` function in [utils.py](./utils.py) and select an evaluator of a similar task for evaluation. 183 | 184 | 2. Transfer to other generation tasks 185 | 186 | We take the data-to-text task as an example to show how to transfer UniEval to an unseen task. 187 | 188 | (1) Create a task-specific evaluator in [metric/evaluator.py](./metric/evaluator.py), initializing it by specifying the pre-trained evaluator used and the dimensions to be evaluated. All required content should be inputted in the `self.evaluate()` function. Details can refer to `D2tEvaluator` in [metric/evaluator.py](./metric/evaluator.py). 189 | 190 | (2) Specify the required content and a specific question description for each dimension in `add_question`. They form the input to the evaluator. The input format for evaluating *naturalness* and *informativeness* in the data-to-text task can be found in [utils.py](./utils.py). 191 | 192 | (3) As in [examples.py](./examples.py), multi-dimensional evaluation scores can be obtained. 193 | 194 | 195 | ## Reproduce 196 | 197 | To reproduce all the results in the paper, we provide all meta-evaluation datasets, codes, and evaluation scores predicted by **UniEval** in the folder [reproduce](./reproduce). 198 | 199 | 200 | -------------------------------------------------------------------------------- /billboard/README.md: -------------------------------------------------------------------------------- 1 | # BillBoard 2 | To submit UniEval to [Bidimensional Leaderboards](https://nlp.cs.washington.edu/billboard/#tasks/cnndm/metrics.html) for summarization, we provide the relevant code here. 3 | 4 | The input should contain three files, `source-file.jsonl`, `generator-output.jsonl`, and `reference-file.jsonl`. Then please run the following script: 5 | ``` 6 | ./run.sh 7 | ``` 8 | The results will be presented in five files, representing the scores of each model output in different dimensions (*fluency*, *coherence*, *consistency*, *relevance* and *overall*). 
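Each line of the three input files is a single JSON object. Judging from the sample files in this folder and from how `evaluate.py` reads them (`data['src']`, `data['hyp']`, and `data['ref'][0]`), the expected keys are `src`, `hyp`, and `ref` (a list of reference strings). Below is a minimal sketch, using hypothetical example data, for writing correctly formatted inputs:
```python
import json

# Hypothetical sample -- replace with your own documents, model outputs, and references.
samples = [
    {
        "src": "Peter and Elizabeth took a taxi to attend the night party in the city.",
        "hyp": "Peter and Elizabeth attend party city.",
        "ref": ["Elizabeth was hospitalized after attending a party with Peter."],
    },
]

# evaluate.py expects one JSON object per line:
#   source-file.jsonl      -> {"src": "..."}
#   generator-output.jsonl -> {"hyp": "..."}
#   reference-file.jsonl   -> {"ref": ["..."]}
with open("source-file.jsonl", "w") as f_src, \
     open("generator-output.jsonl", "w") as f_hyp, \
     open("reference-file.jsonl", "w") as f_ref:
    for sample in samples:
        print(json.dumps({"src": sample["src"]}), file=f_src)
        print(json.dumps({"hyp": sample["hyp"]}), file=f_hyp)
        print(json.dumps({"ref": sample["ref"]}), file=f_ref)
```
The i-th line of each file should describe the same example, since `evaluate.py` pairs the source, hypothesis, and reference lists by index.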
9 | -------------------------------------------------------------------------------- /billboard/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import argparse 4 | sys.path.append("..") 5 | from utils import convert_to_json 6 | from metric.evaluator import get_evaluator 7 | 8 | def load_src(src_path): 9 | src_list = [] 10 | with open(src_path) as f: 11 | for line in f: 12 | data = json.loads(line) 13 | src_list.append(data['src']) 14 | return src_list 15 | 16 | def load_ref(ref_path): 17 | ref_list = [] 18 | with open("reference-file.jsonl")as f: 19 | for line in f: 20 | data = json.loads(line) 21 | ref_list.append(data['ref'][0]) 22 | return ref_list 23 | 24 | def load_output(output_path): 25 | output_list = [] 26 | with open("generator-output.jsonl") as f: 27 | for line in f: 28 | data = json.loads(line) 29 | output_list.append(data['hyp']) 30 | return output_list 31 | 32 | def evaluate(args): 33 | # load data 34 | src_list = load_src(args.src_path) 35 | ref_list = load_ref(args.ref_path) 36 | output_list = load_output(args.hyp_path) 37 | 38 | # Prepare data for pre-trained evaluators 39 | data = convert_to_json(output_list=output_list, 40 | src_list=src_list, ref_list=ref_list) 41 | 42 | # Initialize evaluator for a specific task 43 | evaluator = get_evaluator(task=args.task, 44 | max_length=args.max_source_length, 45 | device=args.device, 46 | cache_dir=args.cache_dir) 47 | 48 | # Get multi-dimensional evaluation scores 49 | eval_scores = evaluator.evaluate(data, print_result=False) 50 | 51 | # Write predicted scores for all dimensions 52 | dims = ['fluency', 'coherence', 'consistency', 'relevance', 'overall'] 53 | for dim in dims: 54 | with open('output_scores_{}.txt'.format(dim), 'w') as f: 55 | for i in range(len(eval_scores)): 56 | print(eval_scores[i][dim], file=f) 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser( 60 | description='Get evaluation scores from UniEval from different NLG tasks' 61 | ) 62 | 63 | parser.add_argument('--src_path', required=True, 64 | help='Path to the source files', type=str) 65 | parser.add_argument('--ref_path', required=True, 66 | help='Path to the reference files', type=str) 67 | parser.add_argument('--hyp_path', required=True, 68 | help='Path to the generated files', type=str) 69 | parser.add_argument('--task', default='summarization', 70 | help='Specific NLG task to be evaluated', type=str) 71 | parser.add_argument('--cache_dir', default=None, 72 | help='Where to store the pretrained models downloaded from huggingface.co', type=str) 73 | parser.add_argument('--device', default='cuda:0', 74 | help='Available device for the calculations', type=str) 75 | parser.add_argument('--max_source_length', default=1024, 76 | help='The maximum total input sequence length after tokenization', type=int) 77 | 78 | args = parser.parse_args() 79 | 80 | evaluate(args) -------------------------------------------------------------------------------- /billboard/generator-output.jsonl: -------------------------------------------------------------------------------- 1 | {"hyp": "Paul merson was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . Andros townsend scored the tottenham midfielder in the 89th minute . Paul merson had another dig at andros townsend after his appearance . The midfielder had been brought on to the england squad last week . 
Click here for all the latest arsenal news news ."} 2 | {"hyp": "Paul merson has restarted his row with andros townsend . The tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley . Andros townsend scores england 's equaliser in their 1-1 friendly draw with italy in turin ."} 3 | {"hyp": "Paul merson has restarted his row with andros townsend after the tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with burnley on sunday . Townsend was brought on in the 83rd minute for tottenham as they drew 0-0 against burnley . Townsend hit back at merson on twitter after scoring for england against italy ."} 4 | -------------------------------------------------------------------------------- /billboard/output_scores_coherence.txt: -------------------------------------------------------------------------------- 1 | 0.11246328216251741 2 | 0.2524910531423081 3 | 0.875345771739276 4 | -------------------------------------------------------------------------------- /billboard/output_scores_consistency.txt: -------------------------------------------------------------------------------- 1 | 0.5720639343819058 2 | 0.9295646026501481 3 | 0.9273843716661299 4 | -------------------------------------------------------------------------------- /billboard/output_scores_fluency.txt: -------------------------------------------------------------------------------- 1 | 0.5960423377733144 2 | 0.9154313160577754 3 | 0.9303071418755243 4 | -------------------------------------------------------------------------------- /billboard/output_scores_overall.txt: -------------------------------------------------------------------------------- 1 | 0.33479411447508534 2 | 0.5296161622650459 3 | 0.8088783258869782 4 | -------------------------------------------------------------------------------- /billboard/output_scores_relevance.txt: -------------------------------------------------------------------------------- 1 | 0.05860690358260355 2 | 0.02097767720995224 3 | 0.5024760182669827 4 | -------------------------------------------------------------------------------- /billboard/reference-file.jsonl: -------------------------------------------------------------------------------- 1 | {"ref": ["Andros Townsend an 83rd minute sub in Tottenham 's draw with Burnley . He was unable to find a winner as the game ended without a goal . Townsend had clashed with Paul Merson last week over England call-up ."]} 2 | {"ref": ["Andros Townsend an 83rd minute sub in Tottenham 's draw with Burnley . He was unable to find a winner as the game ended without a goal . Townsend had clashed with Paul Merson last week over England call-up ."]} 3 | {"ref": ["Andros Townsend an 83rd minute sub in Tottenham 's draw with Burnley . He was unable to find a winner as the game ended without a goal . 
Townsend had clashed with Paul Merson last week over England call-up ."]} 4 | -------------------------------------------------------------------------------- /billboard/run.sh: -------------------------------------------------------------------------------- 1 | python evaluate.py \ 2 | --src_path source-file.jsonl \ 3 | --ref_path reference-file.jsonl \ 4 | --hyp_path generator-output.jsonl \ 5 | -------------------------------------------------------------------------------- /billboard/source-file.jsonl: -------------------------------------------------------------------------------- 1 | {"src": "Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with Burnley on Sunday . 'Just been watching the game , did you miss the coach ? # RubberDub # 7minutes , ' Merson put on Twitter . Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in ( the England team ) then it opens it up to anybody . ' Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England 's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake . 'It 's not as though I was watching hoping he would n't score for England , I 'm genuinely pleased for him and fair play to him \u2013 it was a great goal , ' Merson said . 'It 's just a matter of opinion , and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson , so he should n't have been in the squad . 'When I 'm wrong , I hold my hands up . I do n't have a problem with doing that - I 'll always be the first to admit when I 'm wrong . ' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson ( centre ) criticised Townsend 's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday , saying 'Not bad for a player that should be 'nowhere near the squad ' ay @ PaulMerse ? ' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor ."} 2 | {"src": "Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with Burnley on Sunday . 'Just been watching the game , did you miss the coach ? # RubberDub # 7minutes , ' Merson put on Twitter . Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in ( the England team ) then it opens it up to anybody . ' Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England 's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake . 
'It 's not as though I was watching hoping he would n't score for England , I 'm genuinely pleased for him and fair play to him \u2013 it was a great goal , ' Merson said . 'It 's just a matter of opinion , and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson , so he should n't have been in the squad . 'When I 'm wrong , I hold my hands up . I do n't have a problem with doing that - I 'll always be the first to admit when I 'm wrong . ' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson ( centre ) criticised Townsend 's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday , saying 'Not bad for a player that should be 'nowhere near the squad ' ay @ PaulMerse ? ' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor ."} 3 | {"src": "Paul Merson has restarted his row with Andros Townsend after the Tottenham midfielder was brought on with only seven minutes remaining in his team 's 0-0 draw with Burnley on Sunday . 'Just been watching the game , did you miss the coach ? # RubberDub # 7minutes , ' Merson put on Twitter . Merson initially angered Townsend for writing in his Sky Sports column that 'if Andros Townsend can get in ( the England team ) then it opens it up to anybody . ' Paul Merson had another dig at Andros Townsend after his appearance for Tottenham against Burnley Townsend was brought on in the 83rd minute for Tottenham as they drew 0-0 against Burnley Andros Townsend scores England 's equaliser in their 1-1 friendly draw with Italy in Turin on Tuesday night The former Arsenal man was proven wrong when Townsend hit a stunning equaliser for England against Italy and he duly admitted his mistake . 'It 's not as though I was watching hoping he would n't score for England , I 'm genuinely pleased for him and fair play to him \u2013 it was a great goal , ' Merson said . 'It 's just a matter of opinion , and my opinion was that he got pulled off after half an hour at Manchester United in front of Roy Hodgson , so he should n't have been in the squad . 'When I 'm wrong , I hold my hands up . I do n't have a problem with doing that - I 'll always be the first to admit when I 'm wrong . ' Townsend hit back at Merson on Twitter after scoring for England against Italy Sky Sports pundit Merson ( centre ) criticised Townsend 's call-up to the England squad last week Townsend hit back at Merson after netting for England in Turin on Wednesday , saying 'Not bad for a player that should be 'nowhere near the squad ' ay @ PaulMerse ? ' Any bad feeling between the pair seemed to have passed but Merson was unable to resist having another dig at Townsend after Tottenham drew at Turf Moor ."} 4 | -------------------------------------------------------------------------------- /evaluation_tasks/README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Learning on Evaluation Tasks 2 | 3 |

4 | ![evaluation](../figures/evaluation.png) 5 |

6 | 7 | Based on the Boolean Answer Generator, we then construct pseudo data for each dimension and train them sequentially to obtain UniEval. 8 | 9 | ## Pseudo Data 10 | All the pseudo data for summarization and dialogue response generation can be found [here](https://drive.google.com/file/d/1SHsPPNvEAFNQToCdAFLhPulvQ6jEHdA5/view?usp=sharing). Please unzip it and put it in `./data`. 11 | 12 | ## Training 13 | We use two strategies to train UniEval: Multi-task Learning and Continual Learning. 14 | 15 | ### Multi-task Learning 16 | Run the following script to conduct multi-task learning: 17 | ```bash 18 | export TOKENIZERS_PARALLELISM=true 19 | export OMP_NUM_THREADS=1 20 | 21 | CUDA_VISIBLE_DEVICES=0,1 \ 22 | python -m torch.distributed.launch --nproc_per_node 2 train_seq2seq.py \ 23 | --model_name_or_path MingZhong/unieval-intermediate \ 24 | --do_train \ 25 | --train_file data/summarization/train_all.json \ 26 | --text_column src \ 27 | --summary_column tgt \ 28 | --output_dir ./multitask_summ \ 29 | --per_device_train_batch_size 3 \ 30 | --gradient_accumulation_steps 6 \ 31 | --max_source_length 1024 \ 32 | --max_target_length 16 \ 33 | --save_strategy steps \ 34 | --save_steps 2000 \ 35 | --num_train_epochs 3 \ 36 | --ddp_find_unused_parameters False \ 37 | ``` 38 | 39 | ### Continual Learning 40 | Run the following script to perform continual learning: 41 | ```bash 42 | export TOKENIZERS_PARALLELISM=true 43 | export OMP_NUM_THREADS=1 44 | 45 | CUDA_VISIBLE_DEVICES=0,1 \ 46 | python -m torch.distributed.launch --nproc_per_node 2 train_seq2seq.py \ 47 | --model_name_or_path MingZhong/unieval-intermediate \ 48 | --do_train \ 49 | --train_file data/summarization/coherence_3w.json \ 50 | --text_column src \ 51 | --summary_column tgt \ 52 | --output_dir ./continual_summ_coherence \ 53 | --per_device_train_batch_size 3 \ 54 | --gradient_accumulation_steps 6 \ 55 | --max_source_length 1024 \ 56 | --max_target_length 16 \ 57 | --save_strategy steps \ 58 | --save_steps 500 \ 59 | --num_train_epochs 3 \ 60 | --ddp_find_unused_parameters False \ 61 | 62 | ``` 63 | - After training on *coherence*, we need to continue training for *fluency* based on the obtained checkpoint. In this case, the input data are randomly sampled 20% `coherence_3w.json` (replay data) and 100% `fluency_3w.json`. 64 | - Repeating the above process and training the four dimensions sequentially, we can finally obtain the evaluator for summarization. 
65 | - Training order for summarization: *coherence* → *fluency* → *consistency* → *relevance* 66 | - Training order for dialogue response generation: *coherence* → *naturalness* → *groundedness* → *engagingness* -------------------------------------------------------------------------------- /evaluation_tasks/train_continual.sh: -------------------------------------------------------------------------------- 1 | export TOKENIZERS_PARALLELISM=true 2 | export OMP_NUM_THREADS=1 3 | 4 | CUDA_VISIBLE_DEVICES=0,1 \ 5 | python -m torch.distributed.launch --nproc_per_node 2 train_seq2seq.py \ 6 | --model_name_or_path MingZhong/unieval-intermediate \ 7 | --do_train \ 8 | --train_file data/summarization/coherence_3w.json \ 9 | --text_column src \ 10 | --summary_column tgt \ 11 | --output_dir ./continual_summ_coherence \ 12 | --per_device_train_batch_size 3 \ 13 | --gradient_accumulation_steps 6 \ 14 | --max_source_length 1024 \ 15 | --max_target_length 16 \ 16 | --save_strategy steps \ 17 | --save_steps 500 \ 18 | --num_train_epochs 3 \ 19 | --ddp_find_unused_parameters False \ 20 | -------------------------------------------------------------------------------- /evaluation_tasks/train_multi.sh: -------------------------------------------------------------------------------- 1 | export TOKENIZERS_PARALLELISM=true 2 | export OMP_NUM_THREADS=1 3 | 4 | CUDA_VISIBLE_DEVICES=0,1 \ 5 | python -m torch.distributed.launch --nproc_per_node 2 train_seq2seq.py \ 6 | --model_name_or_path MingZhong/unieval-intermediate \ 7 | --do_train \ 8 | --train_file data/summarization/train_all.json \ 9 | --text_column src \ 10 | --summary_column tgt \ 11 | --output_dir ./multitask_summ \ 12 | --per_device_train_batch_size 3 \ 13 | --gradient_accumulation_steps 6 \ 14 | --max_source_length 1024 \ 15 | --max_target_length 16 \ 16 | --save_strategy steps \ 17 | --save_steps 2000 \ 18 | --num_train_epochs 3 \ 19 | --ddp_find_unused_parameters False \ 20 | -------------------------------------------------------------------------------- /evaluation_tasks/train_seq2seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset, load_metric 31 | 32 | import transformers 33 | from filelock import FileLock 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | MBart50Tokenizer, 41 | MBart50TokenizerFast, 42 | MBartTokenizer, 43 | MBartTokenizerFast, 44 | Seq2SeqTrainer, 45 | Seq2SeqTrainingArguments, 46 | set_seed, 47 | ) 48 | from transformers.file_utils import is_offline_mode 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version 51 | from transformers.utils.versions import require_version 52 | 53 | 54 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 55 | check_min_version("4.17.0.dev0") 56 | 57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 58 | 59 | logger = logging.getLogger(__name__) 60 | 61 | try: 62 | nltk.data.find("tokenizers/punkt") 63 | except (LookupError, OSError): 64 | if is_offline_mode(): 65 | raise LookupError( 66 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 67 | ) 68 | with FileLock(".lock") as lock: 69 | nltk.download("punkt", quiet=True) 70 | 71 | # A list of all multilingual tokenizer which require lang attribute. 72 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] 73 | 74 | 75 | @dataclass 76 | class ModelArguments: 77 | """ 78 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 79 | """ 80 | 81 | model_name_or_path: str = field( 82 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 83 | ) 84 | config_name: Optional[str] = field( 85 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 86 | ) 87 | tokenizer_name: Optional[str] = field( 88 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 89 | ) 90 | cache_dir: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 93 | ) 94 | use_fast_tokenizer: bool = field( 95 | default=True, 96 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 97 | ) 98 | model_revision: str = field( 99 | default="main", 100 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 101 | ) 102 | use_auth_token: bool = field( 103 | default=False, 104 | metadata={ 105 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 106 | "with private models)." 107 | }, 108 | ) 109 | resize_position_embeddings: Optional[bool] = field( 110 | default=None, 111 | metadata={ 112 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 113 | "the model's position embeddings." 114 | }, 115 | ) 116 | 117 | 118 | @dataclass 119 | class DataTrainingArguments: 120 | """ 121 | Arguments pertaining to what data we are going to input our model for training and eval. 
122 | """ 123 | 124 | lang: str = field(default=None, metadata={"help": "Language id for summarization."}) 125 | 126 | dataset_name: Optional[str] = field( 127 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 128 | ) 129 | dataset_config_name: Optional[str] = field( 130 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 131 | ) 132 | text_column: Optional[str] = field( 133 | default=None, 134 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 135 | ) 136 | summary_column: Optional[str] = field( 137 | default=None, 138 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 139 | ) 140 | train_file: Optional[str] = field( 141 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 142 | ) 143 | validation_file: Optional[str] = field( 144 | default=None, 145 | metadata={ 146 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 147 | "(a jsonlines or csv file)." 148 | }, 149 | ) 150 | test_file: Optional[str] = field( 151 | default=None, 152 | metadata={ 153 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 154 | }, 155 | ) 156 | overwrite_cache: bool = field( 157 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 158 | ) 159 | preprocessing_num_workers: Optional[int] = field( 160 | default=None, 161 | metadata={"help": "The number of processes to use for the preprocessing."}, 162 | ) 163 | max_source_length: Optional[int] = field( 164 | default=1024, 165 | metadata={ 166 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 167 | "than this will be truncated, sequences shorter will be padded." 168 | }, 169 | ) 170 | max_target_length: Optional[int] = field( 171 | default=128, 172 | metadata={ 173 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 174 | "than this will be truncated, sequences shorter will be padded." 175 | }, 176 | ) 177 | val_max_target_length: Optional[int] = field( 178 | default=None, 179 | metadata={ 180 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 181 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 182 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 183 | "during ``evaluate`` and ``predict``." 184 | }, 185 | ) 186 | pad_to_max_length: bool = field( 187 | default=False, 188 | metadata={ 189 | "help": "Whether to pad all samples to model maximum sentence length. " 190 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 191 | "efficient on GPU but very bad for TPU." 192 | }, 193 | ) 194 | max_train_samples: Optional[int] = field( 195 | default=None, 196 | metadata={ 197 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 198 | "value if set." 199 | }, 200 | ) 201 | max_eval_samples: Optional[int] = field( 202 | default=None, 203 | metadata={ 204 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 205 | "value if set." 
206 | }, 207 | ) 208 | max_predict_samples: Optional[int] = field( 209 | default=None, 210 | metadata={ 211 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 212 | "value if set." 213 | }, 214 | ) 215 | num_beams: Optional[int] = field( 216 | default=None, 217 | metadata={ 218 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 219 | "which is used during ``evaluate`` and ``predict``." 220 | }, 221 | ) 222 | ignore_pad_token_for_loss: bool = field( 223 | default=True, 224 | metadata={ 225 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 226 | }, 227 | ) 228 | source_prefix: Optional[str] = field( 229 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 230 | ) 231 | 232 | forced_bos_token: Optional[str] = field( 233 | default=None, 234 | metadata={ 235 | "help": "The token to force as the first generated token after the decoder_start_token_id." 236 | "Useful for multilingual models like mBART where the first generated token" 237 | "needs to be the target language token (Usually it is the target language token)" 238 | }, 239 | ) 240 | 241 | def __post_init__(self): 242 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 243 | raise ValueError("Need either a dataset name or a training/validation file.") 244 | else: 245 | if self.train_file is not None: 246 | extension = self.train_file.split(".")[-1] 247 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 248 | if self.validation_file is not None: 249 | extension = self.validation_file.split(".")[-1] 250 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 251 | if self.val_max_target_length is None: 252 | self.val_max_target_length = self.max_target_length 253 | 254 | 255 | summarization_name_mapping = { 256 | "amazon_reviews_multi": ("review_body", "review_title"), 257 | "big_patent": ("description", "abstract"), 258 | "cnn_dailymail": ("article", "highlights"), 259 | "orange_sum": ("text", "summary"), 260 | "pn_summary": ("article", "summary"), 261 | "psc": ("extract_text", "summary_text"), 262 | "samsum": ("dialogue", "summary"), 263 | "thaisum": ("body", "summary"), 264 | "xglue": ("news_body", "news_title"), 265 | "xsum": ("document", "summary"), 266 | "wiki_summary": ("article", "highlights"), 267 | } 268 | 269 | 270 | def main(): 271 | # See all possible arguments in src/transformers/training_args.py 272 | # or by passing the --help flag to this script. 273 | # We now keep distinct sets of args, for a cleaner separation of concerns. 274 | 275 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 276 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 277 | # If we pass only one argument to the script and it's the path to a json file, 278 | # let's parse it to get our arguments. 
279 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 280 | else: 281 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 282 | 283 | # Setup logging 284 | logging.basicConfig( 285 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 286 | datefmt="%m/%d/%Y %H:%M:%S", 287 | handlers=[logging.StreamHandler(sys.stdout)], 288 | ) 289 | log_level = training_args.get_process_log_level() 290 | logger.setLevel(log_level) 291 | datasets.utils.logging.set_verbosity(log_level) 292 | transformers.utils.logging.set_verbosity(log_level) 293 | transformers.utils.logging.enable_default_handler() 294 | transformers.utils.logging.enable_explicit_format() 295 | 296 | # Log on each process the small summary: 297 | logger.warning( 298 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 299 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 300 | ) 301 | logger.info(f"Training/evaluation parameters {training_args}") 302 | 303 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 304 | "t5-small", 305 | "t5-base", 306 | "t5-large", 307 | "t5-3b", 308 | "t5-11b", 309 | ]: 310 | logger.warning( 311 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 312 | "`--source_prefix 'summarize: ' `" 313 | ) 314 | 315 | # Detecting last checkpoint. 316 | last_checkpoint = None 317 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 318 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 319 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 320 | raise ValueError( 321 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 322 | "Use --overwrite_output_dir to overcome." 323 | ) 324 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 325 | logger.info( 326 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 327 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 328 | ) 329 | 330 | # Set seed before initializing model. 331 | set_seed(training_args.seed) 332 | 333 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 334 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 335 | # (the dataset will be downloaded automatically from the datasets Hub). 336 | # 337 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 338 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 339 | # 340 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 341 | # download the dataset. 342 | if data_args.dataset_name is not None: 343 | # Downloading and loading a dataset from the hub. 
344 | raw_datasets = load_dataset( 345 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 346 | ) 347 | else: 348 | data_files = {} 349 | if data_args.train_file is not None: 350 | data_files["train"] = data_args.train_file 351 | extension = data_args.train_file.split(".")[-1] 352 | if data_args.validation_file is not None: 353 | data_files["validation"] = data_args.validation_file 354 | extension = data_args.validation_file.split(".")[-1] 355 | if data_args.test_file is not None: 356 | data_files["test"] = data_args.test_file 357 | extension = data_args.test_file.split(".")[-1] 358 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 359 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 360 | # https://huggingface.co/docs/datasets/loading_datasets.html. 361 | 362 | # Load pretrained model and tokenizer 363 | # 364 | # Distributed training: 365 | # The .from_pretrained methods guarantee that only one local process can concurrently 366 | # download model & vocab. 367 | config = AutoConfig.from_pretrained( 368 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 369 | cache_dir=model_args.cache_dir, 370 | revision=model_args.model_revision, 371 | use_auth_token=True if model_args.use_auth_token else None, 372 | ) 373 | tokenizer = AutoTokenizer.from_pretrained( 374 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 375 | cache_dir=model_args.cache_dir, 376 | use_fast=model_args.use_fast_tokenizer, 377 | revision=model_args.model_revision, 378 | use_auth_token=True if model_args.use_auth_token else None, 379 | ) 380 | model = AutoModelForSeq2SeqLM.from_pretrained( 381 | model_args.model_name_or_path, 382 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 383 | config=config, 384 | cache_dir=model_args.cache_dir, 385 | revision=model_args.model_revision, 386 | use_auth_token=True if model_args.use_auth_token else None, 387 | ) 388 | 389 | model.resize_token_embeddings(len(tokenizer)) 390 | 391 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 392 | if isinstance(tokenizer, MBartTokenizer): 393 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 394 | else: 395 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 396 | 397 | if model.config.decoder_start_token_id is None: 398 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 399 | 400 | if ( 401 | hasattr(model.config, "max_position_embeddings") 402 | and model.config.max_position_embeddings < data_args.max_source_length 403 | ): 404 | if model_args.resize_position_embeddings is None: 405 | logger.warning( 406 | f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " 407 | f"to {data_args.max_source_length}." 408 | ) 409 | model.resize_position_embeddings(data_args.max_source_length) 410 | elif model_args.resize_position_embeddings: 411 | model.resize_position_embeddings(data_args.max_source_length) 412 | else: 413 | raise ValueError( 414 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}" 415 | f" position encodings. 
Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically " 416 | "resize the model's position encodings by passing `--resize_position_embeddings`." 417 | ) 418 | 419 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 420 | 421 | # Preprocessing the datasets. 422 | # We need to tokenize inputs and targets. 423 | if training_args.do_train: 424 | column_names = raw_datasets["train"].column_names 425 | elif training_args.do_eval: 426 | column_names = raw_datasets["validation"].column_names 427 | elif training_args.do_predict: 428 | column_names = raw_datasets["test"].column_names 429 | else: 430 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 431 | return 432 | 433 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 434 | assert ( 435 | data_args.lang is not None 436 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" 437 | 438 | tokenizer.src_lang = data_args.lang 439 | tokenizer.tgt_lang = data_args.lang 440 | 441 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 442 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 443 | forced_bos_token_id = ( 444 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 445 | ) 446 | model.config.forced_bos_token_id = forced_bos_token_id 447 | 448 | # Get the column names for input/target. 449 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 450 | if data_args.text_column is None: 451 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 452 | else: 453 | text_column = data_args.text_column 454 | if text_column not in column_names: 455 | raise ValueError( 456 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 457 | ) 458 | if data_args.summary_column is None: 459 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 460 | else: 461 | summary_column = data_args.summary_column 462 | if summary_column not in column_names: 463 | raise ValueError( 464 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 465 | ) 466 | 467 | # Temporarily set max_target_length for training. 468 | max_target_length = data_args.max_target_length 469 | padding = "max_length" if data_args.pad_to_max_length else False 470 | 471 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 472 | logger.warning( 473 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 474 | f"`{model.__class__.__name__}`. 
This will lead to loss being calculated twice and will take up more memory" 475 | ) 476 | 477 | def preprocess_function(examples): 478 | # remove pairs where at least one record is None 479 | 480 | inputs, targets = [], [] 481 | for i in range(len(examples[text_column])): 482 | if examples[text_column][i] is not None and examples[summary_column][i] is not None: 483 | inputs.append(examples[text_column][i]) 484 | targets.append(examples[summary_column][i]) 485 | 486 | inputs = examples[text_column] 487 | targets = examples[summary_column] 488 | inputs = [prefix + inp for inp in inputs] 489 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 490 | 491 | # Setup the tokenizer for targets 492 | with tokenizer.as_target_tokenizer(): 493 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 494 | 495 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 496 | # padding in the loss. 497 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 498 | labels["input_ids"] = [ 499 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 500 | ] 501 | 502 | model_inputs["labels"] = labels["input_ids"] 503 | return model_inputs 504 | 505 | if training_args.do_train: 506 | if "train" not in raw_datasets: 507 | raise ValueError("--do_train requires a train dataset") 508 | train_dataset = raw_datasets["train"] 509 | if data_args.max_train_samples is not None: 510 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 511 | with training_args.main_process_first(desc="train dataset map pre-processing"): 512 | train_dataset = train_dataset.map( 513 | preprocess_function, 514 | batched=True, 515 | num_proc=data_args.preprocessing_num_workers, 516 | remove_columns=column_names, 517 | load_from_cache_file=not data_args.overwrite_cache, 518 | desc="Running tokenizer on train dataset", 519 | ) 520 | 521 | if training_args.do_eval: 522 | max_target_length = data_args.val_max_target_length 523 | if "validation" not in raw_datasets: 524 | raise ValueError("--do_eval requires a validation dataset") 525 | eval_dataset = raw_datasets["validation"] 526 | if data_args.max_eval_samples is not None: 527 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 528 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 529 | eval_dataset = eval_dataset.map( 530 | preprocess_function, 531 | batched=True, 532 | num_proc=data_args.preprocessing_num_workers, 533 | remove_columns=column_names, 534 | load_from_cache_file=not data_args.overwrite_cache, 535 | desc="Running tokenizer on validation dataset", 536 | ) 537 | 538 | if training_args.do_predict: 539 | max_target_length = data_args.val_max_target_length 540 | if "test" not in raw_datasets: 541 | raise ValueError("--do_predict requires a test dataset") 542 | predict_dataset = raw_datasets["test"] 543 | if data_args.max_predict_samples is not None: 544 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 545 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 546 | predict_dataset = predict_dataset.map( 547 | preprocess_function, 548 | batched=True, 549 | num_proc=data_args.preprocessing_num_workers, 550 | remove_columns=column_names, 551 | load_from_cache_file=not data_args.overwrite_cache, 552 | desc="Running tokenizer on prediction dataset", 553 | 
) 554 | 555 | # Data collator 556 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 557 | data_collator = DataCollatorForSeq2Seq( 558 | tokenizer, 559 | model=model, 560 | label_pad_token_id=label_pad_token_id, 561 | pad_to_multiple_of=8 if training_args.fp16 else None, 562 | ) 563 | 564 | # Metric 565 | metric = load_metric("rouge") 566 | 567 | def postprocess_text(preds, labels): 568 | preds = [pred.strip() for pred in preds] 569 | labels = [label.strip() for label in labels] 570 | 571 | # rougeLSum expects newline after each sentence 572 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 573 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 574 | 575 | return preds, labels 576 | 577 | def compute_metrics(eval_preds): 578 | preds, labels = eval_preds 579 | if isinstance(preds, tuple): 580 | preds = preds[0] 581 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 582 | if data_args.ignore_pad_token_for_loss: 583 | # Replace -100 in the labels as we can't decode them. 584 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 585 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 586 | 587 | # Some simple post-processing 588 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 589 | 590 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 591 | # Extract a few results from ROUGE 592 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 593 | 594 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 595 | result["gen_len"] = np.mean(prediction_lens) 596 | result = {k: round(v, 4) for k, v in result.items()} 597 | return result 598 | 599 | # Initialize our Trainer 600 | trainer = Seq2SeqTrainer( 601 | model=model, 602 | args=training_args, 603 | train_dataset=train_dataset if training_args.do_train else None, 604 | eval_dataset=eval_dataset if training_args.do_eval else None, 605 | tokenizer=tokenizer, 606 | data_collator=data_collator, 607 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 608 | ) 609 | 610 | # Training 611 | if training_args.do_train: 612 | checkpoint = None 613 | if training_args.resume_from_checkpoint is not None: 614 | checkpoint = training_args.resume_from_checkpoint 615 | elif last_checkpoint is not None: 616 | checkpoint = last_checkpoint 617 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 618 | trainer.save_model() # Saves the tokenizer too for easy upload 619 | 620 | metrics = train_result.metrics 621 | max_train_samples = ( 622 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 623 | ) 624 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 625 | 626 | trainer.log_metrics("train", metrics) 627 | trainer.save_metrics("train", metrics) 628 | trainer.save_state() 629 | 630 | # Evaluation 631 | results = {} 632 | max_length = ( 633 | training_args.generation_max_length 634 | if training_args.generation_max_length is not None 635 | else data_args.val_max_target_length 636 | ) 637 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 638 | if training_args.do_eval: 639 | logger.info("*** Evaluate ***") 640 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 641 | max_eval_samples = 
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 642 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 643 | 644 | trainer.log_metrics("eval", metrics) 645 | trainer.save_metrics("eval", metrics) 646 | 647 | if training_args.do_predict: 648 | logger.info("*** Predict ***") 649 | 650 | predict_results = trainer.predict( 651 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 652 | ) 653 | metrics = predict_results.metrics 654 | max_predict_samples = ( 655 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 656 | ) 657 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 658 | 659 | trainer.log_metrics("predict", metrics) 660 | trainer.save_metrics("predict", metrics) 661 | 662 | if trainer.is_world_process_zero(): 663 | if training_args.predict_with_generate: 664 | predictions = tokenizer.batch_decode( 665 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 666 | ) 667 | predictions = [pred.strip() for pred in predictions] 668 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 669 | with open(output_prediction_file, "w") as writer: 670 | writer.write("\n".join(predictions)) 671 | 672 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 673 | if data_args.dataset_name is not None: 674 | kwargs["dataset_tags"] = data_args.dataset_name 675 | if data_args.dataset_config_name is not None: 676 | kwargs["dataset_args"] = data_args.dataset_config_name 677 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 678 | else: 679 | kwargs["dataset"] = data_args.dataset_name 680 | 681 | if data_args.lang is not None: 682 | kwargs["language"] = data_args.lang 683 | 684 | if training_args.push_to_hub: 685 | trainer.push_to_hub(**kwargs) 686 | else: 687 | trainer.create_model_card(**kwargs) 688 | 689 | return results 690 | 691 | 692 | def _mp_fn(index): 693 | # For xla_spawn (TPUs) 694 | main() 695 | 696 | 697 | if __name__ == "__main__": 698 | main() 699 | -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | from utils import convert_to_json 2 | from metric.evaluator import get_evaluator 3 | 4 | # Example for data-to-text 5 | task = 'data2text' 6 | 7 | # a list of model outputs to be evaluataed 8 | output_list = ['You would like to search financial district ?'] 9 | # a list of human-annotated reference texts 10 | ref_list = ['You are looking near the financial district , right ?'] 11 | 12 | # Prepare data for pre-trained evaluators 13 | data = convert_to_json(output_list=output_list, ref_list=ref_list) 14 | # Initialize evaluator for a specific task 15 | evaluator = get_evaluator(task) 16 | # Get multi-dimensional evaluation scores 17 | eval_scores = evaluator.evaluate(data, print_result=True) 18 | 19 | 20 | 21 | ''' 22 | # Example for summarization 23 | task = 'summarization' 24 | 25 | # a list of source documents 26 | src_list = ['Peter and Elizabeth took a taxi to attend the night party in the city. 
\ 27 | While in the party, Elizabeth collapsed and was rushed to the hospital.'] 28 | # a list of human-annotated reference summaries 29 | ref_list = ['Elizabeth was hospitalized after attending a party with Peter.'] 30 | # a list of model outputs to be evaluataed 31 | output_list = ['Peter and Elizabeth attend party city. Elizabeth rushed hospital.'] 32 | 33 | # Prepare data for pre-trained evaluators 34 | data = convert_to_json(output_list=output_list, 35 | src_list=src_list, ref_list=ref_list) 36 | # Initialize evaluator for a specific task 37 | evaluator = get_evaluator(task) 38 | # Get multi-dimensional evaluation scores 39 | eval_scores = evaluator.evaluate(data, print_result=True) 40 | # eval_scores = evaluator.evaluate(data, dims=['coherence', 'consistency', 'fluency'], 41 | # overall=False, print_result=True) 42 | 43 | 44 | 45 | 46 | # Example for dialogue response generation 47 | task = 'dialogue' 48 | 49 | # a list of dialogue histories 50 | src_list = ['hi , do you know much about the internet ? \n i know a lot about different sites and some website design , how about you ? \n\n'] 51 | # a list of additional context that should be included into the generated response 52 | context_list = ['the 3 horizontal line menu on apps and websites is called a hamburger button .\n'] 53 | # a list of model outputs to be evaluated 54 | output_list = ['i do too . did you know the 3 horizontal line menu on apps and websites is called the hamburger button ?'] 55 | 56 | # Prepare data for pre-trained evaluators 57 | data = convert_to_json(output_list=output_list, 58 | src_list=src_list, context_list=context_list) 59 | # Initialize evaluator for a specific task 60 | evaluator = get_evaluator(task) 61 | # Get multi-dimensional evaluation scores 62 | eval_scores = evaluator.evaluate(data, print_result=True) 63 | 64 | 65 | 66 | # Example for factual consistency detection 67 | task = 'fact' 68 | 69 | # a list of source documents 70 | src_list = ['Peter and Elizabeth took a taxi to attend the night party in the city. 
\ 71 | While in the party, Elizabeth collapsed and was rushed to the hospital.'] 72 | # a list of model outputs (claims) to be evaluataed 73 | output_list = ['Tom was rushed to hospital.'] 74 | 75 | # Prepare data for pre-trained evaluators 76 | data = convert_to_json(output_list=output_list, src_list=src_list) 77 | # Initialize evaluator for a specific task 78 | evaluator = get_evaluator(task) 79 | # Get factual consistency scores 80 | eval_scores = evaluator.evaluate(data, print_result=True) 81 | ''' -------------------------------------------------------------------------------- /figures/UniEval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/UniEval/d33e7b6cfebe97b2bafe435adbd818230d5a416a/figures/UniEval.png -------------------------------------------------------------------------------- /figures/evaluation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/UniEval/d33e7b6cfebe97b2bafe435adbd818230d5a416a/figures/evaluation.png -------------------------------------------------------------------------------- /figures/intermediate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maszhongming/UniEval/d33e7b6cfebe97b2bafe435adbd818230d5a416a/figures/intermediate.png -------------------------------------------------------------------------------- /intermediate_tasks/README.md: -------------------------------------------------------------------------------- 1 | # Intermediate Pre-training 2 | 3 |
4 | ![intermediate](../figures/intermediate.png)
5 | 
6 | 
7 | By performing intermediate multi-task learning on T5, we can obtain a Boolean Answer Generator. We have released our intermediate model as [unieval-intermediate](https://huggingface.co/MingZhong/unieval-intermediate); based on it, you can train a custom evaluator for a specific NLG task.
8 | 
9 | ## Pre-train Data
10 | 
11 | In total, we use data from the following four tasks to perform intermediate multi-task learning:
12 | 
13 | - Question Answering: [BoolQ](https://github.com/google-research-datasets/boolean-questions), [BoolQ-NP](https://github.com/allenai/natural-perturbations), [BoolQ-CS](https://github.com/allenai/contrast-sets), [StrategyQA](https://allenai.org/data/strategyqa) and [MultiRC](https://cogcomp.seas.upenn.edu/multirc/).
14 | - Natural Language Inference: [DocNLI](https://arxiv.org/abs/2106.09449), [MRPC](https://huggingface.co/datasets/glue/viewer/mrpc/train) and [QQP](https://huggingface.co/datasets/glue/viewer/qqp/train).
15 | - Self-supervised Task: Opening Sentence Prediction on the [CNN/DailyMail Corpus](https://huggingface.co/datasets/cnn_dailymail).
16 | - Linguistics-Related Task: [CoLA](https://huggingface.co/datasets/glue/viewer/cola/train).
17 | 
18 | The statistics are in [data_info.txt](./data_info.txt). All the pre-train data in the Boolean QA format can be found [here](https://drive.google.com/file/d/16T2tlAZDrgA5LMa5WYRhMz7SrAFwQfH7/view?usp=sharing). Please unzip it and put it in `./data`.
19 | 
20 | ## Training
21 | Run the following script to perform intermediate pre-training:
22 | ```bash
23 | export TOKENIZERS_PARALLELISM=true
24 | export OMP_NUM_THREADS=1
25 | 
26 | CUDA_VISIBLE_DEVICES=0,1,2 \
27 | python -m torch.distributed.launch --nproc_per_node 3 train_seq2seq.py \
28 | --model_name_or_path google/t5-v1_1-large \
29 | --do_train \
30 | --train_file data/intermediate_train.json \
31 | --text_column src \
32 | --summary_column tgt \
33 | --output_dir ./inter_model \
34 | --per_device_train_batch_size 3 \
35 | --gradient_accumulation_steps 4 \
36 | --max_source_length 1024 \
37 | --max_target_length 16 \
38 | --save_strategy epoch \
39 | --num_train_epochs 10 \
40 | --ddp_find_unused_parameters False \
41 | ```
42 | 
43 | - The batch size can be adjusted according to your GPU memory.
44 | - We use the checkpoint from the second epoch as [unieval-intermediate](https://huggingface.co/MingZhong/unieval-intermediate); a minimal usage sketch is shown below.
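To make the Boolean Answer Generator concrete, here is a minimal sketch (not part of the released scripts) of how the intermediate checkpoint answers a yes/no question; it mirrors the scoring logic in `metric/scorer.py`. The question string below is only an illustrative placeholder, since the actual evaluation questions are constructed by `add_question` in `utils.py`.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the released intermediate checkpoint (or your own model from ./inter_model).
tokenizer = AutoTokenizer.from_pretrained("MingZhong/unieval-intermediate")
model = AutoModelForSeq2SeqLM.from_pretrained("MingZhong/unieval-intermediate")
model.eval()

# Illustrative Boolean question only; the real questions are built by utils.add_question.
question = "question: Is this a fluent sentence? </s> sentence: The weather is nice today."
enc = tokenizer(question, max_length=1024, truncation=True, return_tensors="pt")

# T5's forward pass needs decoder inputs, so feed a one-token dummy label
# and read the logits of the first decoding step.
dummy_label = tokenizer("No", return_tensors="pt").input_ids[:, :1]
with torch.no_grad():
    logits = model(input_ids=enc.input_ids,
                   attention_mask=enc.attention_mask,
                   labels=dummy_label).logits[:, 0, :]

probs = torch.softmax(logits, dim=-1)
pos_id = tokenizer("Yes").input_ids[0]  # token id of "Yes"
neg_id = tokenizer("No").input_ids[0]   # token id of "No"
score = probs[0, pos_id] / (probs[0, pos_id] + probs[0, neg_id])
print(f"Normalized P(Yes): {score.item():.4f}")
```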
-------------------------------------------------------------------------------- /intermediate_tasks/data_info.txt: -------------------------------------------------------------------------------- 1 | In NLI Task: 2 | docnli datasets: 3 | positive samples: 29687 4 | negative samples: 30313 5 | total samples: 60000 6 | mrpc datasets: 7 | positive samples: 3893 8 | negative samples: 1908 9 | total samples: 5801 10 | qqp datasets: 11 | positive samples: 7467 12 | negative samples: 12533 13 | total samples: 20000 14 | Statistics of NLI datasets: 15 | positive samples: 41047 16 | negative samples: 44754 17 | total samples: 85801 18 | ---------------------------------------------------- 19 | In SST Task: 20 | fsp datasets: 21 | positive samples: 30000 22 | negative samples: 30000 23 | total samples: 60000 24 | Statistics of SST datasets: 25 | positive samples: 30000 26 | negative samples: 30000 27 | total samples: 60000 28 | ---------------------------------------------------- 29 | In QA Task: 30 | boolq datasets: 31 | positive samples: 7907 32 | negative samples: 4790 33 | total samples: 12697 34 | boolq_cs datasets: 35 | positive samples: 165 36 | negative samples: 170 37 | total samples: 335 38 | boolq_np datasets: 39 | positive samples: 7697 40 | negative samples: 6795 41 | total samples: 14492 42 | multirc datasets: 43 | positive samples: 192 44 | negative samples: 122 45 | total samples: 314 46 | strategyqa datasets: 47 | positive samples: 1071 48 | negative samples: 1219 49 | total samples: 2290 50 | Statistics of QA datasets: 51 | positive samples: 17032 52 | negative samples: 13096 53 | total samples: 30128 54 | ---------------------------------------------------- 55 | In LIN Task: 56 | cola datasets: 57 | positive samples: 6744 58 | negative samples: 2850 59 | total samples: 9594 60 | Statistics of LIN datasets: 61 | positive samples: 6744 62 | negative samples: 2850 63 | total samples: 9594 64 | ---------------------------------------------------- 65 | Total Statistics of Intermediate datasets: 66 | positive samples: 94823 67 | negative samples: 90700 68 | total samples: 185523 -------------------------------------------------------------------------------- /intermediate_tasks/train_inter.sh: -------------------------------------------------------------------------------- 1 | export TOKENIZERS_PARALLELISM=true 2 | export OMP_NUM_THREADS=1 3 | 4 | CUDA_VISIBLE_DEVICES=0,1,2 \ 5 | python -m torch.distributed.launch --nproc_per_node 3 train_seq2seq.py \ 6 | --model_name_or_path google/t5-v1_1-large \ 7 | --do_train \ 8 | --train_file data/intermediate_train.json \ 9 | --text_column src \ 10 | --summary_column tgt \ 11 | --output_dir ./inter_model \ 12 | --per_device_train_batch_size 3 \ 13 | --gradient_accumulation_steps 4 \ 14 | --max_source_length 1024 \ 15 | --max_target_length 16 \ 16 | --save_strategy epoch \ 17 | --num_train_epochs 10 \ 18 | --ddp_find_unused_parameters False \ 19 | -------------------------------------------------------------------------------- /intermediate_tasks/train_seq2seq.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset, load_metric 31 | 32 | import transformers 33 | from filelock import FileLock 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | MBart50Tokenizer, 41 | MBart50TokenizerFast, 42 | MBartTokenizer, 43 | MBartTokenizerFast, 44 | Seq2SeqTrainer, 45 | Seq2SeqTrainingArguments, 46 | set_seed, 47 | ) 48 | from transformers.file_utils import is_offline_mode 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version 51 | from transformers.utils.versions import require_version 52 | 53 | 54 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 55 | check_min_version("4.17.0.dev0") 56 | 57 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 58 | 59 | logger = logging.getLogger(__name__) 60 | 61 | try: 62 | nltk.data.find("tokenizers/punkt") 63 | except (LookupError, OSError): 64 | if is_offline_mode(): 65 | raise LookupError( 66 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 67 | ) 68 | with FileLock(".lock") as lock: 69 | nltk.download("punkt", quiet=True) 70 | 71 | # A list of all multilingual tokenizer which require lang attribute. 72 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] 73 | 74 | 75 | @dataclass 76 | class ModelArguments: 77 | """ 78 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
79 | """ 80 | 81 | model_name_or_path: str = field( 82 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 83 | ) 84 | config_name: Optional[str] = field( 85 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 86 | ) 87 | tokenizer_name: Optional[str] = field( 88 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 89 | ) 90 | cache_dir: Optional[str] = field( 91 | default=None, 92 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 93 | ) 94 | use_fast_tokenizer: bool = field( 95 | default=True, 96 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 97 | ) 98 | model_revision: str = field( 99 | default="main", 100 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 101 | ) 102 | use_auth_token: bool = field( 103 | default=False, 104 | metadata={ 105 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 106 | "with private models)." 107 | }, 108 | ) 109 | resize_position_embeddings: Optional[bool] = field( 110 | default=None, 111 | metadata={ 112 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 113 | "the model's position embeddings." 114 | }, 115 | ) 116 | 117 | 118 | @dataclass 119 | class DataTrainingArguments: 120 | """ 121 | Arguments pertaining to what data we are going to input our model for training and eval. 122 | """ 123 | 124 | lang: str = field(default=None, metadata={"help": "Language id for summarization."}) 125 | 126 | dataset_name: Optional[str] = field( 127 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 128 | ) 129 | dataset_config_name: Optional[str] = field( 130 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 131 | ) 132 | text_column: Optional[str] = field( 133 | default=None, 134 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 135 | ) 136 | summary_column: Optional[str] = field( 137 | default=None, 138 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 139 | ) 140 | train_file: Optional[str] = field( 141 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 142 | ) 143 | validation_file: Optional[str] = field( 144 | default=None, 145 | metadata={ 146 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 147 | "(a jsonlines or csv file)." 148 | }, 149 | ) 150 | test_file: Optional[str] = field( 151 | default=None, 152 | metadata={ 153 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 154 | }, 155 | ) 156 | overwrite_cache: bool = field( 157 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 158 | ) 159 | preprocessing_num_workers: Optional[int] = field( 160 | default=None, 161 | metadata={"help": "The number of processes to use for the preprocessing."}, 162 | ) 163 | max_source_length: Optional[int] = field( 164 | default=1024, 165 | metadata={ 166 | "help": "The maximum total input sequence length after tokenization. 
Sequences longer " 167 | "than this will be truncated, sequences shorter will be padded." 168 | }, 169 | ) 170 | max_target_length: Optional[int] = field( 171 | default=128, 172 | metadata={ 173 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 174 | "than this will be truncated, sequences shorter will be padded." 175 | }, 176 | ) 177 | val_max_target_length: Optional[int] = field( 178 | default=None, 179 | metadata={ 180 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 181 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 182 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 183 | "during ``evaluate`` and ``predict``." 184 | }, 185 | ) 186 | pad_to_max_length: bool = field( 187 | default=False, 188 | metadata={ 189 | "help": "Whether to pad all samples to model maximum sentence length. " 190 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 191 | "efficient on GPU but very bad for TPU." 192 | }, 193 | ) 194 | max_train_samples: Optional[int] = field( 195 | default=None, 196 | metadata={ 197 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 198 | "value if set." 199 | }, 200 | ) 201 | max_eval_samples: Optional[int] = field( 202 | default=None, 203 | metadata={ 204 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 205 | "value if set." 206 | }, 207 | ) 208 | max_predict_samples: Optional[int] = field( 209 | default=None, 210 | metadata={ 211 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 212 | "value if set." 213 | }, 214 | ) 215 | num_beams: Optional[int] = field( 216 | default=None, 217 | metadata={ 218 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 219 | "which is used during ``evaluate`` and ``predict``." 220 | }, 221 | ) 222 | ignore_pad_token_for_loss: bool = field( 223 | default=True, 224 | metadata={ 225 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 226 | }, 227 | ) 228 | source_prefix: Optional[str] = field( 229 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 230 | ) 231 | 232 | forced_bos_token: Optional[str] = field( 233 | default=None, 234 | metadata={ 235 | "help": "The token to force as the first generated token after the decoder_start_token_id." 236 | "Useful for multilingual models like mBART where the first generated token" 237 | "needs to be the target language token (Usually it is the target language token)" 238 | }, 239 | ) 240 | 241 | def __post_init__(self): 242 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 243 | raise ValueError("Need either a dataset name or a training/validation file.") 244 | else: 245 | if self.train_file is not None: 246 | extension = self.train_file.split(".")[-1] 247 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 248 | if self.validation_file is not None: 249 | extension = self.validation_file.split(".")[-1] 250 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
251 | if self.val_max_target_length is None: 252 | self.val_max_target_length = self.max_target_length 253 | 254 | 255 | summarization_name_mapping = { 256 | "amazon_reviews_multi": ("review_body", "review_title"), 257 | "big_patent": ("description", "abstract"), 258 | "cnn_dailymail": ("article", "highlights"), 259 | "orange_sum": ("text", "summary"), 260 | "pn_summary": ("article", "summary"), 261 | "psc": ("extract_text", "summary_text"), 262 | "samsum": ("dialogue", "summary"), 263 | "thaisum": ("body", "summary"), 264 | "xglue": ("news_body", "news_title"), 265 | "xsum": ("document", "summary"), 266 | "wiki_summary": ("article", "highlights"), 267 | } 268 | 269 | 270 | def main(): 271 | # See all possible arguments in src/transformers/training_args.py 272 | # or by passing the --help flag to this script. 273 | # We now keep distinct sets of args, for a cleaner separation of concerns. 274 | 275 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 276 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 277 | # If we pass only one argument to the script and it's the path to a json file, 278 | # let's parse it to get our arguments. 279 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 280 | else: 281 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 282 | 283 | # Setup logging 284 | logging.basicConfig( 285 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 286 | datefmt="%m/%d/%Y %H:%M:%S", 287 | handlers=[logging.StreamHandler(sys.stdout)], 288 | ) 289 | log_level = training_args.get_process_log_level() 290 | logger.setLevel(log_level) 291 | datasets.utils.logging.set_verbosity(log_level) 292 | transformers.utils.logging.set_verbosity(log_level) 293 | transformers.utils.logging.enable_default_handler() 294 | transformers.utils.logging.enable_explicit_format() 295 | 296 | # Log on each process the small summary: 297 | logger.warning( 298 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 299 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 300 | ) 301 | logger.info(f"Training/evaluation parameters {training_args}") 302 | 303 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 304 | "t5-small", 305 | "t5-base", 306 | "t5-large", 307 | "t5-3b", 308 | "t5-11b", 309 | ]: 310 | logger.warning( 311 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 312 | "`--source_prefix 'summarize: ' `" 313 | ) 314 | 315 | # Detecting last checkpoint. 316 | last_checkpoint = None 317 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 318 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 319 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 320 | raise ValueError( 321 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 322 | "Use --overwrite_output_dir to overcome." 323 | ) 324 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 325 | logger.info( 326 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 327 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 328 | ) 329 | 330 | # Set seed before initializing model. 
331 | set_seed(training_args.seed) 332 | 333 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 334 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 335 | # (the dataset will be downloaded automatically from the datasets Hub). 336 | # 337 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 338 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 339 | # 340 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 341 | # download the dataset. 342 | if data_args.dataset_name is not None: 343 | # Downloading and loading a dataset from the hub. 344 | raw_datasets = load_dataset( 345 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 346 | ) 347 | else: 348 | data_files = {} 349 | if data_args.train_file is not None: 350 | data_files["train"] = data_args.train_file 351 | extension = data_args.train_file.split(".")[-1] 352 | if data_args.validation_file is not None: 353 | data_files["validation"] = data_args.validation_file 354 | extension = data_args.validation_file.split(".")[-1] 355 | if data_args.test_file is not None: 356 | data_files["test"] = data_args.test_file 357 | extension = data_args.test_file.split(".")[-1] 358 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 359 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 360 | # https://huggingface.co/docs/datasets/loading_datasets.html. 361 | 362 | # Load pretrained model and tokenizer 363 | # 364 | # Distributed training: 365 | # The .from_pretrained methods guarantee that only one local process can concurrently 366 | # download model & vocab. 
367 | config = AutoConfig.from_pretrained( 368 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 369 | cache_dir=model_args.cache_dir, 370 | revision=model_args.model_revision, 371 | use_auth_token=True if model_args.use_auth_token else None, 372 | ) 373 | tokenizer = AutoTokenizer.from_pretrained( 374 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 375 | cache_dir=model_args.cache_dir, 376 | use_fast=model_args.use_fast_tokenizer, 377 | revision=model_args.model_revision, 378 | use_auth_token=True if model_args.use_auth_token else None, 379 | ) 380 | model = AutoModelForSeq2SeqLM.from_pretrained( 381 | model_args.model_name_or_path, 382 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 383 | config=config, 384 | cache_dir=model_args.cache_dir, 385 | revision=model_args.model_revision, 386 | use_auth_token=True if model_args.use_auth_token else None, 387 | ) 388 | 389 | model.resize_token_embeddings(len(tokenizer)) 390 | 391 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 392 | if isinstance(tokenizer, MBartTokenizer): 393 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 394 | else: 395 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 396 | 397 | if model.config.decoder_start_token_id is None: 398 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 399 | 400 | if ( 401 | hasattr(model.config, "max_position_embeddings") 402 | and model.config.max_position_embeddings < data_args.max_source_length 403 | ): 404 | if model_args.resize_position_embeddings is None: 405 | logger.warning( 406 | f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " 407 | f"to {data_args.max_source_length}." 408 | ) 409 | model.resize_position_embeddings(data_args.max_source_length) 410 | elif model_args.resize_position_embeddings: 411 | model.resize_position_embeddings(data_args.max_source_length) 412 | else: 413 | raise ValueError( 414 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}" 415 | f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically " 416 | "resize the model's position encodings by passing `--resize_position_embeddings`." 417 | ) 418 | 419 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 420 | 421 | # Preprocessing the datasets. 422 | # We need to tokenize inputs and targets. 423 | if training_args.do_train: 424 | column_names = raw_datasets["train"].column_names 425 | elif training_args.do_eval: 426 | column_names = raw_datasets["validation"].column_names 427 | elif training_args.do_predict: 428 | column_names = raw_datasets["test"].column_names 429 | else: 430 | logger.info("There is nothing to do. 
Please pass `do_train`, `do_eval` and/or `do_predict`.") 431 | return 432 | 433 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 434 | assert ( 435 | data_args.lang is not None 436 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" 437 | 438 | tokenizer.src_lang = data_args.lang 439 | tokenizer.tgt_lang = data_args.lang 440 | 441 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 442 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 443 | forced_bos_token_id = ( 444 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 445 | ) 446 | model.config.forced_bos_token_id = forced_bos_token_id 447 | 448 | # Get the column names for input/target. 449 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 450 | if data_args.text_column is None: 451 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 452 | else: 453 | text_column = data_args.text_column 454 | if text_column not in column_names: 455 | raise ValueError( 456 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 457 | ) 458 | if data_args.summary_column is None: 459 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 460 | else: 461 | summary_column = data_args.summary_column 462 | if summary_column not in column_names: 463 | raise ValueError( 464 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 465 | ) 466 | 467 | # Temporarily set max_target_length for training. 468 | max_target_length = data_args.max_target_length 469 | padding = "max_length" if data_args.pad_to_max_length else False 470 | 471 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 472 | logger.warning( 473 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 474 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 475 | ) 476 | 477 | def preprocess_function(examples): 478 | # remove pairs where at least one record is None 479 | 480 | inputs, targets = [], [] 481 | for i in range(len(examples[text_column])): 482 | if examples[text_column][i] is not None and examples[summary_column][i] is not None: 483 | inputs.append(examples[text_column][i]) 484 | targets.append(examples[summary_column][i]) 485 | 486 | inputs = examples[text_column] 487 | targets = examples[summary_column] 488 | inputs = [prefix + inp for inp in inputs] 489 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 490 | 491 | # Setup the tokenizer for targets 492 | with tokenizer.as_target_tokenizer(): 493 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 494 | 495 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 496 | # padding in the loss. 
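        # Note: -100 is the default ignore_index of PyTorch's cross-entropy loss, so these label positions are excluded from the loss.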
497 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 498 | labels["input_ids"] = [ 499 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 500 | ] 501 | 502 | model_inputs["labels"] = labels["input_ids"] 503 | return model_inputs 504 | 505 | if training_args.do_train: 506 | if "train" not in raw_datasets: 507 | raise ValueError("--do_train requires a train dataset") 508 | train_dataset = raw_datasets["train"] 509 | if data_args.max_train_samples is not None: 510 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 511 | with training_args.main_process_first(desc="train dataset map pre-processing"): 512 | train_dataset = train_dataset.map( 513 | preprocess_function, 514 | batched=True, 515 | num_proc=data_args.preprocessing_num_workers, 516 | remove_columns=column_names, 517 | load_from_cache_file=not data_args.overwrite_cache, 518 | desc="Running tokenizer on train dataset", 519 | ) 520 | 521 | if training_args.do_eval: 522 | max_target_length = data_args.val_max_target_length 523 | if "validation" not in raw_datasets: 524 | raise ValueError("--do_eval requires a validation dataset") 525 | eval_dataset = raw_datasets["validation"] 526 | if data_args.max_eval_samples is not None: 527 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 528 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 529 | eval_dataset = eval_dataset.map( 530 | preprocess_function, 531 | batched=True, 532 | num_proc=data_args.preprocessing_num_workers, 533 | remove_columns=column_names, 534 | load_from_cache_file=not data_args.overwrite_cache, 535 | desc="Running tokenizer on validation dataset", 536 | ) 537 | 538 | if training_args.do_predict: 539 | max_target_length = data_args.val_max_target_length 540 | if "test" not in raw_datasets: 541 | raise ValueError("--do_predict requires a test dataset") 542 | predict_dataset = raw_datasets["test"] 543 | if data_args.max_predict_samples is not None: 544 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 545 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 546 | predict_dataset = predict_dataset.map( 547 | preprocess_function, 548 | batched=True, 549 | num_proc=data_args.preprocessing_num_workers, 550 | remove_columns=column_names, 551 | load_from_cache_file=not data_args.overwrite_cache, 552 | desc="Running tokenizer on prediction dataset", 553 | ) 554 | 555 | # Data collator 556 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 557 | data_collator = DataCollatorForSeq2Seq( 558 | tokenizer, 559 | model=model, 560 | label_pad_token_id=label_pad_token_id, 561 | pad_to_multiple_of=8 if training_args.fp16 else None, 562 | ) 563 | 564 | # Metric 565 | metric = load_metric("rouge") 566 | 567 | def postprocess_text(preds, labels): 568 | preds = [pred.strip() for pred in preds] 569 | labels = [label.strip() for label in labels] 570 | 571 | # rougeLSum expects newline after each sentence 572 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 573 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 574 | 575 | return preds, labels 576 | 577 | def compute_metrics(eval_preds): 578 | preds, labels = eval_preds 579 | if isinstance(preds, tuple): 580 | preds = preds[0] 581 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 582 | if data_args.ignore_pad_token_for_loss: 583 | # 
Replace -100 in the labels as we can't decode them. 584 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 585 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 586 | 587 | # Some simple post-processing 588 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 589 | 590 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 591 | # Extract a few results from ROUGE 592 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 593 | 594 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 595 | result["gen_len"] = np.mean(prediction_lens) 596 | result = {k: round(v, 4) for k, v in result.items()} 597 | return result 598 | 599 | # Initialize our Trainer 600 | trainer = Seq2SeqTrainer( 601 | model=model, 602 | args=training_args, 603 | train_dataset=train_dataset if training_args.do_train else None, 604 | eval_dataset=eval_dataset if training_args.do_eval else None, 605 | tokenizer=tokenizer, 606 | data_collator=data_collator, 607 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 608 | ) 609 | 610 | # Training 611 | if training_args.do_train: 612 | checkpoint = None 613 | if training_args.resume_from_checkpoint is not None: 614 | checkpoint = training_args.resume_from_checkpoint 615 | elif last_checkpoint is not None: 616 | checkpoint = last_checkpoint 617 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 618 | trainer.save_model() # Saves the tokenizer too for easy upload 619 | 620 | metrics = train_result.metrics 621 | max_train_samples = ( 622 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 623 | ) 624 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 625 | 626 | trainer.log_metrics("train", metrics) 627 | trainer.save_metrics("train", metrics) 628 | trainer.save_state() 629 | 630 | # Evaluation 631 | results = {} 632 | max_length = ( 633 | training_args.generation_max_length 634 | if training_args.generation_max_length is not None 635 | else data_args.val_max_target_length 636 | ) 637 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 638 | if training_args.do_eval: 639 | logger.info("*** Evaluate ***") 640 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 641 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 642 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 643 | 644 | trainer.log_metrics("eval", metrics) 645 | trainer.save_metrics("eval", metrics) 646 | 647 | if training_args.do_predict: 648 | logger.info("*** Predict ***") 649 | 650 | predict_results = trainer.predict( 651 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 652 | ) 653 | metrics = predict_results.metrics 654 | max_predict_samples = ( 655 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 656 | ) 657 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 658 | 659 | trainer.log_metrics("predict", metrics) 660 | trainer.save_metrics("predict", metrics) 661 | 662 | if trainer.is_world_process_zero(): 663 | if training_args.predict_with_generate: 664 | predictions = tokenizer.batch_decode( 665 | predict_results.predictions, 
skip_special_tokens=True, clean_up_tokenization_spaces=True 666 | ) 667 | predictions = [pred.strip() for pred in predictions] 668 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 669 | with open(output_prediction_file, "w") as writer: 670 | writer.write("\n".join(predictions)) 671 | 672 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 673 | if data_args.dataset_name is not None: 674 | kwargs["dataset_tags"] = data_args.dataset_name 675 | if data_args.dataset_config_name is not None: 676 | kwargs["dataset_args"] = data_args.dataset_config_name 677 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 678 | else: 679 | kwargs["dataset"] = data_args.dataset_name 680 | 681 | if data_args.lang is not None: 682 | kwargs["language"] = data_args.lang 683 | 684 | if training_args.push_to_hub: 685 | trainer.push_to_hub(**kwargs) 686 | else: 687 | trainer.create_model_card(**kwargs) 688 | 689 | return results 690 | 691 | 692 | def _mp_fn(index): 693 | # For xla_spawn (TPUs) 694 | main() 695 | 696 | 697 | if __name__ == "__main__": 698 | main() 699 | -------------------------------------------------------------------------------- /metric/evaluator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | from nltk import sent_tokenize 4 | from metric.scorer import UniEvaluator 5 | sys.path.append("..") 6 | from utils import add_question, print_scores 7 | 8 | class SumEvaluator: 9 | def __init__(self, max_length=1024, device='cuda:0', cache_dir=None): 10 | """ Set up evaluator for text summarization """ 11 | self.scorer = UniEvaluator(model_name_or_path='MingZhong/unieval-sum', 12 | max_length=max_length, 13 | device=device, cache_dir=cache_dir) 14 | self.task = 'summarization' 15 | self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance'] 16 | 17 | def evaluate(self, data, dims=None, overall=True, print_result=False): 18 | """ 19 | Get the scores of all the given dimensions 20 | 21 | dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate 22 | four dimensions: coherence, consistency, fluency, relevance. 23 | 24 | overall: indicates whether the overall score is to be calculated. 25 | Overall score can be customized to a combination of scores based on different 26 | dimensions. The default here is the average score of all the given dimensions. 
27 | 28 | print_result: whether to print the average score of each dimension on the screen 29 | """ 30 | n_data = len(data) 31 | eval_scores = [{} for _ in range(n_data)] 32 | 33 | if dims == None: 34 | eval_dims = self.dimensions 35 | else: 36 | assert isinstance(dims, list) 37 | eval_dims = dims 38 | 39 | for dim in eval_dims: 40 | print('Evaluating {} of {} samples !!!'.format(dim, n_data)) 41 | 42 | # Calculate average sentence-level scores for 'consistency' and 'fluency' 43 | if dim == 'consistency' or dim == 'fluency': 44 | src_list, output_list = [], [] 45 | n_sents = [] # the number of sentences in each generated summary 46 | for i in range(n_data): 47 | if dim == 'consistency': 48 | source = data[i]['source'] 49 | else: 50 | source = '' 51 | system_outputs = sent_tokenize(data[i]['system_output']) 52 | n_sents.append(len(system_outputs)) 53 | for j in range(len(system_outputs)): 54 | src_list.append(source) 55 | output_list.append(system_outputs[j]) 56 | input_list = add_question(dimension=dim, output=output_list, 57 | src=src_list, task=self.task) 58 | sent_score = self.scorer.score(input_list) 59 | 60 | # Get average score for each sample 61 | start_idx = 0 62 | score = [] 63 | for cur_n_sent in n_sents: 64 | score.append(sum(sent_score[start_idx: start_idx + cur_n_sent]) / cur_n_sent) 65 | start_idx += cur_n_sent 66 | 67 | # Calculate summary-level score for 'coherence' and 'relevance' 68 | elif dim == 'coherence' or dim == 'relevance': 69 | src_list, output_list, ref_list = [], [], [] 70 | for i in range(n_data): 71 | src_list.append(data[i]['source']) 72 | output_list.append(data[i]['system_output']) 73 | if dim == 'relevance': 74 | ref_list.append(data[i]['reference']) 75 | input_list = add_question(dimension=dim, output=output_list, 76 | src=src_list, ref=ref_list, task=self.task) 77 | score = self.scorer.score(input_list) 78 | 79 | # Please customize other dimensions here for summarization 80 | else: 81 | raise NotImplementedError('The input format for this dimension is still undefined. \ 82 | Please customize it first.') 83 | 84 | for i in range(n_data): 85 | eval_scores[i][dim] = score[i] 86 | 87 | # Customize your overall score here. 88 | if overall == True: 89 | for i in range(n_data): 90 | eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) 91 | 92 | if print_result == True: 93 | print_scores(eval_scores) 94 | 95 | return eval_scores 96 | 97 | 98 | class DialogEvaluator: 99 | def __init__(self, max_length=1024, device='cuda:0', cache_dir=None): 100 | """ Set up evaluator for dialogues """ 101 | self.scorer = UniEvaluator(model_name_or_path='MingZhong/unieval-dialog', 102 | max_length=max_length, 103 | device=device, cache_dir=cache_dir) 104 | self.task = 'dialogue' 105 | self.dimensions = ['naturalness', 'coherence', 'engagingness', 106 | 'groundedness', 'understandability'] 107 | 108 | def evaluate(self, data, dims=None, overall=True, print_result=False): 109 | """ 110 | Get the scores of all the given dimensions 111 | 112 | dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate 113 | five dimensions: naturalness, coherence, engagingness, groundedness and understandability. 114 | 115 | overall: indicates whether the overall score is to be calculated. 116 | Overall score can be customized to a combination of scores based on different 117 | dimensions. The default here is the average score of all the given dimensions. 
118 | 119 | print_result: whether to print the average score of each dimension on the screen 120 | """ 121 | n_data = len(data) 122 | eval_scores = [{} for _ in range(n_data)] 123 | 124 | if dims == None: 125 | eval_dims = self.dimensions 126 | else: 127 | assert isinstance(dims, list) 128 | eval_dims = dims 129 | 130 | for dim in eval_dims: 131 | print('Evaluating {} of {} samples !!!'.format(dim, n_data)) 132 | 133 | # Calculate summation score for 'engagingness' 134 | if dim == 'engagingness': 135 | src_list, output_list, context_list = [], [], [] 136 | n_sents = [] # the number of sentences in each generated response 137 | for i in range(n_data): 138 | source = data[i]['source'] 139 | context = data[i]['context'] 140 | system_outputs = sent_tokenize(data[i]['system_output']) 141 | n_sents.append(len(system_outputs)) 142 | for j in range(len(system_outputs)): 143 | src_list.append(source) 144 | context_list.append(context) 145 | output_list.append(system_outputs[j]) 146 | input_list = add_question(dimension=dim, output=output_list, 147 | src=src_list, context=context_list, task=self.task) 148 | sent_score = self.scorer.score(input_list) 149 | 150 | # Get the summation score for each sample 151 | start_idx = 0 152 | score = [] 153 | for cur_n_sent in n_sents: 154 | score.append(sum(sent_score[start_idx: start_idx + cur_n_sent])) 155 | start_idx += cur_n_sent 156 | 157 | # Calculate turn-level score for other dimensions 158 | elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']: 159 | src_list, output_list, context_list = [], [], [] 160 | for i in range(n_data): 161 | if dim == 'coherence': 162 | src_list.append(data[i]['source']) 163 | else: 164 | src_list.append('') 165 | output_list.append(data[i]['system_output']) 166 | if dim == 'groundedness': 167 | context_list.append(data[i]['context']) 168 | else: 169 | context_list.append('') 170 | input_list = add_question(dimension=dim, output=output_list, 171 | src=src_list, context=context_list, task=self.task) 172 | score = self.scorer.score(input_list) 173 | 174 | # Please customize other dimensions here for summarization 175 | else: 176 | raise NotImplementedError('The input format for this dimension is still undefined. \ 177 | Please customize it first.') 178 | 179 | for i in range(n_data): 180 | eval_scores[i][dim] = score[i] 181 | 182 | # Customize your overall score here. 183 | if overall == True: 184 | for i in range(n_data): 185 | eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) 186 | 187 | if print_result == True: 188 | print_scores(eval_scores) 189 | 190 | return eval_scores 191 | 192 | 193 | class D2tEvaluator: 194 | def __init__(self, max_length=1024, device='cuda:0', cache_dir=None): 195 | """ Set up evaluator for data-to-text """ 196 | self.scorer = UniEvaluator(model_name_or_path='MingZhong/unieval-sum', 197 | max_length=max_length, 198 | device=device, cache_dir=cache_dir) 199 | self.task = 'data2text' 200 | self.dimensions = ['naturalness', 'informativeness'] 201 | 202 | def evaluate(self, data, dims=None, overall=True, print_result=False): 203 | """ 204 | Get the scores of all the given dimensions 205 | 206 | dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate 207 | two dimensions: naturalness and informativeness. 208 | 209 | overall: indicates whether the overall score is to be calculated. 210 | Overall score can be customized to a combination of scores based on different 211 | dimensions. 
The default here is the average score of all the given dimensions. 212 | 213 | print_result: whether to print the average score of each dimension on the screen 214 | """ 215 | n_data = len(data) 216 | eval_scores = [{} for _ in range(n_data)] 217 | 218 | if dims == None: 219 | eval_dims = self.dimensions 220 | else: 221 | assert isinstance(dims, list) 222 | eval_dims = dims 223 | 224 | for dim in eval_dims: 225 | print('Evaluating {} of {} samples !!!'.format(dim, n_data)) 226 | 227 | output_list, ref_list = [], [] 228 | for i in range(n_data): 229 | output_list.append(data[i]['system_output']) 230 | ref_list.append(data[i]['reference']) 231 | 232 | input_list = add_question(dimension=dim, output=output_list, 233 | ref=ref_list, task=self.task) 234 | score = self.scorer.score(input_list) 235 | 236 | for i in range(n_data): 237 | eval_scores[i][dim] = score[i] 238 | 239 | # Customize your overall score here. 240 | if overall == True: 241 | for i in range(n_data): 242 | eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) 243 | 244 | if print_result == True: 245 | print_scores(eval_scores) 246 | 247 | return eval_scores 248 | 249 | 250 | class FactEvaluator: 251 | def __init__(self, max_length=1024, device='cuda:0', cache_dir=None): 252 | """ Set up evaluator for factual consistency detection """ 253 | self.scorer = UniEvaluator(model_name_or_path='MingZhong/unieval-fact', 254 | max_length=max_length, 255 | device=device, cache_dir=cache_dir) 256 | self.task = 'fact' 257 | self.dim = 'consistency' 258 | 259 | def evaluate(self, data, print_result=False): 260 | """ 261 | Get the factual consistency score (only 1 dimension for this task) 262 | 263 | print_result: whether to print the average factual score on the screen 264 | """ 265 | n_data = len(data) 266 | eval_scores = [{} for _ in range(n_data)] 267 | 268 | print('Evaluating {} of {} samples !!!'.format(self.dim, n_data)) 269 | 270 | # Calculate average sentence-level scores for facutal consistency 271 | src_list, output_list = [], [] 272 | n_sents = [] # the number of sentences in the claim 273 | for i in range(n_data): 274 | source = data[i]['source'] 275 | system_outputs = sent_tokenize(data[i]['system_output']) 276 | n_sents.append(len(system_outputs)) 277 | for j in range(len(system_outputs)): 278 | src_list.append(source) 279 | output_list.append(system_outputs[j]) 280 | input_list = add_question(dimension=self.dim, output=output_list, 281 | src=src_list, task=self.task) 282 | sent_score = self.scorer.score(input_list) 283 | 284 | # Get average score for each sample 285 | start_idx = 0 286 | score = [] 287 | for cur_n_sent in n_sents: 288 | score.append(sum(sent_score[start_idx: start_idx + cur_n_sent]) / cur_n_sent) 289 | start_idx += cur_n_sent 290 | 291 | for i in range(n_data): 292 | eval_scores[i][self.dim] = score[i] 293 | 294 | if print_result == True: 295 | print_scores(eval_scores) 296 | 297 | return eval_scores 298 | 299 | def get_evaluator(task, max_length=1024, device='cuda:0', cache_dir=None): 300 | assert task in ['summarization', 'dialogue', 'data2text', 'fact'] 301 | if task == 'summarization': 302 | return SumEvaluator(max_length=max_length, 303 | device=device, 304 | cache_dir=cache_dir) 305 | elif task == 'dialogue': 306 | return DialogEvaluator(max_length=max_length, 307 | device=device, 308 | cache_dir=cache_dir) 309 | elif task == 'data2text': 310 | return D2tEvaluator(max_length=max_length, 311 | device=device, 312 | cache_dir=cache_dir) 313 | elif task == 'fact': 314 | return 
FactEvaluator(max_length=max_length, 315 | device=device, 316 | cache_dir=cache_dir) 317 | else: 318 | raise NotImplementedError('Other tasks are not implemented, \ 319 | please customize specific tasks here.') 320 | 321 | -------------------------------------------------------------------------------- /metric/scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM 4 | from tqdm import tqdm 5 | 6 | class UniEvaluator: 7 | def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): 8 | """ Set up model """ 9 | self.device = device 10 | self.max_length = max_length 11 | 12 | self.config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) 13 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) 14 | self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config, 15 | cache_dir=cache_dir) 16 | 17 | self.model.eval() 18 | self.model.to(device) 19 | 20 | self.softmax = nn.Softmax(dim=1) 21 | 22 | self.pos_id = self.tokenizer("Yes")["input_ids"][0] 23 | self.neg_id = self.tokenizer("No")["input_ids"][0] 24 | 25 | def score(self, inputs, batch_size=8): 26 | """ 27 | Get scores for the given samples. 28 | final_score = postive_score / (postive_score + negative_score) 29 | """ 30 | 31 | # The implementation of "forward" in T5 still requires decoder_input_ids. 32 | # Therefore, we construct a random one-word target sequence. 33 | # The content of the target has no effect on the final scores. 34 | tgts = ["No" for _ in range(len(inputs))] 35 | 36 | pos_score_list, neg_score_list = [], [] 37 | for i in tqdm(range(0, len(inputs), batch_size)): 38 | src_list = inputs[i: i + batch_size] 39 | tgt_list = tgts[i: i + batch_size] 40 | try: 41 | with torch.no_grad(): 42 | encoded_src = self.tokenizer( 43 | src_list, 44 | max_length=self.max_length, 45 | truncation=True, 46 | padding=True, 47 | return_tensors='pt' 48 | ) 49 | encoded_tgt = self.tokenizer( 50 | tgt_list, 51 | max_length=self.max_length, 52 | truncation=True, 53 | padding=True, 54 | return_tensors='pt' 55 | ) 56 | 57 | src_tokens = encoded_src['input_ids'].to(self.device) 58 | src_mask = encoded_src['attention_mask'].to(self.device) 59 | 60 | tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1) 61 | 62 | output = self.model( 63 | input_ids=src_tokens, 64 | attention_mask=src_mask, 65 | labels = tgt_tokens 66 | ) 67 | logits = output.logits.view(-1, self.model.config.vocab_size) 68 | 69 | pos_score = self.softmax(logits)[:, self.pos_id] # Yes 70 | neg_score = self.softmax(logits)[:, self.neg_id] # No 71 | 72 | cur_pos_score = [x.item() for x in pos_score] 73 | cur_neg_score = [x.item() for x in neg_score] 74 | pos_score_list += cur_pos_score 75 | neg_score_list += cur_neg_score 76 | 77 | except RuntimeError: 78 | print(f'source: {src_list}') 79 | print(f'target: {tgt_list}') 80 | exit(0) 81 | 82 | score_list = [] 83 | for i in range(len(pos_score_list)): 84 | score_list.append(pos_score_list[i] / (pos_score_list[i] + neg_score_list[i])) 85 | 86 | return score_list 87 | -------------------------------------------------------------------------------- /pseudo_data_summ.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | from tqdm import tqdm 4 | import random 5 | import numpy as np 6 | from 
rank_bm25 import BM25Okapi 7 | from nltk import sent_tokenize 8 | from utils import fast_rouge, get_dec_and_ref 9 | 10 | data_path = '/path/to/cnndm_train.jsonl' 11 | 12 | def load_data(data_path): 13 | data = [] 14 | with open(data_path) as f: 15 | for line in f: 16 | data.append(json.loads(line)) 17 | return data 18 | 19 | # Generate disfluent data. 1 positive sample corresponds to n_neg negative samples. 20 | # Each negative sample contains n_noise disfluent noises 21 | def disfluency_transformation(data, n_neg=3, n_noise=1): 22 | new_data = [] 23 | for i in tqdm(range(len(data))): 24 | cur_sample = {} 25 | ### reference summary as groundtruth 26 | # cur_sample['src'] = data[i]['src'] 27 | # cur_sample['tgt'] = ' '.join(data[i]['tgt']) 28 | ### lead 3 sentences as groundtruth 29 | cur_src = sent_tokenize(data[i]['src']) 30 | cur_sample['src'] = ' '.join(cur_src[3:]) 31 | cur_sample['tgt'] = ' '.join(cur_src[:3]) 32 | cur_sample['disfluent_tgt'] = [] 33 | # j-th negative sample for i-th data 34 | for j in range(n_neg): 35 | ### reference summary as groundtruth 36 | # cur_tgt = (' '.join(data[i]['tgt'])).split() 37 | cur_tgt = (' '.join(cur_src[:3])).split() 38 | # add k noises 39 | for k in range(n_noise): 40 | tgt_len = len(cur_tgt) 41 | # length of span for transformation. Sampled from poisson distribution. 42 | span_len = min(tgt_len, np.random.poisson(5, 1)[0]) 43 | # 1: insert, 2: delete, 3: shuffle 44 | transform_type = random.randint(1, 3) 45 | start_idx = random.randint(0, tgt_len - span_len) 46 | if transform_type == 1: 47 | copy_idx = random.randint(0, tgt_len - span_len) 48 | cur_tgt = cur_tgt[:start_idx] + cur_tgt[copy_idx:copy_idx+span_len] + cur_tgt[start_idx:] 49 | elif transform_type == 2: 50 | cur_tgt = cur_tgt[:start_idx] + cur_tgt[start_idx+span_len:] 51 | elif transform_type == 3: 52 | shuffled_span = cur_tgt[start_idx:start_idx+span_len] 53 | random.shuffle(shuffled_span) 54 | cur_tgt = cur_tgt[:start_idx] + shuffled_span + cur_tgt[start_idx+span_len:] 55 | cur_tgt = ' '.join(cur_tgt) 56 | cur_sample['disfluent_tgt'].append(cur_tgt) 57 | new_data.append(cur_sample) 58 | return new_data 59 | 60 | # Generate incoherent data. 1 positive sample corresponds to n_neg negative samples. 
61 | # Each negative sample contains n_noise incoherent sentences 62 | # retrieved path: processed data containing bm25_rankning 63 | def incoherence_transformation(data, n_neg=3, n_noise=1, retrieved_path=None): 64 | if retrieved_path == None: 65 | corpus = [] 66 | for i in range(len(data)): 67 | corpus.append(data[i]['src'].split()) 68 | bm25 = BM25Okapi(corpus) 69 | for i in tqdm(range(len(data))): 70 | query = corpus[i] 71 | scores = bm25.get_scores(query) 72 | retrieved_index = np.flip(np.argsort(scores)).tolist() 73 | cur = {} 74 | cur['src'] = data[i]['src'] 75 | cur['tgt'] = data[i]['tgt'] 76 | cur['bm25_ranking'] = retrieved_index[:100] 77 | ### write data 78 | # with open('/path/to/cnndm/train_with_bm25.jsonl', 'a') as f: 79 | # print(json.dumps(cur), file=f) 80 | else: 81 | data_with_bm25 = load_data(retrieved_path) 82 | new_data = [] 83 | for i in tqdm(range(len(data))): 84 | cnt = 0 85 | # irrelevant_tgt = [] 86 | incoherent_tgt = [] 87 | cur_src = sent_tokenize(data[i]['src']) 88 | for idx in data_with_bm25[i]['bm25_ranking']: 89 | if idx == i or data[idx]['src'] == data[i]['src']: 90 | continue 91 | ''' 92 | # for reference summary 93 | cur_n = min(n_noise, len(data[i]['tgt'])) 94 | cur_n = min(cur_n, len(data[idx]['tgt'])) 95 | old_idx = random.sample(range(0, len(data[i]['tgt'])), cur_n) 96 | new_idx = random.sample(range(0, len(data[idx]['tgt'])), cur_n) 97 | cur_tgt = copy.deepcopy(data[i]['tgt']) 98 | for j in range(cur_n): 99 | cur_tgt[old_idx[j]] = data[idx]['tgt'][new_idx[j]] 100 | ''' 101 | # for lead 3 102 | cur_n = min(n_noise, 3) 103 | cur_tgt = copy.deepcopy(cur_src[:3]) 104 | retrieved_tgt = sent_tokenize(data[idx]['src'])[:3] 105 | old_idx = random.sample(range(0, len(cur_tgt)), cur_n) 106 | new_idx = random.sample(range(0, len(retrieved_tgt)), cur_n) 107 | for j in range(cur_n): 108 | cur_tgt[old_idx[j]] = retrieved_tgt[new_idx[j]] 109 | # irrelevant_tgt.append(' '.join(cur_tgt)) 110 | incoherent_tgt.append(' '.join(cur_tgt)) 111 | cnt += 1 112 | if cnt == n_neg: 113 | break 114 | cur = {} 115 | cur['src'] = ' '.join(cur_src) 116 | cur['tgt'] = ' '.join(cur_src[:3]) 117 | cur['gold_summary'] = data[i]['tgt'] 118 | cur['incoherent_tgt'] = incoherent_tgt 119 | new_data.append(cur) 120 | return new_data 121 | 122 | # Generate irrelevant data. 1 positive sample corresponds to n_neg negative samples. 
123 | # retrieved_path: processed data containing bm25_ranking 124 | def irrelevance_transformation(data, n_neg=3, retrieved_path=None): 125 | data_with_bm25 = load_data(retrieved_path) 126 | new_data = [] 127 | for i in tqdm(range(len(data))): 128 | cnt = 0 129 | irrelevant_tgt = [] 130 | cur_src = sent_tokenize(data[i]['src']) 131 | for idx in data_with_bm25[i]['bm25_ranking']: 132 | if idx == i or data[idx]['tgt'] == data[i]['tgt']: 133 | continue 134 | 135 | retrieved_tgt = sent_tokenize(data[idx]['src'])[:3] # negative samples 136 | irrelevant_tgt.append(' '.join(retrieved_tgt)) 137 | cnt += 1 138 | if cnt == n_neg: 139 | break 140 | cur = {} 141 | cur['src'] = data[i]['src'] 142 | cur['tgt'] = ' '.join(cur_src[:3]) # positive samples 143 | cur['gold_summary'] = data[i]['tgt'] # gold summary 144 | cur['irrelevant_tgt'] = irrelevant_tgt 145 | new_data.append(cur) 146 | return new_data 147 | 148 | def main(): 149 | # load data 150 | data = load_data(data_path) 151 | # process data for the relevance dimension 152 | new_data = irrelevance_transformation(data, retrieved_path='/path/to/cnndm/train_with_bm25.jsonl') 153 | # write new data 154 | with open('/path/to/new_data.jsonl', 'w') as f: 155 | for i in range(len(new_data)): 156 | print(json.dumps(new_data[i]), file=f) 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /reproduce/README.md: -------------------------------------------------------------------------------- 1 | # Reproduce 2 | 3 | To reproduce all the results in the paper, we provide all the meta-evaluation datasets, the code, and the evaluation scores predicted by UniEval here. 4 | 5 | ## Meta-Evaluation Benchmarks 6 | Experiments are conducted on four tasks as follows: 7 | 8 | - Text Summarization: [SummEval](data/summarization/summeval.json) 9 | - Dialogue Response Generation: [Topical_Chat](data/dialogue/topical_chat.json) 10 | - Data-to-text: [SFRES](data/data2text/sfres.json) and [SFHOT](data/data2text/sfhot.json) 11 | - Factual Consistency: [QAGS-CNNDM](data/fact/qags_cnndm.json) and [QAGS-XSum](data/fact/qags_xsum.json) 12 | 13 | Please note that the overall score in SummEval is the average score of the four dimensions, while the overall scores in the other benchmarks are human-annotated scores. 14 | 15 | ## Calculate Correlations with Human Scores 16 | To verify that the proposed evaluator is reliable, we calculate its correlations with human scores on each benchmark. 17 | 18 | We provide scripts to automatically obtain evaluation scores and correlations. For example, for summarization, run the following script: 19 | ``` 20 | ./eval_summarization.sh 21 | ``` 22 | The predicted scores will be stored in the `predict/summarization` folder. 
It will then calculate the correlations between the predicted scores and the human judgments, and the results will be printed on the screen: 23 | ``` 24 | ********** Sample Level Correlations ********* 25 | +-------------+----------+----------+----------+ 26 | | Dimensions | Pearson | Spearman | Kendall | 27 | +-------------+----------+----------+----------+ 28 | | coherence | 0.533249 | 0.591811 | 0.424627 | 29 | | consistency | 0.634377 | 0.434997 | 0.349272 | 30 | | fluency | 0.597067 | 0.451053 | 0.353974 | 31 | | relevance | 0.434236 | 0.465623 | 0.337676 | 32 | | overall | 0.69961 | 0.658277 | 0.476311 | 33 | +-------------+----------+----------+----------+ 34 | 35 | ********* Summary Level Correlations ********* 36 | +-------------+----------+----------+----------+ 37 | | Dimensions | Pearson | Spearman | Kendall | 38 | +-------------+----------+----------+----------+ 39 | | coherence | 0.553818 | 0.575186 | 0.44249 | 40 | | consistency | 0.648491 | 0.445596 | 0.370913 | 41 | | fluency | 0.605978 | 0.449168 | 0.370628 | 42 | | relevance | 0.416225 | 0.42569 | 0.324938 | 43 | | overall | 0.698316 | 0.647441 | 0.496725 | 44 | +-------------+----------+----------+----------+ 45 | 46 | ********** System Level Correlations ********* 47 | +-------------+----------+----------+----------+ 48 | | Dimensions | Pearson | Spearman | Kendall | 49 | +-------------+----------+----------+----------+ 50 | | coherence | 0.810345 | 0.811765 | 0.683333 | 51 | | consistency | 0.945761 | 0.911765 | 0.75 | 52 | | fluency | 0.908509 | 0.844739 | 0.661094 | 53 | | relevance | 0.900644 | 0.838235 | 0.666667 | 54 | | overall | 0.967897 | 0.894118 | 0.733333 | 55 | +-------------+----------+----------+----------+ 56 | ``` 57 | Results for dialogue response generation should be: 58 | ``` 59 | ************** Turn Level Correlations ************* 60 | +-------------------+----------+----------+----------+ 61 | | Dimensions | Pearson | Spearman | Kendall | 62 | +-------------------+----------+----------+----------+ 63 | | naturalness | 0.443666 | 0.513986 | 0.373973 | 64 | | coherence | 0.595143 | 0.612942 | 0.465915 | 65 | | engagingness | 0.55651 | 0.604739 | 0.455941 | 66 | | groundedness | 0.536209 | 0.574954 | 0.451533 | 67 | | understandability | 0.380038 | 0.467807 | 0.360741 | 68 | | overall | 0.632796 | 0.662583 | 0.487272 | 69 | +-------------------+----------+----------+----------+ 70 | ``` 71 | Results for data-to-text should look like: 72 | ``` 73 | SFRES: 74 | ************ Sample Level Correlations *********** 75 | +-----------------+----------+----------+----------+ 76 | | Dimensions | Pearson | Spearman | Kendall | 77 | +-----------------+----------+----------+----------+ 78 | | naturalness | 0.367252 | 0.333399 | 0.247094 | 79 | | informativeness | 0.282079 | 0.224918 | 0.169297 | 80 | | overall | 0.370815 | 0.291593 | 0.214708 | 81 | +-----------------+----------+----------+----------+ 82 | 83 | SFHOT: 84 | +-----------------+----------+----------+----------+ 85 | | Dimensions | Pearson | Spearman | Kendall | 86 | +-----------------+----------+----------+----------+ 87 | | naturalness | 0.397428 | 0.319813 | 0.237635 | 88 | | informativeness | 0.357353 | 0.249329 | 0.191217 | 89 | | overall | 0.406425 | 0.320721 | 0.236024 | 90 | +-----------------+----------+----------+----------+ 91 | ``` 92 | Results of factual consistency detection are: 93 | ``` 94 | QAGS_Xsum: 95 | ********** Sample Level Correlations ********* 96 | +-------------+----------+----------+----------+ 97 | | Dimensions | Pearson 
| Spearman | Kendall | 98 | +-------------+----------+----------+----------+ 99 | | consistency | 0.461376 | 0.48792 | 0.399218 | 100 | +-------------+----------+----------+----------+ 101 | 102 | QAGS_CNNDM: 103 | ********** Sample Level Correlations ********* 104 | +-------------+----------+----------+----------+ 105 | | Dimensions | Pearson | Spearman | Kendall | 106 | +-------------+----------+----------+----------+ 107 | | consistency | 0.681681 | 0.662255 | 0.531636 | 108 | +-------------+----------+----------+----------+ 109 | ``` 110 | 111 | ## Predicted Scores 112 | [unieval_predict](./unieval_predict) folder contains the evaluation scores of UniEval on all meta-evaluation benchmarks. 113 | -------------------------------------------------------------------------------- /reproduce/correlation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import join 3 | from prettytable import PrettyTable 4 | from scipy.stats import spearmanr, pearsonr, kendalltau 5 | from data_utils import load_json 6 | 7 | def calculate_correlation(pred_score, human_score, dim, result): 8 | assert len(pred_score) == len(human_score) 9 | if dim not in result: 10 | result[dim] = [0] * 3 11 | result[dim][0] += pearsonr(pred_score, human_score)[0] 12 | result[dim][1] += spearmanr(pred_score, human_score)[0] 13 | result[dim][2] += kendalltau(pred_score, human_score)[0] 14 | return result 15 | 16 | def print_correlations(result): 17 | table = PrettyTable(['Dimensions','Pearson', 'Spearman', 'Kendall']) 18 | for dim in result: 19 | table.add_row([dim, round(result[dim][0], 6), round(result[dim][1], 6), 20 | round(result[dim][2], 6)]) 21 | print(table) 22 | 23 | def get_unique_value(data, key): 24 | """ 25 | Get a list of unique values for a specific key in the data. 26 | """ 27 | value = set() 28 | for i in range(len(data)): 29 | if data[i][key] not in value: 30 | value.add(data[i][key]) 31 | return list(value) 32 | 33 | def correlation_for_summ(data, overall=True): 34 | """ 35 | Provides calculation results of correlation at sample level, summary level and system level. 
36 | For the specific definitions, please refer to the paper: https://arxiv.org/abs/2010.07100 37 | """ 38 | dimensions = ['coherence', 'consistency', 'fluency', 'relevance'] 39 | if overall == True: 40 | dimensions.append('overall') 41 | 42 | # sample level correlation 43 | print('\n ********** Sample Level Correlations *********') 44 | result = {} 45 | for dim in dimensions: 46 | pred_score, human_score = [], [] 47 | for i in range(len(data)): 48 | pred_score.append(data[i]['predict_scores'][dim]) 49 | human_score.append(data[i]['scores'][dim]) 50 | result = calculate_correlation(pred_score, human_score, dim, result) 51 | print_correlations(result) 52 | 53 | # summary level correlation 54 | print('\n ********* Summary Level Correlations *********') 55 | result = {} 56 | docs = get_unique_value(data, 'doc_id') 57 | for dim in dimensions: 58 | valid_cnt = 0 59 | for doc_idx in docs: 60 | pred_score, human_score = [], [] 61 | for i in range(len(data)): 62 | if data[i]['doc_id'] == doc_idx: 63 | pred_score.append(data[i]['predict_scores'][dim]) 64 | human_score.append(data[i]['scores'][dim]) 65 | if len(set(pred_score)) == 1 or len(set(human_score)) == 1: 66 | continue 67 | result = calculate_correlation(pred_score, human_score, dim, result) 68 | valid_cnt += 1 69 | for j in range(3): 70 | result[dim][j] /= valid_cnt 71 | print_correlations(result) 72 | 73 | # system level correlations 74 | print('\n ********** System Level Correlations *********') 75 | result = {} 76 | systems = get_unique_value(data, 'system_id') 77 | for dim in dimensions: 78 | pred_score, human_score = [], [] 79 | for system_idx in systems: 80 | doc_cnt = 0 81 | cur_pred, cur_human = 0, 0 82 | for i in range(len(data)): 83 | if data[i]['system_id'] == system_idx: 84 | cur_pred += data[i]['predict_scores'][dim] 85 | cur_human += data[i]['scores'][dim] 86 | doc_cnt += 1 87 | pred_score.append(cur_pred / doc_cnt) 88 | human_score.append(cur_human / doc_cnt) 89 | result = calculate_correlation(pred_score, human_score, dim, result) 90 | print_correlations(result) 91 | 92 | 93 | def correlation_for_dialog(data, overall=True): 94 | """ 95 | Calculate turn-level correlation for dialogue response generation. 96 | """ 97 | dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability'] 98 | if overall == True: 99 | dimensions.append('overall') 100 | 101 | # turn level correlation 102 | print('\n ************** Turn Level Correlations *************') 103 | result = {} 104 | for dim in dimensions: 105 | pred_score, human_score = [], [] 106 | for i in range(len(data)): 107 | pred_score.append(data[i]['predict_scores'][dim]) 108 | human_score.append(data[i]['scores'][dim]) 109 | result = calculate_correlation(pred_score, human_score, dim, result) 110 | print_correlations(result) 111 | 112 | 113 | def correlation_for_d2t(data, overall=True): 114 | """ 115 | Calculate sample-level correlation for data-to-text. 
116 | """ 117 | dimensions = ['naturalness', 'informativeness'] 118 | if overall == True: 119 | dimensions.append('overall') 120 | 121 | # sample level correlation 122 | print('\n ************ Sample Level Correlations ***********') 123 | result = {} 124 | for dim in dimensions: 125 | pred_score, human_score = [], [] 126 | for i in range(len(data)): 127 | pred_score.append(data[i]['predict_scores'][dim]) 128 | human_score.append(data[i]['scores'][dim]) 129 | result = calculate_correlation(pred_score, human_score, dim, result) 130 | print_correlations(result) 131 | 132 | def correlation_for_fact(data): 133 | """ 134 | Calculate sample-level factual consistency score. 135 | """ 136 | dim = 'consistency' 137 | 138 | # sample level correlation 139 | print('\n ********** Sample Level Correlations *********') 140 | result = {} 141 | pred_score, human_score = [], [] 142 | for i in range(len(data)): 143 | pred_score.append(data[i]['predict_scores'][dim]) 144 | human_score.append(data[i]['scores'][dim]) 145 | result = calculate_correlation(pred_score, human_score, dim, result) 146 | print_correlations(result) 147 | 148 | def main(args): 149 | data_path = join(join('predict', args.task), '{}_result.json'.format(args.dataset)) 150 | print('\nCorrelations for \'{}\' are shown below:'.format(data_path)) 151 | data = load_json(data_path) 152 | if args.task == 'summarization': 153 | correlation_for_summ(data) 154 | elif args.task == 'dialogue': 155 | correlation_for_dialog(data) 156 | elif args.task == 'data2text': 157 | correlation_for_d2t(data) 158 | else: 159 | correlation_for_fact(data) 160 | 161 | if __name__ == "__main__": 162 | parser = argparse.ArgumentParser( 163 | description='Calculate the correlations between predicted scores and human scores' 164 | ) 165 | 166 | parser.add_argument('--task', required=True, 167 | help='Specific NLG task to be evaluated', type=str) 168 | parser.add_argument('--dataset', required=True, 169 | help='The name of the meta-evaluation benchmark', type=str) 170 | 171 | args = parser.parse_args() 172 | assert args.task in ['summarization', 'dialogue', 'data2text', 'fact'] 173 | 174 | main(args) 175 | -------------------------------------------------------------------------------- /reproduce/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from os.path import exists, join 4 | 5 | def load_json(data_path): 6 | with open(data_path) as f: 7 | data = json.loads(f.read()) 8 | return data 9 | 10 | def write_predict(task, dataset, data, eval_scores): 11 | task_path = join('predict', task) 12 | if not exists(task_path): 13 | os.makedirs(task_path) 14 | write_path = join(task_path, '{}_result.json'.format(dataset)) 15 | if exists(write_path): 16 | print("\nThe predicted scores are not saved because the result file already exists !!!") 17 | else: 18 | assert len(data) == len(eval_scores) 19 | for i in range(len(data)): 20 | data[i]['predict_scores'] = eval_scores[i] 21 | with open(write_path, 'w') as f: 22 | json.dump(data, f, indent=4, ensure_ascii=False) 23 | print('\nPredicted scores are saved in {}'.format(write_path)) 24 | 25 | 26 | -------------------------------------------------------------------------------- /reproduce/eval_data2text.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/data2text/sfres.json 2 | 3 | python predict_score.py \ 4 | --task data2text \ 5 | --data_path ${DATA_DIR} \ 6 | --max_source_length 1024 \ 7 | 8 | python 
correlation.py \ 9 | --task data2text \ 10 | --dataset sfres \ 11 | 12 | DATA_DIR=data/data2text/sfhot.json 13 | 14 | python predict_score.py \ 15 | --task data2text \ 16 | --data_path ${DATA_DIR} \ 17 | --max_source_length 1024 \ 18 | 19 | python correlation.py \ 20 | --task data2text \ 21 | --dataset sfhot \ 22 | 23 | -------------------------------------------------------------------------------- /reproduce/eval_dialogue.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/dialogue/topical_chat.json 2 | 3 | python predict_score.py \ 4 | --task dialogue \ 5 | --data_path ${DATA_DIR} \ 6 | --max_source_length 1024 \ 7 | 8 | python correlation.py \ 9 | --task dialogue \ 10 | --dataset topical_chat \ 11 | -------------------------------------------------------------------------------- /reproduce/eval_fact.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/fact/qags_xsum.json 2 | 3 | python predict_score.py \ 4 | --task fact \ 5 | --data_path ${DATA_DIR} \ 6 | --max_source_length 1024 \ 7 | 8 | python correlation.py \ 9 | --task fact \ 10 | --dataset qags_xsum \ 11 | 12 | DATA_DIR=data/fact/qags_cnndm.json 13 | 14 | python predict_score.py \ 15 | --task fact \ 16 | --data_path ${DATA_DIR} \ 17 | --max_source_length 1024 \ 18 | 19 | python correlation.py \ 20 | --task fact \ 21 | --dataset qags_cnndm \ 22 | 23 | -------------------------------------------------------------------------------- /reproduce/eval_summarization.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=data/summarization/summeval.json 2 | 3 | python predict_score.py \ 4 | --task summarization \ 5 | --data_path ${DATA_DIR} \ 6 | --max_source_length 1024 \ 7 | 8 | python correlation.py \ 9 | --task summarization \ 10 | --dataset summeval \ 11 | -------------------------------------------------------------------------------- /reproduce/predict_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from data_utils import load_json, write_predict 5 | sys.path.append("..") 6 | from metric.evaluator import get_evaluator 7 | 8 | def predict(args, save_result=True): 9 | # load standard meta-evaluation benchmark 10 | data = load_json(args.data_path) 11 | 12 | # Initialize the evaluator for a specific task 13 | evaluator = get_evaluator(task=args.task, 14 | max_length=args.max_source_length, 15 | device=args.device, 16 | cache_dir=args.cache_dir) 17 | 18 | # get the evaluation scores for all the dimensions 19 | eval_scores = evaluator.evaluate(data) 20 | 21 | # save results with predicted scores 22 | if save_result == True: 23 | dataset = os.path.basename(args.data_path[:-5]) # get the name of dataset (w/o '.json') 24 | write_predict(args.task, dataset, data, eval_scores) 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser( 28 | description='Get evaluation scores from UniEval from different NLG tasks' 29 | ) 30 | 31 | parser.add_argument('--data_path', required=True, 32 | help='Path to the meta-evaluation benchmark', type=str) 33 | parser.add_argument('--task', required=True, 34 | help='Specific NLG task to be evaluated', type=str) 35 | parser.add_argument('--cache_dir', default=None, 36 | help='Where to store the pretrained models downloaded from huggingface.co', type=str) 37 | parser.add_argument('--device', default='cuda:0', 38 | help='Available device for the calculations', 
type=str) 39 | parser.add_argument('--max_source_length', default=1024, 40 | help='The maximum total input sequence length after tokenization', type=int) 41 | 42 | args = parser.parse_args() 43 | 44 | predict(args) 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers >= 4.17.0.dev0 2 | accelerate 3 | datasets >= 1.8.0 4 | sentencepiece != 0.1.92 5 | protobuf 6 | rouge-score 7 | nltk 8 | py7zr 9 | torch >= 1.3 10 | evaluate 11 | prettytable -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from prettytable import PrettyTable 2 | 3 | def convert_to_json(output_list, src_list=None, ref_list=None, context_list=None, \ 4 | scores=None, doc_id=None, system_id=None): 5 | """ 6 | Convert the data into the json format. 7 | 8 | output_list: a list of model outputs 9 | src_list: source input for different NLG tasks. For example, the source document for summarization 10 | and the dialogue history for dialogue response generation 11 | ref_list: human-annotated groundtruth 12 | context_list: the context needed to evaluate several specific dimensions. For example, 13 | additional factual information when evaluating engagingness and groundedness in dialogues 14 | scores: human scores for evaluating the model output. They can be used to calculate the correlation 15 | between evaluators and human judgements. The scores should be stored in a dictionary. For example, 16 | {'fluency': 2.0, 'coherence': 3.0} could be the human score for a sample. 17 | doc_id: the index of the input source. It can be used to calculate summary-level correlation for summarization 18 | system_id: the index of the generation system. It can be used to calculate system-level correlation. 19 | """ 20 | json_data = [] 21 | for i in range(len(output_list)): 22 | cur = {} 23 | cur['system_output'] = output_list[i] 24 | if src_list is not None: 25 | cur['source'] = src_list[i] 26 | if ref_list is not None: 27 | cur['reference'] = ref_list[i] 28 | if context_list is not None: 29 | cur['context'] = context_list[i] 30 | if scores is not None: 31 | cur['scores'] = scores[i] 32 | if doc_id is not None: 33 | cur['doc_id'] = doc_id[i] 34 | if system_id is not None: 35 | cur['system_id'] = system_id[i] 36 | json_data.append(cur) 37 | return json_data 38 | 39 | 40 | def add_question(dimension, output, src=None, ref=None, context=None, task=None): 41 | """ 42 | Add questions to generate input in the Bool-QA format for UniEval. 43 | 44 | dimension: the specific dimension to be evaluated 45 | src: source input for different NLG tasks. For example, the source document for summarization 46 | and the dialogue history for dialogue response generation. 47 | output: output text generated by the models 48 | ref: human-annotated groundtruth 49 | context: the context needed to evaluate several specific dimensions. For example, 50 | additional factual information when evaluating engagingness and groundedness in dialogues. 51 | """ 52 | 53 | input_with_question = [] 54 | for i in range(len(output)): 55 | # For summarization 56 | if task == 'summarization': 57 | if dimension == 'fluency': 58 | cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i] 59 | elif dimension == 'coherence': 60 | cur_input = 'question: Is this a coherent summary to the document? 
summary: ' + output[i] + ' document: ' + src[i] 61 | elif dimension == 'consistency': 62 | cur_input = 'question: Is this claim consistent with the document? claim: ' + output[i] + ' document: ' + src[i] 63 | elif dimension == 'relevance': 64 | cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[i] + ' reference: ' + ref[i] 65 | else: 66 | raise NotImplementedError('The input format for this dimension is still undefined. Please customize it first.') 67 | # For dialogues 68 | elif task == 'dialogue': 69 | if dimension == 'naturalness': 70 | cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i] 71 | elif dimension == 'coherence': 72 | cur_input = 'question: Is this a coherent response given the dialogue history? response: '\ 73 | + output[i] + ' dialogue history: ' + src[i] 74 | elif dimension == 'engagingness': 75 | cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? response: '\ 76 | + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i] 77 | elif dimension == 'groundedness': 78 | cur_input = 'question: Is this response consistent with knowledge in the fact? response: '\ 79 | + output[i] + ' fact: ' + context[i] 80 | elif dimension == 'understandability': 81 | cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i] 82 | else: 83 | raise NotImplementedError('The input format for this dimension is still undefined. Please customize it first.') 84 | # For data-to-text 85 | elif task == 'data2text': 86 | if dimension == 'naturalness': 87 | cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i] 88 | elif dimension == 'informativeness': 89 | cur_input = 'question: Is this sentence informative according to the reference? sentence: '\ 90 | + output[i] + ' reference: ' + ref[i] 91 | else: 92 | raise NotImplementedError('The input format for this dimension is still undefined. Please customize it first.') 93 | # For factual consistency detection 94 | elif task == 'fact': 95 | if dimension == 'consistency': 96 | cur_input = 'question: Is this claim consistent with the document? claim: ' + output[i] + ' document: ' + src[i] 97 | else: 98 | raise NotImplementedError('No other dimensions for the factual consistency detection task.') 99 | # For new customized tasks 100 | else: 101 | raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.') 102 | input_with_question.append(cur_input) 103 | return input_with_question 104 | 105 | 106 | def print_scores(scores): 107 | table = PrettyTable(['Dimensions','Score']) 108 | print('\nEvaluation scores are shown below:') 109 | dims = list(scores[0].keys()) 110 | for dim in dims: 111 | cur_score = 0 112 | for i in range(len(scores)): 113 | cur_score += scores[i][dim] 114 | table.add_row([dim, round(cur_score / len(scores), 6)]) 115 | print(table) 116 | --------------------------------------------------------------------------------
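For convenience, here is a minimal usage sketch (not one of the repository files above) that ties the dumped modules together: `convert_to_json` from `utils.py` packs raw strings into the record format the evaluators expect, `get_evaluator` from `metric/evaluator.py` loads the pre-trained evaluator for a task, and its `evaluate` method scores every dimension via the Boolean-QA prompts built by `add_question`. The example texts below are purely illustrative, and the sketch assumes the requirements are installed, the script is run from the repository root, and the corresponding `MingZhong/unieval-*` checkpoint can be downloaded from the Hugging Face Hub.

```python
# Minimal usage sketch: score one summarization sample on all dimensions.
# The source/reference/output strings below are invented for illustration.
from utils import convert_to_json
from metric.evaluator import get_evaluator

task = 'summarization'

src_list = ['Peter and Elizabeth took a taxi to attend the night party in the city. '
            'While at the party, Elizabeth collapsed and was rushed to the hospital.']
ref_list = ['Elizabeth was hospitalized after attending a party with Peter.']
output_list = ['Peter and Elizabeth attend party city. Elizabeth rushed hospital.']

# Pack the lists into the json-style records expected by the evaluators
data = convert_to_json(output_list=output_list, src_list=src_list, ref_list=ref_list)

# Load the pre-trained evaluator for the task (device and cache_dir are optional)
evaluator = get_evaluator(task, max_length=1024, device='cuda:0')

# Returns one dict of per-dimension scores per sample; with the default settings
# an 'overall' average is added, and print_result=True prints a score table
eval_scores = evaluator.evaluate(data, print_result=True)
```

The same pattern should carry over to the other tasks by changing the `task` string and the fields passed to `convert_to_json` (for example, the dialogue evaluator additionally reads a `context` field, built from `context_list`, for the engagingness and groundedness questions).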