├── prompts
│   ├── std_generation_cnndm.txt
│   ├── std_generation_xsum.txt
│   ├── cot_generation_cnndm.txt
│   ├── cot_generation_xsum.txt
│   └── cot_element_extraction.txt
├── assets
│   ├── Dataset.png
│   └── Method.png
├── __pycache__
│   ├── arguments.cpython-38.pyc
│   └── api_request.cpython-38.pyc
├── evaluation
│   ├── __pycache__
│   │   ├── metric.cpython-36.pyc
│   │   └── metric.cpython-38.pyc
│   ├── metric.py
│   └── eva.py
├── api_request.py
├── arguments.py
├── generation.py
├── README.md
└── LICENSE

/prompts/std_generation_cnndm.txt:
--------------------------------------------------------------------------------
Summarize the above article:
--------------------------------------------------------------------------------

/prompts/std_generation_xsum.txt:
--------------------------------------------------------------------------------
Summarize the above article in one sentence:
--------------------------------------------------------------------------------

/assets/Dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alsace08/SumCoT/HEAD/assets/Dataset.png
--------------------------------------------------------------------------------

/assets/Method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alsace08/SumCoT/HEAD/assets/Method.png
--------------------------------------------------------------------------------

/prompts/cot_generation_cnndm.txt:
--------------------------------------------------------------------------------
Let's integrate the above information and summarize the article:
--------------------------------------------------------------------------------

/prompts/cot_generation_xsum.txt:
--------------------------------------------------------------------------------
Let's integrate the above information and summarize the article in one sentence:
--------------------------------------------------------------------------------

/__pycache__/arguments.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alsace08/SumCoT/HEAD/__pycache__/arguments.cpython-38.pyc
--------------------------------------------------------------------------------

/__pycache__/api_request.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alsace08/SumCoT/HEAD/__pycache__/api_request.cpython-38.pyc
--------------------------------------------------------------------------------

/evaluation/__pycache__/metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alsace08/SumCoT/HEAD/evaluation/__pycache__/metric.cpython-36.pyc
--------------------------------------------------------------------------------

/evaluation/__pycache__/metric.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alsace08/SumCoT/HEAD/evaluation/__pycache__/metric.cpython-38.pyc
--------------------------------------------------------------------------------

/prompts/cot_element_extraction.txt:
--------------------------------------------------------------------------------
What are the important entities in this document?
What are the important dates in this document?
What events are happening in this document?
What is the result of these events?
Please answer the above questions:
--------------------------------------------------------------------------------
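For orientation, the element-extraction prompt above is chained with the CoT generation prompt into a two-stage query by `generation.py` (shown later in this listing). Below is a minimal sketch of that assembly, assuming the code is run from the repository root; the article text and the extracted elements are placeholders here, not real model outputs:

```python
# Sketch of the two-stage SumCoT prompt assembly (mirrors generation.py).
article = "..."    # placeholder source document
elements = "..."   # in the real pipeline, this is the model's answer to the four questions

extraction_prompt = open("./prompts/cot_element_extraction.txt").read()
cot_prompt = open("./prompts/cot_generation_cnndm.txt").read()

stage1_input = "Article: " + article + "\n" + extraction_prompt  # -> elements
stage2_input = stage1_input + elements + "\n" + cot_prompt       # -> final summary
```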
/api_request.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import openai


class Decoder:
    """Thin wrapper around the legacy OpenAI Completion API."""

    def __init__(self, api_key):
        self.api_key = api_key

    def decode(self, input, model, max_length):
        response = self.decoder_for_gpt3(model, input, max_length)
        return response

    def decoder_for_gpt3(self, model, input, max_length):
        openai.api_key = self.api_key

        # Map the shorthand model names used in arguments.py to OpenAI engine names.
        if model == "gpt3":
            engine = "text-ada-001"
        elif model == "gpt3-medium":
            engine = "text-babbage-001"
        elif model == "gpt3-large":
            engine = "text-curie-001"
        elif model == "gpt3-xl":
            engine = "text-davinci-002"
        else:
            raise ValueError("model is not properly defined ...")

        # Greedy decoding (temperature=0) for reproducibility.
        response = openai.Completion.create(
            engine=engine,
            prompt=input,
            max_tokens=max_length,
            temperature=0,
            stop=None
        )

        return response["choices"][0]["text"]
--------------------------------------------------------------------------------

/evaluation/metric.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from rouge import Rouge
from bert_score import score


def rouge_score(ref, pred):
    # F1 scores of ROUGE-1/2/L for a single pair, scaled to 0-100.
    rouge = Rouge()
    rs = rouge.get_scores(pred, ref)
    rouge1 = rs[0]["rouge-1"]["f"] * 100
    rouge2 = rs[0]["rouge-2"]["f"] * 100
    rougel = rs[0]["rouge-l"]["f"] * 100
    return rouge1, rouge2, rougel


def bs_score(ref, pred):
    # BERTScore F1 for a single pair; .item() converts the tensor to a float.
    _, _, F1 = score([pred], [ref], lang="en", verbose=True)
    bs = F1.mean().item()
    return bs


class BatchEvaluation:
    """Accumulates ROUGE / BERTScore statistics over a batch of (reference, prediction) pairs."""

    def __init__(self, total_r1=0, total_r2=0, total_rl=0, total_bs=0,
                 call_time_rs=0, call_time_bs=0):
        self.ref = ""
        self.pred = ""

        self.total_r1 = total_r1
        self.total_r2 = total_r2
        self.total_rl = total_rl
        self.total_bs = total_bs
        self.call_time_rs = call_time_rs
        self.call_time_bs = call_time_bs

    def set_text(self, ref, pred):
        self.ref = ref
        self.pred = pred
        return self

    def get_rouge_score(self):
        r1, r2, rl = rouge_score(self.ref, self.pred)
        self.total_r1 += r1
        self.total_r2 += r2
        self.total_rl += rl
        self.call_time_rs += 1

    def get_bs_score(self):
        bs = bs_score(self.ref, self.pred)
        self.total_bs += bs
        self.call_time_bs += 1
--------------------------------------------------------------------------------
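A minimal usage sketch for the metric helpers above, assuming it is run from the `evaluation/` directory with the `rouge` and `bert_score` packages installed; the reference/prediction pairs are toy placeholders:

```python
# Accumulate ROUGE over two toy (reference, prediction) pairs and report the batch mean.
from metric import BatchEvaluation

pairs = [("the cat sat on the mat", "a cat sat on the mat"),
         ("heavy rain flooded the town overnight", "the town was flooded by heavy rain")]

eva = BatchEvaluation()
for ref, pred in pairs:
    eva.set_text(ref, pred)
    eva.get_rouge_score()

print("mean ROUGE-1:", eva.total_r1 / eva.call_time_rs)
print("mean ROUGE-2:", eva.total_r2 / eva.call_time_rs)
print("mean ROUGE-L:", eva.total_rl / eva.call_time_rs)
```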
/arguments.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import argparse


def get_prompt():
    # Load the standard and CoT prompts for both datasets from ./prompts/.
    with open("./prompts/std_generation_cnndm.txt") as f:
        std_generation_cnndm_prompt = f.read()
    with open("./prompts/std_generation_xsum.txt") as f:
        std_generation_xsum_prompt = f.read()
    with open("./prompts/cot_generation_cnndm.txt") as f:
        cot_generation_cnndm_prompt = f.read()
    with open("./prompts/cot_generation_xsum.txt") as f:
        cot_generation_xsum_prompt = f.read()
    with open("./prompts/cot_element_extraction.txt") as f:
        cot_extraction_prompt = f.read()

    prompt = {"std_generation_cnndm_prompt": std_generation_cnndm_prompt,
              "std_generation_xsum_prompt": std_generation_xsum_prompt,
              "cot_generation_cnndm_prompt": cot_generation_cnndm_prompt,
              "cot_generation_xsum_prompt": cot_generation_xsum_prompt,
              "cot_extraction_prompt": cot_extraction_prompt}

    return prompt


def parse_arguments():
    parser = argparse.ArgumentParser(description="SumCoT")
    # A store_true flag is more reliable than type=bool for command-line booleans.
    parser.add_argument("--cot_true", action="store_true",
                        help="standard or cot-based generation")
    parser.add_argument("--model", type=str, default="gpt3-xl",
                        choices=["gpt3", "gpt3-medium", "gpt3-large", "gpt3-xl"],
                        help="model used for decoding")
    parser.add_argument("--dataset", type=str, default="cnndm",
                        choices=["cnndm", "xsum"], help="dataset source")
    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--end_id", type=int, default=0)
    args = parser.parse_args()

    prompt = get_prompt()
    args.cot = prompt["cot_extraction_prompt"]

    if args.dataset == "cnndm":
        args.std_prompt = prompt["std_generation_cnndm_prompt"]
        args.cot_prompt = prompt["cot_generation_cnndm_prompt"]
    elif args.dataset == "xsum":
        args.std_prompt = prompt["std_generation_xsum_prompt"]
        args.cot_prompt = prompt["cot_generation_xsum_prompt"]
    else:
        raise ValueError("Invalid Dataset!")

    return args
--------------------------------------------------------------------------------
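For reference, a small sketch of what `parse_arguments()` resolves to for the XSum setting, assuming it is run from the repository root so the prompt files can be found; `sys.argv` is overridden here purely for illustration:

```python
# Illustrative only: inspect the prompts attached to args for the XSum setting.
import sys
from arguments import parse_arguments

sys.argv = ["generation.py", "--dataset", "xsum", "--start_id", "0", "--end_id", "4"]
args = parse_arguments()

print(repr(args.std_prompt))  # "Summarize the above article in one sentence:"
print(repr(args.cot_prompt))  # "Let's integrate the above information and summarize the article in one sentence:"
print(repr(args.cot))         # the four element-extraction questions plus the answer instruction
```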
/generation.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
from api_request import Decoder
from arguments import parse_arguments


def get_llm_summary(args, decoder):
    # Element-aware test set: ./data/cnndm_element_aware.json or ./data/xsum_element_aware.json
    in_file = os.path.join("./data", args.dataset + "_element_aware.json")
    with open(in_file, "r", encoding="utf-8") as f:
        if "cnndm" in in_file:
            data = json.load(f)["cnndm"]
            data_output = {"cnndm": []}
        elif "xsum" in in_file:
            data = json.load(f)["xsum"]
            data_output = {"xsum": []}
        else:
            raise ValueError("Invalid Dataset!")

    for i in range(args.start_id, args.end_id + 1):
        src = data[i]["src"]
        ori_sum = data[i]["original_summary"]
        new_sum = data[i]["element-aware_summary"]

        # Standard one-step summarization.
        x = "Article: " + src + "\n" + args.std_prompt
        pred_std = decoder.decode(x, model=args.model, max_length=2048)

        # SumCoT: first extract the four core elements, then summarize conditioned on them.
        x = "Article: " + src + "\n" + args.cot
        ele = decoder.decode(x, model=args.model, max_length=2048)
        x = x + ele + "\n" + args.cot_prompt
        pred_cot = decoder.decode(x, model=args.model, max_length=2048)

        if "cnndm" in in_file:
            data_output["cnndm"].append({"id": i,
                                         "src": src,
                                         "original_summary": ori_sum,
                                         "element-aware_summary": new_sum,
                                         "gpt3_summary": pred_std,
                                         "gpt3_cot_summary": pred_cot})
        elif "xsum" in in_file:
            data_output["xsum"].append({"id": i,
                                        "src": src,
                                        "original_summary": ori_sum,
                                        "element-aware_summary": new_sum,
                                        "gpt3_summary": pred_std,
                                        "gpt3_cot_summary": pred_cot})

    data_output = json.dumps(data_output, indent=2)
    if "cnndm" in in_file:
        with open("cnndm_output.json", "w", newline='\n') as g:
            g.write(data_output)
    if "xsum" in in_file:
        with open("xsum_output.json", "w", newline='\n') as g:
            g.write(data_output)


if __name__ == '__main__':
    args = parse_arguments()
    decoder = Decoder(api_key="xxx")  # replace "xxx" with your OpenAI API key

    get_llm_summary(args, decoder)
--------------------------------------------------------------------------------
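Each record appended to `cnndm_output.json` / `xsum_output.json` by the script above has the following shape; the field values below are placeholders, not real outputs:

```python
# One entry of data_output["cnndm"] (or data_output["xsum"]); values are placeholders.
record = {
    "id": 0,
    "src": "<source article>",
    "original_summary": "<original dataset reference summary>",
    "element-aware_summary": "<expert-written element-aware reference summary>",
    "gpt3_summary": "<standard one-step GPT-3 summary>",
    "gpt3_cot_summary": "<two-step SumCoT GPT-3 summary>"
}
```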
/evaluation/eva.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import json
import os
import argparse
from metric import BatchEvaluation


def batch_evaluation(dataset, start_id, end_id, bs_true):
    in_file = os.path.join("../data", dataset + "_element_aware.json")
    with open(in_file, "r", encoding="utf-8") as f:
        if "cnndm" in in_file:
            data = json.load(f)["cnndm"]
        elif "xsum" in in_file:
            data = json.load(f)["xsum"]
        else:
            raise ValueError("Invalid Dataset!")

    eva_ori_std = BatchEvaluation()  # (original ref. summary) vs. (GPT-3 std. summary)
    eva_ori_cot = BatchEvaluation()  # (original ref. summary) vs. (GPT-3 cot summary)
    eva_new_std = BatchEvaluation()  # (element-aware ref. summary) vs. (GPT-3 std. summary)
    eva_new_cot = BatchEvaluation()  # (element-aware ref. summary) vs. (GPT-3 cot summary)

    for i in range(start_id, end_id + 1):
        ori_ref = data[i]["original_summary"]
        new_ref = data[i]["element-aware_summary"]
        std_pred = data[i]["gpt3_summary"]
        cot_pred = data[i]["gpt3_cot_summary"]

        # Skip samples with any missing text.
        if ori_ref == "" or new_ref == "" or std_pred == "" or cot_pred == "":
            continue

        eva_ori_std.set_text(ori_ref, std_pred)
        eva_ori_std.get_rouge_score()
        if bs_true: eva_ori_std.get_bs_score()

        eva_ori_cot.set_text(ori_ref, cot_pred)
        eva_ori_cot.get_rouge_score()
        if bs_true: eva_ori_cot.get_bs_score()

        eva_new_std.set_text(new_ref, std_pred)
        eva_new_std.get_rouge_score()
        if bs_true: eva_new_std.get_bs_score()

        eva_new_cot.set_text(new_ref, cot_pred)
        eva_new_cot.get_rouge_score()
        if bs_true: eva_new_cot.get_bs_score()

    print(f"original ref. summary vs. GPT-3 std. summary:\n"
          f"batch size:{eva_ori_std.call_time_rs}\n"
          f"r1: {eva_ori_std.total_r1 / eva_ori_std.call_time_rs}\n"
          f"r2: {eva_ori_std.total_r2 / eva_ori_std.call_time_rs}\n"
          f"rl: {eva_ori_std.total_rl / eva_ori_std.call_time_rs}\n")

    # print(f"original ref. summary vs. GPT-3 cot summary:\n"
    #       f"batch size:{eva_ori_cot.call_time_rs}\n"
    #       f"r1: {eva_ori_cot.total_r1 / eva_ori_cot.call_time_rs}\n"
    #       f"r2: {eva_ori_cot.total_r2 / eva_ori_cot.call_time_rs}\n"
    #       f"rl: {eva_ori_cot.total_rl / eva_ori_cot.call_time_rs}\n")

    print(f"element-aware ref. summary vs. GPT-3 std. summary:\n"
          f"batch size:{eva_new_std.call_time_rs}\n"
          f"r1: {eva_new_std.total_r1 / eva_new_std.call_time_rs}\n"
          f"r2: {eva_new_std.total_r2 / eva_new_std.call_time_rs}\n"
          f"rl: {eva_new_std.total_rl / eva_new_std.call_time_rs}\n")

    print(f"element-aware ref. summary vs. GPT-3 cot summary:\n"
          f"batch size:{eva_new_cot.call_time_rs}\n"
          f"r1: {eva_new_cot.total_r1 / eva_new_cot.call_time_rs}\n"
          f"r2: {eva_new_cot.total_r2 / eva_new_cot.call_time_rs}\n"
          f"rl: {eva_new_cot.total_rl / eva_new_cot.call_time_rs}\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Evaluation")
    parser.add_argument("--dataset", type=str, default="cnndm",
                        choices=["cnndm", "xsum"], help="dataset source")
    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--end_id", type=int, default=199)
    parser.add_argument("--bs_true", action="store_true")
    args = parser.parse_args()
    # args.end_id = args.start_id
    batch_evaluation(dataset=args.dataset, start_id=args.start_id,
                     end_id=args.end_id, bs_true=args.bs_true)
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
## Element-aware Summarization with Large Language Models: Expert-aligned Evaluation and Chain-of-Thought Method (ACL'23 Long Paper)

**Let's elicit LLMs to summarize step by step, following professional communication theory!**

In this work, you can directly use, or get inspired by:

- A fine-grained generic summary data annotation protocol (combining micro and macro demands)

- An expert-aligned generic summary test set (rewriting [*CNN/DailyMail*](https://paperswithcode.com/dataset/cnn-daily-mail-1) and [*BBC XSum*](https://paperswithcode.com/dataset/xsum))

- An expandable CoT-based open-ended generation path (not only *SumCoT*)


---


## Element-aware Dataset

### Annotation Statement

Our annotation protocol is mainly based on the [*Lasswell Communication Model*](https://en.wikipedia.org/wiki/Lasswell%27s_model_of_communication) --- a well-known communication theory proposed by Lasswell (1948). Additionally, we removed as much noise as possible from the original datasets and performed data analysis (see the paper for more details).

Case comparisons between our element-aware summaries and the original dataset-specific summaries: