├── .gitignore
├── Datasets
│   ├── gsm8k
│   │   └── gsm8k_dataset.py
│   ├── gsm8k_dataset.py
│   ├── math_dataset.py
│   ├── mbpp_dataset.py
│   ├── mmlu_dataset.py
│   └── sample_MATH.py
├── Experiments
│   ├── run_gsm8k.py
│   ├── run_humaneval.py
│   ├── run_math.py
│   ├── run_mbpp.py
│   └── run_mmlu.py
├── LICENSE
├── MAR
│   ├── Agent
│   │   ├── __init__.py
│   │   ├── agent.py
│   │   ├── agent_registry.py
│   │   └── reasoning_profile.py
│   ├── Graph
│   │   ├── __init__.py
│   │   ├── graph.py
│   │   └── node.py
│   ├── LLM
│   │   ├── __init__.py
│   │   ├── gpt_chat.py
│   │   ├── llm.py
│   │   ├── llm_embedding.py
│   │   ├── llm_profile.py
│   │   ├── llm_registry.py
│   │   └── price.py
│   ├── MasRouter
│   │   └── mas_router.py
│   ├── Prompts
│   │   ├── __init__.py
│   │   ├── message_aggregation.py
│   │   ├── output_format.py
│   │   ├── post_process.py
│   │   ├── reasoning.py
│   │   └── tasks_profile.py
│   ├── Roles
│   │   ├── Code
│   │   │   ├── AlgorithmDesigner.json
│   │   │   ├── BugFixer.json
│   │   │   ├── PlanSolver.json
│   │   │   ├── ProgrammingExpert.json
│   │   │   ├── ProjectManager.json
│   │   │   ├── ReflectProgrammer.json
│   │   │   └── TestAnalyst.json
│   │   ├── Commonsense
│   │   │   ├── Critic.json
│   │   │   ├── Economist.json
│   │   │   ├── Historian.json
│   │   │   ├── KnowledgeExpert.json
│   │   │   ├── Reflector.json
│   │   │   ├── Scientist.json
│   │   │   └── WikiSearcher.json
│   │   ├── FinalNode
│   │   │   ├── gsm8k.json
│   │   │   ├── humaneval.json
│   │   │   ├── math.json
│   │   │   ├── mbpp.json
│   │   │   └── mmlu.json
│   │   ├── Math
│   │   │   ├── AlgorithmEngineer.json
│   │   │   ├── CertifiedAccountant.json
│   │   │   ├── Economist.json
│   │   │   ├── Engineer.json
│   │   │   ├── Inspector.json
│   │   │   ├── MathAnalyst.json
│   │   │   ├── MathSolver.json
│   │   │   ├── MathTeacher.json
│   │   │   ├── Mathematician.json
│   │   │   ├── ProgrammingExpert.json
│   │   │   ├── Scientist.json
│   │   │   └── SoftwareDeveloper.json
│   │   ├── __init__.py
│   │   ├── role_example.py
│   │   └── role_registry.py
│   ├── Tools
│   │   ├── coding
│   │   │   ├── executor_factory.py
│   │   │   ├── executor_types.py
│   │   │   ├── executor_utils.py
│   │   │   └── python_executor.py
│   │   ├── reader
│   │   │   └── readers.py
│   │   ├── search
│   │   │   ├── arXiv.py
│   │   │   ├── search.py
│   │   │   └── wiki.py
│   │   ├── vgen
│   │   │   └── dalle3.py
│   │   └── web
│   │       ├── screenshot.py
│   │       └── youtube.py
│   └── Utils
│       ├── const.py
│       ├── globals.py
│       ├── log.py
│       └── utils.py
├── README.md
├── assets
│   ├── intro.png
│   └── pipeline.png
└── template.env

/.gitignore:
--------------------------------------------------------------------------------
docx/
logs/
Datasets/MATH/*
Datasets/MMLU/*

.env
*.jsonl
*.pth
*.pyc
*.csv
--------------------------------------------------------------------------------
/Datasets/gsm8k/gsm8k_dataset.py:
--------------------------------------------------------------------------------
import re

def gsm_data_process(dataset):
    # extract the question, step and answer
    list_data_dict = []
    for data in dataset:
        item = {"task":data["question"]}
        raw_answer = data["answer"]
        raw_answer_list = raw_answer.split("\n####")
        item["step"] = raw_answer_list[0].strip()
        item["answer"] = raw_answer_list[-1].replace(",", "").strip()
        list_data_dict.append(item)

    return list_data_dict

def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        pass
    try:
        import unicodedata
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass
    return False

def gsm_get_predict(pred_str):
    pred_str = re.sub(r'(?<=\d),(?=\d)', '', pred_str)
    if('The answer is ' in pred_str):
        pred = pred_str.split('The answer is ')[-1].strip()
    elif('the answer is ' in pred_str):
        pred = pred_str.split('the answer is ')[-1].strip()
    elif 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if (ans[0] == '{'):
            stack = 1
            a = ''
            for c in ans[1:]:
                if (c == '{'):
                    stack += 1
                    a += c
                elif (c == '}'):
                    stack -= 1
                    if (stack == 0): break
                    a += c
                else:
                    a += c
        else:
            a = ans.split('$')[0].strip()
        a = _strip_string(a)
        pred=a
    else:
        pattern = r'-?\d*\.?\d+'
        pred = re.findall(pattern, pred_str)
        if(len(pred) >= 1):
            # print(pred_str)
            pred = pred[-1]
        else: pred = ''

    if pred != "":
        if pred[-1] == ".":
            pred = pred[:-1]
        if pred[-1] == "/":
            pred = pred[:-1]

    pred=_strip_string(pred)

    if 'boxed' in pred:
        ans = pred.split('boxed')[-1]
        if (ans[0] == '{'):
            stack = 1
            a = ''
            for c in ans[1:]:
                if (c == '{'):
                    stack += 1
                    a += c
                elif (c == '}'):
                    stack -= 1
                    if (stack == 0): break
                    a += c
                else:
                    a += c
        else:
            a = ans.split('$')[0].strip()
        a = _strip_string(a)
        pred=a

    if is_number(pred):
        return pred
    else:
        matches = re.findall(r'\d+', pred)
        return matches[-1] if matches else '0'


def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string

def delete_extra_zero(n):
    try:
        n=float(n)
    except:
        print("None {}".format(n))
        return n
    if isinstance(n, int):
        return str(n)
    if isinstance(n, float):
        n = str(n).rstrip('0')
        n = int(n.rstrip('.')) if n.endswith('.') else float(n)
        n=str(n)
        return n

def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string

def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string

def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")
    # print(string)

    # remove inverse spaces
    string = string.replace("\\!", "")
    # print(string)

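    # The remaining steps canonicalize the LaTeX so that equivalent answer
    # strings compare equal: e.g. "\dfrac{1}{2}", " 1/2", and "0.5" all
    # normalize to the single spelling "\frac{1}{2}".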
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # print(string)

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # print(string)

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # print(string)

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
--------------------------------------------------------------------------------
/Datasets/gsm8k_dataset.py:
--------------------------------------------------------------------------------
import re

def gsm_data_process(dataset):
    # extract the question, step and answer
    list_data_dict = []
    for data in dataset:
        item = {"task":data["question"]}
        raw_answer = data["answer"]
        raw_answer_list = raw_answer.split("\n####")
        item["step"] = raw_answer_list[0].strip()
        item["answer"] = raw_answer_list[-1].replace(",", "").strip()
        list_data_dict.append(item)

    return list_data_dict

def gsm_get_predict(pred_str):
    if('answer is ' in pred_str):
        pred = pred_str.split('answer is ')[-1].strip()
    elif 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if (ans[0] == '{'):
            stack = 1
            a = ''
            for c in ans[1:]:
                if (c == '{'):
                    stack += 1
                    a += c
                elif (c == '}'):
                    stack -= 1
                    if (stack == 0): break
                    a += c
                else:
                    a += c
        else:
            a = ans.split('$')[0].strip()
        a = _strip_string(a)
        pred=a
    else:
        pattern = r'-?\d*\.?\d+'
        pred = re.findall(pattern, pred_str)
        if(len(pred) >= 1):
            # print(pred_str)
            pred = pred[-1]
        else: pred = ''

    if pred != "":
        if pred[-1] == ".":
            pred = pred[:-1]
        if pred[-1] == "/":
            pred = pred[:-1]

    pred=_strip_string(pred)

    if 'boxed' in pred:
        ans = pred.split('boxed')[-1]
        if (ans[0] == '{'):
            stack = 1
            a = ''
            for c in ans[1:]:
                if (c == '{'):
                    stack += 1
                    a += c
                elif (c == '}'):
                    stack -= 1
                    if (stack == 0): break
                    a += c
                else:
                    a += c
        else:
            a = ans.split('$')[0].strip()
        a = _strip_string(a)
        pred=a

    if pred.isdigit():
        return pred
    else:
        matches = re.findall(r'\d+', pred)
        return matches[-1] if matches else '0'


def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string

def delete_extra_zero(n):
    try:
        n=float(n)
    except:
        print("None {}".format(n))
        return n
    if isinstance(n, int):
        return str(n)
    if isinstance(n, float):
        n = str(n).rstrip('0')
        n = int(n.rstrip('.')) if n.endswith('.') else float(n)
        n=str(n)
        return n

def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string

def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string

def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")
    # print(string)

    # remove inverse spaces
    string = string.replace("\\!", "")
    # print(string)

    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # print(string)

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # print(string)

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # print(string)

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

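    # What remains is numeric/structural normalization: pad bare ".5"-style
    # decimals to "0.5", drop a short "k =" prefix, brace sqrt/frac
    # arguments, and strip all spaces.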
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string
--------------------------------------------------------------------------------
/Datasets/math_dataset.py:
--------------------------------------------------------------------------------
import glob
import os
import json
from typing import Union, List, Literal, Any, Dict
import numpy as np

def load_math_dataset(data_path: str, split: Union[Literal['train'], Literal['test'], Literal['sampled_train'], Literal['sampled_test']]='train') -> List[Dict[str, str]]:
    print("Loading Math dataset...")
    category_paths = glob.glob(os.path.join(data_path, split, "*"))
    category_paths = sorted(category_paths)
    print("Number of categories: ", len(category_paths))
    total_data = []
    for category_path in category_paths:
        if os.path.isdir(category_path):
            json_files = glob.glob(os.path.join(category_path, "*.json"))
            for json_file in json_files:
                with open(json_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
                    total_data.append(data)
    print("Total number of questions: ", len(total_data))
    rng = np.random.default_rng(888)
    shuffled_data = list(rng.permutation(total_data))
    return shuffled_data


'''
copied from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
'''
def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string

def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string

def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string

def _strip_string(string):
    # linebreaks
    string = string.replace("\n", "")
    #print(string)

    # remove inverse spaces
    string = string.replace("\\!", "")
    #print(string)

    # replace \\ with \
    string = string.replace("\\\\", "\\")
    #print(string)

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    #print(string)

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    #print(string)

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace(r"\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string

def is_equiv(str1, str2, verbose=False):
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except:
        return str1 == str2


def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    if right_brace_idx is None:
        retval = None
    else:
        retval = string[idx:right_brace_idx + 1]

    return retval


def remove_boxed(s):
    left = "\\boxed{"
    try:
        assert s[:len(left)] == left
        assert s[-1] == "}"
        return s[len(left):-1]
    except:
        return None

def MATH_get_predict(pred_str):
    if '\\boxed' in pred_str:
        pred = remove_boxed(last_boxed_only_string(pred_str))
        return pred.strip() if pred is not None else "0"
    elif('answer is ' in pred_str):
        pred = pred_str.split('answer is ')[-1].strip().rstrip(".")
        return pred.strip()
    elif len(pred_str) > 0:
        return pred_str[-1]
    else:
        return "A"


def MATH_is_correct(pred,reference):
    true_answer_str = remove_boxed(last_boxed_only_string(reference))
    if pred is not None and is_equiv(true_answer_str, pred):
        return True
    return False

--------------------------------------------------------------------------------
/Datasets/mbpp_dataset.py:
--------------------------------------------------------------------------------
from typing import Union, Literal
import pandas as pd

class MbppDataset:
    def __init__(self, split: Union[Literal['train'], Literal['val'], Literal['test'], Literal['prompt']],):
        self._splits = {'train': 'full/train-00000-of-00001.parquet', 'test': 'full/test-00000-of-00001.parquet', 'val': 'full/validation-00000-of-00001.parquet', 'prompt': 'full/prompt-00000-of-00001.parquet'}
        self.df = pd.read_parquet("hf://datasets/google-research-datasets/mbpp/" + self._splits[split])
        # self.df = self.df.sample(frac=0.2).reset_index(drop=True)
        self.df = process_data(self.df)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        return self.df.iloc[index]

class MbppDataLoader:
    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(dataset)))
        if self.shuffle:
            self._shuffle_indices()
        self.index = 0

    def _shuffle_indices(self):
        import random
        random.shuffle(self.indices)

    def __iter__(self):
        batch = []
        for i in range(len(self.indices)):
            batch.append(self.dataset[self.indices[i]])
            if len(batch) == self.batch_size or i == len(self.indices) - 1:
                yield batch
                batch = []

    def __next__(self):
        if self.index >= len(self.dataset):
            raise StopIteration

        batch_indices = self.indices[self.index:self.index + self.batch_size]
        batch = [self.dataset[i] for i in batch_indices]
        self.index += self.batch_size
        return batch

def process_data(df: pd.DataFrame):
    tasks = []
    for i, data_entry in df.iterrows():
        prompt = data_entry["text"]
        test_case = data_entry["test_list"]
        tests = ""
        for test in test_case:
            tests+="\n"+test
        text = f"""
**Task**:
```python
{prompt}
```
Your code should pass these tests:
```python
{tests}
```
"""
        tasks.append(text)
    df["task"] = tasks
    return df

MbppDataset(split='test')
--------------------------------------------------------------------------------
/Datasets/mmlu_dataset.py:
--------------------------------------------------------------------------------
import glob
import pandas as pd
from typing import Union, List, Literal, Any, Dict
import numpy as np
from abc import ABC

class MMLUDataset(ABC):
    def __init__(self,
                 split: Union[Literal['dev'], Literal['val'], Literal['test']],
                 ) -> None:

        self._split = split

        data_path = f"datasets/MMLU/data/{self._split}/"
        self._total_df: pd.DataFrame = self._load_data(data_path)

    @staticmethod
    def get_domain() -> str:
        return 'mmlu'

    @staticmethod
    def _load_data(
        data_path: str,
    ) -> pd.DataFrame:

        rng = np.random.default_rng(888)

        csv_paths = glob.glob(data_path + "*.csv")
        csv_paths = sorted(csv_paths)
        print("Number of topics: ", len(csv_paths))

        names = ['question', 'A', 'B', 'C', 'D', 'correct_answer']

        total_df = pd.DataFrame(columns=names)
        for path in csv_paths:
            single_df = pd.read_csv(path, header=None,
                                    names=names, encoding='utf-8')
            total_df = pd.concat([total_df, single_df])

        total_df = total_df.reset_index(drop=True)

        # Pseudorandom shuffle
        total_df = total_df.reindex(rng.permutation(total_df.index))

        print("Total number of questions: ", len(total_df))

        return total_df

    @property
    def split(self) -> str:
        return self._split

    def __len__(self) -> int:
        return len(self._total_df)

    def __getitem__(self, index: int) -> Union[pd.DataFrame, pd.Series]:
        record = self._total_df.iloc[index]
        assert isinstance(record, pd.DataFrame) or isinstance(record, pd.Series)
        return record

    @staticmethod
    def record_to_input(record: Union[pd.DataFrame, pd.Series]) -> Dict[str, Any]:
        demo_question = (
            f"{record['question']}\n"
            f"Option A: {record['A']}\n"
            f"Option B: {record['B']}\n"
            f"Option C: {record['C']}\n"
            f"Option D: {record['D']}\n"
        )
        input_dict = {"task": demo_question}
        return input_dict

    def postprocess_answer(self, answer: Union[str, List[str]]) -> str:
        if isinstance(answer, list):
            if len(answer) > 0:
                answer = answer[0]
            else:
                answer = ""
        if not isinstance(answer, str):
            raise Exception("Expected string")
        if len(answer) > 0:
            ans_pos = answer.find("answer is")
            if ans_pos != -1:
                answer = answer[ans_pos+len("answer is"):].strip(":").strip().strip("Option").strip()
            answer = answer[0]  # Try to format the answer by taking the first letter
        return answer

    @staticmethod
    def record_to_target_answer(record: Union[pd.DataFrame, pd.Series]) -> str:
        correct_answer = record['correct_answer']
        assert isinstance(correct_answer, str), (
            f"String expected but got {correct_answer} "
            f"of type {type(correct_answer)} (2)"
            f" record={record}")
        return correct_answer

--------------------------------------------------------------------------------
/Datasets/sample_MATH.py:
--------------------------------------------------------------------------------
import os
import json
import random
import shutil
import math

# Set the paths
test_dir = '/Users/lby/Desktop/master/code/MAR/Datasets/MATH/test'
output_dir = '/Users/lby/Desktop/master/code/MAR/Datasets/MATH/sampled_test'

random.seed(42)  # Fix the random seed for reproducibility

# Create the output directory
os.makedirs(output_dir, exist_ok=True)

# Iterate over each category directory
for category in os.listdir(test_dir):
    category_path = os.path.join(test_dir, category)
    if not os.path.isdir(category_path):
        continue  # Skip non-directory entries

    level_groups = {}  # Group the problem files by level

    # Iterate over all JSON files in this category
    for problem_file in os.listdir(category_path):
        if not problem_file.endswith('.json'):
            continue  # Only process JSON files

        file_path = os.path.join(category_path, problem_file)

        # Read the problem data
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                problem_data = json.load(f)
        except Exception as e:
            print(f"Error reading {file_path}: {e}")
            continue

        level = problem_data.get('level', 'Unknown')

        if level not in level_groups:
            level_groups[level] = []
        level_groups[level].append((problem_file, file_path))

    # Sample within each level and collect the results
    sampled_files = []
    for level, files in level_groups.items():
        total = len(files)
        sample_size = max(1, math.ceil(total * 0.1))  # Round up, at least 1

        # Random sampling
        try:
            sampled = random.sample(files, sample_size)
        except ValueError as e:
            print(f"Error sampling {level} in {category}: {e}")
            continue

        sampled_files.extend(sampled)

    # Create the output category directory and copy the files
    output_category_dir = os.path.join(output_dir, category)
    os.makedirs(output_category_dir, exist_ok=True)

    for problem_file, src_path in sampled_files:
        dest_path = os.path.join(output_category_dir, problem_file)
        shutil.copy2(src_path, dest_path)  # Preserve metadata

print("Sampling complete! Results saved in:", output_dir)
--------------------------------------------------------------------------------
/Experiments/run_gsm8k.py:
--------------------------------------------------------------------------------
import sys
import os
import argparse
import yaml
import json
import time
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))


import torch
import torch.nn.functional as F

from MAR.MasRouter.mas_router import MasRouter
from MAR.LLM.llm_profile import llm_profile
from MAR.Agent.reasoning_profile import reasoning_profile
from MAR.Prompts.tasks_profile import tasks_profile
from MAR.Tools.reader.readers import JSONLReader
from MAR.Utils.utils import fix_random_seed,split_list
from MAR.Utils.globals import Cost, PromptTokens, CompletionTokens
MAR.Utils.globals import Cost, PromptTokens, CompletionTokens 22 | from Datasets.gsm8k_dataset import gsm_data_process, gsm_get_predict 23 | from MAR.Utils.log import configure_logging 24 | from loguru import logger 25 | 26 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 27 | 28 | def load_result(result_file): 29 | if not result_file.exists(): 30 | with open(result_file, 'w',encoding='utf-8') as file: 31 | json.dump([], file) 32 | 33 | with open(result_file, 'r',encoding='utf-8') as file: 34 | data = json.load(file) 35 | return data 36 | 37 | def dataloader(data_list, batch_size, i_batch): 38 | return data_list[i_batch*batch_size:i_batch*batch_size + batch_size] 39 | 40 | def load_config(config_path): 41 | with open(config_path, 'r',encoding='utf-8') as file: 42 | return yaml.safe_load(file) 43 | 44 | def parse_args(): 45 | parser = argparse.ArgumentParser(description="AgentPrune Experiments on gsm8k") 46 | parser.add_argument("--dataset_json", type=str, default="datasets/gsm8k/gsm8k.jsonl") 47 | parser.add_argument("--result_file", type=str, default=None) 48 | parser.add_argument('--lr', type=float, default=0.01,help="learning rate") 49 | parser.add_argument('--batch_size', type=int, default=16,help="batch size") 50 | parser.add_argument('--epochs', type=int, default=5, help="Default 5.") 51 | parser.add_argument('--num_rounds',type=int,default=1,help="Number of optimization/inference rounds for one query") 52 | parser.add_argument('--domain', type=str, default="gsm8k",help="Domain (the same as dataset name), default 'gsm8k'") 53 | parser.add_argument('--decision_method', type=str, default='FinalRefer', 54 | help='The decison method of the agentprune') 55 | parser.add_argument('--prompt_file', type=str, default='MAR/Roles/FinalNode/gsm8k.json') 56 | parser.add_argument('--start_epoch', type=int, default=0) 57 | parser.add_argument('--cost_rate', type=float, default=200.0) 58 | parser.add_argument('--max_agent', type=int, default=6) 59 | args = parser.parse_args() 60 | return args 61 | 62 | 63 | if __name__ == '__main__': 64 | args = parse_args() 65 | dataset = JSONLReader().parse_file("Datasets/gsm8k/sample_gsm8k.jsonl") 66 | dataset = gsm_data_process(dataset) 67 | train_dataset, test_dataset = split_list(dataset, 0.2) 68 | 69 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 70 | log_file = f"gsm8k_{current_time}.txt" 71 | fix_random_seed(1234) 72 | configure_logging(log_name=log_file) 73 | 74 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 75 | router = MasRouter(max_agent=args.max_agent, device=device).to(device) 76 | optimizer = torch.optim.Adam(router.parameters(), lr=args.lr) 77 | tasks = tasks_profile 78 | llms = llm_profile 79 | reasonings = reasoning_profile 80 | 81 | logger.info("Start training...") 82 | train_batches = int(len(train_dataset)/args.batch_size) 83 | for epoch in range(args.epochs): 84 | logger.info(f"Epoch {epoch}",80*'-') 85 | total_solved, total_executed = (0, 0) 86 | if epoch < args.start_epoch: 87 | router.load_state_dict(torch.load(f"gsm8k_router_epoch{epoch}_newnew.pth", map_location=torch.device('cuda'))) 88 | continue 89 | for i_batch in range(train_batches): 90 | logger.info(f"Batch {i_batch}",80*'-') 91 | start_ts = time.time() 92 | current_batch = dataloader(train_dataset, args.batch_size,i_batch) 93 | queries = [item['task'] for item in current_batch] 94 | answers = [item['answer'] for item in current_batch] 95 | task_labels = [0 for _ in current_batch] 96 | tasks_y = torch.tensor(task_labels).to(device) 97 
| optimizer.zero_grad() 98 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels) 99 | 100 | task_loss = F.cross_entropy(tasks_probs, tasks_y) 101 | agent_num_loss = 0 102 | utilities = [] 103 | answers_loss = [] 104 | is_solved_list = [] 105 | for result, true_answer, log_prob, cost in zip(results, answers, log_probs, costs): 106 | predict_answer = gsm_get_predict(result) 107 | is_solved = float(predict_answer)==float(true_answer) 108 | total_solved = total_solved + is_solved 109 | total_executed = total_executed + 1 110 | utility = is_solved - cost * args.cost_rate 111 | utilities.append(utility) 112 | is_solved_list.append(is_solved) 113 | answer_loss = -log_prob * utility 114 | answers_loss.append(answer_loss) 115 | 116 | answer_loss = torch.stack(answers_loss).sum() / len(answers_loss) 117 | vae_loss = vae_loss.mean() 118 | is_solved_tensor = torch.tensor(is_solved_list, dtype=torch.float32, device=device).unsqueeze(1) # shape: [N, 1] 119 | # adjust_loss = ((1 - is_solved_tensor) * (router.num_determiner.max_agent - agents_num) + 0.25 * is_solved_tensor * agents_num).mean() 120 | 121 | loss = task_loss + answer_loss + vae_loss*0.001 #+ adjust_loss 122 | loss.backward() 123 | optimizer.step() 124 | 125 | accuracy = total_solved / total_executed 126 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 127 | logger.info(f"Accuracy: {accuracy}") 128 | logger.info(f"utilities:{utilities}") 129 | 130 | logger.info(f"Epoch {epoch} Finishes",80*'-') 131 | torch.save(router.state_dict(), f"gsm8k_router_epoch{epoch}_newnew.pth") 132 | 133 | logger.info("Finish training...") 134 | logger.info("Start testing...") 135 | 136 | test_batches = int(len(test_dataset)/args.batch_size) 137 | total_solved, total_executed = (0, 0) 138 | for i_batch in range(test_batches): 139 | logger.info(f"Batch {i_batch}",80*'-') 140 | start_ts = time.time() 141 | current_batch = dataloader(test_dataset,args.batch_size,i_batch) 142 | queries = [item['task'] for item in current_batch] 143 | answers = [item['answer'] for item in current_batch] 144 | task_labels = [0 for _ in current_batch] 145 | tasks_y = torch.tensor(task_labels).to(device) 146 | optimizer.zero_grad() 147 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels) 148 | utilities = [] 149 | for result, true_answer, log_prob, cost in zip(results, answers, log_probs, costs): 150 | predict_answer = gsm_get_predict(result) 151 | is_solved = float(predict_answer)==float(true_answer) 152 | total_solved = total_solved + is_solved 153 | total_executed = total_executed + 1 154 | utility = is_solved - cost * args.cost_rate 155 | utilities.append(utility) 156 | 157 | accuracy = total_solved / total_executed 158 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 159 | logger.info(f"Accuracy: {accuracy}") 160 | logger.info(f"utilities:{utilities}") 161 | logger.info("Finish testing...") -------------------------------------------------------------------------------- /Experiments/run_humaneval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import io 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 6 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 7 | 8 | import time 9 | import argparse 10 | import yaml 11 | import json 12 | import time 13 | import re 14 | import torch 15 | from loguru import 
logger 16 | import torch.nn.functional as F 17 | 18 | from MAR.MasRouter.mas_router import MasRouter 19 | from MAR.LLM.llm_profile import llm_profile 20 | from MAR.Agent.reasoning_profile import reasoning_profile 21 | from MAR.Prompts.tasks_profile import tasks_profile 22 | from MAR.Tools.reader.readers import JSONLReader 23 | from MAR.Tools.coding.python_executor import PyExecutor 24 | from MAR.Utils.utils import fix_random_seed, split_list 25 | from MAR.Utils.globals import Cost, PromptTokens, CompletionTokens 26 | from MAR.Utils.log import configure_logging 27 | 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | 30 | def load_result(result_file): 31 | if not result_file.exists(): 32 | with open(result_file, 'w',encoding='utf-8') as file: 33 | json.dump([], file) 34 | 35 | with open(result_file, 'r',encoding='utf-8') as file: 36 | data = json.load(file) 37 | return data 38 | 39 | def dataloader(data_list, batch_size, i_batch): 40 | return data_list[i_batch*batch_size:i_batch*batch_size + batch_size] 41 | 42 | def load_config(config_path): 43 | with open(config_path, 'r',encoding='utf-8') as file: 44 | return yaml.safe_load(file) 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser(description="AgentPrune Experiments on humaneval") 48 | parser.add_argument("--dataset_json", type=str, default="Datasets/humaneval/humaneval-py.jsonl") 49 | parser.add_argument("--result_file", type=str, default=None) 50 | parser.add_argument('--lr', type=float, default=0.01,help="learning rate") 51 | parser.add_argument('--batch_size', type=int, default=16,help="batch size") 52 | parser.add_argument('--epochs', type=int, default=10, help="Default 5.") 53 | parser.add_argument('--num_rounds',type=int,default=1,help="Number of optimization/inference rounds for one query") 54 | parser.add_argument('--domain', type=str, default="humaneval",help="Domain (the same as dataset name), default 'humaneval'") 55 | parser.add_argument('--decision_method', type=str, default='FinalRefer', 56 | help='The decison method of the agentprune') 57 | parser.add_argument('--prompt_file', type=str, default='MAR/Roles/FinalNode/humaneval.json') 58 | parser.add_argument('--start_epoch', type=int, default=0) 59 | parser.add_argument('--cost_rate', type=float, default=200.0) 60 | parser.add_argument('--max_agent', type=int, default=6) 61 | args = parser.parse_args() 62 | return args 63 | 64 | 65 | if __name__ == '__main__': 66 | args = parse_args() 67 | dataset = JSONLReader().parse_file("Datasets/humaneval/humaneval-py.jsonl") 68 | train_dataset, test_dataset = split_list(dataset, 0.2) 69 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 70 | log_file = f"humaneval_{current_time}.txt" 71 | fix_random_seed(1234) 72 | configure_logging(log_name=log_file) 73 | 74 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 75 | router = MasRouter(max_agent=args.max_agent,device=device).to(device) 76 | optimizer = torch.optim.Adam(router.parameters(), lr=args.lr) 77 | tasks = tasks_profile 78 | llms = llm_profile 79 | reasonings = reasoning_profile 80 | 81 | logger.info("Start training...") 82 | for epoch in range(args.epochs): 83 | if epoch < args.start_epoch: 84 | router.load_state_dict(torch.load(f"humaneval_router_epoch{epoch}.pth", map_location=torch.device('cuda'))) 85 | continue 86 | logger.info(f"Epoch {epoch}",80*'-') 87 | train_batches = int(len(train_dataset)/args.batch_size) 88 | total_solved, total_executed = (0, 0) 89 | for i_batch in range(train_batches): 90 | 
logger.info(f"Batch {i_batch}",80*'-') 91 | start_ts = time.time() 92 | current_batch = dataloader(train_dataset,args.batch_size,i_batch) 93 | queries = [item['prompt'] for item in current_batch] 94 | tests = [item['test'] for item in current_batch] 95 | task_labels = [2 for _ in current_batch] 96 | tasks_y = torch.tensor(task_labels).to(device) 97 | optimizer.zero_grad() 98 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels, prompt_file=args.prompt_file) 99 | 100 | task_loss = F.cross_entropy(tasks_probs, tasks_y) 101 | utilities = [] 102 | answers_loss = [] 103 | is_solved_list = [] 104 | pattern = r'```python.*```' 105 | for query, result, test, log_prob, cost in zip(queries, results, tests, log_probs, costs): 106 | match = re.search(pattern, result, re.DOTALL|re.MULTILINE) 107 | if match: 108 | answer = match.group(0).lstrip("```python\n").rstrip("\n```") 109 | is_solved, _, _ = PyExecutor().execute(answer, [test], timeout=100) 110 | else: 111 | answer = "" 112 | is_solved = 0 113 | total_solved = total_solved + is_solved 114 | total_executed = total_executed + 1 115 | utility = is_solved - cost * args.cost_rate 116 | utilities.append(utility) 117 | is_solved_list.append(is_solved) 118 | answer_loss = -log_prob * utility 119 | answers_loss.append(answer_loss) 120 | 121 | answer_loss = torch.stack(answers_loss).sum() / len(answers_loss) 122 | vae_loss = vae_loss.mean() 123 | is_solved_tensor = torch.tensor(is_solved_list, dtype=torch.float32, device=device).unsqueeze(1) # shape: [N, 1] 124 | adjust_loss = ((1 - is_solved_tensor) * (router.num_determiner.max_agent - agents_num) + 0.25 * is_solved_tensor * agents_num).mean() 125 | 126 | loss = task_loss + answer_loss + vae_loss*0.001 # + adjust_loss 127 | loss.backward() 128 | optimizer.step() 129 | 130 | accuracy = total_solved / total_executed 131 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 132 | logger.info(f"Accuracy: {accuracy}") 133 | logger.info(f"utilities:{utilities}") 134 | logger.info(f"Epoch {epoch} Finishes",80*'-') 135 | torch.save(router.state_dict(), f"humaneval_router_epoch{epoch}.pth") 136 | logger.info("Finish training...") 137 | logger.info("Start testing...") 138 | test_batches = int(len(test_dataset)/args.batch_size) 139 | total_solved, total_executed = (0, 0) 140 | 141 | for i_batch in range(test_batches): 142 | logger.info(f"Batch {i_batch}",80*'-') 143 | start_ts = time.time() 144 | current_batch = dataloader(test_dataset,args.batch_size,i_batch) 145 | queries = [item['prompt'] for item in current_batch] 146 | tests = [item['test'] for item in current_batch] 147 | task_labels = [2 for _ in current_batch] 148 | tasks_y = torch.tensor(task_labels).to(device) 149 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels, prompt_file=args.prompt_file) 150 | 151 | utilities = [] 152 | pattern = r'```python.*```' 153 | for query, result, test, log_prob, cost in zip(queries, results, tests, log_probs, costs): 154 | match = re.search(pattern, result, re.DOTALL|re.MULTILINE) 155 | if match: 156 | answer = match.group(0).lstrip("```python\n").rstrip("\n```") 157 | is_solved, _, _ = PyExecutor().execute(answer, [test], timeout=100) 158 | else: 159 | is_solved = 0 160 | total_solved = total_solved + is_solved 161 | total_executed = total_executed + 1 162 | utility = is_solved - cost * args.cost_rate 163 | utilities.append(utility) 164 | 165 | accuracy = total_solved / 
total_executed 166 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 167 | logger.info(f"Accuracy: {accuracy}") 168 | logger.info(f"utilities:{utilities}") 169 | logger.info("Finish testing...") -------------------------------------------------------------------------------- /Experiments/run_math.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | import yaml 5 | import json 6 | import time 7 | import torch 8 | import io 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | from MAR.MasRouter.mas_router import MasRouter 16 | from MAR.LLM.llm_profile import llm_profile 17 | from MAR.Agent.reasoning_profile import reasoning_profile 18 | from MAR.Prompts.tasks_profile import tasks_profile 19 | from MAR.Utils.utils import fix_random_seed 20 | from MAR.Utils.globals import Cost, PromptTokens, CompletionTokens 21 | from Datasets.math_dataset import load_math_dataset,MATH_is_correct,MATH_get_predict 22 | from MAR.Utils.log import configure_logging 23 | from loguru import logger 24 | 25 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 26 | 27 | def load_result(result_file): 28 | if not result_file.exists(): 29 | with open(result_file, 'w',encoding='utf-8') as file: 30 | json.dump([], file) 31 | 32 | with open(result_file, 'r',encoding='utf-8') as file: 33 | data = json.load(file) 34 | return data 35 | 36 | def dataloader(data_list, batch_size, i_batch): 37 | return data_list[i_batch*batch_size:i_batch*batch_size + batch_size] 38 | 39 | def load_config(config_path): 40 | with open(config_path, 'r',encoding='utf-8') as file: 41 | return yaml.safe_load(file) 42 | 43 | def parse_args(): 44 | parser = argparse.ArgumentParser(description="AgentPrune Experiments on MATH") 45 | parser.add_argument("--result_file", type=str, default=None) 46 | parser.add_argument('--lr', type=float, default=0.01,help="learning rate") 47 | parser.add_argument('--batch_size', type=int, default=16,help="batch size") 48 | parser.add_argument('--epochs', type=int, default=5, help="Prune every few iterations. 
Default 5.") 49 | parser.add_argument('--num_rounds',type=int,default=1,help="Number of optimization/inference rounds for one query") 50 | parser.add_argument('--domain', type=str, default="gsm8k",help="Domain (the same as dataset name), default 'gsm8k'") 51 | parser.add_argument('--decision_method', type=str, default='FinalRefer', 52 | help='The decison method of the agentprune') 53 | parser.add_argument('--prompt_file', type=str, default='MAR/Roles/FinalNode/math.json') 54 | parser.add_argument('--start_epoch', type=int, default=0) 55 | parser.add_argument('--cost_rate', type=float, default=100.0) 56 | parser.add_argument('--max_agent', type=int, default=6) 57 | args = parser.parse_args() 58 | return args 59 | 60 | 61 | if __name__ == '__main__': 62 | args = parse_args() 63 | train_dataset = load_math_dataset("Datasets/MATH",split="sampled_train") 64 | test_dataset = load_math_dataset("Datasets/MATH",split="sampled_test") 65 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 66 | log_file = f"MATH_{current_time}.txt" 67 | fix_random_seed(1234) 68 | configure_logging(log_name=log_file) 69 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 70 | router = MasRouter(max_agent=args.max_agent,device=device).to(device) 71 | optimizer = torch.optim.Adam(router.parameters(), lr=args.lr) 72 | tasks = tasks_profile 73 | llms = llm_profile 74 | reasonings = reasoning_profile 75 | 76 | logger.info("Start training...") 77 | num_batches = int(len(train_dataset)/args.batch_size) 78 | 79 | for epoch in range(args.epochs): 80 | logger.info(f"Epoch {epoch}",80*'-') 81 | total_solved, total_executed = (0, 0) 82 | if epoch < args.start_epoch: 83 | router.load_state_dict(torch.load(f"math_router_epoch{epoch}.pth", map_location=torch.device('cuda'))) 84 | continue 85 | for i_batch in range(num_batches): 86 | logger.info(f"Batch {i_batch}",80*'-') 87 | start_ts = time.time() 88 | current_batch = dataloader(train_dataset,args.batch_size,i_batch) 89 | queries = [item['problem'] for item in current_batch] 90 | answers = [item['solution'] for item in current_batch] 91 | task_labels = [0 for _ in current_batch] 92 | tasks_y = torch.tensor(task_labels).to(device) 93 | optimizer.zero_grad() 94 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels,prompt_file=args.prompt_file) 95 | 96 | task_loss = F.cross_entropy(tasks_probs, tasks_y) 97 | agent_num_loss = 0 98 | utilities = [] 99 | answers_loss = [] 100 | is_solved_list = [] 101 | for result, true_answer, log_prob, cost in zip(results, answers, log_probs, costs): 102 | predict_answer = MATH_get_predict(result) 103 | is_solved = MATH_is_correct(predict_answer,true_answer) 104 | total_solved = total_solved + is_solved 105 | total_executed = total_executed + 1 106 | utility = is_solved - cost * args.cost_rate 107 | utilities.append(utility) 108 | is_solved_list.append(is_solved) 109 | answer_loss:torch.Tensor = -log_prob * utility 110 | answers_loss.append(answer_loss) 111 | 112 | answer_loss = torch.stack(answers_loss).sum() / len(answers_loss) 113 | vae_loss = vae_loss.mean() 114 | is_solved_tensor = torch.tensor(is_solved_list, dtype=torch.float32, device=device).unsqueeze(1) # shape: [N, 1] 115 | adjust_loss = ((1 - is_solved_tensor) * (router.num_determiner.max_agent - agents_num) + 0.25 * is_solved_tensor * agents_num).mean() 116 | 117 | loss = task_loss + answer_loss + vae_loss*0.001 # + adjust_loss 118 | loss.backward() 119 | optimizer.step() 120 | 121 
| accuracy = total_solved / total_executed 122 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 123 | logger.info(f"Accuracy: {accuracy}") 124 | logger.info(f"utilities:{utilities}") 125 | torch.save(router.state_dict(), f"math_router_epoch{epoch}_new.pth") 126 | logger.info("Finish training...") 127 | logger.info("Start testing...") 128 | total_solved, total_executed = (0, 0) 129 | num_batches = int(len(test_dataset)/args.batch_size) 130 | 131 | for i_batch in range(num_batches): 132 | logger.info(f"Batch {i_batch}",80*'-') 133 | start_ts = time.time() 134 | current_batch = dataloader(test_dataset,args.batch_size,i_batch) 135 | queries = [item['problem'] for item in current_batch] 136 | answers = [item['solution'] for item in current_batch] 137 | task_labels = [0 for _ in current_batch] 138 | tasks_y = torch.tensor(task_labels).to(device) 139 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels,prompt_file=args.prompt_file) 140 | 141 | utilities = [] 142 | for result, true_answer, log_prob, cost in zip(results, answers, log_probs, costs): 143 | predict_answer = MATH_get_predict(result) 144 | is_solved = MATH_is_correct(predict_answer,true_answer) 145 | total_solved = total_solved + is_solved 146 | total_executed = total_executed + 1 147 | utility = is_solved - cost * args.cost_rate 148 | utilities.append(utility) 149 | logger.debug(f"Predict: {predict_answer}") 150 | logger.debug(f"Truth: {true_answer}") 151 | 152 | accuracy = total_solved / total_executed 153 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 154 | logger.info(f"Accuracy: {accuracy}") 155 | logger.info(f"utilities:{utilities}") 156 | logger.info("Finish testing...") 157 | -------------------------------------------------------------------------------- /Experiments/run_mbpp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import io 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 6 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 7 | 8 | import time 9 | import argparse 10 | import yaml 11 | import json 12 | import re 13 | import torch 14 | from loguru import logger 15 | import torch.nn.functional as F 16 | 17 | from MAR.MasRouter.mas_router import MasRouter 18 | from MAR.LLM.llm_profile import llm_profile 19 | from MAR.Agent.reasoning_profile import reasoning_profile 20 | from MAR.Prompts.tasks_profile import tasks_profile 21 | from MAR.Tools.coding.python_executor import PyExecutor 22 | from MAR.Utils.utils import fix_random_seed 23 | from MAR.Utils.globals import Cost, PromptTokens, CompletionTokens 24 | from MAR.Utils.log import configure_logging 25 | 26 | from Datasets.mbpp_dataset import MbppDataset, MbppDataLoader 27 | 28 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 29 | 30 | def load_result(result_file): 31 | if not result_file.exists(): 32 | with open(result_file, 'w',encoding='utf-8') as file: 33 | json.dump([], file) 34 | 35 | with open(result_file, 'r',encoding='utf-8') as file: 36 | data = json.load(file) 37 | return data 38 | 39 | def dataloader(data_list, batch_size, i_batch): 40 | return data_list[i_batch*batch_size:i_batch*batch_size + batch_size] 41 | 42 | def load_config(config_path): 43 | with open(config_path, 'r',encoding='utf-8') as file: 44 | return yaml.safe_load(file) 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser(description="AgentPrune Experiments on mbpp") 48 
| parser.add_argument("--dataset_json", type=str, default="Datasets/mbpp/mbpp.jsonl") 49 | parser.add_argument("--result_file", type=str, default=None) 50 | parser.add_argument('--lr', type=float, default=0.01,help="learning rate") 51 | parser.add_argument('--batch_size', type=int, default=16,help="batch size") 52 | parser.add_argument('--epochs', type=int, default=10, help="Default 10.") 53 | parser.add_argument('--num_rounds',type=int,default=1,help="Number of optimization/inference rounds for one query") 54 | parser.add_argument('--domain', type=str, default="mbpp",help="Domain (the same as dataset name), default 'mbpp'") 55 | parser.add_argument('--decision_method', type=str, default='FinalRefer', 56 | help='The decison method of the agentprune') 57 | parser.add_argument('--prompt_file', type=str, default='MAR/Roles/FinalNode/mbpp.json') 58 | parser.add_argument('--start_epoch', type=int, default=0) 59 | parser.add_argument('--cost_rate', type=float, default=400.0) 60 | parser.add_argument('--max_agent', type=int, default=6) 61 | args = parser.parse_args() 62 | return args 63 | 64 | 65 | if __name__ == '__main__': 66 | args = parse_args() 67 | train_dataset = MbppDataset('train') 68 | test_dataset = MbppDataset('test') 69 | 70 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 71 | log_file = f"mbpp_{current_time}.txt" 72 | fix_random_seed(1234) 73 | configure_logging(log_name=log_file) 74 | 75 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 76 | router = MasRouter(max_agent=args.max_agent,device=device).to(device) 77 | optimizer = torch.optim.Adam(router.parameters(), lr=args.lr) 78 | tasks = tasks_profile 79 | llms = llm_profile 80 | reasonings = reasoning_profile 81 | logger.info("Start training...") 82 | 83 | for epoch in range(args.epochs): 84 | logger.info(f"Epoch {epoch}",80*'-') 85 | total_solved, total_executed = (0, 0) 86 | train_loader = MbppDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 87 | if epoch < args.start_epoch: 88 | router.load_state_dict(torch.load(f"mbpp_router_epoch{epoch}_new.pth", map_location=torch.device('cuda'))) 89 | continue 90 | for i_batch, current_batch in enumerate(train_loader): 91 | logger.info(f"Batch {i_batch}",80*'-') 92 | start_ts = time.time() 93 | queries = [item['task'] for item in current_batch] 94 | tests = [item['test_list'] for item in current_batch] 95 | task_labels = [2 for _ in current_batch] 96 | tasks_y = torch.tensor(task_labels).to(device) 97 | optimizer.zero_grad() 98 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels, prompt_file=args.prompt_file) 99 | 100 | task_loss = F.cross_entropy(tasks_probs, tasks_y) 101 | utilities = [] 102 | answers_loss = [] 103 | is_solved_list = [] 104 | pattern = r'```python.*```' 105 | for query, result, test, log_prob, cost in zip(queries, results, tests, log_probs, costs): 106 | match = re.search(pattern, result, re.DOTALL|re.MULTILINE) 107 | if match: 108 | answer = match.group(0).lstrip("```python\n").rstrip("\n```") 109 | is_solved, _, _ = PyExecutor().execute(answer, test, timeout=100) 110 | else: 111 | is_solved = 0 112 | total_solved = total_solved + is_solved 113 | total_executed = total_executed + 1 114 | utility = is_solved - cost * args.cost_rate 115 | utilities.append(utility) 116 | is_solved_list.append(is_solved) 117 | answer_loss = -log_prob * utility 118 | answers_loss.append(answer_loss) 119 | answer_loss = torch.stack(answers_loss).sum() / 
len(answers_loss) 120 | vae_loss = vae_loss.mean() 121 | is_solved_tensor = torch.tensor(is_solved_list, dtype=torch.float32, device=device).unsqueeze(1) # shape: [N, 1] 122 | adjust_loss = ((1 - is_solved_tensor) * (router.num_determiner.max_agent - agents_num) + 0.25 * is_solved_tensor * agents_num).mean() 123 | 124 | loss = task_loss + answer_loss + vae_loss*0.001 # + adjust_loss 125 | loss.backward() 126 | optimizer.step() 127 | accuracy = total_solved / total_executed 128 | 129 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 130 | logger.info(f"Accuracy: {accuracy}") 131 | logger.info(f"utilities:{utilities}") 132 | torch.save(router.state_dict(), f"mbpp_router_epoch{epoch}_new.pth") 133 | logger.info("End training...") 134 | logger.info("Start testing...") 135 | total_solved, total_executed = (0, 0) 136 | test_loader = MbppDataLoader(test_dataset, batch_size=args.batch_size, shuffle=True) 137 | 138 | for i_batch, current_batch in enumerate(test_loader): 139 | start_ts = time.time() 140 | logger.info(f"Batch {i_batch}",80*'-') 141 | queries = [item['task'] for item in current_batch] 142 | tests = [item['test_list'] for item in current_batch] 143 | task_labels = [2 for _ in current_batch] 144 | tasks_y = torch.tensor(task_labels).to(device) 145 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels, prompt_file=args.prompt_file) 146 | utilities = [] 147 | pattern = r'```python.*```' 148 | for query, result, test, log_prob, cost in zip(queries, results, tests, log_probs, costs): 149 | match = re.search(pattern, result, re.DOTALL|re.MULTILINE) 150 | if match: 151 | answer = match.group(0).removeprefix("```python").removesuffix("```").strip() # same fence-stripping fix as in the training loop 152 | is_solved, _, _ = PyExecutor().execute(answer, test, timeout=100) 153 | else: 154 | is_solved = 0 155 | total_solved = total_solved + is_solved 156 | total_executed = total_executed + 1 157 | utility = is_solved - cost * args.cost_rate 158 | utilities.append(utility) 159 | 160 | accuracy = total_solved / total_executed 161 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 162 | logger.info(f"Accuracy: {accuracy}") 163 | logger.info(f"utilities:{utilities}") 164 | logger.info("End testing...") 165 | -------------------------------------------------------------------------------- /Experiments/run_mmlu.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import io 4 | 5 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 6 | sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') 7 | 8 | import time 9 | import argparse 10 | import yaml 11 | import json 12 | import torch 13 | import numpy as np 14 | from loguru import logger 15 | import torch.nn.functional as F 16 | 17 | from MAR.MasRouter.mas_router import MasRouter 18 | from MAR.LLM.llm_profile import llm_profile 19 | from MAR.Agent.reasoning_profile import reasoning_profile 20 | from MAR.Prompts.tasks_profile import tasks_profile 21 | from MAR.Tools.reader.readers import JSONLReader 22 | from MAR.Tools.coding.python_executor import PyExecutor 23 | from MAR.Utils.utils import fix_random_seed 24 | from MAR.Utils.globals import Cost, PromptTokens, CompletionTokens 25 | from MAR.Utils.log import configure_logging 26 | from Datasets.mmlu_dataset import MMLUDataset 27 | from Datasets.MMLU.download import download 28 | from Datasets.math_dataset import MATH_get_predict 29 | 30 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
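# ---------------------------------------------------------------------------
# Editor's note (not part of run_mmlu.py): every Experiments/run_*.py script in
# this repo trains MasRouter with the same REINFORCE-style objective used in
# run_mbpp.py above: utility = is_solved - cost * cost_rate, and
# answer_loss = -log_prob * utility. A minimal self-contained sketch with
# made-up numbers (all values below are illustrative only):
#
#     import torch
#     log_prob = torch.tensor(-1.2, requires_grad=True)  # log-prob of the sampled routing
#     is_solved, cost, cost_rate = 1.0, 0.0005, 500.0    # solved, at a small API cost
#     utility = is_solved - cost * cost_rate             # 1 - 0.25 = 0.75
#     answer_loss = -log_prob * utility                  # -(-1.2) * 0.75 = 0.9
#     answer_loss.backward()                             # d(loss)/d(log_prob) = -utility
#
# Routings that solve the task cheaply (utility > 0) get their log-probability
# pushed up; expensive failures (utility < 0) are suppressed.
# ---------------------------------------------------------------------------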
31 | 32 | def load_result(result_file): 33 | if not result_file.exists(): 34 | with open(result_file, 'w',encoding='utf-8') as file: 35 | json.dump([], file) 36 | 37 | with open(result_file, 'r',encoding='utf-8') as file: 38 | data = json.load(file) 39 | return data 40 | 41 | def dataloader(data_list, batch_size, i_batch): 42 | return data_list[i_batch*batch_size:i_batch*batch_size + batch_size] 43 | 44 | def load_config(config_path): 45 | with open(config_path, 'r',encoding='utf-8') as file: 46 | return yaml.safe_load(file) 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser(description="MAR Experiments on MMLU") 50 | parser.add_argument("--result_file", type=str, default=None) 51 | parser.add_argument('--lr', type=float, default=0.01,help="learning rate") 52 | parser.add_argument('--batch_size', type=int, default=16,help="batch size") 53 | parser.add_argument('--epochs', type=int, default=10, help="Number of training epochs. Default 10.") 54 | parser.add_argument('--num_rounds',type=int,default=1,help="Number of optimization/inference rounds for one query") 55 | parser.add_argument('--domain', type=str, default="mmlu",help="Domain (the same as dataset name), default 'mmlu'") 56 | parser.add_argument('--decision_method', type=str, default='FinalRefer', 57 | help='The decision method of the agentprune') 58 | parser.add_argument('--prompt_file', type=str, default='MAR/Roles/FinalNode/mmlu.json') 59 | parser.add_argument('--start_epoch', type=int, default=0) 60 | parser.add_argument('--cost_rate', type=float, default=500.0) 61 | parser.add_argument('--max_agent', type=int, default=6) 62 | args = parser.parse_args() 63 | return args 64 | 65 | def infinite_data_loader(dataset): 66 | perm = np.random.permutation(len(dataset)) 67 | while True: 68 | for idx in perm: 69 | record = dataset[idx.item()] 70 | yield record 71 | 72 | 73 | if __name__ == '__main__': 74 | args = parse_args() 75 | current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) 76 | log_file = f"mmlu_{current_time}.txt" 77 | fix_random_seed(1234) 78 | configure_logging(log_name=log_file) 79 | total_solved, total_executed = (0, 0) 80 | 81 | # download() 82 | dataset_train = MMLUDataset('dev') 83 | dataset_test = MMLUDataset('test') 84 | 85 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 86 | router = MasRouter(max_agent=args.max_agent, device=device).to(device) 87 | optimizer = torch.optim.Adam(router.parameters(), lr=args.lr) 88 | tasks = tasks_profile 89 | llms = llm_profile 90 | reasonings = reasoning_profile 91 | logger.info("Start training...") 92 | 93 | train_batch = min(40,len(dataset_train)//args.batch_size) 94 | for i_epoch in range(args.epochs): 95 | if i_epoch < args.start_epoch: 96 | router.load_state_dict(torch.load(f"mmlu_router_epoch{i_epoch}.pth", map_location=device)) 97 | continue 98 | for i_batch in range(train_batch): 99 | print(f"Batch {i_batch}",80*'-') 100 | start_ts = time.time() 101 | current_batch = dataloader(dataset_train, args.batch_size, i_batch) 102 | current_batch = [{"task":dataset_train.record_to_input(record)["task"], "answer":dataset_train.record_to_target_answer(record)} for row, record in current_batch.iterrows()] 103 | 104 | queries = [item['task'] for item in current_batch] 105 | answers = [item['answer'] for item in current_batch] 106 | task_labels = [1 for _ in current_batch] 107 | tasks_y = torch.tensor(task_labels).to(device) 108 | optimizer.zero_grad() 109 | results, costs, log_probs, tasks_probs, vae_loss, agents_num =
router.forward(queries, tasks, llms, reasonings, task_labels, prompt_file=args.prompt_file) 110 | task_loss = F.cross_entropy(tasks_probs, tasks_y) 111 | utilities = [] 112 | answers_loss = [] 113 | is_solved_list = [] 114 | for query, result, answer, log_prob, cost in zip(queries, results, answers, log_probs, costs): 115 | predict_answer = MATH_get_predict(result)[0] 116 | is_solved = str(predict_answer).strip()==str(answer).strip() 117 | total_solved = total_solved + is_solved 118 | total_executed = total_executed + 1 119 | utility = is_solved - cost * args.cost_rate 120 | utilities.append(utility) 121 | is_solved_list.append(is_solved) 122 | answer_loss = -log_prob * utility 123 | answers_loss.append(answer_loss) 124 | logger.debug(f"Raw Result: {result}") 125 | logger.debug(f"Predict: {predict_answer}") 126 | logger.debug(f"Truth: {answer}") 127 | logger.debug(f"Cost: {cost}") 128 | logger.debug(f"is_solved: {is_solved}") 129 | answer_loss = torch.stack(answers_loss).sum() / len(answers_loss) 130 | vae_loss = vae_loss.mean() 131 | is_solved_tensor = torch.tensor(is_solved_list, dtype=torch.float32, device=device).unsqueeze(1) # shape: [N, 1] 132 | # adjust_loss = ((1 - is_solved_tensor) * (router.num_determiner.max_agent - agents_num) + 0.25 * is_solved_tensor * agents_num).mean() 133 | loss = task_loss + answer_loss + vae_loss*0.001 # + adjust_loss 134 | loss.backward() 135 | optimizer.step() 136 | 137 | accuracy = total_solved / total_executed 138 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 139 | logger.info(f"Accuracy: {accuracy}") 140 | 141 | logger.info(f"Epoch {i_epoch} Finishes",80*'-') 142 | torch.save(router.state_dict(), f"mmlu_router_epoch{i_epoch}.pth") 143 | 144 | logger.info("Finish training...") 145 | logger.info("Start testing...") 146 | total_solved, total_executed = (0, 0) 147 | test_batch = min(80, len(dataset_test)//args.batch_size) 148 | for i_batch in range(test_batch): 149 | if i_batch < train_batch: 150 | continue 151 | print(f"Batch {i_batch}",80*'-') 152 | start_ts = time.time() 153 | current_batch = dataloader(dataset_test, args.batch_size, i_batch) 154 | current_batch = [{"task":dataset_test.record_to_input(record)["task"],"answer":dataset_test.record_to_target_answer(record)} for row, record in current_batch.iterrows()] 155 | 156 | queries = [item['task'] for item in current_batch] 157 | answers = [item['answer'] for item in current_batch] 158 | task_labels = [1 for _ in current_batch] 159 | tasks_y = torch.tensor(task_labels).to(device) 160 | results, costs, log_probs, tasks_probs, vae_loss, agents_num = router.forward(queries, tasks, llms, reasonings, task_labels, prompt_file=args.prompt_file) 161 | utilities = [] 162 | answers_loss = [] 163 | 164 | for query, result, answer, log_prob, cost in zip(queries, results, answers, log_probs, costs): 165 | predict_answer = MATH_get_predict(result)[0] 166 | is_solved = str(predict_answer)==str(answer) 167 | total_solved = total_solved + is_solved 168 | total_executed = total_executed + 1 169 | utility = is_solved - cost * args.cost_rate 170 | utilities.append(utility) 171 | 172 | accuracy = total_solved / total_executed 173 | logger.info(f"Batch time {time.time() - start_ts:.3f}") 174 | logger.info(f"Accuracy: {accuracy}") 175 | logger.info(f"utilities:{utilities}") 176 | 177 | logger.info("Finish testing...") 178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache 
License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MAR/Agent/__init__.py: -------------------------------------------------------------------------------- 1 | from MAR.Agent.agent import Agent 2 | from MAR.Agent.agent_registry import AgentRegistry 3 | 4 | __all__ = ['Agent', 5 | 'FinalRefer', 6 | 'AgentRegistry', 7 | ] 8 | -------------------------------------------------------------------------------- /MAR/Agent/agent.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import json 3 | from loguru import logger 4 | 5 | from MAR.Agent.agent_registry import AgentRegistry 6 | from MAR.LLM.llm_registry import LLMRegistry 7 | from MAR.Roles.role_registry import RoleRegistry 8 | from MAR.Graph.node import Node 9 | from MAR.Prompts.message_aggregation import message_aggregation,inner_test 10 | from MAR.Prompts.post_process import post_process 11 | from MAR.Prompts.output_format import output_format_prompt 12 | from MAR.Prompts.reasoning import reasoning_prompt 13 | 14 | 15 | @AgentRegistry.register('Agent') 16 | class Agent(Node): 17 | def __init__(self, id: str | None =None, domain: str = "", role:str = None , llm_name: str = "",reason_name: str = "",): 18 | super().__init__(id, reason_name, domain, llm_name) 19 | self.llm = LLMRegistry.get(llm_name) 20 | self.role = RoleRegistry(domain, role) 21 | self.reason = reason_name 22 | 23 | self.message_aggregation = self.role.get_message_aggregation() 24 | self.description = self.role.get_description() 25 | self.output_format = self.role.get_output_format() 26 | self.post_process = self.role.get_post_process() 27 | self.post_description = self.role.get_post_description() 28 | self.post_output_format = self.role.get_post_output_format() 29 | # Reflect 30 | if reason_name == "Reflection" and self.post_output_format == "None": 31 | self.post_output_format = self.output_format 32 | self.post_description = "\nReflect on possible errors in the answer above and answer again using the same format. 
If you think there are no errors in your previous answers that will affect the results, there is no need to correct them.\n" 33 | 34 | def _process_inputs(self, raw_inputs:Dict[str,str], spatial_info:Dict[str, Dict], temporal_info:Dict[str, Dict], **kwargs): 35 | query = raw_inputs['query'] 36 | spatial_prompt = message_aggregation(raw_inputs, spatial_info, self.message_aggregation) 37 | temporal_prompt = message_aggregation(raw_inputs, temporal_info, self.message_aggregation) 38 | format_prompt = output_format_prompt[self.output_format] 39 | reason_prompt = reasoning_prompt[self.reason] 40 | 41 | system_prompt = f"{self.description}\n{reason_prompt}" 42 | system_prompt += f"\nFormat requirements that must be followed:\n{format_prompt}" if format_prompt else "" 43 | user_prompt = f"{query}\n" 44 | user_prompt += f"At the same time, other agents' outputs are as follows:\n\n{spatial_prompt}" if spatial_prompt else "" 45 | user_prompt += f"\n\nIn the last round of dialogue, other agents' outputs were:\n\n{temporal_prompt}" if temporal_prompt else "" 46 | return [{'role':'system','content':system_prompt},{'role':'user','content':user_prompt}] 47 | 48 | def _execute(self, input:Dict[str,str], spatial_info:Dict[str,Dict], temporal_info:Dict[str,Dict],**kwargs): 49 | """ 50 | Run the agent. 51 | Args: 52 | inputs: dict[str, str]: Raw inputs. 53 | spatial_info: dict[str, dict]: Spatial information. 54 | temporal_info: dict[str, dict]: Temporal information. 55 | Returns: 56 | Any: str: Aggregated message. 57 | """ 58 | query = input['query'] 59 | passed, response= inner_test(input, spatial_info, temporal_info) 60 | if passed: 61 | return response 62 | prompt = self._process_inputs(input, spatial_info, temporal_info, **kwargs) 63 | response = self.llm.gen(prompt) 64 | response = post_process(input, response, self.post_process) 65 | logger.debug(f"Agent {self.id} Role: {self.role.role} LLM: {self.llm.model_name}") 66 | logger.debug(f"system prompt:\n {prompt[0]['content']}") 67 | logger.debug(f"user prompt:\n {prompt[1]['content']}") 68 | logger.debug(f"response:\n {response}") 69 | 70 | # #! 71 | # received_id = [] 72 | # for id, info in spatial_info.items(): 73 | # role = info["role"].role 74 | # received_id.append(id + '(' + role + ')') 75 | # for id, info in temporal_info.items(): 76 | # role = info["role"].role 77 | # received_id.append(id + '(' + role + ')') 78 | 79 | # entry = { 80 | # "id": self.id, 81 | # "role": self.role.role, 82 | # "llm_name": self.llm.model_name, 83 | # "system_prompt": prompt[0]['content'], 84 | # "user_prompt": prompt[1]['content'], 85 | # "received_id": received_id, 86 | # "response": response, 87 | # } 88 | # try: 89 | # with open(f'./result/tmp_log.json', 'r', encoding='utf-8') as f: 90 | # data = json.load(f) 91 | # except (FileNotFoundError, json.JSONDecodeError): 92 | # data = [] 93 | 94 | # data.append(entry) 95 | 96 | # with open(f'./result/tmp_log.json', 'w', encoding='utf-8') as f: 97 | # json.dump(data, f, ensure_ascii=False, indent=2) 98 | # #! 99 | 100 | post_format_prompt = output_format_prompt[self.post_output_format] 101 | if post_format_prompt is not None: 102 | system_prompt = f"{self.post_description}\n" 103 | system_prompt += f"Format requirements that must be followed:\n{post_format_prompt}" 104 | user_prompt = f"{query}\nThe initial thinking information is:\n{response} \n Please refer to the new format requirements when replying." 
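# (Editorial comment, not in the original source: this branch issues a second
# LLM call. The first pass above produced `response` under the role's normal
# output format; when the role defines a post_output_format -- e.g. the
# Reflection reasoning profile wired up in __init__ -- the agent re-prompts
# itself with post_description so it can critique and reformat its own
# first-pass answer before the result is passed downstream.)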
105 | prompt = [{'role':'system','content':system_prompt},{'role':'user','content':user_prompt}] 106 | response = self.llm.gen(prompt) 107 | logger.debug(f"post system prompt:\n {system_prompt}") 108 | logger.debug(f"post user prompt:\n {user_prompt}") 109 | logger.debug(f"post response:\n {response}") 110 | 111 | # #! 112 | # received_id = [] 113 | # role = self.role.role 114 | # received_id.append(self.id + '(' + role + ')') 115 | 116 | # entry = { 117 | # "id": self.id, 118 | # "role": self.role.role, 119 | # "llm_name": self.llm.model_name, 120 | # "system_prompt": prompt[0]['content'], 121 | # "user_prompt": prompt[1]['content'], 122 | # "received_id": received_id, 123 | # "response": response, 124 | # } 125 | # try: 126 | # with open(f'./result/tmp_log.json', 'r', encoding='utf-8') as f: 127 | # data = json.load(f) 128 | # except (FileNotFoundError, json.JSONDecodeError): 129 | # data = [] 130 | 131 | # data.append(entry) 132 | 133 | # with open(f'./result/tmp_log.json', 'w', encoding='utf-8') as f: 134 | # json.dump(data, f, ensure_ascii=False, indent=2) 135 | # #! 136 | return response 137 | 138 | def _async_execute(self, input, spatial_info, temporal_info, **kwargs): 139 | return None 140 | 141 | @AgentRegistry.register('FinalRefer') 142 | class FinalRefer(Node): 143 | def __init__(self, id: str | None =None, agent_name = "", domain = "", llm_name = "", prompt_file = ""): 144 | super().__init__(id, agent_name, domain, llm_name) 145 | self.llm = LLMRegistry.get(llm_name) 146 | self.prompt_file = json.load(open(f"{prompt_file}", 'r', encoding='utf-8')) 147 | 148 | def _process_inputs(self, raw_inputs, spatial_info, temporal_info, **kwargs): 149 | system_prompt = f"{self.prompt_file['system']}" 150 | spatial_str = "" 151 | for id, info in spatial_info.items(): 152 | spatial_str += id + ": " + info['output'] + "\n\n" 153 | user_prompt = f"The task is:\n\n {raw_inputs['query']}.\n At the same time, the output of other agents is as follows:\n\n{spatial_str} {self.prompt_file['user']}" 154 | return [{'role':'system','content':system_prompt},{'role':'user','content':user_prompt}] 155 | 156 | def _execute(self, input, spatial_info, temporal_info, **kwargs): 157 | prompt = self._process_inputs(input, spatial_info, temporal_info, **kwargs) 158 | response = self.llm.gen(prompt) 159 | logger.debug(f"Final Refer Node LLM: {self.llm.model_name}") 160 | logger.debug(f"Final System Prompt:\n {prompt[0]['content']}") 161 | logger.debug(f"Final User Prompt:\n {prompt[1]['content']}") 162 | logger.debug(f"Final Response:\n {response}") 163 | # #! 164 | # received_id = [] 165 | # for id, info in spatial_info.items(): 166 | # role = info["role"].role 167 | # received_id.append(id + '(' + role + ')') 168 | # for id, info in temporal_info.items(): 169 | # role = info["role"].role 170 | # received_id.append(id + '(' + role + ')') 171 | 172 | # entry = { 173 | # "id": self.id, 174 | # "role": "FinalDecision", 175 | # "llm_name": self.llm.model_name, 176 | # "system_prompt": prompt[0]['content'], 177 | # "user_prompt": prompt[1]['content'], 178 | # "received_id": received_id, 179 | # "response": response, 180 | # } 181 | # try: 182 | # with open(f'./result/tmp_log.json', 'r', encoding='utf-8') as f: 183 | # data = json.load(f) 184 | # except (FileNotFoundError, json.JSONDecodeError): 185 | # data = [] 186 | 187 | # data.append(entry) 188 | 189 | # with open(f'./result/tmp_log.json', 'w', encoding='utf-8') as f: 190 | # json.dump(data, f, ensure_ascii=False, indent=2) 191 | # #! 
192 | return response 193 | 194 | def _async_execute(self, input, spatial_info, temporal_info, **kwargs): 195 | return None -------------------------------------------------------------------------------- /MAR/Agent/agent_registry.py: -------------------------------------------------------------------------------- 1 | from class_registry import ClassRegistry 2 | from typing import Type 3 | 4 | from MAR.Graph import Node 5 | 6 | class AgentRegistry: 7 | registry = ClassRegistry() 8 | @classmethod 9 | def register(cls, *args, **kwargs): 10 | return cls.registry.register(*args, **kwargs) 11 | 12 | @classmethod 13 | def keys(cls): 14 | return cls.registry.keys() 15 | 16 | @classmethod 17 | def get(cls, name:str, *args, **kwargs)->Node: 18 | return cls.registry.get(name, *args, **kwargs) 19 | 20 | @classmethod 21 | def get_class(cls, name:str) -> Type: 22 | return cls.registry.get_class(name) -------------------------------------------------------------------------------- /MAR/Agent/reasoning_profile.py: -------------------------------------------------------------------------------- 1 | reasoning_profile = [{'Name': 'IO', 'Description': 'In single-agent IO reasoning, a single agent directly gives an output based on the input.'}, 2 | {'Name': 'CoT', 'Description': 'In single-agent CoT reasoning, a single agent reasons step-by-step to achieve a goal.'}, 3 | {'Name': 'Chain', 'Description': 'In multi-agent chain reasoning, multiple agents sequentially reason and pass information in a chain-like manner.'}, 4 | {'Name': 'FullConnected', 'Description': 'In multi-agent full-graph reasoning, multiple agents reason collectively over the entire graph structure.'}, 5 | {'Name': 'Debate', 'Description': 'In multi-agent debate reasoning, multiple agents engage in a structured argumentative dialogue to explore different perspectives, challenge assumptions, and reach a consensus.'}, 6 | {'Name': 'Reflection', 'Description': 'In reflection reasoning, multiple agents reflect on their own reasoning processes and outcomes to improve their performance.'},] -------------------------------------------------------------------------------- /MAR/Graph/__init__.py: -------------------------------------------------------------------------------- 1 | from MAR.Graph.node import Node 2 | from MAR.Graph.graph import Graph 3 | 4 | __all__ = ["Node", 5 | "Graph",] -------------------------------------------------------------------------------- /MAR/Graph/node.py: -------------------------------------------------------------------------------- 1 | import shortuuid 2 | from typing import List, Any, Optional,Dict 3 | from abc import ABC, abstractmethod 4 | import warnings 5 | import asyncio 6 | 7 | 8 | class Node(ABC): 9 | """ 10 | Represents a processing unit within a graph-based framework. 11 | 12 | This class encapsulates the functionality for a node in a graph, managing 13 | connections to other nodes, handling inputs and outputs, and executing 14 | assigned operations. It supports both individual and aggregated processing modes. 15 | 16 | Attributes: 17 | id (uuid.UUID): Unique identifier for the node. 18 | agent_type(str): Associated agent name for node-specific operations. 19 | spatial_predecessors (List[Node]): Nodes that precede this node in the graph. 20 | spatial_successors (List[Node]): Nodes that succeed this node in the graph. 21 | inputs (List[Any]): Inputs to be processed by the node. 22 | outputs (List[Any]): Results produced after node execution. 
23 | raw_inputs (List[Any]): The original input contains the question or math problem. 24 | last_memory (Dict[str,List[Any]]): Input and output of the previous timestamp. 25 | 26 | Methods: 27 | add_predecessor(operation): 28 | Adds a node as a predecessor of this node, establishing a directed connection. 29 | add_successor(operation): 30 | Adds a node as a successor of this node, establishing a directed connection. 31 | memory_update(): 32 | Update the last_memory. 33 | get_spatial_info(): 34 | Get all of the info from spatial spatial_predecessors. 35 | execute(**kwargs): 36 | Processes the inputs through the node's operation, handling each input individually. 37 | _execute(input, **kwargs): 38 | An internal method that defines how a single input is processed by the node. This method should be implemented specifically for each node type. 39 | _process_inputs(raw_inputs, spatial_info, temporal_info, **kwargs)->List[Any]: 40 | An internal medthod to process the raw_input, the spatial info and temporal info to get the final inputs. 41 | """ 42 | 43 | def __init__(self, 44 | id: Optional[str], 45 | agent_name:str="", 46 | domain:str="", 47 | llm_name:str = "", 48 | ): 49 | """ 50 | Initializes a new Node instance. 51 | """ 52 | self.id:str = id if id is not None else shortuuid.ShortUUID().random(length=4) 53 | self.agent_name:str = agent_name 54 | self.domain:str = domain 55 | self.llm_name:str = llm_name 56 | self.spatial_predecessors: List[Node] = [] 57 | self.spatial_successors: List[Node] = [] 58 | self.temporal_predecessors: List[Node] = [] 59 | self.temporal_successors: List[Node] = [] 60 | self.inputs: List[Any] = [] 61 | self.outputs: List[Any] = [] 62 | self.raw_inputs: List[Any] = [] 63 | self.role = "" 64 | self.last_memory: Dict[str,List[Any]] = {'inputs':[],'outputs':[],'raw_inputs':[]} 65 | 66 | @property 67 | def node_name(self): 68 | return self.__class__.__name__ 69 | 70 | def add_predecessor(self, operation: 'Node', st='spatial'): 71 | if st == 'spatial' and operation not in self.spatial_predecessors: 72 | self.spatial_predecessors.append(operation) 73 | operation.spatial_successors.append(self) 74 | elif st == 'temporal' and operation not in self.temporal_predecessors: 75 | self.temporal_predecessors.append(operation) 76 | operation.temporal_successors.append(self) 77 | 78 | def add_successor(self, operation: 'Node', st='spatial'): 79 | if st =='spatial' and operation not in self.spatial_successors: 80 | self.spatial_successors.append(operation) 81 | operation.spatial_predecessors.append(self) 82 | elif st == 'temporal' and operation not in self.temporal_successors: 83 | self.temporal_successors.append(operation) 84 | operation.temporal_predecessors.append(self) 85 | 86 | def remove_predecessor(self, operation: 'Node', st='spatial'): 87 | if st =='spatial' and operation in self.spatial_predecessors: 88 | self.spatial_predecessors.remove(operation) 89 | operation.spatial_successors.remove(self) 90 | elif st =='temporal' and operation in self.temporal_predecessors: 91 | self.temporal_predecessors.remove(operation) 92 | operation.temporal_successors.remove(self) 93 | 94 | def remove_successor(self, operation: 'Node', st='spatial'): 95 | if st =='spatial' and operation in self.spatial_successors: 96 | self.spatial_successors.remove(operation) 97 | operation.spatial_predecessors.remove(self) 98 | elif st =='temporal' and operation in self.temporal_successors: 99 | self.temporal_successors.remove(operation) 100 | operation.temporal_predecessors.remove(self) 101 | 102 | def 
clear_connections(self): 103 | self.spatial_predecessors: List[Node] = [] 104 | self.spatial_successors: List[Node] = [] 105 | self.temporal_predecessors: List[Node] = [] 106 | self.temporal_successors: List[Node] = [] 107 | 108 | def update_memory(self): 109 | self.last_memory['inputs'] = self.inputs 110 | self.last_memory['outputs'] = self.outputs 111 | self.last_memory['raw_inputs'] = self.raw_inputs 112 | 113 | def get_spatial_info(self)->Dict[str,Dict]: 114 | """ Return a dict that maps id to info. """ 115 | spatial_info = {} 116 | if self.spatial_predecessors is not None: 117 | for predecessor in self.spatial_predecessors: 118 | predecessor_outputs = predecessor.outputs 119 | if isinstance(predecessor_outputs, list) and len(predecessor_outputs): 120 | predecessor_output = predecessor_outputs[-1] 121 | elif isinstance(predecessor_outputs, list) and len(predecessor_outputs)==0: 122 | continue 123 | else: 124 | predecessor_output = predecessor_outputs 125 | spatial_info[predecessor.id] = {"role":predecessor.role,"output":predecessor_output} 126 | 127 | return spatial_info 128 | 129 | def get_temporal_info(self)->Dict[str,Any]: 130 | temporal_info = {} 131 | if self.temporal_predecessors is not None: 132 | for predecessor in self.temporal_predecessors: 133 | predecessor_outputs = predecessor.last_memory['outputs'] 134 | if isinstance(predecessor_outputs, list) and len(predecessor_outputs): 135 | predecessor_output = predecessor_outputs[-1] 136 | elif isinstance(predecessor_outputs, list) and len(predecessor_outputs)==0: 137 | continue 138 | else: 139 | predecessor_output = predecessor_outputs 140 | temporal_info[predecessor.id] = {"role":predecessor.role,"output":predecessor_output} 141 | 142 | return temporal_info 143 | 144 | def execute(self, input:Any, **kwargs): 145 | self.outputs = [] 146 | spatial_info:Dict[str,Dict] = self.get_spatial_info() 147 | temporal_info:Dict[str,Dict] = self.get_temporal_info() 148 | results = [self._execute(input, spatial_info, temporal_info, **kwargs)] 149 | 150 | for result in results: 151 | if not isinstance(result, list): 152 | result = [result] 153 | self.outputs.extend(result) 154 | return self.outputs 155 | 156 | 157 | async def async_execute(self, input:Any, **kwargs): 158 | 159 | self.outputs = [] 160 | spatial_info:Dict[str,Any] = self.get_spatial_info() 161 | temporal_info:Dict[str,Any] = self.get_temporal_info() 162 | tasks = [asyncio.create_task(self._async_execute(input, spatial_info, temporal_info, **kwargs))] 163 | results = await asyncio.gather(*tasks, return_exceptions=False) 164 | for result in results: 165 | if not isinstance(result, list): 166 | result = [result] 167 | self.outputs.extend(result) 168 | return self.outputs 169 | 170 | @abstractmethod 171 | def _execute(self, input:List[Any], spatial_info:Dict[str,Any], temporal_info:Dict[str,Any], **kwargs): 172 | """ To be overriden by the descendant class """ 173 | """ Use the processed input to get the result """ 174 | 175 | @abstractmethod 176 | async def _async_execute(self, input:List[Any], spatial_info:Dict[str,Any], temporal_info:Dict[str,Any], **kwargs): 177 | """ To be overriden by the descendant class """ 178 | """ Use the processed input to get the result """ 179 | 180 | @abstractmethod 181 | def _process_inputs(self, raw_inputs:List[Any], spatial_info:Dict[str,Any], temporal_info:Dict[str,Any], **kwargs)->List[Any]: 182 | """ To be overriden by the descendant class """ 183 | """ Process the raw_inputs(most of the time is a List[Dict]) """ 184 | 
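Editor's note: the Node API above is easiest to see end-to-end with a toy subclass. The sketch below is illustrative only and not part of the repository (EchoNode is a hypothetical class); it wires two nodes spatially and shows how the downstream node receives its predecessor's output through get_spatial_info().

from MAR.Graph.node import Node

class EchoNode(Node):
    # Minimal concrete Node: echoes the query plus whatever upstream produced.
    def _process_inputs(self, raw_inputs, spatial_info, temporal_info, **kwargs):
        return [raw_inputs]

    def _execute(self, input, spatial_info, temporal_info, **kwargs):
        upstream = "; ".join(info["output"] for info in spatial_info.values())
        suffix = f" | saw: {upstream}" if upstream else ""
        return f"{self.id}: {input['query']}{suffix}"

    async def _async_execute(self, input, spatial_info, temporal_info, **kwargs):
        return self._execute(input, spatial_info, temporal_info, **kwargs)

a, b = EchoNode(id="a"), EchoNode(id="b")
a.add_successor(b)                    # also registers a as b's spatial predecessor
a.execute({"query": "2+2?"})          # fills a.outputs
print(b.execute({"query": "2+2?"}))   # ['b: 2+2? | saw: a: 2+2?']

Temporal edges work the same way via add_successor(node, st='temporal'), except get_temporal_info() reads from last_memory, i.e. the previous round's outputs recorded by update_memory().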
-------------------------------------------------------------------------------- /MAR/LLM/__init__.py: -------------------------------------------------------------------------------- 1 | from MAR.LLM.llm_registry import LLMRegistry 2 | from MAR.LLM.gpt_chat import GPTChat 3 | 4 | __all__ = ["LLMRegistry", 5 | "GPTChat",] 6 | -------------------------------------------------------------------------------- /MAR/LLM/gpt_chat.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | from typing import List, Union, Optional 3 | from tenacity import retry, wait_random_exponential, stop_after_attempt 4 | from typing import Dict, Any 5 | from dotenv import load_dotenv 6 | import os 7 | import requests 8 | from groq import Groq, AsyncGroq 9 | from openai import OpenAI, AsyncOpenAI 10 | 11 | from MAR.LLM.price import cost_count 12 | from MAR.LLM.llm import LLM 13 | from MAR.LLM.llm_registry import LLMRegistry 14 | 15 | load_dotenv() 16 | MINE_BASE_URL = os.getenv('BASE_URL') 17 | MINE_API_KEYS = os.getenv('API_KEY') 18 | 19 | 20 | @LLMRegistry.register('ALLChat') 21 | class ALLChat(LLM): 22 | def __init__(self, model_name: str): 23 | self.model_name = model_name 24 | 25 | @retry(wait=wait_random_exponential(max=100), stop=stop_after_attempt(10)) 26 | def gen( 27 | self, 28 | messages: Union[List[Dict], str], 29 | max_tokens: Optional[int] = None, 30 | temperature: Optional[float] = None, 31 | num_comps: Optional[int] = None, 32 | ) -> Union[List[str], str]: 33 | if max_tokens is None: 34 | max_tokens = self.DEFAULT_MAX_TOKENS 35 | if temperature is None: 36 | temperature = self.DEFAULT_TEMPERATURE 37 | if num_comps is None: 38 | num_comps = self.DEFUALT_NUM_COMPLETIONS 39 | 40 | if isinstance(messages, str): 41 | messages = [{'role':"user", 'content':messages}] 42 | client = OpenAI(base_url = os.environ.get("URL"), 43 | api_key = os.environ.get("KEY")) 44 | chat_completion = client.chat.completions.create( 45 | messages = messages, 46 | model = self.model_name, 47 | ) 48 | response = chat_completion.choices[0].message.content 49 | prompt = "".join([item['content'] for item in messages]) 50 | cost_count(prompt, response, self.model_name) 51 | return response 52 | 53 | async def agen( 54 | self, 55 | messages: Union[List[Dict], str], 56 | max_tokens: Optional[int] = None, 57 | temperature: Optional[float] = None, 58 | num_comps: Optional[int] = None, 59 | ) -> Union[List[str], str]: 60 | 61 | if max_tokens is None: 62 | max_tokens = self.DEFAULT_MAX_TOKENS 63 | if temperature is None: 64 | temperature = self.DEFAULT_TEMPERATURE 65 | if num_comps is None: 66 | num_comps = self.DEFUALT_NUM_COMPLETIONS 67 | 68 | if isinstance(messages, str): 69 | messages = [{'role':"user", 'content':messages}] 70 | 71 | client = AsyncOpenAI(base_url = os.environ.get("URL"), 72 | api_key = os.environ.get("KEY"),) 73 | chat_completion = await client.chat.completions.create( 74 | messages = messages, 75 | model = self.model_name, 76 | max_tokens = max_tokens, 77 | temperature = temperature, 78 | ) 79 | response = chat_completion.choices[0].message.content 80 | 81 | return response 82 | 83 | 84 | @LLMRegistry.register('Deepseek') 85 | class DSChat(LLM): 86 | def __init__(self, model_name: str): 87 | self.model_name = model_name 88 | 89 | @retry(wait=wait_random_exponential(max=100), stop=stop_after_attempt(10)) 90 | def gen( 91 | self, 92 | messages: Union[List[Dict], str], 93 | max_tokens: Optional[int] = None, 94 | temperature: Optional[float] = None, 95 | num_comps: 
Optional[int] = None, 96 | ) -> Union[List[str], str]: 97 | if max_tokens is None: 98 | max_tokens = self.DEFAULT_MAX_TOKENS 99 | if temperature is None: 100 | temperature = self.DEFAULT_TEMPERATURE 101 | if num_comps is None: 102 | num_comps = self.DEFUALT_NUM_COMPLETIONS 103 | 104 | if isinstance(messages, str): 105 | messages = [{'role':"user", 'content':messages}] 106 | client = OpenAI(base_url = os.environ.get("DS_URL"), 107 | api_key = os.environ.get("DS_KEY")) 108 | chat_completion = client.chat.completions.create( 109 | messages = messages, 110 | model = self.model_name, 111 | ) 112 | response = chat_completion.choices[0].message.content 113 | prompt = "".join([item['content'] for item in messages]) 114 | cost_count(prompt, response, self.model_name) 115 | return response 116 | 117 | async def agen( 118 | self, 119 | messages: Union[List[Dict], str], 120 | max_tokens: Optional[int] = None, 121 | temperature: Optional[float] = None, 122 | num_comps: Optional[int] = None, 123 | ) -> Union[List[str], str]: 124 | 125 | if max_tokens is None: 126 | max_tokens = self.DEFAULT_MAX_TOKENS 127 | if temperature is None: 128 | temperature = self.DEFAULT_TEMPERATURE 129 | if num_comps is None: 130 | num_comps = self.DEFUALT_NUM_COMPLETIONS 131 | 132 | if isinstance(messages, str): 133 | messages = [{'role':"user", 'content':messages}] 134 | 135 | client = AsyncOpenAI(base_url = os.environ.get("DS_URL"), 136 | api_key = os.environ.get("DS_KEY"),) 137 | chat_completion = await client.chat.completions.create( 138 | messages = messages, 139 | model = self.model_name, 140 | max_tokens = max_tokens, 141 | temperature = temperature, 142 | ) 143 | response = chat_completion.choices[0].message.content 144 | 145 | return response 146 | 147 | @retry(wait=wait_random_exponential(max=100), stop=stop_after_attempt(10)) 148 | async def achat( 149 | model: str, 150 | msg: List[Dict],): 151 | request_url = MINE_BASE_URL 152 | authorization_key = MINE_API_KEYS 153 | headers = { 154 | 'Content-Type': 'application/json', 155 | 'authorization': authorization_key 156 | } 157 | data = { 158 | "name": model + '-y', 159 | "inputs": { 160 | "stream": False, 161 | "msg": repr(msg), 162 | } 163 | } 164 | async with aiohttp.ClientSession() as session: 165 | async with session.post(request_url, headers=headers ,json=data) as response: 166 | response_data = await response.json() 167 | if isinstance(response_data['data'],str): 168 | prompt = "".join([item['content'] for item in msg]) 169 | cost_count(prompt,response_data['data'], model) 170 | return response_data['data'] 171 | else: 172 | raise Exception("api error") 173 | 174 | @retry(wait=wait_random_exponential(max=100), stop=stop_after_attempt(10)) 175 | def chat( 176 | model: str, 177 | msg: List[Dict],): 178 | request_url = MINE_BASE_URL 179 | authorization_key = MINE_API_KEYS 180 | headers = { 181 | 'Content-Type': 'application/json', 182 | 'authorization': authorization_key 183 | } 184 | data = { 185 | "name": model+'-y', 186 | "inputs": { 187 | "stream": False, 188 | "msg": repr(msg), 189 | } 190 | } 191 | response = requests.post(request_url, headers=headers ,json=data) 192 | response_data = response.json() 193 | if isinstance(response_data['data'],str): 194 | prompt = "".join([item['content'] for item in msg]) 195 | cost_count(prompt,response_data['data'], model) 196 | return response_data['data'] 197 | else: 198 | raise Exception("api error") 199 | 200 | @LLMRegistry.register('GPTChat') 201 | class GPTChat(LLM): 202 | def __init__(self, model_name: str): 203 | 
self.model_name = model_name 204 | 205 | async def agen( 206 | self, 207 | messages: Union[List[Dict], str], 208 | max_tokens: Optional[int] = None, 209 | temperature: Optional[float] = None, 210 | num_comps: Optional[int] = None, 211 | ) -> Union[List[str], str]: 212 | 213 | if max_tokens is None: 214 | max_tokens = self.DEFAULT_MAX_TOKENS 215 | if temperature is None: 216 | temperature = self.DEFAULT_TEMPERATURE 217 | if num_comps is None: 218 | num_comps = self.DEFUALT_NUM_COMPLETIONS 219 | 220 | if isinstance(messages, str): 221 | messages = [{'role':"user", 'content':messages}] 222 | return await achat(self.model_name,messages) 223 | 224 | def gen( 225 | self, 226 | messages: Union[List[Dict], str], 227 | max_tokens: Optional[int] = None, 228 | temperature: Optional[float] = None, 229 | num_comps: Optional[int] = None, 230 | ) -> Union[List[str], str]: 231 | 232 | if max_tokens is None: 233 | max_tokens = self.DEFAULT_MAX_TOKENS 234 | if temperature is None: 235 | temperature = self.DEFAULT_TEMPERATURE 236 | if num_comps is None: 237 | num_comps = self.DEFUALT_NUM_COMPLETIONS 238 | 239 | if isinstance(messages, str): 240 | messages = [{'role':"user", 'content':messages}] 241 | return chat(self.model_name,messages) 242 | 243 | 244 | @LLMRegistry.register('Groq') 245 | class GroqChat(LLM): 246 | def __init__(self, model_name: str): 247 | self.model_name = model_name 248 | 249 | @retry(wait=wait_random_exponential(max=100), stop=stop_after_attempt(10)) 250 | def gen( 251 | self, 252 | messages: Union[List[Dict], str], 253 | max_tokens: Optional[int] = None, 254 | temperature: Optional[float] = None, 255 | num_comps: Optional[int] = None, 256 | ) -> Union[List[str], str]: 257 | # TODO: Add num_comps to the request 258 | if max_tokens is None: 259 | max_tokens = self.DEFAULT_MAX_TOKENS 260 | if temperature is None: 261 | temperature = self.DEFAULT_TEMPERATURE 262 | if num_comps is None: 263 | num_comps = self.DEFUALT_NUM_COMPLETIONS 264 | 265 | if isinstance(messages, str): 266 | messages = [{'role':"user", 'content':messages}] 267 | 268 | client = Groq(api_key=os.environ.get("GROQ_API_KEY"),) 269 | chat_completion = client.chat.completions.create( 270 | messages = messages, 271 | model = self.model_name, 272 | ) 273 | response = chat_completion.choices[0].message.content 274 | prompt = "".join([item['content'] for item in messages]) 275 | cost_count(prompt, response, self.model_name) 276 | return response 277 | 278 | async def agen( 279 | self, 280 | messages: Union[List[Dict], str], 281 | max_tokens: Optional[int] = None, 282 | temperature: Optional[float] = None, 283 | num_comps: Optional[int] = None, 284 | ) -> Union[List[str], str]: 285 | # TODO: Add num_comps to the request 286 | if max_tokens is None: 287 | max_tokens = self.DEFAULT_MAX_TOKENS 288 | if temperature is None: 289 | temperature = self.DEFAULT_TEMPERATURE 290 | if num_comps is None: 291 | num_comps = self.DEFUALT_NUM_COMPLETIONS 292 | 293 | if isinstance(messages, str): 294 | messages = [{'role':"user", 'content':messages}] 295 | 296 | client = AsyncGroq(api_key=os.environ.get("GROQ_API_KEY"),) 297 | chat_completion = await client.chat.completions.create( 298 | messages = messages, 299 | model = self.model_name, 300 | max_tokens = max_tokens, 301 | temperature = temperature, 302 | ) 303 | response = chat_completion.choices[0].message.content 304 | 305 | return response 306 | 307 | 308 | @LLMRegistry.register('OpenRouter') 309 | class OpenRouterChat(LLM): 310 | def __init__(self, model_name: str): 311 | self.model_name = 
model_name 312 | 313 | @retry(wait=wait_random_exponential(max=100), stop=stop_after_attempt(10)) 314 | def gen( 315 | self, 316 | messages: Union[List[Dict], str], 317 | max_tokens: Optional[int] = None, 318 | temperature: Optional[float] = None, 319 | num_comps: Optional[int] = None, 320 | ) -> Union[List[str], str]: 321 | if max_tokens is None: 322 | max_tokens = self.DEFAULT_MAX_TOKENS 323 | if temperature is None: 324 | temperature = self.DEFAULT_TEMPERATURE 325 | if num_comps is None: 326 | num_comps = self.DEFUALT_NUM_COMPLETIONS 327 | 328 | if isinstance(messages, str): 329 | messages = [{'role':"user", 'content':messages}] 330 | client = OpenAI(base_url = os.environ.get("OPENROUTER_BASE_URL"), 331 | api_key = os.environ.get("OPENROUTER_API_KEY"),) 332 | chat_completion = client.chat.completions.create( 333 | messages = messages, 334 | model = self.model_name, 335 | ) 336 | response = chat_completion.choices[0].message.content 337 | return response 338 | 339 | async def agen( 340 | self, 341 | messages: Union[List[Dict], str], 342 | max_tokens: Optional[int] = None, 343 | temperature: Optional[float] = None, 344 | num_comps: Optional[int] = None, 345 | ) -> Union[List[str], str]: 346 | # TODO 347 | return 0 -------------------------------------------------------------------------------- /MAR/LLM/llm.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Union, Optional,Dict 3 | 4 | class LLM(ABC): 5 | DEFAULT_MAX_TOKENS = 81920 6 | DEFAULT_TEMPERATURE = 1 7 | DEFUALT_NUM_COMPLETIONS = 1 8 | 9 | @abstractmethod 10 | async def agen( 11 | self, 12 | messages: Union[List[Dict], str], 13 | max_tokens: Optional[int] = None, 14 | temperature: Optional[float] = None, 15 | num_comps: Optional[int] = None, 16 | ) -> Union[List[str], str]: 17 | 18 | pass 19 | 20 | @abstractmethod 21 | def gen( 22 | self, 23 | messages: Union[List[Dict], str], 24 | max_tokens: Optional[int] = None, 25 | temperature: Optional[float] = None, 26 | num_comps: Optional[int] = None, 27 | ) -> Union[List[str], str]: 28 | 29 | pass 30 | -------------------------------------------------------------------------------- /MAR/LLM/llm_embedding.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer 2 | import torch 3 | 4 | def get_sentence_embedding(sentence): 5 | model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') 6 | embeddings = model.encode(sentence) 7 | return torch.tensor(embeddings) 8 | 9 | class SentenceEncoder(torch.nn.Module): 10 | def __init__(self,device=None): 11 | super().__init__() 12 | self.device = device if device else 'cuda' if torch.cuda.is_available() else 'cpu' 13 | self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2',device=self.device) 14 | 15 | def forward(self, sentence): 16 | if len(sentence) == 0: 17 | return torch.tensor([]).to(self.device) 18 | embeddings = self.model.encode(sentence,convert_to_tensor=True,device=self.device) 19 | return embeddings -------------------------------------------------------------------------------- /MAR/LLM/llm_profile.py: -------------------------------------------------------------------------------- 1 | llm_profile = [ 2 | {'Name': 'gpt-4o-mini', 3 | 'Description': 'GPT-4o Mini is a smaller version of the GPT-4o language model, designed for faster inference and reduced memory usage. 
It retains the same capabilities as the full-size model, but with fewer parameters.\n\ 4 | The model costs $0.15 per million input tokens and $0.6 per million output tokens\n\ 5 | In General Q&A Benchmark MMLU, GPT-4o-mini achieves an accuracy of 77.8.\n\ 6 | In Reasoning Benchmark GPQA, GPT-4o-mini achieves an accuracy of 40.2.\n\ 7 | In Coding Benchmark HumanEval, GPT-4o-mini achieves an accuracy of 85.7.\n\ 8 | In Math Benchmark MATH, GPT-4o-mini achieves an accuracy of 66.09.'}, 9 | {'Name': 'claude-3-5-haiku-20241022', 10 | 'Description': 'The new Claude 3.5 Haiku combines rapid response times with improved reasoning capabilities, making it ideal for tasks that require both speed and intelligence. Claude 3.5 Haiku improves on its predecessor and matches the performance of Claude 3 Opus.\n\ 11 | The model costs $1.0 per million input tokens and $5.0 per million output tokens\n\ 12 | In General Q&A Benchmark MMLU, claude-3-5-haiku achieves an accuracy of 67.9.\n\ 13 | In Reasoning Benchmark GPQA, claude-3-5-haiku achieves an accuracy of 41.6.\n\ 14 | In Coding Benchmark HumanEval, claude-3-5-haiku achieves an accuracy of 86.3.\n\ 15 | In Math Benchmark MATH, claude-3-5-haiku achieves an accuracy of 65.9.'}, 16 | {'Name': 'gemini-1.5-flash-latest', 17 | 'Description': 'Gemini 1.5 Flash was purpose-built as our fastest, most cost-efficient model yet for high volume tasks, at scale, to address developers feedback asking for lower latency and cost.\n\ 18 | The model costs $0.15 per million input tokens and $0.6 per million output tokens\n\ 19 | In General Q&A Benchmark MMLU, gemini-1.5-flash achieves an accuracy of 80.0.\n\ 20 | In Reasoning Benchmark GPQA, gemini-1.5-flash achieves an accuracy of 39.5.\n\ 21 | In Coding Benchmark HumanEval, gemini-1.5-flash achieves an accuracy of 82.6.\n\ 22 | In Math Benchmark MATH, gemini-1.5-flash achieves an accuracy of 74.4.'}, 23 | {'Name': 'llama-3.1-70b-instruct', 24 | 'Description': 'The Meta Llama-3.1-70b-instruct multilingual large language model (LLM) is a pretrained and instruction tuned generative model in 70B (text in/text out).\n\ 25 | The model costs $0.2 per million input tokens and $0.2 per million output tokens\n\ 26 | In General Q&A Benchmark MMLU, Llama 3.1 achieves an accuracy of 79.1.\n\ 27 | In Reasoning Benchmark GPQA, Llama 3.1 achieves an accuracy of 46.7.\n\ 28 | In Coding Benchmark HumanEval, Llama 3.1 achieves an accuracy of 80.7.\n\ 29 | In Math Benchmark MATH, Llama 3.1 achieves an accuracy of 60.3.'}, 30 | {'Name': 'deepseek-chat', 31 | 'Description': 'DeepSeek-V3 is a powerful open-source Mixture-of-Experts (MoE) language model developed by Chinese AI company DeepSeek, featuring 671 billion total parameters with 37 billion activated per token, achieving performance comparable to leading closed-source models like GPT-4.\n\ 32 | The model costs $0.27 per million input tokens and $1.1 per million output tokens\n\ 33 | In General Q&A Benchmark MMLU, deepseek-chat achieves an accuracy of 88.5.\n\ 34 | In Reasoning Benchmark GPQA, deepseek-chat achieves an accuracy of 59.1.\n\ 35 | In Coding Benchmark HumanEval, deepseek-chat achieves an accuracy of 88.4.\n\ 36 | In Math Benchmark MATH, deepseek-chat achieves an accuracy of 85.1'}, 37 | ] -------------------------------------------------------------------------------- /MAR/LLM/llm_registry.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from class_registry import ClassRegistry 3 | 4 | from MAR.LLM.llm 
-------------------------------------------------------------------------------- /MAR/LLM/price.py: -------------------------------------------------------------------------------- 1 | from MAR.Utils.globals import Cost, PromptTokens, CompletionTokens 2 | import tiktoken 3 | # GPT-4: https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo 4 | # GPT3.5: https://platform.openai.com/docs/models/gpt-3-5 5 | # DALL-E: https://openai.com/pricing 6 | 7 | def cal_token(model:str, text:str): 8 | encoder = tiktoken.encoding_for_model('gpt-4o') # NOTE: the model argument is currently ignored; all text is tokenized with the gpt-4o encoding, so counts are approximate for non-OpenAI models 9 | num_tokens = len(encoder.encode(text)) 10 | return num_tokens 11 | 12 | def cost_count(prompt, response, model_name): 13 | prompt_len: int 14 | completion_len: int 15 | price: float 16 | 17 | prompt_len = cal_token(model_name, prompt) 18 | completion_len = cal_token(model_name, response) 19 | if model_name not in MODEL_PRICE.keys(): 20 | return 0, 0, 0 21 | prompt_price = MODEL_PRICE[model_name]["input"] 22 | completion_price = MODEL_PRICE[model_name]["output"] 23 | price = prompt_len * prompt_price / 1000000 + completion_len * completion_price / 1000000 24 | 25 | Cost.instance().value += price 26 | PromptTokens.instance().value += prompt_len 27 | CompletionTokens.instance().value += completion_len 28 | 29 | # print(f"Prompt Tokens: {prompt_len}, Completion Tokens: {completion_len}") 30 | return price, prompt_len, completion_len 31 | 32 | MODEL_PRICE = { # prices are USD per million tokens 33 | "gpt-3.5-turbo-0125":{ 34 | "input": 0.5, 35 | "output": 1.5 36 | }, 37 | "gpt-3.5-turbo-1106":{ 38 | "input": 1.0, 39 | "output": 2.0 40 | }, 41 | "gpt-4-1106-preview":{ 42 | "input": 10.0, 43 | "output": 30.0 44 | }, 45 | "gpt-4o":{ 46 | "input": 2.5, 47 | "output": 10.0 48 | }, 49 | "gpt-4o-mini":{ 50 | "input": 0.15, 51 | "output": 0.6 52 | }, 53 | "claude-3-5-haiku-20241022":{ 54 | "input": 0.8, 55 | "output": 4.0 56 | }, 57 | "claude-3-5-sonnet-20241022":{ 58 | "input": 3.0, 59 | "output": 15.0 60 | }, 61 | "gemini-1.5-flash-latest":{ 62 | "input": 0.15, 63 | "output": 0.60 64 | }, 65 | "gemini-2.0-flash-thinking-exp":{ 66 | "input": 4.0, 67 | "output": 16.0 68 | }, 69 | "llama-3.3-70b-versatile":{ 70 | "input": 0.2, 71 | "output": 0.2 72 | }, 73 | "Meta-Llama-3.1-70B-Instruct":{ 74 | "input": 0.2, 75 | "output": 0.2 76 | }, 77 | "llama-3.1-70b-instruct":{ 78 | "input": 0.2, 79 | "output": 0.2 80 | }, 81 | 'deepseek-chat':{ 82 | 'input': 0.27, 83 | 'output': 1.1 84 | }, 85 | 'deepseek-ai/DeepSeek-V3':{ 86 | 'input': 0.27, 87 | 'output': 1.1 88 | } 89 | } 90 |
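# Illustrative sanity check of the accounting above (not part of the repository):
if __name__ == "__main__":
    price, n_prompt, n_completion = cost_count("What is 2 + 2?", "The answer is 4", "gpt-4o-mini")
    # expected: n_prompt * 0.15 / 1e6 + n_completion * 0.6 / 1e6 dollars, with the
    # global Cost/PromptTokens/CompletionTokens singletons updated as a side effect
    print(price, n_prompt, n_completion)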
-------------------------------------------------------------------------------- /MAR/Prompts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiyue/masrouter/97849a3dcab21d2962e345551be7f3c6935c66d6/MAR/Prompts/__init__.py -------------------------------------------------------------------------------- /MAR/Prompts/message_aggregation.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict 3 | 4 | from MAR.Tools.coding.python_executor import execute_code_get_return 5 | from MAR.Tools.coding.python_executor import PyExecutor 6 | from Datasets.gsm8k_dataset import gsm_get_predict 7 | 8 | # ["Normal", "PythonExecute", "PythonInnerTest", "PHP"] 9 | def message_aggregation(raw_inputs:Dict[str,str], messages:Dict[str,Dict], aggregation_method): 10 | """ 11 | Aggregate messages from other agents in temporal and spatial dimensions. 12 | Args: 13 | messages: Dict[str,Dict]: A dict of messages from other agents. 14 | aggregation_method: str: Aggregation method. 15 | Returns: 16 | str: Aggregated message. 17 | """ 18 | if aggregation_method == "Normal": 19 | return normal_agg(raw_inputs, messages) 20 | elif aggregation_method == "PythonExecute": 21 | return python_execute(raw_inputs, messages) 22 | elif aggregation_method == "PythonInnerTest": 23 | return python_inner_test(raw_inputs, messages) 24 | elif aggregation_method == "PHP": 25 | return php(raw_inputs, messages) 26 | else: 27 | raise ValueError(f"Invalid aggregation method: {aggregation_method}") 28 | 29 | def normal_agg(raw_inputs:Dict[str,str], messages:Dict[str,Dict]): 30 | """ 31 | Aggregate messages from other agents in temporal and spatial dimensions normally. 32 | Args: 33 | messages: Dict[str,Dict]: A dict of messages from other agents. 34 | Returns: 35 | str: Aggregated message. 36 | """ 37 | # Aggregate messages normally 38 | aggregated_message = "" 39 | for id, info in messages.items(): 40 | aggregated_message += f"Agent {id}, role is {info['role'].role}, output is:\n\n {info['output']}\n\n" 41 | return aggregated_message 42 | 43 | def python_execute(raw_inputs:Dict[str,str],messages:Dict[str,Dict]): 44 | """ 45 | Aggregate messages from other agents in temporal and spatial dimensions by executing Python code. 46 | Args: 47 | messages: Dict[str,Dict]: A dict of messages from other agents. 48 | Returns: 49 | str: Aggregated message. 50 | """ 51 | # Execute Python code to aggregate messages 52 | aggregated_message = "" 53 | hints = "(Hint: The answer is near to" 54 | pattern = r'```python.*```' 55 | for id, info in messages.items(): 56 | aggregated_message += f"Agent {id}, role is {info['role'].role}, output is:\n\n {info['output']}\n\n" 57 | match = re.search(pattern, info['output'], re.DOTALL|re.MULTILINE) 58 | if match: 59 | code = match.group(0).removeprefix("```python\n").removesuffix("\n```") # the executed snippet must leave its result in a local variable named answer 60 | answer = execute_code_get_return(code) 61 | if answer: 62 | hints += f" {answer}," 63 | aggregated_message += f"The execution result of the code is {answer}." 64 | hints += ")." 65 | aggregated_message += hints if hints != "(Hint: The answer is near to)." else "" 66 | return aggregated_message 67 |
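# NOTE on the fence stripping above: str.removeprefix/removesuffix (Python 3.9+) remove
# the literal markers only. str.lstrip/rstrip would be subtly wrong here because they
# strip *character sets*: "```python\nprint(1)\n```".lstrip("```python\n") also eats the
# leading "p" of "print", producing code that no longer runs.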
else "" 66 | return aggregated_message 67 | 68 | def extract_example(prompt: Dict[str,str]) -> list: 69 | # the prompt['query'] only contains the code snippet 70 | prompt = prompt['query'] 71 | lines = (line.strip() for line in prompt.split('\n') if line.strip()) 72 | results = [] 73 | lines_iter = iter(lines) 74 | for line in lines_iter: 75 | if line.startswith('>>>'): 76 | function_call = line[4:] 77 | expected_output = next(lines_iter, None) 78 | if expected_output: 79 | results.append(f"assert {function_call} == {expected_output}") 80 | if line.startswith('assert'): 81 | results.append(line) 82 | return results 83 | 84 | def python_inner_test(raw_inputs:Dict[str,str],messages:Dict[str,Dict]): 85 | """ 86 | Aggregate messages from other agents in temporal and spatial dimensions by running inner tests. 87 | Args: 88 | messages: Dict[str,Dict]: A dict of messages from other agents. 89 | Returns: 90 | Any:str: Aggregated message. 91 | """ 92 | # Run inner tests to aggregate messages 93 | internal_tests = extract_example(raw_inputs) 94 | aggregated_message = "" 95 | pattern = r'```python.*```' 96 | for id, info in messages.items(): 97 | aggregated_message += f"Agent {id}, role is {info['role'].role}, output is:\n\n {info['output']}\n\n" 98 | match = re.search(pattern, info['output'], re.DOTALL|re.MULTILINE) 99 | if match: 100 | code = match.group(0).lstrip("```python\n").rstrip("\n```") 101 | is_solved, feedback, state = PyExecutor().execute(code, internal_tests, timeout=100) 102 | if is_solved: 103 | aggregated_message += f"\nThe code is solved.\n {feedback}" 104 | else: 105 | aggregated_message += f"The code is not solved.\n {feedback}" 106 | return aggregated_message 107 | 108 | def php(raw_inputs:Dict[str,str],messages:Dict[str,Dict]): 109 | """ 110 | Aggregate messages from other agents in temporal and spatial dimensions using PHP. 111 | Args: 112 | messages: Dict[str,Dict]: A dict of messages from other agents. 113 | Returns: 114 | Any:str: Aggregated message. 115 | """ 116 | # Use PHP to aggregate messages 117 | aggregated_message = "" 118 | hints = "(Hint: The answer is near to" 119 | python_pattern = r'```python.*```' 120 | 121 | for id, info in messages.items(): 122 | aggregated_message += f"Agent {id}, role is {info['role'].role}, output is:\n\n {info['output']}\n\n" 123 | python_match = re.search(python_pattern, info['output'], re.DOTALL|re.MULTILINE) 124 | if python_match: 125 | code = python_match.group(0).lstrip("```python\n").rstrip("\n```") # the result must in the local vars answer 126 | answer = execute_code_get_return(code) 127 | if answer: 128 | hints += f" {answer}," 129 | aggregated_message += f"The execution result of the code is {answer}." 130 | if 'the answer is ' in info['output'] or 'The answer is ' in info['output']: 131 | answer = gsm_get_predict(info['output']) 132 | hints += f" {answer}," 133 | hints += ")." 134 | aggregated_message += hints if hints != "(Hint: The answer is near to)." 
else "" 135 | return aggregated_message 136 | 137 | def inner_test(raw_inputs:Dict[str,str], spatial_info:Dict[str,Dict], temporal_info:Dict[str,Dict], ): 138 | # Use inter tests to aggregate messages 139 | internal_tests = extract_example(raw_inputs) 140 | if internal_tests == []: 141 | return False, "" 142 | pattern = r'```python.*```' 143 | for id, info in spatial_info.items(): 144 | match = re.search(pattern, info['output'], re.DOTALL|re.MULTILINE) 145 | if match: 146 | code = match.group(0).lstrip("```python\n").rstrip("\n```") 147 | is_solved, feedback, state = PyExecutor().execute(code, internal_tests, timeout=10) 148 | if is_solved: 149 | return is_solved, info['output'] 150 | for id, info in temporal_info.items(): 151 | match = re.search(pattern, info['output'], re.DOTALL|re.MULTILINE) 152 | if match: 153 | code = match.group(0).lstrip("```python\n").rstrip("\n```") 154 | is_solved, feedback, state = PyExecutor().execute(code, internal_tests, timeout=10) 155 | if is_solved: 156 | return is_solved, info['output'] 157 | return False, "" 158 | -------------------------------------------------------------------------------- /MAR/Prompts/output_format.py: -------------------------------------------------------------------------------- 1 | # Options: ["None", "Text", "Analyze", "Calculation", "Examine", "Answer", "CodeCompletion", "CodeSolver", "Keys"] 2 | output_format_prompt = { 3 | "None": None, 4 | "Text": "", 5 | "Calculation": "Please provide the formula for the problem and bring in the numerical values to solve the problem.\n\ 6 | The last line of your output must contain only the final result without any units or redundant explanation,\ 7 | for example: The answer is 140\n\ 8 | If it is a multiple choice question, please output the options. For example: The answer is A.\n\ 9 | However, The answer is 140$ or The answer is Option A or The answer is A.140 is not allowed.", 10 | "Examine": "If you are provided with other responses, check that they are correct and match each other.\n\ 11 | Check whether the logic/calculation of the problem solving and analysis process is correct(if present).\n\ 12 | Check whether the code corresponds to the solution analysis(if present).\n\ 13 | Give your own complete solving process using the same format." , 14 | "Answer": "The last line of your output must contain only the final result without any units or redundant explanation,\ 15 | for example: The answer is 140\n\ 16 | If it is a multiple choice question, please output the options. For example: The answer is A.\n\ 17 | However, The answer is 140$ or The answer is Option A or The answer is A.140 is not allowed.\n", 18 | "CodeCompletion": "You will be given a function signature and its docstring by the user.\n\ 19 | Write your full implementation of this function.\n\ 20 | Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```\n\ 21 | Do not change function names and input variable types in tasks.", 22 | "CodeSolver": "Analyze the question and write functions to solve the problem.\n\ 23 | The function should not take any arguments and use the final result as the return value.\n\ 24 | The last line of code calls the function you wrote and assigns the return value to the \(answer\) variable.\n\ 25 | Use a Python code block to write your response. 
-------------------------------------------------------------------------------- /MAR/Prompts/output_format.py: -------------------------------------------------------------------------------- 1 | # Options: ["None", "Text", "Analyze", "Calculation", "Examine", "Answer", "CodeCompletion", "CodeSolver", "Keys"] 2 | output_format_prompt = { 3 | "None": None, 4 | "Text": "", 5 | "Calculation": "Please provide the formula for the problem and bring in the numerical values to solve the problem.\n\ 6 | The last line of your output must contain only the final result without any units or redundant explanation,\ 7 | for example: The answer is 140\n\ 8 | If it is a multiple choice question, please output the options. For example: The answer is A.\n\ 9 | However, The answer is 140$ or The answer is Option A or The answer is A.140 is not allowed.", 10 | "Examine": "If you are provided with other responses, check that they are correct and match each other.\n\ 11 | Check whether the logic/calculation of the problem solving and analysis process is correct (if present).\n\ 12 | Check whether the code corresponds to the solution analysis (if present).\n\ 13 | Give your own complete solving process using the same format.", 14 | "Answer": "The last line of your output must contain only the final result without any units or redundant explanation,\ 15 | for example: The answer is 140\n\ 16 | If it is a multiple choice question, please output the options. For example: The answer is A.\n\ 17 | However, The answer is 140$ or The answer is Option A or The answer is A.140 is not allowed.\n", 18 | "CodeCompletion": "You will be given a function signature and its docstring by the user.\n\ 19 | Write your full implementation of this function.\n\ 20 | Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```\n\ 21 | Do not change function names and input variable types in tasks.", 22 | "CodeSolver": "Analyze the question and write functions to solve the problem.\n\ 23 | The function should not take any arguments and use the final result as the return value.\n\ 24 | The last line of code calls the function you wrote and assigns the return value to the answer variable.\n\ 25 | Use a Python code block to write your response. For example:\n```python\ndef fun():\n x = 10\n y = 20\n return x + y\nanswer = fun()\n```\n", 26 | "Keys": "Please provide relevant keywords that need to be searched on the Internet, relevant databases, or Wikipedia.\n\ 27 | Use a keyword block to give a list of keywords of your choice.\n\ 28 | Please give a few concise short keywords; the number should be less than four.\n\ 29 | For example:\n```keyword\n['catfish effect', 'Shakespeare', 'global warming']\n```\n\ 30 | If there is no entity in the question that needs to be searched, you don't have to provide it." 31 | } -------------------------------------------------------------------------------- /MAR/Prompts/post_process.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict 3 | from loguru import logger 4 | 5 | from MAR.Tools.coding.python_executor import execute_code_get_return 6 | from MAR.Tools.coding.python_executor import PyExecutor 7 | from langchain_community.tools import WikipediaQueryRun 8 | from langchain_community.utilities import WikipediaAPIWrapper 9 | 10 | # Options: ["None", "PythonExecute", "PythonInnerTest", "Wiki", "Search", "Reflection"] 11 | def post_process(raw_inputs:Dict[str,str], output:str, post_method:str): 12 | if post_method is None or post_method == "None": 13 | return output 14 | elif post_method == "PythonExecute": 15 | return python_execute(raw_inputs, output) 16 | elif post_method == "PythonInnerTest": 17 | return python_inner_test(raw_inputs, output) 18 | elif post_method == "Wiki": 19 | return wiki(raw_inputs, output) 20 | elif post_method == "Search": 21 | return search(raw_inputs, output) 22 | elif post_method == "Reflection": 23 | return reflection(raw_inputs, output) 24 | else: 25 | raise ValueError(f"Invalid post-processing method: {post_method}") 26 | 27 | 28 | def python_execute(raw_inputs:Dict[str,str], output:str): 29 | """ 30 | Execute Python code to post-process the output. 31 | Args: 32 | output: str: The output from the LLM. 33 | Returns: 34 | str: The post-processed output. 35 | """ 36 | # Execute Python code to post-process the output 37 | pattern = r'```python.*```' 38 | match = re.search(pattern, output, re.DOTALL|re.MULTILINE) 39 | if match: 40 | code = match.group(0).removeprefix("```python\n").removesuffix("\n```") 41 | output += f"\nthe answer is {execute_code_get_return(code)}" 42 | return output 43 | 44 | def extract_example(prompt: Dict[str,str]) -> list: 45 | # prompt['query'] contains only the code snippet 46 | prompt = prompt['query'] 47 | lines = (line.strip() for line in prompt.split('\n') if line.strip()) 48 | results = [] 49 | lines_iter = iter(lines) 50 | for line in lines_iter: 51 | if line.startswith('>>>'): 52 | function_call = line[4:] 53 | expected_output = next(lines_iter, None) 54 | if expected_output: 55 | results.append(f"assert {function_call} == {expected_output}") 56 | return results 57 | 58 | def python_inner_test(raw_inputs:Dict[str,str], output:str): 59 | """ 60 | Run the unit tests extracted from the prompt against the code in the output. 61 | Args: 62 | raw_inputs: Dict[str,str]: The raw inputs. 63 | output: str: The output from the LLM. 64 | Returns: 65 | str: The post-processed output.
66 | """ 67 | internal_tests = extract_example(raw_inputs) 68 | pattern = r'```python.*```' 69 | match = re.search(pattern, output, re.DOTALL|re.MULTILINE) 70 | if match: 71 | code = match.group(0).lstrip("```python\n").rstrip("\n```") 72 | is_solved, feedback, state = PyExecutor().execute(code, internal_tests, timeout=10) 73 | if is_solved: 74 | output += f"\nThe code is solved.\n {feedback}" 75 | else: 76 | output += f"\nThe code is not solved.\n {feedback}" 77 | return output 78 | 79 | def wiki(raw_inputs:Dict[str,str], output:str): 80 | """ 81 | Extract information from Wikipedia to post-process the output. 82 | Args: 83 | output: str: The output from the LLM. 84 | Returns: 85 | Any: str: The post-processed output. 86 | """ 87 | # Extract information from Wikipedia to post-process the output 88 | pattern = r'```keyword.*```' 89 | match = re.search(pattern, output, re.DOTALL|re.MULTILINE) 90 | if match: 91 | keywords = match.group(0).lstrip("```keyword\n").rstrip("\n```") 92 | # Extract information from Wikipedia 93 | logger.info(f"keywords: {keywords}") 94 | keywords = eval(keywords) 95 | wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(top_k_results=2)) 96 | for keyword in keywords: 97 | if type(keyword) == str or type(keyword) == dict: 98 | output += f"\n{wikipedia.run(keyword)}" 99 | return output 100 | 101 | def search(raw_inputs:Dict[str,str], output:str): 102 | """ 103 | Search for information to post-process the output. 104 | Args: 105 | output: str: The output from the LLM. 106 | Returns: 107 | Any: str: The post-processed output. 108 | """ 109 | #TODO Search for information to post-process the output 110 | #Maybe bocha or brave search 111 | 112 | return output 113 | 114 | def reflection(raw_inputs:Dict[str,str], output:str): 115 | """ 116 | Reflect on the output to post-process the output. 117 | Args: 118 | output: str: The output from the LLM. 119 | Returns: 120 | Any: str: The post-processed output. 121 | """ 122 | return output -------------------------------------------------------------------------------- /MAR/Prompts/reasoning.py: -------------------------------------------------------------------------------- 1 | reasoning_prompt = {"IO": "", 2 | "CoT": "Please give step by step answers to the questions.", 3 | "Chain": "", 4 | "FullConnected": "", 5 | "Debate": "Please try your best to give answers that are different or opposite to those of other agents.", 6 | "Reflection":""} -------------------------------------------------------------------------------- /MAR/Prompts/tasks_profile.py: -------------------------------------------------------------------------------- 1 | tasks_profile = [{'Name': 'Math', 'Description': 'A mathematics problem often involves logical reasoning, calculations, and applying various mathematical concepts such as algebra, geometry, calculus, or statistics. The goal is to solve for unknowns, prove theorems, or model real-world phenomena using mathematical methods.'}, 2 | {'Name': 'Commonsense', 'Description': 'A commonsense question typically involves general knowledge, reasoning, or problem-solving skills. It may require understanding of everyday concepts, social norms, or human behavior to answer questions about common situations or scenarios.'}, 3 | {'Name': 'Code', 'Description': 'A code question typically involves writing, debugging, or understanding computer code. 
-------------------------------------------------------------------------------- /MAR/Prompts/reasoning.py: -------------------------------------------------------------------------------- 1 | reasoning_prompt = {"IO": "", 2 | "CoT": "Please give step-by-step answers to the questions.", 3 | "Chain": "", 4 | "FullConnected": "", 5 | "Debate": "Please try your best to give answers that are different or opposite to those of other agents.", 6 | "Reflection": ""} -------------------------------------------------------------------------------- /MAR/Prompts/tasks_profile.py: -------------------------------------------------------------------------------- 1 | tasks_profile = [{'Name': 'Math', 'Description': 'A mathematics problem often involves logical reasoning, calculations, and applying various mathematical concepts such as algebra, geometry, calculus, or statistics. The goal is to solve for unknowns, prove theorems, or model real-world phenomena using mathematical methods.'}, 2 | {'Name': 'Commonsense', 'Description': 'A commonsense question typically involves general knowledge, reasoning, or problem-solving skills. It may require understanding of everyday concepts, social norms, or human behavior to answer questions about common situations or scenarios.'}, 3 | {'Name': 'Code', 'Description': 'A code question typically involves writing, debugging, or understanding computer code. It may require knowledge of programming languages, algorithms, data structures, or software development practices to solve coding problems or implement software solutions.'},] -------------------------------------------------------------------------------- /MAR/Roles/Code/AlgorithmDesigner.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "AlgorithmDesigner", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "You are an algorithm designer. You will be given a function signature and its docstring by the user.\nYou need to specify the specific design of the algorithm, including explanations of the algorithm, usage instructions, and API references.\nYou can refer to specific examples.\nWhen the implementation logic is complex, you can give the pseudocode logic of the main algorithm.\nKeep your reply concise, preferably within fifty words.", 5 | "OutputFormat": "Text", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Code/BugFixer.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "BugFixer", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "You are a programming expert. You will be given a function signature and its docstring by the user. Use a Python code block to write your full implementation (restate the function signature).", 5 | "OutputFormat": "CodeCompletion", 6 | "PostProcess": "PythonInnerTest", 7 | "PostDescription": "You need to provide modified and improved Python code based on the current code implementation and problems that arise during testing.\nYou can refer to specific examples.\nWrite your full implementation (restate the function signature). ", 8 | "PostOutputFormat": "CodeCompletion" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Code/PlanSolver.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "PlanSolver", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "Please give the pseudocode of the function.", 5 | "OutputFormat": "Text", 6 | "PostProcess": "None", 7 | "PostDescription": "You are a programming expert. You will be given a function signature and its docstring by the user.\n Use a Python code block to write your full implementation (restate the function signature).", 8 | "PostOutputFormat": "CodeCompletion" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Code/ProgrammingExpert.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "ProgrammingExpert", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "You are a programming expert. 
You will be given a function signature and its docstring by the user.\nYou can draw on the specific examples in the docstring.\n Use a Python code block to write your full implementation (restate the function signature).", 5 | "OutputFormat": "CodeCompletion", 6 | "PostProcess": "PythonInnerTest", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Code/ProjectManager.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "ProjectManager", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "You are a project manager. You will be given a function signature and its docstring by the user.\nYou are responsible for overseeing the overall structure of the code, ensuring that the code is structured to complete the task. Implement code concisely and correctly without pursuing over-engineering.\nYou need to suggest optimal design patterns to ensure that the code follows best practices for maintainability and flexibility.\nKeep your reply concise, preferably within fifty words.", 5 | "OutputFormat": "Text", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Code/ReflectProgrammer.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "ReflectProgrammer", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "You are a programming expert. You will be given a function signature and its docstring by the user. Use a Python code block to write your full implementation (restate the function signature).", 5 | "OutputFormat": "CodeCompletion", 6 | "PostProcess": "PythonInnerTest", 7 | "PostDescription": "Reflect on possible errors in the answer above and answer again.", 8 | "PostOutputFormat": "CodeCompletion" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Code/TestAnalyst.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "TestAnalyst", 3 | "MessageAggregation": "PythonInnerTest", 4 | "Description": "You are a Test Analyst. You will be given a function signature and its docstring by the user.\nYou need to point out problems in the current code or solution based on the test data and possible test feedback in the question.\nYou need to provide additional special use cases, boundary conditions, etc. that should be paid attention to when writing code.\nYou can point out any potential errors in the code.\nKeep your reply concise, preferably within fifty words.", 5 | "OutputFormat": "Text", 6 | "PostProcess": "None", 7 | "PostDescription": "You are a programming expert. You will be given a function signature and its docstring by the user.\nGive your own answers to problems that arise in other implementations.\n Use a Python code block to write your full implementation (restate the function signature).", 8 | "PostOutputFormat": "CodeCompletion" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/Critic.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Critic", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are an excellent critic. 
Please point out potential issues in the other agents' analyses point by point. Give your critical opinion. Finally, give the final result.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/Economist.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Economist", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are an experienced economist with expertise in macroeconomics, microeconomics, and financial markets. Your role is to provide well-reasoned and evidence-based answers to the given questions.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/Historian.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Historian", 3 | "MessageAggregation": "Normal", 4 | "Description": "You research and analyze cultural, economic, political, and social events in the past, collect data from primary sources and use it to develop theories about what happened during various periods of history.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/KnowledgeExpert.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "KnowledgeExpert", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a knowledgeable expert in question answering. Please analyze step by step and choose the correct answer.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/Reflector.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Reflector", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a knowledgeable expert in question answering. Please analyze step by step and choose the correct answer.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "Reflect on possible errors in the answer above and answer again.", 8 | "PostOutputFormat": "Answer" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/Scientist.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Scientist", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a scientist with knowledge and insights in natural science. You will be given a complex math problem. 
Your task is to provide a thorough and detailed solving process, including any necessary proofs and explanations.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Commonsense/WikiSearcher.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "WikiSearcher", 3 | "MessageAggregation": "Normal", 4 | "Description": "Please give several key entities that need to be searched on Wikipedia to solve the problem. These entities should be separate entries, which can be found on Wikipedia.", 5 | "OutputFormat": "Keys", 6 | "PostProcess": "Wiki", 7 | "PostDescription": "You are a knowledgeable expert in question answering. Please answer the question based on the explanation of the question keywords obtained from the Wikipedia search.", 8 | "PostOutputFormat": "Answer" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/FinalNode/gsm8k.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "You will be given a math problem, analysis and code from other agents. Please find the most reliable answer based on the analysis and results of other agents. Give reasons for making decisions. The last line of your output contains only the final result without any units, for example: The answer is 140. However, The answer is 140$ or The answer is Option A or The answer is A.140 is not allowed. Remember not to add units or any other content.", 3 | "user": "Please provide the final answer based on the analysis and results of other agents. For example: the answer is 140 or the answer is 0" 4 | } -------------------------------------------------------------------------------- /MAR/Roles/FinalNode/humaneval.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "You are the top decision-maker and are good at analyzing and summarizing other people's opinions, finding errors and giving final answers. \nUse a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```\nDo not include anything other than Python code blocks in your response.", 3 | "user": "You will be given a function signature and its docstring by the user.\nYou may be given the overall code design, algorithm framework, code implementation or test problems.\nWrite your full implementation (restate the function signature).\nUse a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```\nDo not include anything other than Python code blocks in your response." 4 | } -------------------------------------------------------------------------------- /MAR/Roles/FinalNode/math.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "You will be given a math problem, analysis and code from other agents. Please find the most reliable answer based on the analysis and results of other agents. Give reasons for making decisions. Your answer should be wrapped in \\boxed{} without any units, for example: The answer is \\boxed{140}. Remember not to add units.", 3 | "user": "Please provide the final answer based on the analysis and results of other agents. 
For example: the answer is \\boxed{140} or the answer is \\boxed{0}" 4 | } -------------------------------------------------------------------------------- /MAR/Roles/FinalNode/mbpp.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "You are the top decision-maker and are good at analyzing and summarizing other people's opinions, finding errors and giving final answers. \nUse a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```\nDo not include anything other than Python code blocks in your response.", 3 | "user": "You will be given a function signature and its docstring by the user.\nYou may be given the overall code design, algorithm framework, code implementation or test problems.\nWrite your full implementation (restate the function signature).\nUse a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```\nDo not include anything other than Python code blocks in your response." 4 | } -------------------------------------------------------------------------------- /MAR/Roles/FinalNode/mmlu.json: -------------------------------------------------------------------------------- 1 | { 2 | "system": "You are the top decision-maker and are good at analyzing and summarizing other people's opinions, finding errors and giving final answers.", 3 | "user": "\nOnly one answer out of the offered 4 is correct.\nYou must choose the correct answer to the question.\nYour response must be one of the 4 letters: A, B, C or D, corresponding to the correct answer.\nI will give you some other people's answers and analysis.\nThe last line of the reply should contain only one sentence (the answer is \\boxed{A/B/C/D}.) and nothing else.\nFor example, the answer is \\boxed{A}." 4 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/AlgorithmEngineer.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "AlgorithmEngineer", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a proficient algorithm engineer with expertise in designing, analyzing, and optimizing algorithms. You are skilled in fields such as algorithm design, mathematical theories, machine learning, and optimization. You will be given a math problem, analysis and code from other agents. Integrate step-by-step reasoning and Python code to solve math problems. ", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/CertifiedAccountant.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "CertifiedAccountant", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a Certified Accountant. You will be given financial problems and scenarios. 
You always analyze and understand the problem correctly and give correct calculations and solutions.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/Economist.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Economist", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are an experienced economist with expertise in macroeconomics, microeconomics, and financial markets. Your role is to provide well-reasoned and evidence-based answers to the given questions.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/Engineer.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Engineer", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are an experienced engineer. You are familiar with designing, developing, and maintaining engineering systems and solutions. You will be given a math problem. Give your own solving process based on your knowledge.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } 10 | -------------------------------------------------------------------------------- /MAR/Roles/Math/Inspector.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Inspector", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are an Inspector. You will be given a math problem, analysis and code from other agents. Check whether the logic/calculation of the problem solving and analysis process is correct (if present). Check whether the code corresponds to the solution analysis (if present). Give your own solving process step by step based on hints.", 5 | "OutputFormat": "Answer", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } 10 | -------------------------------------------------------------------------------- /MAR/Roles/Math/MathAnalyst.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "MathAnalyst", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a mathematical analyst. You will be given a math problem, analysis and code from other agents. You need to first analyze the problem-solving process, where the variables are represented by letters. Then you substitute the values into the analysis process to perform calculations and get the results.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } 10 | -------------------------------------------------------------------------------- /MAR/Roles/Math/MathSolver.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "MathSolver", 3 | "MessageAggregation": "PHP", 4 | "Description": "You are a math expert. You will be given a math problem and hints from other agents. 
Give your own solving process based on hints.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/MathTeacher.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "MathTeacher", 3 | "MessageAggregation": "PHP", 4 | "Description": "You are an excellent math teacher and always teach your students math problems correctly. And I am one of your students. You will be given a math problem; teach me step by step how to solve it.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/Mathematician.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Mathematician", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a mathematician who is good at math games, arithmetic calculation, and long-term planning.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/ProgrammingExpert.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "ProgrammingExpert", 3 | "MessageAggregation": "PHP", 4 | "Description": "You are a programming expert. You will be given a math problem, analysis and code from other agents. Integrate step-by-step reasoning and Python code to solve math problems. Analyze the question and write functions to solve the problem. ", 5 | "OutputFormat": "CodeSolver", 6 | "PostProcess": "PythonExecute", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "Answer" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/Scientist.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "Scientist", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a scientist with knowledge and insights in natural science. You will be given a complex math problem. Your task is to provide a thorough and detailed solving process, including any necessary proofs and explanations.", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "None", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "None" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/Math/SoftwareDeveloper.json: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "SoftwareDeveloper", 3 | "MessageAggregation": "Normal", 4 | "Description": "You are a skilled software developer with expertise in software architecture, coding, debugging, and system design. Your role is to analyze the given problem, design efficient solutions, and provide clear and concise functions to solve the problem.
", 5 | "OutputFormat": "Calculation", 6 | "PostProcess": "PythonExecute", 7 | "PostDescription": "None", 8 | "PostOutputFormat": "Answer" 9 | } -------------------------------------------------------------------------------- /MAR/Roles/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiyue/masrouter/97849a3dcab21d2962e345551be7f3c6935c66d6/MAR/Roles/__init__.py -------------------------------------------------------------------------------- /MAR/Roles/role_example.py: -------------------------------------------------------------------------------- 1 | { 2 | "Name": "MathSolver", # str:Role name 3 | "MessageAggregation": "PHP", # str: How to aggregate messages from other agents in temporal and spatial dimensions. 4 | # The order of the list is the priority of the aggregation method. 5 | # Options: ["Normal", "PythonExecute", "PythonInnerTest", "PHP"] 6 | "Description": "You are a math expert.\n\ 7 | You will be given a math problem and hints from other agents.\n\ 8 | Give your own solving process step by step based on hints.", # str:Role description about what tasks need to be completed? 9 | "OutputFormat": "Answer", # str: Output format of the role. 10 | # Options: ["None", "Text", "Analyze", "Calculation", "Examine", "Answer", "CodeCompletion", "CodeSolver", "Keys"] 11 | "PostProcess": "Wiki", # str: Post-processing methods for the output. 12 | # Options: ["None", "PythonExecute", "PythonInnerTest", "Wiki", "Search", "Reflection"] 13 | "PostDescription": "Reflect on possible errors in the answer above and answer again using the same format.", # str: Post-processing description. 14 | "PostOutputFormat": "Answer", # str: Post-processing output format. 15 | # Options: ["None", "Text", "Analyze", "Calculation", "Examine", "Answer", "CodeCompletion", "CodeSolver", "Keys"] 16 | # None means no post-processing. 
17 | } -------------------------------------------------------------------------------- /MAR/Roles/role_registry.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class RoleRegistry: 4 | def __init__(self, domain, role): 5 | self.domain = domain 6 | self.role = role 7 | self.role_profile = self.get_role_profile() 8 | 9 | def get_role_profile(self): 10 | with open(f"MAR/Roles/{self.domain}/{self.role}.json") as f: profile = json.load(f) # context manager so the file handle is closed 11 | return profile 12 | 13 | def get_name(self): 14 | return self.role_profile['Name'] 15 | 16 | def get_message_aggregation(self): 17 | return self.role_profile['MessageAggregation'] 18 | 19 | def get_description(self): 20 | return self.role_profile['Description'] 21 | 22 | def get_output_format(self): 23 | return self.role_profile['OutputFormat'] 24 | 25 | def get_reasoning(self): 26 | return self.role_profile.get('Reasoning') # .get: the shipped role JSONs do not define a Reasoning key 27 | 28 | def get_post_process(self): 29 | return self.role_profile['PostProcess'] 30 | 31 | def get_post_description(self): 32 | return self.role_profile['PostDescription'] 33 | 34 | def get_post_output_format(self): 35 | return self.role_profile['PostOutputFormat'] 36 | -------------------------------------------------------------------------------- /MAR/Tools/coding/executor_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from MAR.Utils.log import logger 6 | from MAR.Tools.coding.python_executor import PyExecutor 7 | from MAR.Tools.coding.executor_types import Executor 8 | 9 | EXECUTOR_MAPPING = { 10 | "py": PyExecutor, 11 | "python": PyExecutor, 12 | } 13 | 14 | def executor_factory(lang: str) -> Executor: 15 | 16 | if lang not in EXECUTOR_MAPPING: 17 | raise ValueError(f"Invalid language for executor: {lang}") 18 | 19 | executor_class = EXECUTOR_MAPPING[lang] 20 | return executor_class() -------------------------------------------------------------------------------- /MAR/Tools/coding/executor_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from typing import NamedTuple, List, Tuple 5 | from abc import ABC, abstractmethod 6 | 7 | class ExecuteResult(NamedTuple): 8 | is_passing: bool 9 | feedback: str 10 | state: Tuple[bool] 11 | 12 | class Executor(ABC): 13 | @abstractmethod 14 | def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 15 | ... 16 | 17 | @abstractmethod 18 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 19 | ... 20 | 21 | -------------------------------------------------------------------------------- /MAR/Tools/coding/executor_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import json 6 | from threading import Thread 7 | 8 | def timeout_handler(_, __): 9 | raise TimeoutError() 10 | 11 | 12 | def to_jsonl(dict_data, file_path): 13 | with open(file_path, 'a') as file: 14 | json_line = json.dumps(dict_data) 15 | file.write(json_line + os.linesep) 16 | 17 | 18 | class PropagatingThread(Thread): 19 | def run(self): 20 | self.exc = None 21 | try: 22 | if hasattr(self, '_Thread__target'): 23 | # Thread uses name mangling prior to Python 3.
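# (On Python 2, Thread stored the callable as the name-mangled attribute
# _Thread__target; on Python 3 the plain _target branch below is taken. This
# subclass exists so exceptions raised inside the worker propagate to join().)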
24 | self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs) 25 | else: 26 | self.ret = self._target(*self._args, **self._kwargs) 27 | except BaseException as e: 28 | self.exc = e 29 | 30 | def join(self, timeout=None): 31 | super(PropagatingThread, self).join(timeout) 32 | if self.exc: 33 | raise self.exc 34 | return self.ret 35 | 36 | 37 | def function_with_timeout(func, args, timeout): 38 | result_container = [] 39 | 40 | def wrapper(): 41 | result_container.append(func(*args)) 42 | 43 | thread = PropagatingThread(target=wrapper) 44 | thread.start() 45 | thread.join(timeout) 46 | 47 | if thread.is_alive(): 48 | raise TimeoutError() 49 | else: 50 | return result_container[0] 51 | 52 | 53 | -------------------------------------------------------------------------------- /MAR/Tools/coding/python_executor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import ast 5 | import astunparse 6 | from typing import List 7 | 8 | from MAR.Tools.coding.executor_utils import function_with_timeout 9 | from MAR.Tools.coding.executor_types import ExecuteResult, Executor 10 | 11 | 12 | def get_call_str(assert_statement: str) -> str: 13 | ast_parsed = ast.parse(assert_statement) 14 | try: 15 | call_str = ast_parsed.body[0].test.left # type: ignore 16 | except: 17 | call_str = ast_parsed.body[0].test # type: ignore 18 | 19 | return astunparse.unparse(call_str).strip() 20 | 21 | def get_output(func: str, assert_statement: str, timeout: int = 5) -> str: 22 | try: 23 | exec(f"from typing import *\n{func}", globals()) 24 | func_call = get_call_str(assert_statement) 25 | output = function_with_timeout(eval, (func_call, globals()), timeout) 26 | return output 27 | except TimeoutError: 28 | return "TIMEOUT" 29 | except Exception as e: 30 | return str(e) 31 | 32 | def execute_code_get_return(code: str): 33 | local_vars = {} 34 | try: 35 | exec(code, {}, local_vars) 36 | if 'answer' in local_vars: 37 | return local_vars['answer'] 38 | else: 39 | return None 40 | except Exception as e: 41 | return f"Error occurred: {e}" 42 | 43 | class PyExecutor(Executor): 44 | def execute(self, func: str, tests: List[str], timeout: int = 5, verbose: bool = True) -> ExecuteResult: 45 | # Combine function code and assert statement 46 | imports = 'from typing import *' 47 | func_test_list = [f'{imports}\n{func}\n{test}' for test in tests] 48 | 49 | # Run the tests and collect the results 50 | success_tests = [] 51 | failed_tests = [] 52 | is_passing = True 53 | num_tests = len(func_test_list) 54 | for i in range(num_tests): 55 | try: 56 | function_with_timeout(exec, (func_test_list[i], globals()), timeout) 57 | success_tests.append(tests[i]) 58 | except Exception: 59 | output = get_output(func, tests[i], timeout=timeout) 60 | failed_tests.append(f"{tests[i]} # output: {output}") 61 | is_passing = False 62 | 63 | state = [test in success_tests for test in tests] 64 | 65 | feedback = "Tests passed:\n" + "\n".join(success_tests) + "\n\nTests failed:" 66 | feedback += "\n" + "\n".join(failed_tests) 67 | return is_passing, feedback, tuple(state) 68 | 69 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 70 | """ 71 | Evaluates the implementation on Human-Eval Python. 
72 | 73 | probably should be written in a dataset-agnostic way but not now 74 | """ 75 | 76 | code = f"""{func} 77 | 78 | {test} 79 | 80 | check({name}) 81 | """ 82 | try: 83 | function_with_timeout(exec, (code, globals()), timeout) 84 | return True 85 | except Exception: 86 | return False 87 | -------------------------------------------------------------------------------- /MAR/Tools/reader/readers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from openai import OpenAI 5 | import pdb 6 | 7 | """INSTALL 8 | pip install openai --upgrade 9 | pip install python-docx 10 | pip install markdown 11 | pip install PyPDF2 12 | pip install openpyxl 13 | pip install beautifulsoup4 14 | pip install pylatexenc 15 | pip install python-pptx 16 | pip install xlrd 17 | """ 18 | 19 | import json 20 | import os 21 | import pandas as pd 22 | import charset_normalizer 23 | import docx 24 | import markdown 25 | import PyPDF2 26 | import openpyxl 27 | import yaml 28 | import zipfile 29 | import subprocess 30 | from pathlib import Path 31 | from abc import ABC, abstractmethod 32 | from typing import Union, Any, Optional 33 | from bs4 import BeautifulSoup 34 | from pylatexenc.latex2text import LatexNodes2Text 35 | from pptx import Presentation 36 | 37 | from MAR.Utils.log import logger 38 | from MAR.Utils.globals import Cost 39 | 40 | from dotenv import load_dotenv 41 | load_dotenv() 42 | import aiohttp 43 | import requests 44 | from openai import OpenAI, AsyncOpenAI 45 | 46 | OPENAI_API_KEY=os.getenv("OPENAI_API_KEY") 47 | 48 | 49 | # Refs: https://platform.openai.com/docs/api-reference 50 | # Refs: https://github.com/Significant-Gravitas/AutoGPT/blob/0e332c0c1221857f3ce96490f073c1c88bcbd367/autogpts/autogpt/autogpt/commands/file_operations_utils.py 51 | 52 | class Reader(ABC): 53 | @abstractmethod 54 | def parse(self, file_path: Path) -> str: 55 | """ To be overriden by the descendant class """ 56 | 57 | 58 | class TXTReader(Reader): 59 | def parse(self, file_path: Path) -> str: 60 | content = charset_normalizer.from_path(file_path).best() 61 | encoding = getattr(content, 'encoding', None) 62 | logger.info(f"Reading TXT file from {file_path} using encoding '{encoding}'.") 63 | return str(content) 64 | 65 | class PDFReader(Reader): 66 | def parse(self, file_path: Path) -> str: 67 | logger.info(f"Reading PDF file from {file_path}.") 68 | content = PyPDF2.PdfReader(file_path) 69 | text = "" 70 | for page_idx in range(len(content.pages)): 71 | text += f'Page {page_idx + 1}\n' + content.pages[page_idx].extract_text() 72 | return text 73 | 74 | class DOCXReader(Reader): 75 | def parse(self, file_path: Path) -> str: 76 | logger.info(f"Reading DOCX file from {file_path}.") 77 | content = docx.Document(str(file_path)) 78 | text = "" 79 | for i, para in enumerate(content.paragraphs): 80 | text += f'Page {i + 1}:\n' + para.text 81 | return text 82 | 83 | class JSONReader(Reader): 84 | def parse_file(self, file_path: Path) -> list: 85 | logger.info(f"Reading JSON file from {file_path}.") 86 | try: 87 | with open(file_path, "r") as f: 88 | data = json.load(f) 89 | #text = str(data) 90 | return data#text 91 | except: 92 | return [] 93 | 94 | def parse(self, file_path: Path) -> str: 95 | logger.info(f"Reading JSON file from {file_path}.") 96 | try: 97 | with open(file_path, "r") as f: 98 | data = json.load(f) 99 | text = str(data) 100 | return text 101 | except: 102 | return '' 103 | 104 | class JSONLReader(Reader): 105 | def 
parse_file(self, file_path) -> list: 106 | logger.info(f"Reading JSON Lines file from {file_path}.") 107 | with open(file_path, "r",encoding='utf-8') as f: 108 | lines = [json.loads(line) for line in f] 109 | #text = '\n'.join([str(line) for line in lines]) 110 | return lines #text 111 | 112 | def parse(self, file_path) -> str: 113 | logger.info(f"Reading JSON Lines file from {file_path}.") 114 | with open(file_path, "r",encoding='utf-8') as f: 115 | lines = [json.loads(line) for line in f] 116 | text = '\n'.join([str(line) for line in lines]) 117 | return text 118 | 119 | class XMLReader(Reader): 120 | def parse(self, file_path: Path) -> str: 121 | logger.info(f"Reading XML file from {file_path}.") 122 | with open(file_path, "r") as f: 123 | data = BeautifulSoup(f, "xml") 124 | text = data.get_text() 125 | return text 126 | 127 | class YAMLReader(Reader): 128 | def parse(self, file_path: Path, return_str=True) -> Union[str, Any]: 129 | logger.info(f"Reading YAML file from {file_path}.") 130 | with open(file_path, "r") as f: 131 | data = yaml.load(f, Loader=yaml.FullLoader) 132 | text = str(data) 133 | if return_str: 134 | return text 135 | else: 136 | return data 137 | 138 | class HTMLReader(Reader): 139 | def parse(self, file_path: Path) -> str: 140 | logger.info(f"Reading HTML file from {file_path}.") 141 | with open(file_path, "r") as f: 142 | data = BeautifulSoup(f, "html.parser") 143 | text = data.get_text() 144 | return text 145 | 146 | class MarkdownReader(Reader): 147 | def parse(self, file_path: Path) -> str: 148 | logger.info(f"Reading Markdown file from {file_path}.") 149 | with open(file_path, "r") as f: 150 | data = markdown.markdown(f.read()) 151 | text = "".join(BeautifulSoup(data, "html.parser").findAll(string=True)) 152 | return text 153 | 154 | class LaTexReader(Reader): 155 | def parse(self, file_path: Path) -> str: 156 | logger.info(f"Reading LaTex file from {file_path}.") 157 | with open(file_path, "r") as f: 158 | data = f.read() 159 | text = LatexNodes2Text().latex_to_text(data) 160 | return text 161 | 162 | 163 | 164 | class AudioReader(Reader): 165 | @staticmethod 166 | def parse(file_path: Path) -> str: 167 | logger.info(f"Transcribing audio file from {file_path}.") 168 | client = OpenAI(api_key=OPENAI_API_KEY) 169 | try: 170 | # reuse the client created above (re-creating it here would discard the explicit api_key) 171 | with open(file_path, "rb") as audio_file: 172 | transcript = client.audio.translations.create( 173 | model="whisper-1", 174 | file=audio_file 175 | ) 176 | return transcript.text 177 | except Exception as e: 178 | logger.info(f"Error transcribing audio file: {e}") 179 | return "Error transcribing audio file." 180 | 181 | class PPTXReader(Reader): 182 | def parse(self, file_path: Path) -> str: 183 | logger.info(f"Reading PowerPoint file from {file_path}.") 184 | try: 185 | pres = Presentation(str(file_path)) 186 | text = [] 187 | for slide_idx, slide in enumerate(pres.slides): 188 | text.append(f"Slide {slide_idx + 1}:\n") 189 | for shape in slide.shapes: 190 | if hasattr(shape, "text"): 191 | text.append(shape.text) 192 | return "\n".join(text) 193 | except Exception as e: 194 | logger.info(f"Error reading PowerPoint file: {e}") 195 | return "Error reading PowerPoint file."
196 | 197 | class ExcelReader(Reader): 198 | def parse(self, file_path: Path) -> str: 199 | logger.info(f"Reading Excel file from {file_path}.") 200 | try: 201 | excel_data = pd.read_excel(file_path, sheet_name=None) 202 | 203 | all_sheets_text = [] 204 | for sheet_name, data in excel_data.items(): 205 | all_sheets_text.append(f"Sheet Name: {sheet_name}\n{data.to_string()}\n") 206 | 207 | return "\n".join(all_sheets_text) 208 | except Exception as e: 209 | logger.info(f"Error reading Excel file: {e}") 210 | return "Error reading Excel file." 211 | 212 | class XLSXReader(Reader): 213 | def parse(self, file_path: Path) -> str: 214 | logger.info(f"Reading XLSX file from {file_path}.") 215 | workbook = openpyxl.load_workbook(file_path, data_only=True) 216 | text = "" 217 | 218 | for sheet in workbook: 219 | text += f"\nSheet: {sheet.title}\n" 220 | for row in sheet.iter_rows(values_only=True): 221 | row_data = [str(cell) if cell is not None else "" for cell in row] 222 | text += "\t".join(row_data) + "\n" 223 | 224 | return text 225 | 226 | class ZipReader(Reader): 227 | def parse(self, file_path: str) -> Optional[str]: 228 | # Only supports files that can be represented as text 229 | logger.info(f"Reading ZIP file from {file_path}.") 230 | try: 231 | file_content = "" 232 | with zipfile.ZipFile(file_path, 'r') as zip_ref: 233 | extract_dir = file_path[:-4] + '/' 234 | zip_ref.extractall(extract_dir) 235 | reader = FileReader() 236 | for file_name in zip_ref.namelist(): 237 | file_content += f'File {file_name}:\n"{reader.read_file(extract_dir + file_name)}"\n' 238 | return file_content 239 | 240 | except zipfile.BadZipFile: 241 | logger.info("Invalid ZIP file.") 242 | 243 | except Exception as e: 244 | logger.info(f"Error reading ZIP file: {e}") 245 | 246 | 247 | class PythonReader(Reader): 248 | def parse(self, file_path: Path): 249 | logger.info(f"Executing and reading Python file from {file_path}.") 250 | execution_result = "" 251 | error = "" 252 | file_content = "" 253 | try: 254 | completed_process = subprocess.run(["python", file_path], capture_output=True, text=True, check=True) 255 | execution_result = "Output:\n" + completed_process.stdout 256 | except subprocess.CalledProcessError as e: 257 | error = "Error:\n" + e.stderr 258 | except Exception as e: 259 | logger.info(f"Error executing Python file: {e}") 260 | 261 | try: 262 | with open(file_path, "r") as file: 263 | file_content = "\nFile Content:\n" + file.read() 264 | except Exception as e: 265 | logger.info(f"Error reading Python file: {e}") 266 | return file_content, execution_result, error 267 | 268 | 269 | class IMGReader(Reader): 270 | def parse(self, file_path: str, task: str = "Describe this image in as much detail as possible."): 271 | # logger.info(f"Reading image file from {file_path}.") 272 | # runner = VisualLLMRegistry.get() 273 | # answer = runner.gen(task, file_path) 274 | # return answer 275 | return "" 276 | 277 | class VideoReader(Reader): 278 | def parse(self, file_path: str, task: str = "Describe this video in as much detail as possible.", frame_interval: int = 30, used_audio: bool = True): 279 | # logger.info(f"Processing video file from {file_path} with frame interval {frame_interval}.") 280 | # runner = VisualLLMRegistry.get() 281 | # answer = runner.gen_video(task, file_path, frame_interval) 282 | 283 | # if used_audio: 284 | # audio_content = AudioReader.parse(file_path) 285 | 286 | # return answer + "The audio includes:\n" + audio_content 287 | return "" 288 | 289 | 290 | # Supports 42 kinds of files. 
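# FileReader (defined below) dispatches on the file suffix through this table; an
# illustrative lookup, assuming a hypothetical "paper.pdf" exists on disk:
#   READER_MAP[".pdf"].parse(Path("paper.pdf"))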
291 | READER_MAP = { 292 | ".png": IMGReader(), 293 | ".jpg": IMGReader(), 294 | ".jpeg": IMGReader(), 295 | ".gif": IMGReader(), 296 | ".bmp": IMGReader(), 297 | ".tiff": IMGReader(), 298 | ".tif": IMGReader(), 299 | ".webp": IMGReader(), 300 | ".mp3": AudioReader(), 301 | ".m4a": AudioReader(), 302 | ".wav": AudioReader(), 303 | ".MOV": VideoReader(), 304 | ".mp4": VideoReader(), 305 | ".mov": VideoReader(), 306 | ".avi": VideoReader(), 307 | ".mpg": VideoReader(), 308 | ".mpeg": VideoReader(), 309 | ".wmv": VideoReader(), 310 | ".flv": VideoReader(), 311 | ".webm": VideoReader(), 312 | ".zip": ZipReader(), 313 | ".pptx": PPTXReader(), 314 | ".xlsx": ExcelReader(), 315 | ".xls": ExcelReader(), 316 | ".txt": TXTReader(), 317 | ".csv": TXTReader(), 318 | ".pdf": PDFReader(), 319 | ".docx": DOCXReader(), 320 | ".json": JSONReader(), 321 | ".jsonld": JSONReader(), 322 | ".jsonl": JSONLReader(), 323 | ".xml": XMLReader(), 324 | ".yaml": YAMLReader(), 325 | ".yml": YAMLReader(), 326 | ".html": HTMLReader(), 327 | ".htm": HTMLReader(), 328 | ".xhtml": HTMLReader(), 329 | ".md": MarkdownReader(), 330 | ".markdown": MarkdownReader(), 331 | ".tex": LaTexReader(), 332 | ".py": PythonReader(), 333 | ".pdb": TXTReader(), 334 | } 335 | 336 | class FileReader: 337 | def set_reader(self, suffix) -> None: 338 | self.reader = READER_MAP.get(suffix, TXTReader())  # fall back to plain-text reading for unmapped suffixes 339 | logger.info(f"Setting Reader to {type(self.reader).__name__}") 340 | 341 | def read_file(self, file_path: str, task="describe the file") -> str: 342 | suffix = '.' + file_path.split(".")[-1] 343 | self.set_reader(suffix) 344 | if isinstance(self.reader, (IMGReader, VideoReader)): 345 | file_content = self.reader.parse(file_path, task) 346 | else: 347 | file_content = self.reader.parse(file_path) 348 | logger.info(f"Reading file {file_path} using {type(self.reader).__name__}") 349 | return file_content 350 | 351 | 352 | class GeneralReader: 353 | def __init__(self): 354 | self.file_reader = FileReader() 355 | self.name = "General File Reader" 356 | self.description = """A general file reader supporting the formats: 'py', 'java', 'cpp', 'c', 'js', 357 | 'css', 'html', 'htm', 'xml', 'txt', 'jsonl', 'csv', 'json', 358 | 'jsonld', 'yaml', 'yml', 'xlsx', 'xls', 'jpg', 'png', 359 | 'jpeg', 'gif', 'bmp', 'mp3', 'wav', 'ogg', 'mp4', 'avi', 'mkv', 360 | 'mov', 'pdf', 'doc', 'docx', 'ppt', 'pptx', 'md', 'markdown', 361 | 'tex', 'zip', 'tar', 'gz', '7z', 'rar'. 
362 | """ 363 | 364 | def read(self, task, file): 365 | 366 | files_content = "" 367 | file_content = self.file_reader.read_file(file, task) 368 | suffix = file.split(".")[-1] 369 | 370 | if suffix in ['py', 'java', 'cpp', 'c', 'js', 'css', 'html', 'htm', 'xml']: 371 | files_content += f'\nThe {suffix} file contains:\n---\n{file_content[0]}' 372 | if file_content[1] != '': 373 | files_content += f'\nExecution result:\n{file_content[1]}' 374 | if file_content[2] != '': 375 | files_content += f'\nExecution error message:\n{file_content[2]}' 376 | files_content += '\n---' 377 | 378 | elif suffix in ['txt', 'jsonl', 'csv', 'json', 'jsonld', 'jsonl', 'yaml', 'yml', 379 | 'xlsx', 'xls', 'jpg', 'png', 'jpeg', 'gif', 'bmp', 'mp3', 'wav', 380 | 'ogg', 'mp4', 'avi', 'mkv', 'mov', 'pdf', 'doc', 'docx', 'ppt', 381 | 'pptx', 'md', 'markdown', 'tex', 'zip', 'tar', 'gz', '7z', 'rar']: 382 | files_content += f'\nThe {suffix} file contains:\n---\n{file_content}\n---' 383 | 384 | return files_content 385 | -------------------------------------------------------------------------------- /MAR/Tools/search/arXiv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import arxiv 5 | 6 | 7 | class ArxivSearch: 8 | def __init__(self): 9 | self.name = "ArXiv Searcher" 10 | self.description = "Search for a paper on ArXiv" 11 | 12 | def search(self, query=None, id_list=None, sort_by=arxiv.SortCriterion.Relevance, sort_order=arxiv.SortOrder.Descending): 13 | search = arxiv.Search(query=query, id_list=id_list, max_results=1, sort_by=sort_by, sort_order=sort_order) 14 | results = arxiv.Client().results(search) 15 | paper = next(results, None) 16 | 17 | return paper 18 | -------------------------------------------------------------------------------- /MAR/Tools/search/search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | from dotenv import load_dotenv 6 | from googleapiclient.discovery import build 7 | import requests 8 | import ast 9 | load_dotenv() 10 | 11 | 12 | class GoogleSearchEngine(): 13 | def __init__(self) -> None: 14 | load_dotenv() 15 | self.api_key = os.getenv("GOOGLE_API_KEY") 16 | self.cse_id = os.getenv("GOOGLE_CSE_ID") 17 | self.service = build("customsearch", "v1", developerKey=self.api_key) 18 | 19 | def search(self, query: str, num: int = 3): 20 | try: 21 | res = self.service.cse().list(q=query, cx=self.cse_id, num=num).execute() 22 | return '\n'.join([item['snippet'] for item in res['items']]) 23 | except: 24 | return '' 25 | 26 | 27 | class SearchAPIEngine(): 28 | 29 | def search(self, query: str, item_num: int = 3): 30 | try: 31 | url = "https://www.searchapi.io/api/v1/search" 32 | params = { 33 | "engine": "google", 34 | "q": query, 35 | "api_key": os.getenv("SEARCHAPI_API_KEY") 36 | } 37 | 38 | response = ast.literal_eval(requests.get(url, params = params).text) 39 | 40 | except: 41 | return '' 42 | 43 | if 'knowledge_graph' in response.keys() and 'description' in response['knowledge_graph'].keys(): 44 | return response['knowledge_graph']['description'] 45 | if 'organic_results' in response.keys() and len(response['organic_results']) > 0: 46 | 47 | return '\n'.join([res['snippet'] for res in response['organic_results'][:item_num]]) 48 | return '' 49 | 50 | 51 | 52 | if __name__ == "__main__": 53 | print(SearchAPIEngine().search("Juergen Schmidhuber")) 
-------------------------------------------------------------------------------- /MAR/Tools/search/wiki.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import wikipedia 5 | import asyncio 6 | 7 | class WikiSearch: 8 | def __init__(self): 9 | self.name = "Wikipedia SearchEngine" 10 | self.description = "Search for an item in Wikipedia" 11 | 12 | def search(self, query): 13 | result = wikipedia.search(query[:300], results=1, suggestion=True) 14 | print(result) 15 | if len(result[0]) != 0: 16 | return wikipedia.page(title=result[0]).content 17 | 18 | if result[1] is not None: 19 | result = wikipedia.search(result[1], results=1) 20 | return wikipedia.page(title=result[0]).content 21 | 22 | return None 23 | 24 | async def get_wikipedia_summary(title): 25 | try: 26 | wikipedia.set_lang("en") 27 | summ = wikipedia.summary(title) 28 | return summ 29 | except wikipedia.exceptions.DisambiguationError as e: 30 | return await get_wikipedia_summary(e.options[0]) 31 | except wikipedia.exceptions.PageError: 32 | return "" 33 | 34 | async def search_wiki(query): 35 | wikipedia.set_lang("en") 36 | result = wikipedia.search(query, results=2, suggestion=True) 37 | print(result) 38 | ret = "" 39 | tasks = [] 40 | 41 | if len(result[0]) != 0: 42 | for res in result[0]: 43 | tasks.append(get_wikipedia_summary(res)) 44 | summaries = await asyncio.gather(*tasks) 45 | for res, summa in zip(result[0], summaries): 46 | if len(summa): 47 | ret += f"The summary of {res} in Wikipedia is: {summa}\n" 48 | if result[1] is not None: 49 | summa = await get_wikipedia_summary(result[1]) 50 | if len(summa): 51 | ret += f"The summary of {result[1]} in Wikipedia is: {summa}\n" 52 | return ret 53 | 54 | 55 | async def search_wiki_main(queries): 56 | tasks = [search_wiki(query) for query in queries] 57 | results = await asyncio.gather(*tasks) 58 | return results 59 | 60 | if __name__ == "__main__": 61 | queries = ["Python", "Asyncio", "Wikipedia"] 62 | asyncio.run(search_wiki_main(queries)) -------------------------------------------------------------------------------- /MAR/Tools/vgen/dalle3.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from https://github.com/abi/screenshot-to-code/blob/5e3a174203dd6e59603c2fa944b14c7b398bfade/backend/image_generation.py 2 | #!/usr/bin/env python 3 | # -*- coding: utf-8 -*- 4 | 5 | import asyncio 6 | import os 7 | import re 8 | from openai import AsyncOpenAI 9 | from bs4 import BeautifulSoup 10 | 11 | 12 | async def process_tasks(prompts, api_key): 13 | tasks = [generate_image(prompt, api_key) for prompt in prompts] 14 | results = await asyncio.gather(*tasks, return_exceptions=True) 15 | 16 | processed_results = [] 17 | for result in results: 18 | if isinstance(result, Exception): 19 | print(f"An exception occurred: {result}") 20 | processed_results.append(None) 21 | else: 22 | processed_results.append(result) 23 | 24 | return processed_results 25 | 26 | 27 | async def generate_image(prompt, api_key): 28 | client = AsyncOpenAI(api_key=api_key) 29 | image_params = { 30 | "model": "dall-e-3", 31 | "quality": "standard", 32 | "style": "natural", 33 | "n": 1, 34 | "size": "1024x1024", 35 | "prompt": prompt, 36 | } 37 | res = await client.images.generate(**image_params) 38 | return res.data[0].url 39 | 40 | 41 | def extract_dimensions(url): 42 | # Regular expression to match numbers in the format '300x200' 43 | matches = 
re.findall(r"(\d+)x(\d+)", url) 44 | 45 | if matches: 46 | width, height = matches[0] # Extract the first match 47 | width = int(width) 48 | height = int(height) 49 | return (width, height) 50 | else: 51 | return (100, 100) 52 | 53 | 54 | def create_alt_url_mapping(code): 55 | soup = BeautifulSoup(code, "html.parser") 56 | images = soup.find_all("img") 57 | 58 | mapping = {} 59 | 60 | for image in images: 61 | if not image["src"].startswith("https://placehold.co"): 62 | mapping[image["alt"]] = image["src"] 63 | 64 | return mapping 65 | 66 | 67 | async def generate_images(code, api_key, image_cache): 68 | # Fine all images 69 | soup = BeautifulSoup(code, "html.parser") 70 | images = soup.find_all("img") 71 | 72 | # Extract alt texts as image prompts 73 | alts = [] 74 | for img in images: 75 | # Only include URL if the image starts with htt[s://placehold.co 76 | # and it's not already in the image_cache 77 | if ( 78 | img["src"].startswith("https://placehold.co") 79 | and image_cache.get(img.get("alt")) is None 80 | ): 81 | alts.append(img.get("alt", None)) 82 | 83 | # Exclude images with no alt text 84 | alts = [alt for alt in alts if alt is not None] 85 | 86 | # Remove deplicates 87 | prompts = list(set(alts)) 88 | 89 | # Return early if there are no images to replace 90 | if len(prompts) == 0: 91 | return code 92 | 93 | # Generate images 94 | results = await process_tasks(prompts, api_key) 95 | 96 | # Create a dict mapping alt text to image URL 97 | mapped_image_urls = dict(zip(prompts, results)) 98 | 99 | # Merge with image_cache 100 | mapped_image_urls = {**mapped_image_urls, **image_cache} 101 | 102 | # Replace old image URLs with the generated URLs 103 | for img in images: 104 | # Skip images that don't start with https://placehold.co (leave them alone) 105 | if not img["src"].startswith("https://placehold.co"): 106 | continue 107 | 108 | new_url = mapped_image_urls[img.get("alt")] 109 | 110 | if new_url: 111 | # Set width and height attributes 112 | width, height = extract_dimensions(img["src"]) 113 | img["width"] = width 114 | img["height"] = height 115 | # Replace img['src'] with the mapped image URL 116 | img["src"] = new_url 117 | else: 118 | print("Image generation failed for alt text:" + img.get("alt")) 119 | 120 | # Return the modified HTML 121 | # (need to prettify it because BeautifulSoup messes up the formatting) 122 | return soup.prettify() 123 | -------------------------------------------------------------------------------- /MAR/Tools/web/screenshot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Ref: https://github.com/abi/screenshot-to-code/blob/main/backend/routes/screenshot.py 5 | """ 6 | 7 | 8 | import base64 9 | from fastapi import APIRouter 10 | from pydantic import BaseModel 11 | import httpx 12 | 13 | router = APIRouter() 14 | 15 | def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str: 16 | base64_image = base64.b64encode(image_bytes).decode("utf-8") 17 | return f"data:{mime_type};base64,{base64_image}" 18 | 19 | 20 | async def capture_screenshot(target_url, api_key, device="desktop") -> bytes: 21 | api_base_url = "https://api.screenshotone.com/take" 22 | 23 | params = { 24 | "access_key": api_key, 25 | "url": target_url, 26 | "full_page": "true", 27 | "device_scale_factor": "1", 28 | "format": "png", 29 | "block_ads": "true", 30 | "block_cookie_banners": "true", 31 | "block_trackers": "true", 32 | "cache": "false", 33 | "viewport_width": "342", 34 | 
"viewport_height": "684", 35 | } 36 | 37 | if device == "desktop": 38 | params["viewport_width"] = "1280" 39 | params["viewport_height"] = "832" 40 | 41 | async with httpx.AsyncClient(timeout=60) as client: 42 | response = await client.get(api_base_url, params=params) 43 | if response.status_code == 200 and response.content: 44 | return response.content 45 | else: 46 | raise Exception("Error taking screenshot") 47 | 48 | 49 | class ScreenshotRequest(BaseModel): 50 | url: str 51 | apiKey: str 52 | 53 | 54 | class ScreenshotResponse(BaseModel): 55 | url: str 56 | 57 | -------------------------------------------------------------------------------- /MAR/Tools/web/youtube.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from pytube import YouTube 5 | from MAR.Utils.const import MAR_ROOT 6 | 7 | def Youtube(url, has_subtitles): 8 | # get video id from url 9 | video_id=url.split('v=')[-1].split('&')[0] 10 | # Create a YouTube object 11 | youtube = YouTube(url) 12 | # Get the best available video stream 13 | video_stream = youtube.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first() 14 | if has_subtitles: 15 | # Download the video to a location 16 | print('Downloading video') 17 | video_stream.download(output_path=f"{MAR_ROOT}/workspace",filename=f"{video_id}.mp4") 18 | print('Video downloaded successfully') 19 | return f"{MAR_ROOT}/workspace/{video_id}.mp4" 20 | else: 21 | return video_stream.url -------------------------------------------------------------------------------- /MAR/Utils/const.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | from pathlib import Path 6 | 7 | 8 | MAR_ROOT = Path(os.path.realpath(os.path.join(os.path.split(__file__)[0], "../.."))) 9 | -------------------------------------------------------------------------------- /MAR/Utils/globals.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | from typing import Union, Literal, List 4 | 5 | class Singleton: 6 | _instance = None 7 | 8 | @classmethod 9 | def instance(cls): 10 | if cls._instance is None: 11 | cls._instance = cls() 12 | return cls._instance 13 | 14 | def reset(self): 15 | self.value = 0.0 16 | 17 | class Cost(Singleton): 18 | def __init__(self): 19 | self.value = 0.0 20 | 21 | class PromptTokens(Singleton): 22 | def __init__(self): 23 | self.value = 0.0 24 | 25 | class CompletionTokens(Singleton): 26 | def __init__(self): 27 | self.value = 0.0 28 | 29 | class Time(Singleton): 30 | def __init__(self): 31 | self.value = "" 32 | 33 | class Mode(Singleton): 34 | def __init__(self): 35 | self.value = "" 36 | -------------------------------------------------------------------------------- /MAR/Utils/log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | from loguru import logger 8 | from MAR.Utils.const import MAR_ROOT 9 | 10 | def configure_logging(print_level: str = "INFO", logfile_level: str = "DEBUG", log_name:str = "log.txt") -> None: 11 | """ 12 | Configure the logging settings for the application. 13 | 14 | Args: 15 | print_level (str): The logging level for console output. 16 | logfile_level (str): The logging level for file output. 
17 | """ 18 | logger.remove() 19 | logger.add(sys.stderr, level=print_level) 20 | logger.add(MAR_ROOT /f'logs/{log_name}', level=logfile_level) 21 | 22 | def initialize_log_file(experiment_name: str, time_stamp: str) -> Path: 23 | """ 24 | Initialize the log file with a start message and return its path. 25 | 26 | Args: 27 | mode (str): The mode of operation, used in the file path. 28 | time_stamp (str): The current timestamp, used in the file path. 29 | 30 | Returns: 31 | Path: The path to the initialized log file. 32 | """ 33 | try: 34 | log_file_path = MAR_ROOT / f'result/{experiment_name}/logs/log_{time_stamp}.txt' 35 | os.makedirs(log_file_path.parent, exist_ok=True) 36 | with open(log_file_path, 'w') as file: 37 | file.write("============ Start ============\n") 38 | except OSError as error: 39 | logger.error(f"Error initializing log file: {error}") 40 | raise 41 | return log_file_path 42 | 43 | def swarmlog(sender: str, text: str, cost: float, prompt_tokens: int, complete_tokens: int, log_file_path: str) -> None: 44 | """ 45 | Custom log function for swarm operations. Includes dynamic global variables. 46 | 47 | Args: 48 | sender (str): The name of the sender. 49 | text (str): The text message to log. 50 | cost (float): The cost associated with the operation. 51 | result_file (Path, optional): Path to the result file. Default is None. 52 | solution (list, optional): Solution data to be logged. Default is an empty list. 53 | """ 54 | # Directly reference global variables for dynamic values 55 | formatted_message = ( 56 | f"{sender} | 💵Total Cost: ${cost:.5f} | " 57 | f"Prompt Tokens: {prompt_tokens} | " 58 | f"Completion Tokens: {complete_tokens} | \n {text}" 59 | ) 60 | print(formatted_message) 61 | 62 | try: 63 | os.makedirs(log_file_path.parent, exist_ok=True) 64 | with open(log_file_path, 'a') as file: 65 | file.write(f"{formatted_message}\n") 66 | except OSError as error: 67 | logger.error(f"Error initializing log file: {error}") 68 | raise 69 | 70 | 71 | def main(): 72 | configure_logging() 73 | # Example usage of swarmlog with dynamic values 74 | swarmlog("SenderName", "This is a test message.", 0.123) 75 | 76 | if __name__ == "__main__": 77 | main() 78 | 79 | -------------------------------------------------------------------------------- /MAR/Utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import shortuuid 4 | from collections import Counter 5 | import random 6 | import os 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import seaborn as sns 10 | 11 | from typing import List, Union, Literal, Optional 12 | 13 | ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)") 14 | INVALID_ANS = "[invalid]" 15 | 16 | N_SHOT = 8 17 | COT_FLAG = True 18 | DEBUG = False 19 | ANSWER_TRIGGER = "The answer is" 20 | 21 | def extract_answer_from_output(completion): 22 | match = ANS_RE.search(completion) 23 | if match: 24 | match_str = match.group(1).strip() 25 | match_str = match_str.replace(",", "") 26 | return match_str 27 | else: 28 | return INVALID_ANS 29 | 30 | def is_correct(model_answer, answer): 31 | gt_answer = extract_answer_from_output(answer) 32 | assert gt_answer != INVALID_ANS 33 | return model_answer == gt_answer 34 | 35 | def clean_answer(model_pred): 36 | model_pred = model_pred.lower() 37 | preds = model_pred.split(ANSWER_TRIGGER.lower()) 38 | answer_flag = True if len(preds) > 1 else False 39 | if answer_flag: 40 | # Pick first answer with flag 41 | pred = preds[1] 42 | else: 43 | # Pick 
last number without flag 44 | pred = preds[-1] 45 | 46 | pred = pred.replace(",", "") 47 | pred = re.findall(r"-?\d+\.?\d*", pred) 48 | 49 | if len(pred) == 0: 50 | return INVALID_ANS 51 | 52 | if answer_flag: 53 | # choose the first element in list 54 | pred = pred[0] 55 | else: 56 | # choose the last element in list 57 | pred = pred[-1] 58 | 59 | # (For arithmetic tasks) strip a trailing period from the prediction 60 | if pred[-1] == ".": 61 | pred = pred[:-1] 62 | 63 | return pred 64 | 65 | def nuclear_norm(matrix): 66 | _, S, _ = torch.svd(matrix) 67 | return torch.sum(S) 68 | 69 | def frobenius_norm(A, S): 70 | return torch.norm(A - S, p='fro') 71 | 72 | used_ids = set() 73 | def generate_unique_ids(n:int=1,pre:str="",length:int=4)->List[str]: 74 | ids = set() 75 | while len(ids) < n: 76 | random_id = shortuuid.ShortUUID().random(length=length) 77 | if pre: 78 | random_id = f"{pre}_{random_id}" 79 | if random_id in used_ids: 80 | length += 1 81 | continue 82 | ids.add(random_id); used_ids.add(random_id)  # record globally so later calls stay unique 83 | return list(ids) 84 | 85 | def extract_json(raw:str)->str: 86 | """ 87 | Extract the json string from the raw string. 88 | If there is no json string, return an empty string. 89 | """ 90 | json_pattern = r'\{.*\}' 91 | match = re.search(json_pattern, raw, re.DOTALL) 92 | return match.group(0) if match else "" 93 | 94 | def fix_random_seed(seed:int=1234): 95 | torch.manual_seed(seed) 96 | torch.cuda.manual_seed(seed) 97 | torch.cuda.manual_seed_all(seed) 98 | torch.backends.cudnn.deterministic = True 99 | torch.backends.cudnn.benchmark = False 100 | 101 | def find_mode(nums): 102 | count = Counter(nums) 103 | mode, _ = count.most_common(1)[0] 104 | return mode 105 | 106 | 107 | def get_kwargs(mode:Union[Literal['DirectAnswer','CoT','IO','Reflection','FullConnected','Random','Chain','Debate','Layered','Star'],str] 108 | ,N:int): 109 | initial_spatial_probability: float = 0.5 110 | fixed_spatial_masks: Optional[List[List[int]]] = None 111 | initial_temporal_probability: float = 0.5 112 | fixed_temporal_masks:Optional[List[List[int]]] = None 113 | node_kwargs = None 114 | num_rounds = 1 115 | # agent_names = [] 116 | 117 | def generate_layered_graph(N,layer_num=2): 118 | adj_matrix = [[0 for _ in range(N)] for _ in range(N)] 119 | base_size = N // layer_num 120 | remainder = N % layer_num 121 | layers = [] 122 | for i in range(layer_num): 123 | size = base_size + (1 if i < remainder else 0) 124 | layers.extend([i] * size) 125 | random.shuffle(layers) 126 | for i in range(N): 127 | current_layer = layers[i] 128 | for j in range(N): 129 | if layers[j] == current_layer + 1: 130 | adj_matrix[i][j] = 1 131 | return adj_matrix 132 | 133 | def generate_star_graph(n): 134 | matrix = [[0] * n for _ in range(n)] 135 | for i in range(0, n): 136 | for j in range(i+1,n): 137 | matrix[i][j] = 1 138 | return matrix 139 | 140 | if mode=='DirectAnswer' or mode=='CoT' or mode=='IO' or mode=='Reflection': 141 | fixed_spatial_masks = [[0 for _ in range(N)] for _ in range(N)] 142 | fixed_temporal_masks = [[0 for _ in range(N)] for _ in range(N)] 143 | elif mode=='FullConnected': 144 | fixed_spatial_masks = [[1 if i!=j else 0 for i in range(N)] for j in range(N)] 145 | fixed_temporal_masks = [[1 for _ in range(N)] for _ in range(N)] 146 | elif mode=='Random': 147 | fixed_spatial_masks = [[random.randint(0, 1) if i!=j else 0 for i in range(N)] for j in range(N)] 148 | fixed_temporal_masks = [[random.randint(0, 1) for _ in range(N)] for _ in range(N)] 149 | elif mode=='Chain': 150 | fixed_spatial_masks = [[1 if i==j+1 else 0 for i 
in range(N)] for j in range(N)] 151 | fixed_temporal_masks = [[1 if i==0 and j==N-1 else 0 for i in range(N)] for j in range(N)] 152 | elif mode == 'Debate': 153 | fixed_spatial_masks = [[0 for i in range(N)] for j in range(N)] 154 | fixed_temporal_masks = [[1 for i in range(N)] for j in range(N)] 155 | num_rounds = 2 156 | elif mode == 'Layered': 157 | fixed_spatial_masks = generate_layered_graph(N) 158 | fixed_temporal_masks = [[1 for i in range(N)] for j in range(N)] 159 | elif mode == 'Star': 160 | fixed_spatial_masks = generate_star_graph(N) 161 | fixed_temporal_masks = [[1 for i in range(N)] for j in range(N)] 162 | 163 | 164 | return {"initial_spatial_probability": initial_spatial_probability, 165 | "fixed_spatial_masks": fixed_spatial_masks, 166 | "initial_temporal_probability": initial_temporal_probability, 167 | "fixed_temporal_masks": fixed_temporal_masks, 168 | "node_kwargs": node_kwargs, 169 | "num_rounds": num_rounds,} 170 | 171 | def split_list(input_list, ratio): 172 | if not (0 < ratio < 1): 173 | raise ValueError("Ratio must be between 0 and 1.") 174 | 175 | random.shuffle(input_list) 176 | split_index = int(len(input_list) * ratio) 177 | part1 = input_list[:split_index] 178 | part2 = input_list[split_index:] 179 | 180 | return part1, part2 181 | 182 | def plot_embedding_heatmap(embedding: torch.Tensor, title: str, save_path: str): 183 | embedding_np = embedding.detach().cpu().numpy() 184 | 185 | plt.figure(figsize=(10, max(4, embedding_np.shape[0] * 0.4))) 186 | sns.heatmap(embedding_np, cmap="viridis", cbar=True) 187 | 188 | plt.title(title) 189 | plt.xlabel("Embedding Dimension") 190 | plt.ylabel("Index") 191 | 192 | plt.tight_layout() 193 | plt.savefig(save_path) 194 | plt.close() 195 | 196 | def plot_row_similarity(embeddings: torch.Tensor, title: str, save_path: str): 197 | embeddings_np = embeddings.detach().cpu().numpy() 198 | row_similarities = np.corrcoef(embeddings_np) 199 | 200 | plt.figure(figsize=(10, max(4, row_similarities.shape[0] * 0.4))) 201 | sns.heatmap(row_similarities, cmap="viridis", cbar=True) 202 | 203 | plt.title(title) 204 | plt.xlabel("Index") 205 | plt.ylabel("Index") 206 | 207 | plt.tight_layout() 208 | plt.savefig(save_path) 209 | plt.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ACL 2025] MasRouter: Learning to Route LLMs for Multi-Agent Systems 2 | 3 | ## 📰 News 4 | 5 | - 🎉 Updates (2025-5-15) MasRouter is accepted to ACL 2025 Main! 6 | - 🚩 Updates (2025-2-16) Initial upload to arXiv [PDF](https://arxiv.org/abs/2502.11133). 7 | 8 | 9 | ## 🤔 Why MasRouter? 10 | 11 | **MasRouter** expands LLM routing to multi-agent systems (MAS) *for the first time*. It leverages the powerful reasoning capabilities of an LLM-based MAS while keeping it relatively cost-effective. 12 | 13 | ![intro](assets/intro.png) 14 | 15 | ## 👋🏻 Method Overview 16 | 17 | **MasRouter** integrates all components of MAS into a unified routing framework. It employs collaboration mode determination, role allocation, and LLM routing through a cascaded controller network, progressively constructing a MAS that balances effectiveness and efficiency. 18 | 19 | ![pipeline](assets/pipeline.png) 20 | 21 | ## 🏃‍♂️‍➡️ Quick Start 22 | 23 | ### 📊 Datasets 24 | 25 | Please download the `GSM8K`, `HumanEval`, `MATH`, `MBPP`, `MMLU` datasets and place them in the `Datasets` folder. 
The file structure should be organized as follows: 26 | ``` 27 | Datasets 28 | └── gsm8k 29 | └── gsm8k.jsonl 30 | └── humaneval 31 | └── humaneval-py.jsonl 32 | └── MATH 33 | └── test 34 | └── train 35 | └── mbpp 36 | └── mbpp.jsonl 37 | └── MMLU 38 | └── data 39 | ``` 40 | 41 | ### 🔑 Add API keys 42 | 43 | Add your API keys in `template.env` and rename it to `.env`. We recommend using an API endpoint that can access multiple LLMs. 44 | ```python 45 | URL = "" # the URL of the LLM backend 46 | KEY = "" # the API key 47 | ``` 48 | 49 | ### 🐹 Run the code 50 | 51 | The command below reproduces the experimental results on the `mbpp` dataset. 52 | 53 | ```bash 54 | python Experiments/run_mbpp.py 55 | ``` 56 | 57 | ## 📚 Citation 58 | 59 | If you find this repo useful, please consider citing our paper as follows: 60 | ```bibtex 61 | @misc{yue2025masrouter, 62 | title={MasRouter: Learning to Route LLMs for Multi-Agent Systems}, 63 | author={Yanwei Yue and Guibin Zhang and Boyang Liu and Guancheng Wan and Kun Wang and Dawei Cheng and Yiyan Qi}, 64 | year={2025}, 65 | eprint={2502.11133}, 66 | archivePrefix={arXiv}, 67 | primaryClass={cs.LG}, 68 | url={https://arxiv.org/abs/2502.11133}, 69 | } 70 | ``` 71 | 72 | ## 🙏 Acknowledgement 73 | 74 | Special thanks to the following repositories for their invaluable code and datasets: 75 | 76 | - [MapCoder](https://github.com/Md-Ashraful-Pramanik/MapCoder) 77 | - [GPTSwarm](https://github.com/metauto-ai/GPTSwarm) 78 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiyue/masrouter/97849a3dcab21d2962e345551be7f3c6935c66d6/assets/intro.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiyue/masrouter/97849a3dcab21d2962e345551be7f3c6935c66d6/assets/pipeline.png -------------------------------------------------------------------------------- /template.env: -------------------------------------------------------------------------------- 1 | URL = "" 2 | KEY = "" --------------------------------------------------------------------------------