├── README.md ├── biography ├── article.json ├── eval_conversation.py └── gen_conversation.py ├── gsm ├── eval_gsm.py └── gen_gsm.py ├── math └── gen_math.py ├── mmlu ├── eval_mmlu.py └── gen_mmlu.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Improving Factuality and Reasoning in Language Models through Multiagent Debate 2 | 3 | ### [Project Page](https://composable-models.github.io/llm_debate/) | [Paper](https://arxiv.org/abs/2305.14325) 4 | 5 | [Yilun Du](https://yilundu.github.io/), 6 | [Shuang Li](https://shuangli59.github.io/), 7 | [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab), 8 | [Joshua B. Tenenbaum](https://scholar.google.com/citations?user=rRJ9wTJMUB8C&hl=en), 9 | [Igor Mordatch](https://scholar.google.com/citations?user=Vzr1RukAAAAJ&hl=en) 10 | 11 | This is a preliminary implementation of the paper "Improving Factuality and Reasoning in Language Models through Multiagent Debate". More tasks and settings will be released soon. 12 | You may see some additional debate logs [here](https://www.dropbox.com/sh/6kq5ixfnf4zqk09/AABezsYsBhgg1IQAZ12yQ43_a?dl=0). 13 | 14 | Also, check out gauss5930's awesome implementation of multiagent debate on opensource LLMs [here](https://github.com/gauss5930/LLM-Agora)! 15 | 16 | ## Running experiments 17 | 18 | The code for running arithmetic, GSM, biographies, and MMLU tasks may be found in the following subfolders 19 | 20 | * ./math/ contains code for running math 21 | * ./gsm/ contains code for running gsm 22 | * ./biography/ contains code for running biographies 23 | * ./mmlu/ contains code for running mmlu results. 
24 | 25 | **Math:** 26 | 27 | To generate and evaluate answers for Math problems through multiagent debate, cd into the math directory and run: 28 | `python gen_math.py` 29 | 30 | **Grade School Math:** 31 | 32 | To generate answers for Grade School Math problems through multiagent debate, cd into the gsm directory and run: 33 | `python gen_gsm.py` 34 | 35 | To evaluate the generated results of Grade School Math problems: 36 | `python eval_gsm.py` 37 | 38 | You can download the GSM dataset [here](https://github.com/openai/grade-school-math) 39 | 40 | 41 | **Biography:** 42 | 43 | To generate answers for Biography problems through multiagent debate, cd into the biography directory and run: 44 | `python gen_conversation.py` 45 | 46 | To evaluate the generated results for Biography problems: 47 | `python eval_conversation.py` 48 | 49 | **MMLU:** 50 | 51 | To generate answers for MMLU through multiagent debate, cd into the MMLU directory and run: 52 | `python gen_mmlu.py` 53 | 54 | To evaluate the generated results of MMLU: 55 | `python eval_mmlu.py` 56 | 57 | You can download the MMLU dataset [here](https://github.com/hendrycks/test) 58 | 59 | If you would like to cite the paper, here is a bibtex file: 60 | ``` 61 | @article{du2023improving, 62 | title={Improving Factuality and Reasoning in Language Models through Multiagent Debate}, 63 | author={Du, Yilun and Li, Shuang and Torralba, Antonio and Tenenbaum, Joshua B and Mordatch, Igor}, 64 | journal={arXiv preprint arXiv:2305.14325}, 65 | year={2023} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /biography/eval_conversation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import openai 3 | import numpy as np 4 | import time 5 | 6 | def parse_bullets(sentence): 7 | bullets_preprocess = sentence.split("\n") 8 | bullets = [] 9 | 10 | for bullet in bullets_preprocess: 11 | try: 12 | idx = 
bullet.find(next(filter(str.isalpha, bullet))) 13 | except: 14 | continue 15 | 16 | bullet = bullet[idx:] 17 | 18 | if len(bullet) != 0: 19 | bullets.append(bullet) 20 | 21 | return bullets 22 | 23 | 24 | def parse_yes_no(string): 25 | """ 26 | Parses a string containing "yes" or "no" and returns a boolean value. 27 | 28 | Args: 29 | string (str): The string to parse. 30 | 31 | Returns: 32 | bool: True if the string contains "yes", False if the string contains "no". 33 | 34 | Raises: 35 | ValueError: If the input string does not contain "yes" or "no". 36 | """ 37 | 38 | if "uncertain" in string.lower(): 39 | return None 40 | elif "yes" in string.lower(): 41 | return True 42 | elif "no" in string.lower(): 43 | return False 44 | else: 45 | return None 46 | 47 | def filter_people(person): 48 | people = person.split("(")[0] 49 | return people 50 | 51 | if __name__ == "__main__": 52 | response = json.load(open("biography_1_2.json", "r")) 53 | 54 | with open("article.json", "r") as f: 55 | gt_data = json.load(f) 56 | 57 | gt_data_filter = {} 58 | 59 | for k, v in gt_data.items(): 60 | k = filter_people(k) 61 | gt_data_filter[k] = v 62 | 63 | gt_data = gt_data_filter 64 | 65 | people = list(response.keys()) 66 | 67 | accuracies = [] 68 | 69 | for person in people: 70 | 71 | if person not in gt_data: 72 | continue 73 | 74 | gt_description = gt_data[person] 75 | gt_bullets = parse_bullets(gt_description) 76 | bio_descriptions = response[person]# [2][-1]['content'] 77 | 78 | for description in bio_descriptions: 79 | 80 | bio_description = description[-1]['content'] 81 | 82 | bio_bullets = parse_bullets(bio_description) 83 | if len(bio_bullets) == 1: 84 | if len(bio_bullets[0]) < 400: 85 | continue 86 | 87 | bio_bullets = " ".join(bio_bullets) 88 | # continue 89 | 90 | for bullet in gt_bullets: 91 | message = [{"role": "user", "content": "Consider the following biography of {}: \n {} \n\n Is the above biography above consistent with the fact below? 
\n\n {} \n Give a single word answer, yes, no, or uncertain. Carefully check the precise dates and locations between the fact and the above biography.".format(person, bio_bullets, bullet)}] 92 | 93 | try: 94 | completion = openai.ChatCompletion.create( 95 | model="gpt-3.5-turbo-0301", 96 | messages=message, 97 | n=1) 98 | except Exception as e: 99 | print("sleeping") 100 | time.sleep(20) 101 | continue 102 | 103 | print(message) 104 | 105 | content = completion["choices"][0]["message"]["content"] 106 | print(content) 107 | accurate = parse_yes_no(content) 108 | 109 | if accurate is not None: 110 | accuracies.append(float(accurate)) 111 | 112 | print("accuracies:", np.mean(accuracies), np.std(accuracies) / (len(accuracies) ** 0.5)) 113 | 114 | -------------------------------------------------------------------------------- /biography/gen_conversation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import openai 3 | import random 4 | from tqdm import tqdm 5 | 6 | def parse_bullets(sentence): 7 | bullets_preprocess = sentence.split("\n") 8 | bullets = [] 9 | 10 | for bullet in bullets_preprocess: 11 | try: 12 | idx = bullet.find(next(filter(str.isalpha, bullet))) 13 | except: 14 | continue 15 | 16 | bullet = bullet[idx:] 17 | 18 | if len(bullet) != 0: 19 | bullets.append(bullet) 20 | 21 | return bullets 22 | 23 | 24 | def filter_people(person): 25 | people = person.split("(")[0] 26 | return people 27 | 28 | 29 | def construct_message(agents, idx, person, final=False): 30 | prefix_string = "Here are some bullet point biographies of {} given by other agents: ".format(person) 31 | 32 | if len(agents) == 0: 33 | return {"role": "user", "content": "Closely examine your biography and provide an updated bullet point biography."} 34 | 35 | 36 | for i, agent in enumerate(agents): 37 | agent_response = agent[idx]["content"] 38 | response = "\n\n Agent response: ```{}```".format(agent_response) 39 | 40 | prefix_string = 
prefix_string + response 41 | 42 | if final: 43 | prefix_string = prefix_string + "\n\n Closely examine your biography and the biography of other agents and provide an updated bullet point biography.".format(person, person) 44 | else: 45 | prefix_string = prefix_string + "\n\n Using these other biographies of {} as additional advice, what is your updated bullet point biography of the computer scientist {}?".format(person, person) 46 | 47 | return {"role": "user", "content": prefix_string} 48 | 49 | 50 | def construct_assistant_message(completion): 51 | content = completion["choices"][0]["message"]["content"] 52 | return {"role": "assistant", "content": content} 53 | 54 | 55 | if __name__ == "__main__": 56 | with open("article.json", "r") as f: 57 | data = json.load(f) 58 | 59 | people = sorted(data.keys()) 60 | people = [filter_people(person) for person in people] 61 | random.seed(1) 62 | random.shuffle(people) 63 | 64 | agents = 3 65 | rounds = 2 66 | 67 | generated_description = {} 68 | 69 | 70 | for person in tqdm(people[:40]): 71 | agent_contexts = [[{"role": "user", "content": "Give a bullet point biography of {} highlighting their contributions and achievements as a computer scientist, with each fact separated with a new line character. 
".format(person)}] for agent in range(agents)] 72 | 73 | for round in range(rounds): 74 | for i, agent_context in enumerate(agent_contexts): 75 | 76 | if round != 0: 77 | agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:] 78 | 79 | if round == (rounds - 1): 80 | message = construct_message(agent_contexts_other, 2*round - 1, person=person, final=True) 81 | else: 82 | message = construct_message(agent_contexts_other, 2*round - 1, person=person, final=False) 83 | agent_context.append(message) 84 | 85 | try: 86 | completion = openai.ChatCompletion.create( 87 | model="gpt-3.5-turbo-0301", 88 | messages=agent_context, 89 | n=1) 90 | except: 91 | completion = openai.ChatCompletion.create( 92 | model="gpt-3.5-turbo-0301", 93 | messages=agent_context, 94 | n=1) 95 | 96 | print(completion) 97 | assistant_message = construct_assistant_message(completion) 98 | agent_context.append(assistant_message) 99 | 100 | bullets = parse_bullets(completion["choices"][0]['message']['content']) 101 | 102 | # The LM just doesn't know this person so no need to create debates 103 | if len(bullets) == 1: 104 | break 105 | 106 | generated_description[person] = agent_contexts 107 | 108 | json.dump(generated_description, open("biography_{}_{}.json".format(agents, rounds), "w")) 109 | 110 | -------------------------------------------------------------------------------- /gsm/eval_gsm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import openai 3 | import numpy as np 4 | import time 5 | import re 6 | 7 | def parse_bullets(sentence): 8 | bullets_preprocess = sentence.split("\n") 9 | bullets = [] 10 | 11 | for bullet in bullets_preprocess: 12 | try: 13 | idx = bullet.find(next(filter(str.isalpha, bullet))) 14 | except: 15 | continue 16 | 17 | bullet = bullet[idx:] 18 | 19 | if len(bullet) != 0: 20 | bullets.append(bullet) 21 | 22 | return bullets 23 | 24 | 25 | def parse_yes_no(string): 26 | """ 27 | Parses a string containing 
"yes" or "no" and returns a boolean value. 28 | 29 | Args: 30 | string (str): The string to parse. 31 | 32 | Returns: 33 | bool: True if the string contains "yes", False if the string contains "no". 34 | 35 | Raises: 36 | ValueError: If the input string does not contain "yes" or "no". 37 | """ 38 | if "yes" in string.lower(): 39 | return True 40 | elif "no" in string.lower(): 41 | return False 42 | else: 43 | return None 44 | 45 | 46 | def solve_math_problems(input_str): 47 | pattern = r"\d+\.?\d*" 48 | 49 | matches = re.findall(pattern, input_str) 50 | if matches: 51 | return matches[-1] 52 | 53 | return None 54 | 55 | def parse_answer(input_str): 56 | pattern = r"\{([0-9.,$]*)\}" 57 | matches = re.findall(pattern, input_str) 58 | 59 | solution = None 60 | 61 | for match_str in matches[::-1]: 62 | solution = re.sub(r"[^0-9.]", "", match_str) 63 | if solution: 64 | break 65 | 66 | return solution 67 | 68 | 69 | def compute_accuracy(gt, pred_solution): 70 | answers = solve_math_problems(gt) 71 | 72 | if answers is None: 73 | return None 74 | 75 | if type(pred_solution) == list: 76 | pred_answers = [] 77 | 78 | for pred_solution in pred_solutions: 79 | pred_answer = parse_answer(pred_solution) 80 | 81 | if pred_answer is None: 82 | pred_answer = solve_math_problems(pred_solution) 83 | 84 | pred_answers.append(pred_answer) 85 | 86 | # print("pred_answers: ", pred_answers) 87 | pred_answer = most_frequent(pred_answers) 88 | # print("pred answer: ", pred_answer) 89 | # pred_answer = pred_answers[0] 90 | else: 91 | pred_answer = parse_answer(pred_solution) 92 | if pred_answer is None: 93 | pred_answer = solve_math_problems(pred_solution) 94 | 95 | if pred_answer is None: 96 | return 1 97 | 98 | # try: 99 | if float(answers) == float(pred_answer): 100 | return 1 101 | else: 102 | return 0 103 | # except: 104 | # import pdb 105 | # pdb.set_trace() 106 | # print(pred_solution) 107 | 108 | 109 | def most_frequent(List): 110 | counter = 0 111 | num = List[0] 112 | 113 | for i 
in List: 114 | current_frequency = List.count(i) 115 | if current_frequency > counter: 116 | counter = current_frequency 117 | num = i 118 | 119 | return num 120 | 121 | if __name__ == "__main__": 122 | response_dict = json.load(open("gsm_debate_3_3.json", "r")) 123 | 124 | questions = list(response_dict.keys()) 125 | 126 | accuracies = [] 127 | 128 | for question in questions: 129 | responses, gt = response_dict[question] 130 | 131 | pred_solutions = [] 132 | for response in responses: 133 | pred_solution = response[-1]['content'] 134 | 135 | pred_solutions.append(pred_solution) 136 | 137 | accurate = compute_accuracy(gt, pred_solutions) 138 | 139 | if accurate is not None: 140 | accuracies.append(float(accurate)) 141 | else: 142 | import pdb 143 | pdb.set_trace() 144 | print(gt) 145 | 146 | print("accuracies:", np.mean(accuracies), np.std(accuracies) / (len(accuracies) ** 0.5)) 147 | 148 | -------------------------------------------------------------------------------- /gsm/gen_gsm.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import json 3 | import numpy as np 4 | import random 5 | 6 | def construct_message(agents, question, idx): 7 | if len(agents) == 0: 8 | return {"role": "user", "content": "Can you double check that your answer is correct. Please reiterate your answer, with your final answer a single numerical number, in the form \\boxed{{answer}}."} 9 | 10 | prefix_string = "These are the solutions to the problem from other agents: " 11 | 12 | for agent in agents: 13 | agent_response = agent[idx]["content"] 14 | response = "\n\n One agent solution: ```{}```".format(agent_response) 15 | 16 | prefix_string = prefix_string + response 17 | 18 | prefix_string = prefix_string + """\n\n Using the solutions from other agents as additional information, can you provide your answer to the math problem? \n The original math problem is {}. 
Your final answer should be a single numerical number, in the form \\boxed{{answer}}, at the end of your response.""".format(question) 19 | return {"role": "user", "content": prefix_string} 20 | 21 | 22 | def construct_assistant_message(completion): 23 | content = completion["choices"][0]["message"]["content"] 24 | return {"role": "assistant", "content": content} 25 | 26 | 27 | def read_jsonl(path: str): 28 | with open(path) as fh: 29 | return [json.loads(line) for line in fh.readlines() if line] 30 | 31 | if __name__ == "__main__": 32 | agents = 3 33 | rounds = 2 34 | random.seed(0) 35 | 36 | generated_description = {} 37 | 38 | questions = read_jsonl("/data/vision/billf/scratch/yilundu/llm_iterative_debate/grade-school-math/grade_school_math/data/test.jsonl") 39 | random.shuffle(questions) 40 | 41 | for data in questions[:100]: 42 | question = data['question'] 43 | answer = data['answer'] 44 | 45 | agent_contexts = [[{"role": "user", "content": """Can you solve the following math problem? {} Explain your reasoning. Your final answer should be a single numerical number, in the form \\boxed{{answer}}, at the end of your response. 
""".format(question)}] for agent in range(agents)] 46 | 47 | for round in range(rounds): 48 | for i, agent_context in enumerate(agent_contexts): 49 | 50 | if round != 0: 51 | agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:] 52 | message = construct_message(agent_contexts_other, question, 2*round - 1) 53 | agent_context.append(message) 54 | 55 | completion = openai.ChatCompletion.create( 56 | model="gpt-3.5-turbo-0301", 57 | messages=agent_context, 58 | n=1) 59 | 60 | assistant_message = construct_assistant_message(completion) 61 | agent_context.append(assistant_message) 62 | 63 | generated_description[question] = (agent_contexts, answer) 64 | 65 | json.dump(generated_description, open("gsm_{}_{}.json".format(agents, rounds), "w")) 66 | 67 | import pdb 68 | pdb.set_trace() 69 | print(answer) 70 | print(agent_context) 71 | -------------------------------------------------------------------------------- /math/gen_math.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import json 3 | import numpy as np 4 | import time 5 | import pickle 6 | from tqdm import tqdm 7 | 8 | def parse_bullets(sentence): 9 | bullets_preprocess = sentence.split("\n") 10 | bullets = [] 11 | 12 | for bullet in bullets_preprocess: 13 | try: 14 | idx = bullet.find(next(filter(str.isalpha, bullet))) 15 | except: 16 | continue 17 | 18 | bullet = bullet[idx:] 19 | 20 | if len(bullet) != 0: 21 | bullets.append(bullet) 22 | 23 | return bullets 24 | 25 | 26 | def generate_answer(answer_context): 27 | try: 28 | completion = openai.ChatCompletion.create( 29 | model="gpt-3.5-turbo-0301", 30 | messages=answer_context, 31 | n=1) 32 | except: 33 | print("retrying due to an error......") 34 | time.sleep(20) 35 | return generate_answer(answer_context) 36 | 37 | return completion 38 | 39 | 40 | def construct_message(agents, question, idx): 41 | 42 | # Use introspection in the case in which there are no other agents. 
43 | if len(agents) == 0: 44 | return {"role": "user", "content": "Can you verify that your answer is correct. Please reiterate your answer, making sure to state your answer at the end of the response."} 45 | 46 | prefix_string = "These are the recent/updated opinions from other agents: " 47 | 48 | for agent in agents: 49 | agent_response = agent[idx]["content"] 50 | response = "\n\n One agent response: ```{}```".format(agent_response) 51 | 52 | prefix_string = prefix_string + response 53 | 54 | prefix_string = prefix_string + "\n\n Use these opinions carefully as additional advice, can you provide an updated answer? Make sure to state your answer at the end of the response.".format(question) 55 | return {"role": "user", "content": prefix_string} 56 | 57 | 58 | def construct_assistant_message(completion): 59 | content = completion["choices"][0]["message"]["content"] 60 | return {"role": "assistant", "content": content} 61 | 62 | def parse_answer(sentence): 63 | parts = sentence.split(" ") 64 | 65 | for part in parts[::-1]: 66 | try: 67 | answer = float(part) 68 | return answer 69 | except: 70 | continue 71 | 72 | 73 | def most_frequent(List): 74 | counter = 0 75 | num = List[0] 76 | 77 | for i in List: 78 | current_frequency = List.count(i) 79 | if current_frequency > counter: 80 | counter = current_frequency 81 | num = i 82 | 83 | return num 84 | 85 | 86 | if __name__ == "__main__": 87 | answer = parse_answer("My answer is the same as the other agents and AI language model: the result of 12+28*19+6 is 550.") 88 | 89 | agents = 2 90 | rounds = 3 91 | np.random.seed(0) 92 | 93 | evaluation_round = 100 94 | scores = [] 95 | 96 | generated_description = {} 97 | 98 | for round in tqdm(range(evaluation_round)): 99 | a, b, c, d, e, f = np.random.randint(0, 30, size=6) 100 | 101 | answer = a + b * c + d - e * f 102 | agent_contexts = [[{"role": "user", "content": """What is the result of {}+{}*{}+{}-{}*{}? 
Make sure to state your answer at the end of the response.""".format(a, b, c, d, e, f)}] for agent in range(agents)] 103 | 104 | content = agent_contexts[0][0]['content'] 105 | question_prompt = "We seek to find the result of {}+{}*{}+{}-{}*{}?".format(a, b, c, d, e, f) 106 | 107 | for round in range(rounds): 108 | for i, agent_context in enumerate(agent_contexts): 109 | 110 | if round != 0: 111 | agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:] 112 | message = construct_message(agent_contexts_other, question_prompt, 2*round - 1) 113 | agent_context.append(message) 114 | 115 | print("message: ", message) 116 | 117 | completion = generate_answer(agent_context) 118 | 119 | assistant_message = construct_assistant_message(completion) 120 | agent_context.append(assistant_message) 121 | print(completion) 122 | 123 | text_answers = [] 124 | 125 | for agent_context in agent_contexts: 126 | text_answer = string = agent_context[-1]['content'] 127 | text_answer = text_answer.replace(",", ".") 128 | text_answer = parse_answer(text_answer) 129 | 130 | if text_answer is None: 131 | continue 132 | 133 | text_answers.append(text_answer) 134 | 135 | generated_description[(a, b, c, d, e, f)] = (agent_contexts, answer) 136 | 137 | try: 138 | text_answer = most_frequent(text_answers) 139 | if text_answer == answer: 140 | scores.append(1) 141 | else: 142 | scores.append(0) 143 | except: 144 | continue 145 | 146 | print("performance:", np.mean(scores), np.std(scores) / (len(scores) ** 0.5)) 147 | 148 | pickle.dump(generated_description, open("math_agents{}_rounds{}.p".format(agents, rounds), "wb")) 149 | import pdb 150 | pdb.set_trace() 151 | print(answer) 152 | print(agent_context) 153 | -------------------------------------------------------------------------------- /mmlu/eval_mmlu.py: -------------------------------------------------------------------------------- 1 | import json 2 | import openai 3 | import numpy as np 4 | import time 5 | import re 6 | 7 | def 
parse_bullets(sentence): 8 | bullets_preprocess = sentence.split("\n") 9 | bullets = [] 10 | 11 | for bullet in bullets_preprocess: 12 | try: 13 | idx = bullet.find(next(filter(str.isalpha, bullet))) 14 | except: 15 | continue 16 | 17 | bullet = bullet[idx:] 18 | 19 | if len(bullet) != 0: 20 | bullets.append(bullet) 21 | 22 | return bullets 23 | 24 | 25 | def parse_yes_no(string): 26 | """ 27 | Parses a string containing "yes" or "no" and returns a boolean value. 28 | 29 | Args: 30 | string (str): The string to parse. 31 | 32 | Returns: 33 | bool: True if the string contains "yes", False if the string contains "no". 34 | 35 | Raises: 36 | ValueError: If the input string does not contain "yes" or "no". 37 | """ 38 | if "yes" in string.lower(): 39 | return True 40 | elif "no" in string.lower(): 41 | return False 42 | else: 43 | return None 44 | 45 | 46 | def solve_math_problems(input_str): 47 | pattern = r"\d+\.?\d*" 48 | 49 | matches = re.findall(pattern, input_str) 50 | if matches: 51 | return matches[-1] 52 | 53 | return None 54 | 55 | def parse_answer(input_str): 56 | pattern = r'\((\w)\)' 57 | matches = re.findall(pattern, input_str) 58 | 59 | solution = None 60 | # print("predicted solution") 61 | # print(input_str) 62 | # print("matches") 63 | # print(matches) 64 | 65 | for match_str in matches[::-1]: 66 | solution = match_str.upper() 67 | if solution: 68 | break 69 | 70 | return solution 71 | 72 | 73 | def compute_accuracy(gt, pred_solutions): 74 | if type(pred_solutions) == list: 75 | pred_answers = [] 76 | 77 | for pred_solution in pred_solutions: 78 | pred_answer = parse_answer(pred_solution) 79 | 80 | if pred_answer is None: 81 | pred_answer = solve_math_problems(pred_solution) 82 | 83 | if pred_answer is not None: 84 | pred_answers.append(pred_answer) 85 | 86 | if pred_answer is None: 87 | return 0 88 | pred_answer = most_frequent(pred_answers) 89 | # pred_answer = pred_answers[0] 90 | else: 91 | pred_answer = parse_answer(pred_solutions) 92 | if 
pred_answer is None: 93 | pred_answer = solve_math_problems(pred_solutions) 94 | 95 | if gt == pred_answer: 96 | return 1 97 | else: 98 | return 0 99 | 100 | 101 | def most_frequent(List): 102 | counter = 0 103 | num = List[0] 104 | 105 | for i in List: 106 | current_frequency = List.count(i) 107 | if current_frequency > counter: 108 | counter = current_frequency 109 | num = i 110 | 111 | return num 112 | 113 | if __name__ == "__main__": 114 | response_dict = json.load(open("mmlu_personalities_3_2.json", "r")) 115 | questions = list(response_dict.keys()) 116 | 117 | accuracies = [] 118 | 119 | for question in questions: 120 | responses, gt = response_dict[question] 121 | 122 | pred_solutions = [] 123 | for response in responses: 124 | pred_solution = response[-1]['content'] 125 | 126 | pred_solutions.append(pred_solution) 127 | # break 128 | 129 | # pred_solutions = pred_solutions[:1] 130 | 131 | accurate = compute_accuracy(gt, pred_solutions) 132 | 133 | 134 | if accurate is not None: 135 | accuracies.append(float(accurate)) 136 | else: 137 | import pdb 138 | pdb.set_trace() 139 | print(gt) 140 | 141 | print("accuracies:", np.mean(accuracies), np.std(accuracies) / (len(accuracies) ** 0.5)) 142 | 143 | -------------------------------------------------------------------------------- /mmlu/gen_mmlu.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import pandas as pd 3 | import json 4 | import time 5 | import random 6 | import openai 7 | 8 | def construct_message(agents, question, idx): 9 | if len(agents) == 0: 10 | return {"role": "user", "content": "Can you double check that your answer is correct. 
def construct_assistant_message(completion):
    """Wrap the first chat-completion choice as an assistant chat message."""
    content = completion["choices"][0]["message"]["content"]
    return {"role": "assistant", "content": content}


def generate_answer(answer_context):
    """Query the chat model with the given message history, retrying on error.

    Sleeps 20s and retries indefinitely on API failures (rate limits,
    transient errors). Returns the raw completion object.
    """
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo-0301",
            messages=answer_context,
            n=1)
    except Exception as e:
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and
        # hide the failure reason; catch Exception and report it.
        print("retrying due to an error......", e)
        time.sleep(20)
        return generate_answer(answer_context)

    return completion
{}: A) {}, B) {}, C) {}, D) {} Explain your answer, putting the answer in the form (X) at the end of your response.".format(question, a, b, c, d) 51 | 52 | answer = df.iloc[ix, 5] 53 | 54 | return question, answer 55 | 56 | if __name__ == "__main__": 57 | agents = 3 58 | rounds = 2 59 | 60 | tasks = glob("/data/vision/billf/scratch/yilundu/llm_iterative_debate/mmlu/data/test/*.csv") 61 | 62 | dfs = [pd.read_csv(task) for task in tasks] 63 | 64 | random.seed(0) 65 | response_dict = {} 66 | 67 | for i in range(100): 68 | df = random.choice(dfs) 69 | ix = len(df) 70 | idx = random.randint(0, ix-1) 71 | 72 | question, answer = parse_question_answer(df, idx) 73 | 74 | agent_contexts = [[{"role": "user", "content": question}] for agent in range(agents)] 75 | 76 | for round in range(rounds): 77 | for i, agent_context in enumerate(agent_contexts): 78 | 79 | if round != 0: 80 | agent_contexts_other = agent_contexts[:i] + agent_contexts[i+1:] 81 | message = construct_message(agent_contexts_other, question, 2 * round - 1) 82 | agent_context.append(message) 83 | 84 | completion = generate_answer(agent_context) 85 | 86 | assistant_message = construct_assistant_message(completion) 87 | agent_context.append(assistant_message) 88 | print(completion) 89 | 90 | response_dict[question] = (agent_contexts, answer) 91 | 92 | json.dump(response_dict, open("mmlu_{}_{}.json".format(agents, rounds), "w")) 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.4 2 | openai==0.27.6 3 | pandas==1.5.3 4 | tqdm==4.64.1 5 | --------------------------------------------------------------------------------