├── sample_classification_f1_pred.jsonl
├── sample_classification_f1_test.jsonl
├── README.md
├── install.sh
├── sample_generation_pred.jsonl
├── sample_generation_test.jsonl
├── log
│   └── 2022-11-21_04-09-46.json
├── sample_test.py
├── evaluation.py
└── rouge_metric.py

/sample_classification_f1_pred.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "nikluge-au-2022-train-000015", "input": "왜 청년들 일하는 데에 끼어 드는데?", "output": 0}
2 | {"id": "nikluge-au-2022-train-000016", "input": "왜 굳이 청년 정책 관련된 일을 하는데?", "output": 0}
3 | {"id": "nikluge-au-2022-train-000017", "input": "국민들이 암말 않고 열심히 세금 내니까 정말 별 개 그지 같은 데에 돈쓰고 쳐 자빠졌네...", "output": 1}
4 | {"id": "nikluge-au-2022-train-000018", "input": "예를 들어 누가 살해 됐고 용의자가 있는데 용의자가 자기 아니라 하면 진범이 있어야 하잖아?", "output": 0}
5 | {"id": "nikluge-au-2022-train-000019", "input": "&name& 일당은 뭐야?", "output": 0}
--------------------------------------------------------------------------------
/sample_classification_f1_test.jsonl:
--------------------------------------------------------------------------------
1 | {"id": "nikluge-au-2022-train-000015", "input": "왜 청년들 일하는 데에 끼어 드는데?", "output": 1}
2 | {"id": "nikluge-au-2022-train-000016", "input": "왜 굳이 청년 정책 관련된 일을 하는데?", "output": 1}
3 | {"id": "nikluge-au-2022-train-000017", "input": "국민들이 암말 않고 열심히 세금 내니까 정말 별 개 그지 같은 데에 돈쓰고 쳐 자빠졌네...", "output": 1}
4 | {"id": "nikluge-au-2022-train-000018", "input": "예를 들어 누가 살해 됐고 용의자가 있는데 용의자가 자기 아니라 하면 진범이 있어야 하잖아?", "output": 0}
5 | {"id": "nikluge-au-2022-train-000019", "input": "&name& 일당은 뭐야?", "output": 0}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # National Institute of Korean Language (국립국어원) Competition Evaluation Code
2 | 
3 | This is the evaluation code used for the National Institute of Korean Language competitions.
4 | It does not yet cover every metric and input format; support will be added continuously.
5 | sample_test.py contains examples of how to use the evaluation code.
6 | 
7 | # Input
8 | 
9 | Datasets are provided in JSON-L (jsonlines) format, and each JSON object has three keys: 'id', 'input', and 'output'. Only 'id' and 'output' are used for evaluation.
10 | The training data and the test data are provided in the same JSON-L format; in the test data, the output field for each text is provided as an empty list. Participating teams fill that field with their model's output and submit the result.
11 | ※ The training data and the submission data have the same format and structure.
12 | 
13 | 
14 | The gold file and the prediction file must contain the same number of items, every id in the prediction file must exist in the gold data, and no id may be duplicated.
15 | 
16 | If the data violates the format described above, evaluation is not performed and an error message is returned instead of an evaluation result.
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | 
2 | pip install --upgrade pip # ensures that pip is current
3 | 
4 | pip install -U scikit-learn
5 | pip install --user -U nltk
6 | 
7 | apt-get install g++ openjdk-8-jdk python3-dev python3-pip curl
8 | python3 -m pip install konlpy
9 | apt-get install curl git
10 | bash <(curl -s https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh)
11 | 
12 | git clone https://github.com/google-research/bleurt.git
13 | cd bleurt
14 | pip install .
15 | 
16 | pip install evaluate
17 | pip install bert-score
18 | 
19 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
20 | 
21 | pip install mecab
22 | 
23 | wget https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip
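# Note: evaluation.py loads this checkpoint with score.BleurtScorer("BLEURT-20"), a path
# resolved relative to the working directory, so run the evaluation from the directory
# that contains the unzipped BLEURT-20/ folder (or adjust `checkpoint` in evaluation.py).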
24 | unzip BLEURT-20.zip -------------------------------------------------------------------------------- /sample_generation_pred.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "nikluge-2023-sc-dev-000001", "input": {"sentence1": "친구가 침대에 누워도 되냐고 물었다.", "sentence3": "친구가 양말과 겉옷을 벗고 침대에 누웠다."}, "output": "나는 양말과 겉옷을 벗으면 누워도 된다고 답해주었다."} 2 | {"id": "nikluge-2023-sc-dev-000002", "input": {"sentence1": "그는 도박을 하기 위해 부모님 돈을 가져다 썼다.", "sentence3": "결국 부모님 돈은 탕진되었고 그는 엄청난 원망을 받았다."}, "output": "부모님 돈은 점점 바닥나기 시작했다."} 3 | {"id": "nikluge-2023-sc-dev-000003", "input": {"sentence1": "취객이 영희의 앞을 가로막으며 영희를 희롱했다.", "sentence3": "철수의 기세에 눌린 취객은 뒤로 주춤 물러섰다."}, "output": "덩치가 큰 철수가 영희를 뒤로 숨기며 취객에게 다가섰다."} 4 | {"id": "nikluge-2023-sc-dev-000004", "input": {"sentence1": "나는 동생의 온도를 재기 위해 겨드랑이에 체온기를 꽂았다.", "sentence3": "그래서 나는 동생을 데리고 병원으로 향했다."}, "output": "그런데 생각보다 동생의 열이 높았다."} 5 | {"id": "nikluge-2023-sc-dev-000005", "input": {"sentence1": "친구와 둘이서 야구장에 갔다.", "sentence3": "하지만 친구는 잔뜩 신난 채 자신의 팀을 응원했다."}, "output": "나는 야구에 흥미가 없어서 졸았다."} 6 | -------------------------------------------------------------------------------- /sample_generation_test.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "nikluge-2023-sc-dev-000001", "input": {"sentence1": "친구가 침대에 누워도 되냐고 물었다.", "sentence3": "친구가 양말과 겉옷을 벗고 침대에 누웠다."}, "output": "나는 양말과 겉옷을 벗으면 누워도 된다고 답해주었다."} 2 | {"id": "nikluge-2023-sc-dev-000002", "input": {"sentence1": "그는 도박을 하기 위해 부모님 돈을 가져다 썼다.", "sentence3": "결국 부모님 돈은 탕진되었고 그는 엄청난 원망을 받았다."}, "output": "부모님 돈은 점점 바닥나기 시작했다."} 3 | {"id": "nikluge-2023-sc-dev-000003", "input": {"sentence1": "취객이 영희의 앞을 가로막으며 영희를 희롱했다.", "sentence3": "철수의 기세에 눌린 취객은 뒤로 주춤 물러섰다."}, "output": "덩치가 큰 철수가 영희를 뒤로 숨기며 취객에게 다가섰다."} 4 | {"id": "nikluge-2023-sc-dev-000004", "input": {"sentence1": "나는 동생의 온도를 재기 위해 겨드랑이에 체온기를 꽂았다.", "sentence3": "그래서 나는 동생을 데리고 병원으로 향했다."}, "output": "그런데 생각보다 동생의 열이 높았다."} 5 | {"id": "nikluge-2023-sc-dev-000005", "input": {"sentence1": "친구와 둘이서 야구장에 갔다.", "sentence3": "하지만 친구는 잔뜩 신난 채 자신의 팀을 응원했다."}, "output": "나는 야구에 흥미가 없어서 졸았다."} 6 | -------------------------------------------------------------------------------- /log/2022-11-21_04-09-46.json: -------------------------------------------------------------------------------- 1 | { 2 | "evaluation_complete": { 3 | "sample_generation_pred.jsonl": { 4 | "ROUGE": [ 5 | { 6 | "rouge-2": { 7 | "f": 0.7327070382345199, 8 | "p": 0.7553118939883646, 9 | "r": 0.7127307715899062 10 | }, 11 | "rouge-1": { 12 | "f": 0.8557483310257651, 13 | "p": 0.8802598638433663, 14 | "r": 0.8339819873544405 15 | }, 16 | "rouge-l": { 17 | "f": 0.7677128595677981, 18 | "p": 0.7908485284279742, 19 | "r": 0.7471444010426592 20 | } 21 | } 22 | ], 23 | "BLEU": [ 24 | 3.250373184486248e-232 25 | ] 26 | } 27 | }, 28 | "evalutation_fail": {} 29 | } -------------------------------------------------------------------------------- /sample_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | from evaluation import evaluation 4 | import os 5 | 6 | def jsonload(fname, encoding="utf-8"): 7 | with open(fname, encoding=encoding) as f: 8 | j = json.load(f) 9 | 10 | return j 11 | 12 | 13 | # json 개체를 파일이름으로 깔끔하게 저장 14 | def jsondump(j, fname): 15 | with open(fname, "w", encoding="UTF8") as f: 16 | json.dump(j, f, ensure_ascii=False, indent=4) 17 | 18 | 19 | # jsonl 파일 읽어서 list에 저장 20 | def jsonlload(fname, 
encoding="utf-8"): 21 | json_list = [] 22 | with open(fname, encoding=encoding) as f: 23 | for line in f.readlines(): 24 | json_list.append(json.loads(line)) 25 | return json_list 26 | 27 | 28 | if __name__ == '__main__': 29 | 30 | 31 | test_file_path = './2025_경진대회/korean_language_rag_V1.0_test_gold.json' 32 | submit_file_path = './2025_경진대회/result.test_rag.qwen-8b_250604.json' 33 | 34 | # test_file_path = './함의분석_result_with_SFT.json' 35 | # submit_file_path = './함의test(output포함).json' 36 | 37 | 38 | # test_file_path = './국회회의록안건별요약_test_with_answer.json' 39 | # submit_file_path = './국회회의록안건별요약_result_without SFT.json' 40 | 41 | # test_file_path = './sample_mse_test.jsonl' 42 | # submit_file_path = './sample_mse_pred.jsonl' 43 | 44 | # test_file_path = '(정답지)nikluge-2022-nli-test-answer.jsonl' 45 | # submit_file_path = 'mse_test_236(100).jsonl' 46 | 47 | # test_file_path = 'sample_sa_test.jsonl' 48 | # submit_file_path = 'sample_sa_pred.jsonl' 49 | 50 | # test_file_path = 'sample_generation_test.jsonl' 51 | # submit_file_path = 'sample_generation_pred.jsonl' 52 | 53 | # test_file_path = './data/nikluge-sc-2023-test-answer.jsonl' 54 | # submit_file_path = '02.jsonl' 55 | 56 | # test_file_path = './data/nikluge-2022-nli-test-answer.jsonl' 57 | # submit_file_path = './submit/2022_확신성 추론_180개.jsonl' 58 | print(test_file_path) 59 | print(submit_file_path) 60 | 61 | # test_file_path = 'sample_multi-label-dict_true.jsonl' 62 | # submit_file_path = 'sample_multi-label-dict_pred.jsonl' 63 | 64 | log_file_path = 'log/' 65 | log_dict = { 66 | 'evaluation_complete':{ 67 | 68 | }, 69 | 'evalutation_fail':{ 70 | 71 | } 72 | } 73 | # test file load 74 | test_data = jsonload(test_file_path) 75 | submit_data = jsonload(submit_file_path) 76 | 77 | # for filename in os.listdir('./submit'): 78 | # 79 | # # submit_file_path = './submit/'+filename 80 | # if filename == submit_file_path: 81 | # submit_file_path = './submit/'+filename 82 | # else: 83 | # continue 84 | # # 제출파일 load, json, jsonl 두 형태 모두 처리 85 | # try: 86 | # submit_data = jsonload(submit_file_path) 87 | # except: 88 | # try: 89 | # submit_data = jsonlload(submit_file_path) 90 | # except: 91 | # print(submit_file_path + ' 파일 형식 오류 - json, 또는 jsonl 형식이 아님') 92 | # log_dict['evalutation_fail'][submit_file_path] = '파일 형식 오류 - json, 또는 jsonl 형식이 아님' 93 | 94 | # result = evaluation(submit_data, test_data, evaluation_metrics=['classification_micro_F1', 'classification_macro_F1', 'classification_weighted_F1']) 95 | # result = evaluation(submit_data, test_data, evaluation_metrics=['MSE']) 96 | result = evaluation(submit_data, test_data, evaluation_metrics=['korean_contest_RAG_QA']) 97 | # result = evaluation(submit_data, test_data, evaluation_metrics=['ROUGE-1'], ratio=1, iteration=1) 98 | # result = evaluation(submit_data, test_data, evaluation_metrics=['sa_f1'], ratio=1, iteration=1) 99 | # result = evaluation(submit_data, test_data, evaluation_metrics=['ROUGE-1', 'BLEU', 'ROUGE-L']) 100 | # result = evaluation(submit_data, test_data, evaluation_metrics=['multi_label_classification_micro_F1']) 101 | 102 | # result = evaluation(submit_data, test_data, evaluation_metrics=['ROUGE-1', 'BLEU', 'ROUGE-L']) 103 | try: 104 | # result = evaluation(submit_data, test_data, evaluation_metrics=['bleurt', 'bertscore', 'ROUGE-1'], ratio=0.7, iteration=2) 105 | print(result) 106 | log_dict['evaluation_complete'][submit_file_path] = result 107 | except: 108 | print(submit_file_path + ' 파일 형식 오류 - 데이터 형태가 기준과 다름') 109 | log_dict['evalutation_fail'][submit_file_path] = 
'파일 형식 오류 - 데이터 형태가 제출 기준과 다름'
110 | 
111 |     # save the log file
112 |     now = datetime.now()
113 |     log_file_name = now.strftime('%Y-%m-%d_%H-%M-%S') + '.json'
114 |     jsondump(log_dict, log_file_path + log_file_name)
115 |     print(log_dict)
116 | 
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import f1_score
2 | from sklearn.metrics import mean_squared_error
3 | from sklearn.metrics import accuracy_score
4 | from nltk.translate.bleu_score import sentence_bleu
5 | from rouge_metric import Rouge
6 | import random
7 | from konlpy.tag import Mecab
8 | import evaluate
9 | import tensorflow as tf
10 | 
11 | # allow TensorFlow to grow GPU memory on demand
12 | gpus = tf.config.experimental.list_physical_devices('GPU')
13 | if gpus:
14 |     try:
15 |         for gpu in gpus:
16 |             tf.config.experimental.set_memory_growth(gpu, True)
17 |     except RuntimeError as e:
18 |         print(e)
19 | 
20 | 
21 | bert_scorer = evaluate.load('bertscore')
22 | bert_model_type = 'bert-base-multilingual-cased'
23 | 
24 | from bleurt import score
25 | 
26 | checkpoint = "BLEURT-20"
27 | 
28 | scorer = score.BleurtScorer(checkpoint)
29 | 
30 | tokenizer = Mecab()
31 | 
32 | 
33 | def data_sampling(true_data_list, pred_data_list, ratio):
34 |     index_list = list(range(len(true_data_list)))
35 |     random.shuffle(index_list)
36 | 
37 |     sampled_true_data_list = [0 for _ in range(len(true_data_list))]
38 |     sampled_pred_data_list = [0 for _ in range(len(pred_data_list))]
39 |     for i in range(len(true_data_list)):
40 |         sampled_true_data_list[i] = true_data_list[index_list[i]]
41 |         sampled_pred_data_list[i] = pred_data_list[index_list[i]]
42 | 
43 |     ratio_idx = int(len(true_data_list) * ratio)
44 | 
45 |     return sampled_true_data_list[:ratio_idx], sampled_pred_data_list[:ratio_idx]
46 | 
47 | def calc_multi_label_classification_micro_F1(true, pred):
48 |     # dict-valued outputs (str 'True'/'true' or bool) are handled; the list branches are placeholders
49 |     if type(true[0]) is list:
50 |         if type(true[0][0]) is int:
51 |             pass
52 |         elif type(true[0][0]) is float:
53 |             pass
54 |         elif type(true[0][0]) is bool:
55 |             pass
56 |         elif type(true[0][0]) is str:
57 |             pass
58 |         else:
59 |             return -1
60 | 
61 |     elif type(true[0]) is dict:
62 | 
63 |         sample_key = next(iter(true[0]))
64 | 
65 |         if type(true[0][sample_key]) is int:
66 |             pass
67 |         elif type(true[0][sample_key]) is float:
68 |             pass
69 |         elif type(true[0][sample_key]) is str:
70 |             def dict_to_list(input_dict):
71 |                 output_list = []
72 |                 for instance in input_dict.values():
73 |                     if instance == 'True' or instance == 'true':
74 |                         output_list.append(1)
75 |                     else:
76 |                         output_list.append(0)
77 | 
78 |                 return output_list
79 | 
80 |             formated_pred = list(map(lambda x: dict_to_list(x), pred))
81 |             formated_true = list(map(lambda x: dict_to_list(x), true))
82 |             f1_micro = f1_score(y_true=formated_true, y_pred=formated_pred, average='micro')
83 | 
84 |             return f1_micro
85 | 
86 |         elif type(true[0][sample_key]) is bool:
87 |             def dict_to_list(input_dict):
88 |                 output_list = []
89 |                 for instance in input_dict.values():
90 |                     if instance is True:
91 |                         output_list.append(1)
92 |                     else:
93 |                         output_list.append(0)
94 |                 return output_list
95 |             formated_pred = list(map(lambda x: dict_to_list(x), pred))
96 |             formated_true = list(map(lambda x: dict_to_list(x), true))
97 |             f1_micro = f1_score(y_true=formated_true, y_pred=formated_pred, average='micro')
98 |             return f1_micro
99 | 
100 |         else:
101 |             return -1
102 |     else:
103 |         return -1
104 | 
105 | 
106 | def calc_classification_micro_F1(true, pred):
107 |     return f1_score(true, pred, average='micro')
108 | 
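# Hypothetical usage sketch (labels invented for illustration, not taken from the sample files):
# for single-label classification, micro-averaged F1 reduces to plain accuracy, e.g.
#     calc_classification_micro_F1([0, 1, 1, 0], [0, 1, 0, 0])  # -> 0.75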
109 | 110 | def calc_classification_macro_F1(true, pred): 111 | return f1_score(true, pred, average='macro') 112 | 113 | 114 | def calc_classification_weighted_F1(true, pred): 115 | return f1_score(true, pred, average='weighted') 116 | 117 | 118 | def calc_MSE(true, pred): 119 | for i in range(len(true)): 120 | if type(true[i]) == str: 121 | true[i] = float(true[i]) 122 | if type(pred[i]) == str: 123 | pred[i] = float(pred[i]) 124 | return mean_squared_error(true, pred) 125 | 126 | 127 | def calc_ROUGE_1(true, pred): 128 | rouge_evaluator = Rouge( 129 | metrics=["rouge-n", "rouge-l"], 130 | max_n=2, 131 | limit_length=True, 132 | length_limit=1000, 133 | length_limit_type="words", 134 | use_tokenizer=True, 135 | apply_avg=True, 136 | apply_best=False, 137 | alpha=0.5, # Default F1_score 138 | weight_factor=1.0, 139 | ) 140 | 141 | scores = rouge_evaluator.get_scores(pred, true) 142 | return scores['rouge-1']['f'] 143 | 144 | 145 | def calc_ROUGE_L(true, pred): 146 | rouge_evaluator = Rouge( 147 | metrics=["rouge-n", "rouge-l"], 148 | max_n=2, 149 | limit_length=True, 150 | length_limit=1000, 151 | length_limit_type="words", 152 | use_tokenizer=True, 153 | apply_avg=True, 154 | apply_best=False, 155 | alpha=0.5, # Default F1_score 156 | weight_factor=1.0, 157 | ) 158 | 159 | scores = rouge_evaluator.get_scores(pred, true) 160 | return scores['rouge-l']['f'] 161 | 162 | 163 | def calc_BLEU(true, pred, apply_avg=True, apply_best=False, use_mecab=True): 164 | stacked_bleu = [] 165 | 166 | if type(true[0]) is str: 167 | true = list(map(lambda x: [x], true)) 168 | 169 | for i in range(len(true)): 170 | best_bleu = 0 171 | sum_bleu = 0 172 | for j in range(len(true[i])): 173 | 174 | if use_mecab: 175 | ref = tokenizer.morphs(true[i][j]) 176 | candi = tokenizer.morphs(pred[i]) 177 | else: 178 | ref = true[i][j].split() 179 | candi = pred[i].split() 180 | 181 | 182 | score = sentence_bleu([ref], candi, weights=(1, 0, 0, 0)) 183 | 184 | sum_bleu += score 185 | if score > best_bleu: 186 | best_bleu = score 187 | 188 | avg_bleu = sum_bleu / len(true[i]) 189 | if apply_best: 190 | stacked_bleu.append(best_bleu) 191 | if apply_avg: 192 | stacked_bleu.append(avg_bleu) 193 | 194 | return sum(stacked_bleu) / len(stacked_bleu) 195 | 196 | 197 | def evaluation_sa_f1(true_data, pred_data): 198 | true_data_list = true_data 199 | pred_data_list = pred_data 200 | 201 | ce_eval = { 202 | 'TP': 0, 203 | 'FP': 0, 204 | 'FN': 0, 205 | 'TN': 0 206 | } 207 | 208 | pipeline_eval = { 209 | 'TP': 0, 210 | 'FP': 0, 211 | 'FN': 0, 212 | 'TN': 0 213 | } 214 | 215 | for i in range(len(true_data_list)): 216 | 217 | # TP, FN checking 218 | is_ce_found = False 219 | is_pipeline_found = False 220 | for y_ano in true_data_list[i]: 221 | y_category = y_ano[0] 222 | y_polarity = y_ano[2] 223 | for p_ano in pred_data_list[i]: 224 | p_category = p_ano[0] 225 | p_polarity = p_ano[1] 226 | 227 | if y_category == p_category: 228 | is_ce_found = True 229 | if y_polarity == p_polarity: 230 | is_pipeline_found = True 231 | 232 | break 233 | 234 | if is_ce_found is True: 235 | ce_eval['TP'] += 1 236 | else: 237 | ce_eval['FN'] += 1 238 | 239 | if is_pipeline_found is True: 240 | pipeline_eval['TP'] += 1 241 | else: 242 | pipeline_eval['FN'] += 1 243 | 244 | is_ce_found = False 245 | is_pipeline_found = False 246 | 247 | # FP checking 248 | for p_ano in pred_data_list[i]: 249 | p_category = p_ano[0] 250 | p_polarity = p_ano[1] 251 | 252 | for y_ano in true_data_list[i]: 253 | y_category = y_ano[0] 254 | y_polarity = y_ano[2] 255 | 256 | if 
y_category == p_category: 257 | is_ce_found = True 258 | if y_polarity == p_polarity: 259 | is_pipeline_found = True 260 | 261 | break 262 | 263 | if is_ce_found is False: 264 | ce_eval['FP'] += 1 265 | 266 | if is_pipeline_found is False: 267 | pipeline_eval['FP'] += 1 268 | 269 | is_ce_found = False 270 | is_pipeline_found = False 271 | 272 | # ce_precision = ce_eval['TP']/(ce_eval['TP']+ce_eval['FP']) 273 | # ce_recall = ce_eval['TP']/(ce_eval['TP']+ce_eval['FN']) 274 | # 275 | # ce_result = { 276 | # 'Precision': ce_precision, 277 | # 'Recall': ce_recall, 278 | # 'F1': 2*ce_recall*ce_precision/(ce_recall+ce_precision) 279 | # } 280 | 281 | pipeline_precision = pipeline_eval['TP'] / (pipeline_eval['TP'] + pipeline_eval['FP']) 282 | pipeline_recall = pipeline_eval['TP'] / (pipeline_eval['TP'] + pipeline_eval['FN']) 283 | if pipeline_recall == 0 and pipeline_precision == 0: 284 | pipeline_f1 = 0 285 | else: 286 | pipeline_f1 = 2 * pipeline_recall * pipeline_precision / (pipeline_recall + pipeline_precision) 287 | 288 | pipeline_result = { 289 | 'Precision': pipeline_precision, 290 | 'Recall': pipeline_recall, 291 | 'F1': pipeline_f1 292 | } 293 | 294 | return { 295 | "sa_f1": pipeline_result['F1'] 296 | } 297 | # return { 298 | # 'category extraction result': ce_result, 299 | # 'entire pipeline result': pipeline_result 300 | # } 301 | 302 | 303 | def calc_bleurt(true_data, pred_data): 304 | if type(true_data[0]) is list: 305 | true_data = list(map(lambda x: x[0], true_data)) 306 | 307 | scores = scorer.score(references=true_data, candidates=pred_data, batch_size=64) 308 | 309 | return sum(scores) / len(scores) 310 | 311 | 312 | def calc_bertscore(true_data, pred_data): 313 | if type(true_data[0]) is list: 314 | true_data = list(map(lambda x: x[0], true_data)) 315 | 316 | scores = bert_scorer.compute(predictions=pred_data, references=true_data, model_type=bert_model_type) 317 | 318 | return sum(scores['f1']) / len(scores['f1']) 319 | 320 | def calc_Accuracy(true_data, pred_data): 321 | 322 | return accuracy_score(true_data, pred_data) 323 | 324 | def calc_multi_target_Accuracy(true_data, pred_data): 325 | """ 326 | Function to calculate multi-label accuracy between true labels and predicted labels. 327 | 328 | Args: 329 | - true_data: List of dictionaries with true label data from multiple documents. 330 | - pred_data: List of dictionaries with predicted label data from multiple documents. 331 | 332 | Returns: 333 | - accuracy_score: The proportion of correct labels predicted across all targets. 334 | - If there's a mismatch in the number of sentences, return an error log. 335 | """ 336 | 337 | # Flatten the output lists from true_data and pred_data 338 | true_output = [] 339 | pred_output = [] 340 | 341 | for doc_output in true_data: 342 | true_output.extend(doc_output) 343 | 344 | for doc_output in pred_data: 345 | pred_output.extend(doc_output) 346 | 347 | # Check if the number of sentences in true and pred data match 348 | if len(true_output) != len(pred_output): 349 | return f"Error: Mismatch in the number of true and predicted outputs. 
True data: {len(true_output)}, Pred data: {len(pred_output)}" 350 | 351 | # Sort both lists by id 352 | true_output = sorted(true_output, key=lambda x: x["id"]) 353 | pred_output = sorted(pred_output, key=lambda x: x["id"]) 354 | 355 | # Count correct predictions 356 | correct_count = 0 357 | total_count = len(true_output) 358 | 359 | for true_item, pred_item in zip(true_output, pred_output): 360 | # Check if IDs match 361 | if true_item["id"] != pred_item["id"]: 362 | return f"Error: Mismatch in IDs: {true_item['id']} != {pred_item['id']}" 363 | 364 | # Compare labels 365 | if true_item["label"] == pred_item["label"]: 366 | correct_count += 1 367 | 368 | # Calculate accuracy score 369 | accuracy_score = correct_count / total_count if total_count > 0 else 0 370 | return accuracy_score 371 | 372 | 373 | def calc_exact_match(true_data, pred_data): 374 | """ 375 | Calculate Exact Match score where true_data may contain multiple acceptable answers separated by # 376 | """ 377 | correct = 0 378 | total = len(true_data) 379 | 380 | for true, pred in zip(true_data, pred_data): 381 | # Split true answer into acceptable variations 382 | acceptable_answers = true.split('#') 383 | # Check if prediction matches any acceptable answer 384 | if any(pred.strip() == ans.strip() for ans in acceptable_answers): 385 | correct += 1 386 | 387 | return correct / total if total > 0 else 0 388 | 389 | def normalize_answer_text(text): 390 | """ 391 | Normalize answer text by removing quotes and extra whitespace 392 | """ 393 | # Remove both single and double quotes 394 | text = text.replace('"', '').replace("'", "") 395 | # Remove extra whitespace 396 | text = text.strip() 397 | return text 398 | 399 | def extract_answer_and_reason(text): 400 | """ 401 | Split the answer into selected answer part and reasoning part 402 | """ 403 | # Find the split point with '가 옳다' or similar patterns 404 | split_patterns = ['가 옳다', '이 옳다'] 405 | 406 | for pattern in split_patterns: 407 | if pattern in text: 408 | split_idx = text.find(pattern) + len(pattern) 409 | answer = text[:split_idx].strip() 410 | reason = text[split_idx:].strip() 411 | # Remove leading punctuation from reason 412 | reason = reason.lstrip('., ') 413 | return answer, reason 414 | 415 | # If no pattern is found, return the whole text as answer and empty reason 416 | return text.strip(), "" 417 | 418 | def evaluation_korean_contest_culture_QA(true_data, pred_data): 419 | # Separate questions by type 420 | multiple_choice_qs = {"true": [], "pred": []} 421 | short_answer_qs = {"true": [], "pred": []} 422 | descriptive_qs = {"true": [], "pred": []} 423 | 424 | # Categorize questions by type 425 | for true_item, pred_item in zip(true_data, pred_data): 426 | if true_item["id"] != pred_item["id"]: 427 | return { 428 | "error": f"ID mismatch: {true_item['id']} != {pred_item['id']}" 429 | } 430 | 431 | q_type = true_item["input"]["question_type"] 432 | true_ans = true_item["output"]["answer"] 433 | pred_ans = pred_item["output"]["answer"] 434 | 435 | if q_type == "선다형": 436 | multiple_choice_qs["true"].append(true_ans) 437 | multiple_choice_qs["pred"].append(pred_ans) 438 | elif q_type == "단답형": 439 | short_answer_qs["true"].append(true_ans) 440 | short_answer_qs["pred"].append(pred_ans) 441 | elif q_type == "서술형": 442 | descriptive_qs["true"].append(true_ans) 443 | descriptive_qs["pred"].append(pred_ans) 444 | 445 | # Calculate scores for each type 446 | scores = {} 447 | 448 | # Multiple choice questions (Accuracy) 449 | if multiple_choice_qs["true"]: 450 | 
scores["accuracy"] = calc_Accuracy(multiple_choice_qs["true"], multiple_choice_qs["pred"]) 451 | else: 452 | scores["accuracy"] = 0 453 | 454 | # Short answer questions (Exact Match) 455 | if short_answer_qs["true"]: 456 | scores["exact_match"] = calc_exact_match(short_answer_qs["true"], short_answer_qs["pred"]) 457 | else: 458 | scores["exact_match"] = 0 459 | 460 | # Descriptive questions (ROUGE, BERTScore, BLEURT) 461 | if descriptive_qs["true"]: 462 | scores["rouge_1"] = calc_ROUGE_1(descriptive_qs["true"], descriptive_qs["pred"]) 463 | scores["bertscore"] = calc_bertscore(descriptive_qs["true"], descriptive_qs["pred"]) 464 | scores["bleurt"] = calc_bleurt(descriptive_qs["true"], descriptive_qs["pred"]) 465 | scores["descriptive_avg"] = (scores["rouge_1"] + scores["bertscore"] + scores["bleurt"]) / 3 466 | else: 467 | scores["rouge_1"] = 0 468 | scores["bertscore"] = 0 469 | scores["bleurt"] = 0 470 | scores["descriptive_avg"] = 0 471 | 472 | # Calculate final score (average of the three types) 473 | type_scores = [] 474 | if multiple_choice_qs["true"]: 475 | type_scores.append(scores["accuracy"]) 476 | if short_answer_qs["true"]: 477 | type_scores.append(scores["exact_match"]) 478 | if descriptive_qs["true"]: 479 | type_scores.append(scores["descriptive_avg"]) 480 | 481 | scores["final_score"] = sum(type_scores) / len(type_scores) if type_scores else 0 482 | 483 | return scores 484 | 485 | def evaluation_korean_contest_RAG_QA(true_data, pred_data): 486 | scores = { 487 | "exact_match": 0, 488 | "rouge_1": 0, 489 | "bertscore": 0, 490 | "bleurt": 0, 491 | "descriptive_avg": 0, 492 | "final_score": 0 493 | } 494 | 495 | # Prepare lists for answer and reason evaluation 496 | true_answers = [] 497 | pred_answers = [] 498 | true_reasons = [] 499 | pred_reasons = [] 500 | 501 | # Process each QA pair 502 | for true_item, pred_item in zip(true_data, pred_data): 503 | if true_item["id"] != pred_item["id"]: 504 | return { 505 | "error": f"ID mismatch: {true_item['id']} != {pred_item['id']}" 506 | } 507 | 508 | # Extract answer and reason parts 509 | true_ans, true_reason = extract_answer_and_reason(true_item["output"]["answer"]) 510 | pred_ans, pred_reason = extract_answer_and_reason(pred_item["output"]["answer"]) 511 | 512 | # Normalize answers 513 | true_ans = normalize_answer_text(true_ans) 514 | pred_ans = normalize_answer_text(pred_ans) 515 | 516 | true_answers.append(true_ans) 517 | pred_answers.append(pred_ans) 518 | 519 | if true_reason and pred_reason: # Only include if both have reasoning 520 | true_reasons.append(true_reason) 521 | pred_reasons.append(pred_reason) 522 | 523 | # Calculate Exact Match score for answers 524 | scores["exact_match"] = calc_exact_match(true_answers, pred_answers) 525 | 526 | # Calculate generation metrics for reasoning if we have reasoning pairs 527 | if true_reasons and pred_reasons: 528 | scores["rouge_1"] = calc_ROUGE_1(true_reasons, pred_reasons) 529 | scores["bertscore"] = calc_bertscore(true_reasons, pred_reasons) 530 | scores["bleurt"] = calc_bleurt(true_reasons, pred_reasons) 531 | scores["descriptive_avg"] = (scores["rouge_1"] + scores["bertscore"] + scores["bleurt"]) / 3 532 | 533 | # Calculate final score (average of EM and descriptive_avg) 534 | scores["final_score"] = (scores["exact_match"] + scores["descriptive_avg"]) / 2 535 | 536 | return scores 537 | 538 | def evaluation(inferenced_data, ground_truth, evaluation_metrics=[], ratio=1, iteration=1): 539 | temp_ground_truth_dict = {} 540 | true_data_list = [] 541 | pred_data_list = [] 542 | 
543 | if len(inferenced_data) != len(ground_truth): 544 | return { 545 | 'error': '제출 파일과 정답 파일의 데이터 개수가 서로 다름' 546 | } 547 | 548 | # sa_f1 인 경우 549 | if 'sa_f1' in evaluation_metrics: 550 | # 데이터 list로 변경 551 | for data in ground_truth: 552 | if data['id'] in temp_ground_truth_dict: 553 | return { 554 | "error": "정답 데이터에 중복된 id를 가지는 경우 존재" 555 | } 556 | temp_ground_truth_dict[data['id']] = data['annotation'] 557 | 558 | for data in inferenced_data: 559 | if data['id'] not in temp_ground_truth_dict: 560 | return { 561 | "error": "제출 파일과 정답 파일의 id가 일치하지 않음" 562 | } 563 | true_data_list.append(temp_ground_truth_dict[data['id']]) 564 | pred_data_list.append(data['annotation']) 565 | sampled_true_data_list, sampled_pred_data_list = data_sampling(true_data_list, pred_data_list, ratio) 566 | 567 | return evaluation_sa_f1(sampled_true_data_list, sampled_pred_data_list) 568 | 569 | elif 'korean_contest_culture_QA' in evaluation_metrics: 570 | return evaluation_korean_contest_culture_QA(ground_truth, inferenced_data) 571 | 572 | elif 'korean_contest_RAG_QA' in evaluation_metrics: 573 | return evaluation_korean_contest_RAG_QA(ground_truth, inferenced_data) 574 | 575 | # 평가 가능한 metric 목록 576 | defined_evaluation_metric_list = ['classification_micro_F1', 'classification_macro_F1', 577 | 'classification_weighted_F1', 'MSE', 'ROUGE-1', 'BLEU', 'bleurt', 'bertscore', 'multi_label_classification_micro_F1', 'ROUGE-L', 'Accuracy', 'multi_target_Accuracy'] 578 | metric_to_func = { 579 | "classification_micro_F1": calc_classification_micro_F1, 580 | "classification_macro_F1": calc_classification_macro_F1, 581 | "classification_weighted_F1": calc_classification_weighted_F1, 582 | "MSE": calc_MSE, 583 | "ROUGE-1": calc_ROUGE_1, 584 | "ROUGE-L": calc_ROUGE_L, 585 | "BLEU": calc_BLEU, 586 | "bleurt": calc_bleurt, 587 | "bertscore": calc_bertscore, 588 | 'multi_label_classification_micro_F1': calc_multi_label_classification_micro_F1, 589 | 'Accuracy': calc_Accuracy, 590 | 'multi_target_Accuracy': calc_multi_target_Accuracy 591 | } 592 | 593 | # 평가 대상 metric 정리 594 | for metric in evaluation_metrics: 595 | if metric not in defined_evaluation_metric_list: 596 | return { 597 | "error": f"evaluation metric 중 {metric}은 정의된 metric에 포함되어있지 않습니다." 
598 | } 599 | 600 | evaluation_result = {metric: [] for metric in evaluation_metrics} 601 | 602 | # 데이터 list로 변경 603 | for data in ground_truth: 604 | if data['id'] in temp_ground_truth_dict: 605 | return { 606 | "error": "정답 데이터에 중복된 id를 가지는 경우 존재" 607 | } 608 | temp_ground_truth_dict[data['id']] = data['output'] 609 | 610 | for data in inferenced_data: 611 | if data['id'] not in temp_ground_truth_dict: 612 | return { 613 | "error": "제출 파일과 정답 파일의 id가 일치하지 않음" 614 | } 615 | true_data_list.append(temp_ground_truth_dict[data['id']]) 616 | pred_data_list.append(data['output']) 617 | 618 | # 평가 - iteration회 만금 ratio 비율로 sampling 한 데이터에 대해 평가 619 | for i in range(iteration): 620 | sampled_true_data_list, sampled_pred_data_list = data_sampling(true_data_list, pred_data_list, ratio) 621 | for metric in evaluation_metrics: 622 | result = metric_to_func[metric](sampled_true_data_list, sampled_pred_data_list) 623 | if type(result) is str: 624 | return { 625 | "error": result 626 | } 627 | 628 | evaluation_result[metric].append(result) 629 | 630 | #iteration 만큼 반복된 결과에 대한 평균 631 | for key, value in evaluation_result.items(): 632 | evaluation_result[key] = sum(value) / len(value) 633 | 634 | return evaluation_result 635 | -------------------------------------------------------------------------------- /rouge_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import platform 4 | import itertools 5 | import collections 6 | import pkg_resources # pip install py-rouge 7 | from io import open 8 | from konlpy.tag import Mecab 9 | 10 | 11 | 12 | 13 | class Rouge: 14 | DEFAULT_METRICS = {"rouge-n"} 15 | DEFAULT_N = 1 16 | STATS = ["f", "p", "r"] 17 | AVAILABLE_METRICS = {"rouge-n", "rouge-l", "rouge-w"} 18 | AVAILABLE_LENGTH_LIMIT_TYPES = {"words", "bytes"} 19 | REMOVE_CHAR_PATTERN = re.compile("[^A-Za-z0-9가-힣]") 20 | 21 | 22 | def __init__( 23 | self, 24 | metrics=None, 25 | max_n=None, 26 | limit_length=True, 27 | length_limit=1000, 28 | length_limit_type="words", 29 | apply_avg=True, 30 | apply_best=False, 31 | use_tokenizer=True, 32 | alpha=0.5, 33 | weight_factor=1.0, 34 | ): 35 | self.metrics = metrics[:] if metrics is not None else Rouge.DEFAULT_METRICS 36 | for m in self.metrics: 37 | if m not in Rouge.AVAILABLE_METRICS: 38 | raise ValueError("Unknown metric '{}'".format(m)) 39 | 40 | 41 | self.max_n = max_n if "rouge-n" in self.metrics else None 42 | # Add all rouge-n metrics 43 | if self.max_n is not None: 44 | index_rouge_n = self.metrics.index("rouge-n") 45 | del self.metrics[index_rouge_n] 46 | self.metrics += ["rouge-{}".format(n) for n in range(1, self.max_n + 1)] 47 | self.metrics = set(self.metrics) 48 | 49 | 50 | self.limit_length = limit_length 51 | if self.limit_length: 52 | if length_limit_type not in Rouge.AVAILABLE_LENGTH_LIMIT_TYPES: 53 | raise ValueError("Unknown length_limit_type '{}'".format(length_limit_type)) 54 | 55 | 56 | self.length_limit = length_limit 57 | if self.length_limit == 0: 58 | self.limit_length = False 59 | self.length_limit_type = length_limit_type 60 | 61 | 62 | self.use_tokenizer = use_tokenizer 63 | if use_tokenizer: 64 | self.tokenizer = Mecab() 65 | 66 | 67 | self.apply_avg = apply_avg 68 | self.apply_best = apply_best 69 | self.alpha = alpha 70 | self.weight_factor = weight_factor 71 | if self.weight_factor <= 0: 72 | raise ValueError("ROUGE-W weight factor must greater than 0.") 73 | 74 | 75 | def tokenize_text(self, text): 76 | if self.use_tokenizer: 77 | return 
self.tokenizer.morphs(text) 78 | else: 79 | return text 80 | 81 | 82 | @staticmethod 83 | def split_into_sentences(text): 84 | return text.split("\n") 85 | 86 | 87 | @staticmethod 88 | def _get_ngrams(n, text): 89 | ngram_set = collections.defaultdict(int) 90 | max_index_ngram_start = len(text) - n 91 | for i in range(max_index_ngram_start + 1): 92 | ngram_set[tuple(text[i : i + n])] += 1 93 | return ngram_set 94 | 95 | 96 | @staticmethod 97 | def _split_into_words(sentences): 98 | return list(itertools.chain(*[_.split() for _ in sentences])) 99 | 100 | 101 | @staticmethod 102 | def _get_word_ngrams_and_length(n, sentences): 103 | assert len(sentences) > 0 104 | assert n > 0 105 | 106 | 107 | tokens = Rouge._split_into_words(sentences) 108 | return Rouge._get_ngrams(n, tokens), tokens, len(tokens) - (n - 1) 109 | 110 | 111 | @staticmethod 112 | def _get_unigrams(sentences): 113 | assert len(sentences) > 0 114 | 115 | 116 | tokens = Rouge._split_into_words(sentences) 117 | unigram_set = collections.defaultdict(int) 118 | for token in tokens: 119 | unigram_set[token] += 1 120 | return unigram_set, len(tokens) 121 | 122 | 123 | @staticmethod 124 | def _compute_p_r_f_score( 125 | evaluated_count, 126 | reference_count, 127 | overlapping_count, 128 | alpha=0.5, 129 | weight_factor=1.0, 130 | ): 131 | precision = 0.0 if evaluated_count == 0 else overlapping_count / float(evaluated_count) 132 | if weight_factor != 1.0: 133 | precision = precision ** (1.0 / weight_factor) 134 | recall = 0.0 if reference_count == 0 else overlapping_count / float(reference_count) 135 | if weight_factor != 1.0: 136 | recall = recall ** (1.0 / weight_factor) 137 | f1_score = Rouge._compute_f_score(precision, recall, alpha) 138 | return {"f": f1_score, "p": precision, "r": recall} 139 | 140 | 141 | @staticmethod 142 | def _compute_f_score(precision, recall, alpha=0.5): 143 | return ( 144 | 0.0 145 | if (recall == 0.0 or precision == 0.0) 146 | else precision * recall / ((1 - alpha) * precision + alpha * recall) 147 | ) 148 | 149 | 150 | @staticmethod 151 | def _compute_ngrams(evaluated_sentences, reference_sentences, n): 152 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 153 | raise ValueError("Collections must contain at least 1 sentence.") 154 | 155 | 156 | evaluated_ngrams, _, evaluated_count = Rouge._get_word_ngrams_and_length( 157 | n, evaluated_sentences 158 | ) 159 | reference_ngrams, _, reference_count = Rouge._get_word_ngrams_and_length( 160 | n, reference_sentences 161 | ) 162 | 163 | 164 | # Gets the overlapping ngrams between evaluated and reference 165 | overlapping_ngrams = set(evaluated_ngrams.keys()).intersection(set(reference_ngrams.keys())) 166 | overlapping_count = 0 167 | for ngram in overlapping_ngrams: 168 | overlapping_count += min(evaluated_ngrams[ngram], reference_ngrams[ngram]) 169 | 170 | 171 | return evaluated_count, reference_count, overlapping_count 172 | 173 | 174 | @staticmethod 175 | def _compute_ngrams_lcs(evaluated_sentences, reference_sentences, weight_factor=1.0): 176 | def _lcs(x, y): 177 | m = len(x) 178 | n = len(y) 179 | vals = collections.defaultdict(int) 180 | dirs = collections.defaultdict(int) 181 | 182 | 183 | for i in range(1, m + 1): 184 | for j in range(1, n + 1): 185 | if x[i - 1] == y[j - 1]: 186 | vals[i, j] = vals[i - 1, j - 1] + 1 187 | dirs[i, j] = "|" 188 | elif vals[i - 1, j] >= vals[i, j - 1]: 189 | vals[i, j] = vals[i - 1, j] 190 | dirs[i, j] = "^" 191 | else: 192 | vals[i, j] = vals[i, j - 1] 193 | dirs[i, j] = "<" 194 | 195 | 196 | 
return vals, dirs 197 | 198 | 199 | def _wlcs(x, y, weight_factor): 200 | m = len(x) 201 | n = len(y) 202 | vals = collections.defaultdict(float) 203 | dirs = collections.defaultdict(int) 204 | lengths = collections.defaultdict(int) 205 | 206 | 207 | for i in range(1, m + 1): 208 | for j in range(1, n + 1): 209 | if x[i - 1] == y[j - 1]: 210 | length_tmp = lengths[i - 1, j - 1] 211 | vals[i, j] = ( 212 | vals[i - 1, j - 1] 213 | + (length_tmp + 1) ** weight_factor 214 | - length_tmp ** weight_factor 215 | ) 216 | dirs[i, j] = "|" 217 | lengths[i, j] = length_tmp + 1 218 | elif vals[i - 1, j] >= vals[i, j - 1]: 219 | vals[i, j] = vals[i - 1, j] 220 | dirs[i, j] = "^" 221 | lengths[i, j] = 0 222 | else: 223 | vals[i, j] = vals[i, j - 1] 224 | dirs[i, j] = "<" 225 | lengths[i, j] = 0 226 | 227 | 228 | return vals, dirs 229 | 230 | 231 | def _mark_lcs(mask, dirs, m, n): 232 | while m != 0 and n != 0: 233 | if dirs[m, n] == "|": 234 | m -= 1 235 | n -= 1 236 | mask[m] = 1 237 | elif dirs[m, n] == "^": 238 | m -= 1 239 | elif dirs[m, n] == "<": 240 | n -= 1 241 | else: 242 | raise UnboundLocalError("Illegal move") 243 | 244 | 245 | return mask 246 | 247 | 248 | if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0: 249 | raise ValueError("Collections must contain at least 1 sentence.") 250 | 251 | 252 | evaluated_unigrams_dict, evaluated_count = Rouge._get_unigrams(evaluated_sentences) 253 | reference_unigrams_dict, reference_count = Rouge._get_unigrams(reference_sentences) 254 | 255 | 256 | # Has to use weight factor for WLCS 257 | use_WLCS = weight_factor != 1.0 258 | if use_WLCS: 259 | evaluated_count = evaluated_count ** weight_factor 260 | reference_count = 0 261 | 262 | 263 | overlapping_count = 0.0 264 | for reference_sentence in reference_sentences: 265 | reference_sentence_tokens = reference_sentence.split() 266 | if use_WLCS: 267 | reference_count += len(reference_sentence_tokens) ** weight_factor 268 | hit_mask = [0 for _ in range(len(reference_sentence_tokens))] 269 | 270 | 271 | for evaluated_sentence in evaluated_sentences: 272 | evaluated_sentence_tokens = evaluated_sentence.split() 273 | 274 | 275 | if use_WLCS: 276 | _, lcs_dirs = _wlcs( 277 | reference_sentence_tokens, 278 | evaluated_sentence_tokens, 279 | weight_factor, 280 | ) 281 | else: 282 | _, lcs_dirs = _lcs(reference_sentence_tokens, evaluated_sentence_tokens) 283 | _mark_lcs( 284 | hit_mask, 285 | lcs_dirs, 286 | len(reference_sentence_tokens), 287 | len(evaluated_sentence_tokens), 288 | ) 289 | 290 | 291 | overlapping_count_length = 0 292 | for ref_token_id, val in enumerate(hit_mask): 293 | if val == 1: 294 | token = reference_sentence_tokens[ref_token_id] 295 | if evaluated_unigrams_dict[token] > 0 and reference_unigrams_dict[token] > 0: 296 | evaluated_unigrams_dict[token] -= 1 297 | reference_unigrams_dict[ref_token_id] -= 1 298 | 299 | 300 | if use_WLCS: 301 | overlapping_count_length += 1 302 | if ( 303 | ref_token_id + 1 < len(hit_mask) and hit_mask[ref_token_id + 1] == 0 304 | ) or ref_token_id + 1 == len(hit_mask): 305 | overlapping_count += overlapping_count_length ** weight_factor 306 | overlapping_count_length = 0 307 | else: 308 | overlapping_count += 1 309 | 310 | 311 | if use_WLCS: 312 | reference_count = reference_count ** weight_factor 313 | 314 | 315 | return evaluated_count, reference_count, overlapping_count 316 | 317 | 318 | def get_scores(self, hypothesis, references): 319 | if isinstance(hypothesis, str): 320 | hypothesis, references = [hypothesis], [references] 321 | 322 | 323 | if 
type(hypothesis) != type(references): 324 | raise ValueError("'hyps' and 'refs' are not of the same type") 325 | 326 | 327 | if len(hypothesis) != len(references): 328 | raise ValueError("'hyps' and 'refs' do not have the same length") 329 | scores = {} 330 | has_rouge_n_metric = ( 331 | len([metric for metric in self.metrics if metric.split("-")[-1].isdigit()]) > 0 332 | ) 333 | if has_rouge_n_metric: 334 | scores.update(self._get_scores_rouge_n(hypothesis, references)) 335 | # scores = {**scores, **self._get_scores_rouge_n(hypothesis, references)} 336 | 337 | 338 | has_rouge_l_metric = ( 339 | len([metric for metric in self.metrics if metric.split("-")[-1].lower() == "l"]) > 0 340 | ) 341 | if has_rouge_l_metric: 342 | scores.update(self._get_scores_rouge_l_or_w(hypothesis, references, False)) 343 | # scores = {**scores, **self._get_scores_rouge_l_or_w(hypothesis, references, False)} 344 | 345 | 346 | has_rouge_w_metric = ( 347 | len([metric for metric in self.metrics if metric.split("-")[-1].lower() == "w"]) > 0 348 | ) 349 | if has_rouge_w_metric: 350 | scores.update(self._get_scores_rouge_l_or_w(hypothesis, references, True)) 351 | # scores = {**scores, **self._get_scores_rouge_l_or_w(hypothesis, references, True)} 352 | 353 | 354 | return scores 355 | 356 | 357 | def _get_scores_rouge_n(self, all_hypothesis, all_references): 358 | metrics = [metric for metric in self.metrics if metric.split("-")[-1].isdigit()] 359 | 360 | 361 | if self.apply_avg or self.apply_best: 362 | scores = {metric: {stat: 0.0 for stat in Rouge.STATS} for metric in metrics} 363 | else: 364 | scores = { 365 | metric: [{stat: [] for stat in Rouge.STATS} for _ in range(len(all_hypothesis))] 366 | for metric in metrics 367 | } 368 | 369 | 370 | for sample_id, (hypothesis, references) in enumerate(zip(all_hypothesis, all_references)): 371 | assert isinstance(hypothesis, str) 372 | has_multiple_references = False 373 | if isinstance(references, list): 374 | has_multiple_references = len(references) > 1 375 | if not has_multiple_references: 376 | references = references[0] 377 | 378 | 379 | # Prepare hypothesis and reference(s) 380 | hypothesis = self._preprocess_summary_as_a_whole(hypothesis) 381 | references = ( 382 | [self._preprocess_summary_as_a_whole(reference) for reference in references] 383 | if has_multiple_references 384 | else [self._preprocess_summary_as_a_whole(references)] 385 | ) 386 | 387 | 388 | # Compute scores 389 | for metric in metrics: 390 | suffix = metric.split("-")[-1] 391 | n = int(suffix) 392 | 393 | 394 | # Aggregate 395 | if self.apply_avg: 396 | # average model 397 | total_hypothesis_ngrams_count = 0 398 | total_reference_ngrams_count = 0 399 | total_ngrams_overlapping_count = 0 400 | 401 | 402 | for reference in references: 403 | ( 404 | hypothesis_count, 405 | reference_count, 406 | overlapping_ngrams, 407 | ) = Rouge._compute_ngrams(hypothesis, reference, n) 408 | total_hypothesis_ngrams_count += hypothesis_count 409 | total_reference_ngrams_count += reference_count 410 | total_ngrams_overlapping_count += overlapping_ngrams 411 | 412 | 413 | score = Rouge._compute_p_r_f_score( 414 | total_hypothesis_ngrams_count, 415 | total_reference_ngrams_count, 416 | total_ngrams_overlapping_count, 417 | self.alpha, 418 | ) 419 | 420 | 421 | for stat in Rouge.STATS: 422 | scores[metric][stat] += score[stat] 423 | else: 424 | # Best model 425 | if self.apply_best: 426 | best_current_score = None 427 | for reference in references: 428 | ( 429 | hypothesis_count, 430 | reference_count, 431 | 
overlapping_ngrams, 432 | ) = Rouge._compute_ngrams(hypothesis, reference, n) 433 | score = Rouge._compute_p_r_f_score( 434 | hypothesis_count, 435 | reference_count, 436 | overlapping_ngrams, 437 | self.alpha, 438 | ) 439 | if best_current_score is None or score["r"] > best_current_score["r"]: 440 | best_current_score = score 441 | 442 | 443 | for stat in Rouge.STATS: 444 | scores[metric][stat] += best_current_score[stat] 445 | # Keep all 446 | else: 447 | for reference in references: 448 | ( 449 | hypothesis_count, 450 | reference_count, 451 | overlapping_ngrams, 452 | ) = Rouge._compute_ngrams(hypothesis, reference, n) 453 | score = Rouge._compute_p_r_f_score( 454 | hypothesis_count, 455 | reference_count, 456 | overlapping_ngrams, 457 | self.alpha, 458 | ) 459 | for stat in Rouge.STATS: 460 | scores[metric][sample_id][stat].append(score[stat]) 461 | 462 | 463 | # Compute final score with the average or the the max 464 | if (self.apply_avg or self.apply_best) and len(all_hypothesis) > 1: 465 | for metric in metrics: 466 | for stat in Rouge.STATS: 467 | scores[metric][stat] /= len(all_hypothesis) 468 | 469 | 470 | return scores 471 | 472 | 473 | def _get_scores_rouge_l_or_w(self, all_hypothesis, all_references, use_w=False): 474 | metric = "rouge-w" if use_w else "rouge-l" 475 | if self.apply_avg or self.apply_best: 476 | scores = {metric: {stat: 0.0 for stat in Rouge.STATS}} 477 | else: 478 | scores = { 479 | metric: [{stat: [] for stat in Rouge.STATS} for _ in range(len(all_hypothesis))] 480 | } 481 | 482 | 483 | for sample_id, (hypothesis_sentences, references_sentences) in enumerate( 484 | zip(all_hypothesis, all_references) 485 | ): 486 | assert isinstance(hypothesis_sentences, str) 487 | has_multiple_references = False 488 | if isinstance(references_sentences, list): 489 | has_multiple_references = len(references_sentences) > 1 490 | if not has_multiple_references: 491 | references_sentences = references_sentences[0] 492 | 493 | 494 | # Prepare hypothesis and reference(s) 495 | hypothesis_sentences = self._preprocess_summary_per_sentence(hypothesis_sentences) 496 | references_sentences = ( 497 | [ 498 | self._preprocess_summary_per_sentence(reference) 499 | for reference in references_sentences 500 | ] 501 | if has_multiple_references 502 | else [self._preprocess_summary_per_sentence(references_sentences)] 503 | ) 504 | 505 | 506 | # Compute scores 507 | # Aggregate 508 | if self.apply_avg: 509 | # average model 510 | total_hypothesis_ngrams_count = 0 511 | total_reference_ngrams_count = 0 512 | total_ngrams_overlapping_count = 0 513 | 514 | 515 | for reference_sentences in references_sentences: 516 | ( 517 | hypothesis_count, 518 | reference_count, 519 | overlapping_ngrams, 520 | ) = Rouge._compute_ngrams_lcs( 521 | hypothesis_sentences, 522 | reference_sentences, 523 | self.weight_factor if use_w else 1.0, 524 | ) 525 | total_hypothesis_ngrams_count += hypothesis_count 526 | total_reference_ngrams_count += reference_count 527 | total_ngrams_overlapping_count += overlapping_ngrams 528 | 529 | 530 | score = Rouge._compute_p_r_f_score( 531 | total_hypothesis_ngrams_count, 532 | total_reference_ngrams_count, 533 | total_ngrams_overlapping_count, 534 | self.alpha, 535 | self.weight_factor if use_w else 1.0, 536 | ) 537 | for stat in Rouge.STATS: 538 | scores[metric][stat] += score[stat] 539 | else: 540 | # Best model 541 | if self.apply_best: 542 | best_current_score = None 543 | best_current_score_wlcs = None 544 | for reference_sentences in references_sentences: 545 | ( 546 | 
hypothesis_count, 547 | reference_count, 548 | overlapping_ngrams, 549 | ) = Rouge._compute_ngrams_lcs( 550 | hypothesis_sentences, 551 | reference_sentences, 552 | self.weight_factor if use_w else 1.0, 553 | ) 554 | score = Rouge._compute_p_r_f_score( 555 | hypothesis_count, 556 | reference_count, 557 | overlapping_ngrams, 558 | self.alpha, 559 | self.weight_factor if use_w else 1.0, 560 | ) 561 | 562 | 563 | if use_w: 564 | reference_count_for_score = reference_count ** ( 565 | 1.0 / self.weight_factor 566 | ) 567 | overlapping_ngrams_for_score = overlapping_ngrams 568 | score_wlcs = ( 569 | overlapping_ngrams_for_score / reference_count_for_score 570 | ) ** (1.0 / self.weight_factor) 571 | 572 | 573 | if ( 574 | best_current_score_wlcs is None 575 | or score_wlcs > best_current_score_wlcs 576 | ): 577 | best_current_score = score 578 | best_current_score_wlcs = score_wlcs 579 | else: 580 | if best_current_score is None or score["r"] > best_current_score["r"]: 581 | best_current_score = score 582 | 583 | 584 | for stat in Rouge.STATS: 585 | scores[metric][stat] += best_current_score[stat] 586 | # Keep all 587 | else: 588 | for reference_sentences in references_sentences: 589 | ( 590 | hypothesis_count, 591 | reference_count, 592 | overlapping_ngrams, 593 | ) = Rouge._compute_ngrams_lcs( 594 | hypothesis_sentences, 595 | reference_sentences, 596 | self.weight_factor if use_w else 1.0, 597 | ) 598 | score = Rouge._compute_p_r_f_score( 599 | hypothesis_count, 600 | reference_count, 601 | overlapping_ngrams, 602 | self.alpha, 603 | self.weight_factor, 604 | ) 605 | 606 | 607 | for stat in Rouge.STATS: 608 | scores[metric][sample_id][stat].append(score[stat]) 609 | 610 | 611 | # Compute final score with the average or the the max 612 | if (self.apply_avg or self.apply_best) and len(all_hypothesis) > 1: 613 | for stat in Rouge.STATS: 614 | scores[metric][stat] /= len(all_hypothesis) 615 | 616 | 617 | return scores 618 | 619 | 620 | def _preprocess_summary_as_a_whole(self, summary): 621 | sentences = Rouge.split_into_sentences(summary) 622 | 623 | 624 | # Truncate 625 | if self.limit_length: 626 | # By words 627 | if self.length_limit_type == "words": 628 | summary = " ".join(sentences) 629 | all_tokens = summary.split() # Counting as in the perls script 630 | summary = " ".join(all_tokens[: self.length_limit]) 631 | 632 | 633 | # By bytes 634 | elif self.length_limit_type == "bytes": 635 | summary = "" 636 | current_len = 0 637 | for sentence in sentences: 638 | sentence = sentence.strip() 639 | sentence_len = len(sentence) 640 | 641 | 642 | if current_len + sentence_len < self.length_limit: 643 | if current_len != 0: 644 | summary += " " 645 | summary += sentence 646 | current_len += sentence_len 647 | else: 648 | if current_len > 0: 649 | summary += " " 650 | summary += sentence[: self.length_limit - current_len] 651 | break 652 | else: 653 | summary = " ".join(sentences) 654 | 655 | 656 | summary = Rouge.REMOVE_CHAR_PATTERN.sub(" ", summary.lower()).strip() 657 | 658 | 659 | tokens = self.tokenize_text(Rouge.REMOVE_CHAR_PATTERN.sub(" ", summary)) 660 | preprocessed_summary = [" ".join(tokens)] 661 | 662 | 663 | return preprocessed_summary 664 | 665 | 666 | def _preprocess_summary_per_sentence(self, summary): 667 | sentences = Rouge.split_into_sentences(summary) 668 | 669 | 670 | # Truncate 671 | if self.limit_length: 672 | final_sentences = [] 673 | current_len = 0 674 | # By words 675 | if self.length_limit_type == "words": 676 | for sentence in sentences: 677 | tokens = 
sentence.strip().split() 678 | tokens_len = len(tokens) 679 | if current_len + tokens_len < self.length_limit: 680 | sentence = " ".join(tokens) 681 | final_sentences.append(sentence) 682 | current_len += tokens_len 683 | else: 684 | sentence = " ".join(tokens[: self.length_limit - current_len]) 685 | final_sentences.append(sentence) 686 | break 687 | # By bytes 688 | elif self.length_limit_type == "bytes": 689 | for sentence in sentences: 690 | sentence = sentence.strip() 691 | sentence_len = len(sentence) 692 | if current_len + sentence_len < self.length_limit: 693 | final_sentences.append(sentence) 694 | current_len += sentence_len 695 | else: 696 | sentence = sentence[: self.length_limit - current_len] 697 | final_sentences.append(sentence) 698 | break 699 | sentences = final_sentences 700 | 701 | 702 | final_sentences = [] 703 | for sentence in sentences: 704 | sentence = Rouge.REMOVE_CHAR_PATTERN.sub(" ", sentence.lower()).strip() 705 | 706 | 707 | tokens = self.tokenize_text(Rouge.REMOVE_CHAR_PATTERN.sub(" ", sentence)) 708 | 709 | 710 | sentence = " ".join(tokens) 711 | 712 | 713 | final_sentences.append(sentence) 714 | 715 | 716 | return final_sentences --------------------------------------------------------------------------------
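A rough usage sketch of the pipeline above (not part of the repository): the paths refer to the sample files included here, the metric list mirrors the run recorded in log/2022-11-21_04-09-46.json, and sample_test.py remains the official entry point.

import json
from evaluation import evaluation  # note: importing evaluation loads the BLEURT-20 checkpoint and the bertscore model

def jsonlload(fname, encoding="utf-8"):
    # same helper as in sample_test.py: one JSON object per line
    with open(fname, encoding=encoding) as f:
        return [json.loads(line) for line in f]

true_data = jsonlload("sample_generation_test.jsonl")
pred_data = jsonlload("sample_generation_pred.jsonl")

# evaluation(inferenced_data, ground_truth, ...) aligns items by 'id' and averages each metric
# over `iteration` random samples of ratio * N items (here the full set, once).
result = evaluation(pred_data, true_data,
                    evaluation_metrics=['ROUGE-1', 'BLEU', 'ROUGE-L'],
                    ratio=1, iteration=1)
print(result)  # e.g. {'ROUGE-1': ..., 'BLEU': ..., 'ROUGE-L': ...}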