├── scripts ├── eval.sh └── train.sh ├── utils ├── __init__.py ├── data.py ├── utils.py └── prompts.py ├── configs └── config_list.json ├── LICENSE ├── README.md ├── .gitignore ├── main.py ├── main_wo_supr.py ├── main_ws.py ├── evaluate.py ├── self_consistency_1102.py └── self-refine.py /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | python evaluate.py \ 2 | --model_name x_gpt4o \ 3 | --dataset_name rare_disease_302 \ 4 | --stage inital \ 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import MedDataset 2 | 3 | from .utils import prase_json, simple_retry 4 | from .prompts import ( 5 | get_doc_system_message, 6 | get_supervisor_system_message, 7 | get_inital_message, 8 | get_consultant_message, 9 | get_evaluate_prompts 10 | ) 11 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | 2 | # MAC 3 | python main.py \ 4 | --model_name x_gpt4o \ 5 | --stage inital \ 6 | --times 1 \ 7 | --num_doctors 4 \ 8 | --n_round 10 9 | 10 | # MAC with Specialist 11 | python main_ws.py \ 12 | --model_name x_gpt4o \ 13 | --stage inital \ 14 | --times 1 \ 15 | --num_specialists 4 \ 16 | --n_round 10 17 | 18 | 19 | # MAC without Supervisor 20 | python main_wo_supr.py \ 21 | --model_name x_gpt4o \ 22 | --stage inital \ 23 | --times 1 \ 24 | --num_doctors 4 \ 25 | --n_round 10 -------------------------------------------------------------------------------- /configs/config_list.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "model": "gpt-3.5-turbo-0125", 4 | "api_key": "your api key", 5 | "base_url": "If using a third party, please provide the third party's website", 6 | "tags": [ 7 | "x_gpt35_turbo", 
8 | "x_gpt-35-turbo" 9 | ] 10 | }, 11 | { 12 | "model": "gpt-4-turbo-2024-04-09", 13 | "api_key": "your api key", 14 | "base_url": "If using a third party, please provide the third party's website", 15 | "tags": [ 16 | "x_gpt4_turbo" 17 | ] 18 | }, 19 | { 20 | "model": "gpt-4o", 21 | "api_key": "", 22 | "base_url": "", 23 | "tags": [ 24 | "x_gpt4o" 25 | ] 26 | }, 27 | { 28 | "model": "llama3.1", 29 | "api_key": "NotRequired", 30 | "base_url": "http://0.0.0.0:4000", 31 | "tags": [ 32 | "llama3.1" 33 | ] 34 | } 35 | ] -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | 4 | 5 | class MedDataset: 6 | 7 | dataset_dir = "dataset" # the directory where the dataset is stored 8 | 9 | def __init__(self, dataname: str="rare_disease_cases_302"): 10 | dataname = f"{dataname}.json" 11 | self.data_path = osp.join(self.dataset_dir, dataname) 12 | self.cases = None 13 | self.load() 14 | 15 | 16 | def __len__(self): 17 | return len(self.cases) 18 | 19 | def load(self): 20 | with open(self.data_path, "r") as file: 21 | data = json.load(file) 22 | self.cases = data["Cases"] 23 | 24 | def __getitem__(self, idx: int): 25 | case = self.cases[idx] 26 | disease_type = case["Type"] 27 | disease_name = case["Final Name"] 28 | disease_crl = case["Case URL"] 29 | disease_initial_presentation = case["Initial Presentation"] 30 | disease_follow_up_presentation = case["Follow-up Presentation"] 31 | 32 | return disease_type, disease_name, disease_crl, disease_initial_presentation, disease_follow_up_presentation -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 geteff1 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this 
software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import json 4 | 5 | from functools import wraps 6 | 7 | 8 | def prase_json(text): 9 | flag = False 10 | if "```json" in text: 11 | json_match = re.search(r"```json(.*?)```", text, re.DOTALL) 12 | if json_match: 13 | json_str = json_match.group(1).strip() 14 | json_data = json.loads(json_str) 15 | flag = True 16 | elif "```JSON" in text: 17 | json_match = re.search(r"```JSON(.*?)```", text, re.DOTALL) 18 | if json_match: 19 | json_str = json_match.group(1).strip() 20 | json_data = json.loads(json_str) 21 | flag = True 22 | elif "```" in text: 23 | json_match = re.search(r"```(.*?)```", text, re.DOTALL) 24 | if json_match: 25 | json_str = json_match.group(1).strip() 26 | json_data = json.loads(json_str) 27 | flag = True 28 | else: 29 | json_match = re.search(r"{.*?}", text, re.DOTALL) 30 | if json_match: 31 | json_str = json_match.group(0).strip() 32 | json_data = json.loads(json_str) 33 | flag = True 34 | if not flag: 35 | json_text = text.strip("```json\n").strip("\n```") 36 | json_data = json.loads(json_text) 37 | return json_data 38 | 39 | 40 | def simple_retry(max_attempts=100, delay=1): 41 | def decorator(func): 42 | @wraps(func) 43 | def wrapper(*args, **kwargs): 44 | for attempt in range(max_attempts): 45 | try: 46 | return func(*args, **kwargs) 47 | except Exception as e: 48 | if attempt < max_attempts - 1: 49 | print( 50 | f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {delay} second..." 51 | ) 52 | time.sleep(delay) 53 | else: 54 | print( 55 | f"All {max_attempts} attempts failed. 
Last error: {str(e)}" 56 | ) 57 | raise 58 | 59 | return wrapper 60 | 61 | return decorator 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-agent-conversation-for-disease-diagnosis 2 | 3 | ## Introduction 4 | 5 | This repository presents a novel multi-agent conversation framework designed to enhance the capabilities of Large Language Models (LLMs) in diagnosing complex diseases. Our approach, structured under the Autogen framework, allows for in-depth conversations among LLMs, paving the way for more accurate and nuanced disease diagnosis. 6 | 7 | ## Preprint Article 8 | 9 | Our work has been documented in a preprint article titled "One is Not Enough: Multi-Agent Conversation Framework Enhances Rare Disease Diagnostic Capabilities of Large Language Models". For more insights into our study, please visit: https://www.researchsquare.com/article/rs-3757148/v1 10 | 11 | Multi Agent Conversation Flow 12 | ![image](https://github.com/geteff1/Multi-agent-conversation-for-disease-diagnosis/assets/148701415/357585db-30b8-487d-83f6-1d8640e9ec38) 13 | 14 | In 2024-08-26 updates: 1.You may vary the number of doctor agents; 2.You may exclude the supervisor agent; 3.You may assign case specific clinical specialty to doctor agents; 4.You may change the base model of the framework. 15 | **Test Dataset** 16 | 17 | 302 disease cases were retrieved. Each case was curated as primary consultation and follow-up consultation to test the effectiveness of LLMs in actual clincial scenarios. 18 | ![Figure 2](https://github.com/geteff1/Multi-agent-conversation-for-disease-diagnosis/assets/148701415/8762cb39-adaf-42a9-b123-9aef73e578bc) 19 | 20 | ## Runtime Estimate 21 | 22 | The estimated time to run a single case using our framework is approximately 5-10 minutes, varying slightly based on system specifications and network conditions. 
23 | 24 | ## Setup 25 | * Install anaconda: https://www.anaconda.com/distribution/ 26 | * set up conda environment w/ python 3.8, ex: 27 | * `conda create --name mac python=3.8` 28 | * `conda activate mac` 29 | * `pip install pyautogen==0.2.32` 30 | 31 | 32 | ## Training 33 | You should first set your API, proxy, and corresponding model list in **configs/config_list.json**. 34 | ```bash 35 | { 36 | "model": "gpt-4o", # It can be OpenAI's models, or others such as Claude, Gemini, LLaMA 3.1, etc. 37 | "api_key": "", # your API 38 | "base_url": "", # base URL 39 | "tags": [ 40 | "x_gpt4o" 41 | ] # You can assign different tags to different models. 42 | }, 43 | ``` 44 | You can also use locally deployed models, such as oLlama and LiteLLM together. For more details, see "https://microsoft.github.io/autogen/docs/topics/non-openai-models/local-litellm-ollama." 45 | ```bash 46 | { 47 | "model": "llama3.1", 48 | "api_key": "NotRequired", 49 | "base_url": "http://0.0.0.0:4000", 50 | "tags": [ 51 | "llama3.1" 52 | ] 53 | } 54 | ``` 55 | 56 | All commands should be run under the project root directory. 57 | 58 | ```bash 59 | sh scripts/train.sh 60 | ``` 61 | 62 | ## Evaluation 63 | All commands should be run under the project root directory. 64 | 65 | ```bash 66 | bash scripts/eval.sh 67 | ``` 68 | 69 | ## Results 70 | Results will be saved in a folder named `outputs/`. 71 | 72 | ## Contributing 73 | 74 | We welcome contributions to this project. If you have suggestions for improvements or want to report issues, please feel free to open an issue or submit a pull request. 
75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import time 4 | import json 5 | import argparse 6 | 7 | import os.path as osp 8 | from tqdm import tqdm 9 | 10 | from autogen import ( 11 | GroupChat, 12 | UserProxyAgent, 13 | GroupChatManager, 14 | AssistantAgent, 15 | config_list_from_json, 16 | ) 17 | 18 | from utils import * 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Medagents Setting") 22 | parser.add_argument( 23 | "--config", 24 | type=str, 25 | default="configs/config_list.json", 26 | help="the llm models' config file", 27 | ) 28 | parser.add_argument( 29 | "--model_name", 30 | type=str, 31 | default="x_gpt35_turbo", 32 | choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"], 33 | help="the llm models", 34 | ) 35 | parser.add_argument( 36 | "--dataset_name", 37 | type=str, 38 | default="rare_disease_302", 39 | choices=["rare_disease_302"], 40 | help="choice different dataset", 41 | ) 42 | parser.add_argument( 43 | "--stage", 44 | type=str, 45 | default="inital", 46 | choices=["inital", "follow_up"], 47 | help="choice different stages", 48 | ) 49 | parser.add_argument( 50 | "--times", 51 | type=int, 52 | default=1, 53 | choices=[1, 2, 3], 54 | help="choice different stages", 55 | ) 56 | parser.add_argument( 57 | "--output_dir", 58 | type=str, 59 | default="output", 60 | help="log file", 61 | ) 62 | parser.add_argument("--num_doctors", type=int, default=3, help="number of experts") 63 | parser.add_argument("--n_round", type=int, default=13, help="attempt_vote") 64 | 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | 70 | # @simple_retry(max_attempts=100, delay=1) 71 | def process_single_case(args, dataset, idx, output_dir, 
model_config): 72 | case_cost = 0.0 73 | case_info = {} 74 | 75 | ( 76 | case_type, 77 | case_name, 78 | case_crl, 79 | case_initial_presentation, 80 | case_follow_up_presentation, 81 | ) = dataset[idx] 82 | 83 | json_name = f"{case_crl}.json" 84 | conversation_name = f"{case_crl}_conversation.json" 85 | identify = f"{args.num_doctors}-{args.n_round}" 86 | 87 | output_dir = osp.join( 88 | output_dir, 89 | "MAC", 90 | args.stage, 91 | args.model_name, 92 | identify, 93 | str(args.times), 94 | ) 95 | 96 | if not osp.exists(output_dir): 97 | os.makedirs(output_dir) 98 | 99 | file_names = os.listdir(output_dir) 100 | 101 | json_files = [file for file in file_names if file.endswith(".json")] 102 | 103 | if json_name in json_files and conversation_name in json_files: 104 | return 105 | 106 | if args.stage == "inital": 107 | case_presentation = case_initial_presentation 108 | elif args.stage == "follow_up": 109 | case_presentation = case_follow_up_presentation 110 | else: 111 | raise NotImplementedError 112 | 113 | Docs = [] 114 | for index in range(args.num_doctors): 115 | name = f"Doctor{index}" 116 | doc_system_message = get_doc_system_message( 117 | doctor_name=name, stage=args.stage) 118 | 119 | Doc = AssistantAgent( 120 | name=name, 121 | llm_config=model_config, 122 | system_message=doc_system_message, 123 | ) 124 | Docs.append(Doc) 125 | 126 | supervisor_system_message = get_supervisor_system_message( 127 | stage=args.stage, use_specialist=False 128 | ) 129 | 130 | Supervisor = AssistantAgent( 131 | name="Supervisor", 132 | llm_config=model_config, 133 | system_message=supervisor_system_message, 134 | ) 135 | 136 | agents = Docs + [Supervisor] 137 | groupchat = GroupChat( 138 | agents=agents, 139 | messages=[], 140 | max_round=args.n_round, 141 | speaker_selection_method="auto", # "auto" or "round_robin": 下一个发言者以循环方式选择,即按照agents中提供的顺序进行迭代. 
效果不太理想,需要更改prompt 142 | admin_name="Supervisor", 143 | select_speaker_auto_verbose=False, 144 | allow_repeat_speaker=True, 145 | send_introductions=False, 146 | max_retries_for_selecting_speaker=args.n_round // (1 + args.num_doctors), 147 | ) 148 | 149 | time.sleep(5) 150 | manager = GroupChatManager( 151 | groupchat=groupchat, 152 | llm_config=model_config, 153 | is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0, 154 | ) 155 | inital_message = get_inital_message(patient_history=case_presentation, stage=args.stage) 156 | 157 | output = Supervisor.initiate_chat( 158 | manager, 159 | message=inital_message, 160 | ) 161 | # case cost 162 | for agent in agents: 163 | case_cost += agent.client.total_usage_summary["total_cost"] 164 | # Save the complete conversation 165 | conversation_path = osp.join(output_dir, conversation_name) 166 | with open(conversation_path, "w") as file: 167 | json.dump(output.chat_history, file, indent=4) 168 | critic_output = [ 169 | item 170 | for i, item in enumerate(output.chat_history) 171 | if item.get("name") == None 172 | and '"Most Likely Diagnosis":' in item.get("content") 173 | ] 174 | 175 | syn_report = critic_output[-1]["content"] 176 | 177 | json_output = prase_json(syn_report) 178 | 179 | case_info["Type"] = case_type 180 | case_info["Crl"] = case_crl 181 | case_info["Cost"] = case_cost 182 | case_info["Presentation"] = case_presentation 183 | case_info["Name"] = case_name 184 | case_info["Most Likely"] = json_output.get("Most Likely Diagnosis") 185 | case_info["Other Possible"] = json_output.get("Differential") or json_output.get( 186 | "Differential Diagnosis" 187 | ) 188 | 189 | if args.stage == "inital": 190 | case_info["Recommend Tests"] = json_output.get( 191 | "Recommend Tests" 192 | ) or json_output.get("Recommended Tests") 193 | 194 | recorder_path = osp.join(output_dir, json_name) 195 | with open(recorder_path, "w") as file: 196 | json.dump(case_info, file, indent=4) 197 | 198 | 199 | def main(): 
200 | args = parse_args() 201 | 202 | filter_criteria = { 203 | "tags": [args.model_name], 204 | } 205 | 206 | config_list = config_list_from_json( 207 | env_or_file=args.config, filter_dict=filter_criteria 208 | ) 209 | 210 | model_config = { 211 | "cache_seed": None, 212 | "temperature": 1, 213 | "config_list": config_list, 214 | "timeout": 300, 215 | } 216 | 217 | dataset = MedDataset(dataname=args.dataset_name) 218 | 219 | data_len = len(dataset) 220 | 221 | output_dir = args.output_dir 222 | 223 | for idx in tqdm(range(data_len)): 224 | try: 225 | process_single_case(args, dataset, idx, output_dir, model_config) 226 | except Exception as e: 227 | print(f"Failed to process case {idx} after all attempts: {str(e)}") 228 | continue 229 | 230 | 231 | if __name__ == "__main__": 232 | main() 233 | 234 | -------------------------------------------------------------------------------- /main_wo_supr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import json 5 | import argparse 6 | from functools import wraps 7 | 8 | import os.path as osp 9 | from tqdm import tqdm 10 | 11 | from autogen import ( 12 | GroupChat, 13 | UserProxyAgent, 14 | ConversableAgent, 15 | AssistantAgent, 16 | GroupChatManager, 17 | config_list_from_json, 18 | ) 19 | 20 | from utils import * 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Medagents Setting") 25 | parser.add_argument( 26 | "--config", 27 | type=str, 28 | default="configs/config_list.json", 29 | help="the llm models' config file", 30 | ) 31 | parser.add_argument( 32 | "--query_model_name", 33 | type=str, 34 | default="x_gpt4o", 35 | choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"], 36 | help="the llm models", 37 | ) 38 | parser.add_argument( 39 | "--model_name", 40 | type=str, 41 | default="x_gpt35_turbo", 42 | choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"], 43 | help="the llm models", 44 | ) 45 | 
parser.add_argument( 46 | "--dataset_name", 47 | type=str, 48 | default="rare_disease_302", 49 | choices=["rare_disease_cases_150", "rare_disease_302"], 50 | help="choice different dataset", 51 | ) 52 | parser.add_argument( 53 | "--stage", 54 | type=str, 55 | default="inital", 56 | choices=["inital", "follow_up"], 57 | help="choice different stages", 58 | ) 59 | 60 | parser.add_argument( 61 | "--times", 62 | type=int, 63 | default=1, 64 | choices=[1, 2, 3], 65 | help="choice different stages", 66 | ) 67 | parser.add_argument( 68 | "--output_dir", 69 | type=str, 70 | default="output", 71 | help="log file", 72 | ) 73 | parser.add_argument( 74 | "--num_doctors", type=int, default=3, help="number of experts" 75 | ) 76 | parser.add_argument("--n_round", type=int, default=10, help="attempt_vote") 77 | parser.add_argument("--query_round", type=int, default=1, help="query times") 78 | 79 | args = parser.parse_args() 80 | 81 | return args 82 | 83 | 84 | @simple_retry(max_attempts=100, delay=1) 85 | def process_single_case( 86 | args, dataset, idx, output_dir, model_config, 87 | ): 88 | case_cost = 0.0 89 | case_info = {} 90 | 91 | ( 92 | case_type, 93 | case_name, 94 | case_crl, 95 | case_initial_presentation, 96 | case_follow_up_presentation, 97 | ) = dataset[idx] 98 | 99 | json_name = f"{case_crl}.json" 100 | conversation_name = f"{case_crl}_conversation.json" 101 | identify = f"{args.num_doctors}-{args.n_round}" 102 | 103 | output_dir = osp.join( 104 | output_dir, 105 | "MAC_WOEXPERT_WOCRITIC", 106 | args.stage, 107 | args.model_name, 108 | identify, 109 | str(args.times), 110 | ) 111 | 112 | if not osp.exists(output_dir): 113 | os.makedirs(output_dir) 114 | 115 | file_names = os.listdir(output_dir) 116 | 117 | json_files = [file for file in file_names if file.endswith(".json")] 118 | 119 | if json_name in json_files and conversation_name in json_files: 120 | return 121 | 122 | if args.stage == "inital": 123 | case_presentation = case_initial_presentation 124 | elif 
args.stage == "follow_up": 125 | case_presentation = case_follow_up_presentation 126 | else: 127 | raise NotImplementedError 128 | 129 | user_proxy = UserProxyAgent( 130 | name="Admin", 131 | system_message="A human admin doctor.", 132 | code_execution_config=False, 133 | human_input_mode="NEVER", # choose human input mode 134 | ) 135 | 136 | 137 | Docs = [] 138 | for index in range(args.num_doctors): 139 | name = f"Doctor{index}" 140 | doc_system_message = get_doc_system_message( 141 | doctor_name=name, stage=args.stage) 142 | Doc = AssistantAgent( 143 | name=name, 144 | llm_config=model_config, 145 | system_message=doc_system_message, 146 | ) 147 | Docs.append(Doc) 148 | 149 | 150 | groupchat = GroupChat( 151 | agents=[user_proxy] + Docs, 152 | messages=[], 153 | max_round=args.n_round, 154 | speaker_selection_method="auto", #"auto" or "round_robin": 下一个发言者以循环方式选择,即按照agents中提供的顺序进行迭代. 效果不太理想,需要更改prompt 155 | select_speaker_auto_verbose=False, 156 | allow_repeat_speaker=True, 157 | send_introductions=False, 158 | max_retries_for_selecting_speaker=args.n_round // (args.num_doctors), 159 | ) 160 | time.sleep(5) 161 | manager = GroupChatManager( 162 | groupchat=groupchat, 163 | llm_config=model_config, 164 | is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0, 165 | ) 166 | 167 | inital_message = get_inital_message(patient_history=case_presentation, stage=args.stage) 168 | output = user_proxy.initiate_chat( 169 | manager, 170 | message=inital_message, 171 | ) 172 | 173 | for agent in Docs: 174 | case_cost += agent.client.total_usage_summary['total_cost'] 175 | # Save the complete conversation 176 | conversation_path = osp.join(output_dir, conversation_name) 177 | with open(conversation_path, "w") as file: 178 | json.dump(output.chat_history, file, indent=4) 179 | critic_output = [item for i, item in enumerate(output.chat_history) if '"Most Likely Diagnosis":' in item.get("content")] 180 | 181 | syn_report = critic_output[-1]["content"] 182 | 183 
| json_output = prase_json(syn_report) 184 | 185 | case_info["Type"] = case_type 186 | case_info["Crl"] = case_crl 187 | case_info["Cost"] = case_cost 188 | case_info["Presentation"] = case_presentation 189 | case_info["Name"] = case_name 190 | case_info["Most Likely"] = json_output.get("Most Likely Diagnosis") 191 | case_info["Other Possible"] = json_output.get("Differential") or json_output.get( 192 | "Differential Diagnosis" 193 | ) 194 | 195 | if args.stage == "inital": 196 | case_info["Recommend Tests"] = json_output.get( 197 | "Recommend Tests" 198 | ) or json_output.get("Recommended Tests") 199 | 200 | recorder_path = osp.join(output_dir, json_name) 201 | with open(recorder_path, "w") as file: 202 | json.dump(case_info, file, indent=4) 203 | 204 | 205 | def main(): 206 | args = parse_args() 207 | 208 | filter_criteria = { 209 | "tags": [args.model_name], 210 | } 211 | 212 | config_list = config_list_from_json( 213 | env_or_file=args.config, filter_dict=filter_criteria 214 | ) 215 | 216 | 217 | 218 | model_config = { 219 | "cache_seed": None, 220 | "temperature": 0, 221 | "config_list": config_list, 222 | "timeout": 300, 223 | } 224 | 225 | 226 | dataset = MedDataset(dataname=args.dataset_name) 227 | 228 | data_len = len(dataset) 229 | 230 | output_dir = args.output_dir 231 | 232 | for idx in tqdm(range(data_len)): 233 | try: 234 | process_single_case( 235 | args, dataset, idx, output_dir, model_config, 236 | ) 237 | except Exception as e: 238 | print(f"Failed to process case {idx} after all attempts: {str(e)}") 239 | continue 240 | 241 | 242 | if __name__ == "__main__": 243 | main() 244 | -------------------------------------------------------------------------------- /main_ws.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import json 5 | import argparse 6 | from functools import wraps 7 | 8 | import os.path as osp 9 | from tqdm import tqdm 10 | 11 | from autogen import ( 
12 | GroupChat, 13 | UserProxyAgent, 14 | ConversableAgent, 15 | AssistantAgent, 16 | GroupChatManager, 17 | config_list_from_json, 18 | ) 19 | 20 | from utils import * 21 | 22 | 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="Medagents Setting") 27 | parser.add_argument( 28 | "--config", 29 | type=str, 30 | default="configs/config_list.json", 31 | help="the llm models' config file", 32 | ) 33 | parser.add_argument( 34 | "--query_model_name", 35 | type=str, 36 | default="x_gpt4o", 37 | choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"], 38 | help="the llm models", 39 | ) 40 | parser.add_argument( 41 | "--model_name", 42 | type=str, 43 | default="x_gpt35_turbo", 44 | choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"], 45 | help="the llm models", 46 | ) 47 | parser.add_argument( 48 | "--dataset_name", 49 | type=str, 50 | default="rare_disease_302", 51 | choices=["rare_disease_302"], 52 | help="choice different dataset", 53 | ) 54 | parser.add_argument( 55 | "--stage", 56 | type=str, 57 | default="inital", 58 | choices=["inital", "follow_up"], 59 | help="choice different stages", 60 | ) 61 | 62 | parser.add_argument( 63 | "--times", 64 | type=int, 65 | default=1, 66 | choices=[1, 2, 3], 67 | help="choice different stages", 68 | ) 69 | parser.add_argument( 70 | "--output_dir", 71 | type=str, 72 | default="output", 73 | help="log file", 74 | ) 75 | parser.add_argument( 76 | "--num_specialists", type=int, default=3, help="number of experts" 77 | ) 78 | parser.add_argument("--n_round", type=int, default=13, help="attempt_vote") 79 | parser.add_argument("--query_round", type=int, default=1, help="query times") 80 | 81 | args = parser.parse_args() 82 | 83 | return args 84 | 85 | 86 | 87 | 88 | 89 | @simple_retry(max_attempts=100, delay=1) 90 | def process_single_case( 91 | args, dataset, idx, output_dir, model_config, query_model_config 92 | ): 93 | case_cost = 0.0 94 | case_info = {} 95 | 96 | ( 97 | case_type, 98 | case_name, 99 | 
case_crl, 100 | case_initial_presentation, 101 | case_follow_up_presentation, 102 | ) = dataset[idx] 103 | 104 | json_name = f"{case_crl}.json" 105 | conversation_name = f"{case_crl}_conversation.json" 106 | identify = f"{args.num_specialists}-{args.n_round}" 107 | 108 | output_dir = osp.join( 109 | output_dir, 110 | "MAC_WS", 111 | args.stage, 112 | args.model_name, 113 | identify, 114 | str(args.times), 115 | ) 116 | 117 | if not osp.exists(output_dir): 118 | os.makedirs(output_dir) 119 | 120 | file_names = os.listdir(output_dir) 121 | 122 | json_files = [file for file in file_names if file.endswith(".json")] 123 | 124 | if json_name in json_files and conversation_name in json_files: 125 | return 126 | 127 | if args.stage == "inital": 128 | case_presentation = case_initial_presentation 129 | elif args.stage == "follow_up": 130 | case_presentation = case_follow_up_presentation 131 | else: 132 | raise NotImplementedError 133 | 134 | coordinator = ConversableAgent( 135 | "Medical_Coordinator", 136 | system_message="You are a Medical Coordinator. Your role is to provide the patient's medical history and ask questions to determine the appropriate specialist. You should seek clarification and ensure all relevant information is covered.", 137 | llm_config=query_model_config, 138 | human_input_mode="NEVER", # Never ask for human input. 139 | ) 140 | 141 | consultant = ConversableAgent( 142 | "Senior_Medical_Consultant", 143 | system_message="You are a Senior Medical Consultant. Your role is to answer the Medical Coordinator's questions, recommend the appropriate specialist based on the medical history provided, and correct any misconceptions.", 144 | llm_config=query_model_config, 145 | human_input_mode="NEVER", # Never ask for human input. 
146 | ) 147 | 148 | consultant_message = get_consultant_message(case_presentation, int(args.num_specialists)) 149 | 150 | result = coordinator.initiate_chat( 151 | consultant, message=consultant_message, max_turns=args.query_round 152 | ) 153 | top_k_specialists = prase_json(result.chat_history[-1]["content"])[ 154 | "top_k_specialists" 155 | ] 156 | assert len(top_k_specialists) == args.num_specialists 157 | case_cost += result.cost["usage_including_cached_inference"]["total_cost"] 158 | 159 | Docs = [] 160 | for specialist in top_k_specialists: 161 | name = specialist.replace(" ", "_") 162 | doc_system_message = get_doc_system_message( 163 | doctor_name=name, stage=args.stage) 164 | 165 | Doc = AssistantAgent( 166 | name=name, 167 | llm_config=model_config, 168 | system_message=doc_system_message, 169 | ) 170 | Docs.append(Doc) 171 | 172 | 173 | supervisor_system_message = get_supervisor_system_message( 174 | stage=args.stage, use_specialist=True, specialists=top_k_specialists 175 | ) 176 | 177 | Supervisor = AssistantAgent( 178 | name="Supervisor", 179 | llm_config=model_config, 180 | system_message=supervisor_system_message, 181 | ) 182 | 183 | agents = Docs + [Supervisor] 184 | groupchat = GroupChat( 185 | agents=agents, 186 | messages=[], 187 | max_round=args.n_round, 188 | speaker_selection_method="auto", #"auto" or "round_robin": 下一个发言者以循环方式选择,即按照agents中提供的顺序进行迭代. 
效果不太理想,需要更改prompt 189 | admin_name="Critic", 190 | select_speaker_auto_verbose=False, 191 | allow_repeat_speaker=True, 192 | send_introductions=False, 193 | max_retries_for_selecting_speaker=args.n_round // (1 + args.num_specialists), 194 | ) 195 | time.sleep(5) 196 | manager = GroupChatManager( 197 | groupchat=groupchat, 198 | llm_config=model_config, 199 | is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0, 200 | ) 201 | 202 | inital_message = get_inital_message(patient_history=case_presentation, stage=args.stage) 203 | 204 | output = Supervisor.initiate_chat( 205 | manager, 206 | message=inital_message, 207 | ) 208 | 209 | #case cost 210 | for agent in agents: 211 | case_cost += agent.client.total_usage_summary ['total_cost'] 212 | 213 | # Save the complete conversation 214 | conversation_path = osp.join(output_dir, conversation_name) 215 | with open(conversation_path, "w") as file: 216 | json.dump(output.chat_history, file, indent=4) 217 | 218 | 219 | critic_output = [ 220 | item 221 | for i, item in enumerate(output.chat_history) 222 | if item.get("name") == None 223 | and '"Most Likely Diagnosis":' in item.get("content") 224 | ] 225 | 226 | syn_report = critic_output[-1]["content"] 227 | 228 | json_output = prase_json(syn_report) 229 | 230 | case_info["Type"] = case_type 231 | case_info["Crl"] = case_crl 232 | case_info["Cost"] = case_cost 233 | case_info["Presentation"] = case_presentation 234 | case_info["Name"] = case_name 235 | case_info["Most Likely"] = json_output.get("Most Likely Diagnosis") 236 | case_info["Other Possible"] = json_output.get("Differential") or json_output.get( 237 | "Differential Diagnosis" 238 | ) 239 | 240 | if args.stage == "inital": 241 | case_info["Recommend Tests"] = json_output.get( 242 | "Recommend Tests" 243 | ) or json_output.get("Recommended Tests") 244 | 245 | recorder_path = osp.join(output_dir, json_name) 246 | with open(recorder_path, "w") as file: 247 | json.dump(case_info, file, indent=4) 248 | 

def main():
    """Entry point for the MAC-with-specialists pipeline.

    Builds one deterministic LLM config for the specialist-routing query and
    one higher-temperature config for the group discussion, then runs every
    dataset case through process_single_case.
    """
    args = parse_args()

    # Two tag filters: the routing query may use a different model than the
    # multi-agent discussion itself.
    query_filter_criteria = {"tags": [args.query_model_name]}
    filter_criteria = {"tags": [args.model_name]}

    query_config_list = config_list_from_json(
        env_or_file=args.config, filter_dict=query_filter_criteria
    )
    config_list = config_list_from_json(
        env_or_file=args.config, filter_dict=filter_criteria
    )

    # Deterministic (temperature 0) settings for specialist selection ...
    query_model_config = {
        "cache_seed": None,
        "temperature": 0,
        "config_list": query_config_list,
        "timeout": 120,
    }
    # ... and exploratory (temperature 1) settings for the discussion.
    model_config = {
        "cache_seed": None,
        "temperature": 1,
        "config_list": config_list,
        "timeout": 300,
    }

    dataset = MedDataset(dataname=args.dataset_name)
    output_dir = args.output_dir

    for idx in tqdm(range(len(dataset))):
        try:
            process_single_case(
                args, dataset, idx, output_dir, model_config, query_model_config
            )
        except Exception as e:
            # process_single_case retries internally; reaching here means all
            # attempts were exhausted, so report and move on to the next case.
            print(f"Failed to process case {idx} after all attempts: {str(e)}")
            continue


if __name__ == "__main__":
    main()

# ===================== /evaluate.py =====================

import os
import time
import json
import argparse
import re

import pandas as pd
import os.path as osp
from tqdm import tqdm
from functools import wraps

from autogen.io import IOStream
from autogen.formatting_utils import colored
from autogen import ConversableAgent, config_list_from_json
from autogen.agentchat.utils import gather_usage_summary
from autogen.code_utils import content_str
from utils import *
def parse_args():
    """CLI arguments for the evaluation stage (teacher-LLM grading)."""
    parser = argparse.ArgumentParser(description="Medagents Setting")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/config_list.json",
        help="the llm models' config file",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="x_gpt4o",
        choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"],
        help="the llm models",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="rare_disease_302",
        choices=["rare_disease_302",],
        help="choice different dataset",
    )
    parser.add_argument(
        "--stage",
        type=str,
        default="inital",
        choices=["inital", "follow_up"],
        help="choice different stages",
    )
    parser.add_argument(
        "--recom_test",
        action="store_true",
        default=False,
        help="The failure rate of the recommended test is relatively high, \
            therefore it requires separate testing and should only be tested during the initial stage.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/openai-gpt-3.5-turbo/3-3/inital",
        help="log file",
    )
    return parser.parse_args()


def load(data_path):
    """Read a JSON file and return the parsed object."""
    with open(data_path, "r") as file:
        return json.load(file)


def _grade_with_teacher(teacher_agent, prompt_text):
    """Send one grading prompt to the teacher agent.

    Returns (parsed_json, raw_reply, cumulative_cost).
    NOTE(review): gather_usage_summary reports the agent's *cumulative* usage
    (including cached inference); callers mirror the original accounting by
    adding this snapshot to case_cost after each call.
    """
    messages = [{"content": prompt_text, "role": "user"}]
    reply = content_str(teacher_agent.generate_reply(messages))
    cost = gather_usage_summary([teacher_agent])[
        "usage_including_cached_inference"
    ]["total_cost"]
    return prase_json(reply), reply, cost


@simple_retry(max_attempts=100, delay=1)
def process_single_case(args, evaluate_dir, json_name, case_info, grond_truth, model_config,):
    """Grade one saved case with a teacher LLM and write the augmented record.

    Depending on --recom_test, either the recommended tests or the
    most-likely + differential diagnoses are scored against grond_truth
    (the case's reference diagnosis). The evaluated record is written to
    evaluate_dir/json_name.
    """
    ROM_T_PPROMPT_TEMPLATE, MOST_PPROMPT_TEMPLATE, POSSI_PPROMPT_TEMPLATE = get_evaluate_prompts()

    teacher_system_message = """You are a medical expert, please evaluate the following possible diagnosis provided by the student."""

    teacher_agent = ConversableAgent(
        "Teacher",
        system_message=teacher_system_message,
        llm_config=model_config,
        human_input_mode="NEVER",  # Never ask for human input.
    )

    case_cost = 0.0
    iostream = IOStream.get_default()

    if args.recom_test == True and args.stage == "inital":
        # Generators (main.py / main_ws.py / self_consistency) store this
        # field under the key "Recommend Tests"; accept both spellings
        # instead of raising KeyError on "Recommended Tests".
        recommend_test = str(
            case_info.get("Recommended Tests") or case_info.get("Recommend Tests")
        )
        recommend_test_text = ROM_T_PPROMPT_TEMPLATE.format(
            correct_diagnosis=grond_truth, recommended_tests=recommend_test
        )
        recommend_output, recommend_reply, cost = _grade_with_teacher(
            teacher_agent, recommend_test_text
        )
        case_cost += cost
        case_info["Recommended Tests Evaluation"] = recommend_output

        iostream.print(colored(teacher_agent.name, "yellow"), "response the ", colored("Recommended Tests", "blue"), ":", flush=True)
        iostream.print(colored(recommend_reply,), flush=True)
        iostream.print(colored("*" * 60, "light_cyan"), flush=True)
        iostream.print(colored("Eval Costs: " , "yellow"), colored(case_cost, "red"), flush=True)
    else:
        # Score the single most-likely diagnosis first ...
        most_likely = str(case_info["Most Likely"])
        most_likely_text = MOST_PPROMPT_TEMPLATE.format(
            correct_diagnosis=grond_truth, diagnosis=most_likely
        )
        most_likely_output, most_likely_reply, cost = _grade_with_teacher(
            teacher_agent, most_likely_text
        )
        case_cost += cost
        case_info["Most Likely Evaluation"] = most_likely_output

        iostream.print(colored(teacher_agent.name, "yellow"), "response the ", colored("Most Likely", "light_red"), ":", flush=True)
        iostream.print(colored(most_likely_reply,), flush=True)
        iostream.print(colored("*" * 60, "light_cyan"), flush=True)

        # ... then the full differential list (top diagnosis included).
        other_possible = str(case_info["Other Possible"])
        possible_all = most_likely + "," + other_possible
        possible_text = POSSI_PPROMPT_TEMPLATE.format(
            correct_diagnosis=grond_truth, possible_diagnoses=possible_all
        )
        possible_output, possible_reply, cost = _grade_with_teacher(
            teacher_agent, possible_text
        )
        case_cost += cost
        case_info["Other Possible Evaluation"] = possible_output

        iostream.print(colored(teacher_agent.name, "yellow"), "response the ", colored("Other Possible", "light_green"), ":", flush=True)
        iostream.print(colored(possible_reply,), flush=True)
        iostream.print(colored("*" * 60, "light_cyan"), flush=True)
        iostream.print(colored("Eval Costs: " , "yellow"), colored(case_cost, "red"), flush=True)

    recorder_path = osp.join(evaluate_dir, json_name)
    with open(recorder_path, "w") as file:
        json.dump(case_info, file, indent=4)


def main():
    """Grade every saved case, then aggregate accuracy / average scores."""
    args = parse_args()
    filter_criteria = {"tags": [args.model_name]}

    config_list = config_list_from_json(
        env_or_file=args.config, filter_dict=filter_criteria
    )

    # Deterministic grading: temperature 0.
    model_config = {
        "cache_seed": None,
        "temperature": 0,
        "config_list": config_list,
        "timeout": 120,
    }

    dataset = MedDataset(dataname=args.dataset_name)
    data_len = len(dataset)

    # Directories are constant across cases. The original recomputed them
    # inside the loop and relied on the loop leaking them into the summary
    # section below, which raised NameError on an empty dataset.
    output_dir = args.output_dir
    evaluate_dir = output_dir.replace("output", "evaluation")
    if args.recom_test == True:
        evaluate_dir = osp.join("recom_test", evaluate_dir)
    if not osp.exists(evaluate_dir):
        os.makedirs(evaluate_dir)

    file_names = os.listdir(output_dir)
    # Exclude the *_conversation.json transcripts the generators write next
    # to each case record: they are not gradable cases, and counting them
    # inflated total_sample (and crashed the aggregation loop below).
    json_files = [
        file
        for file in file_names
        if file.endswith(".json") and not file.endswith("_conversation.json")
    ]

    for idx in tqdm(range(data_len)):
        (
            case_type,
            case_name,
            case_crl,
            case_initial_presentation,
            case_follow_up_presentation,
        ) = dataset[idx]

        json_name = f"{case_crl}.json"

        # Re-list each iteration: process_single_case adds files as we go.
        out_names = os.listdir(evaluate_dir)
        if json_name in out_names:
            continue  # already graded

        if json_name in json_files:
            case_info = load(data_path=osp.join(output_dir, json_name))
            grond_truth = case_name
            try:
                process_single_case(args, evaluate_dir, json_name, case_info, grond_truth, model_config)
            except Exception as e:
                print(f"Failed to process case {idx} after all attempts: {str(e)}")
                continue

    all_cases_list = []
    total_sample = len(json_files)

    if args.recom_test == True and args.stage == "inital":
        recom_sample = 0
        recom_score = 0

        for out_json in json_files:
            json_path = osp.join(evaluate_dir, out_json)
            if not osp.exists(json_path):
                # Case exhausted all grading attempts above; skip rather
                # than crash. It still counts against total_sample.
                continue
            with open(json_path) as json_file:
                case_data = json.load(json_file)
            recom_score += float(case_data["Recommended Tests Evaluation"]["Score"])
            if float(case_data["Recommended Tests Evaluation"]["Score"]) >= 4:
                recom_sample += 1
            all_cases_list.append(case_data)

        acc_recom = recom_sample * 1.0 / total_sample
        avg_score_recom = recom_score * 1.0 / total_sample

        iostream = IOStream.get_default()
        iostream.print(colored("Acc Recommend Test:" , "yellow"), "response the ", colored(f"{acc_recom:.2%}", "light_green"), ":", flush=True)
        iostream.print(colored("Avg Score Recommend Test:" , "yellow"), "response the ", colored(f"{avg_score_recom:.2}", "light_green"), ":", flush=True)
    else:
        most_likely_sample = 0
        most_likely_score = 0
        possible_sample = 0
        possible_score = 0

        for out_json in json_files:
            json_path = osp.join(evaluate_dir, out_json)
            if not osp.exists(json_path):
                continue  # failed case; see note in the branch above
            with open(json_path) as json_file:
                case_data = json.load(json_file)
            most_likely_score += float(case_data["Most Likely Evaluation"]["Score"])
            possible_score += float(case_data["Other Possible Evaluation"]["Score"])
            # NOTE(review): the recom_test branch counts >= 4 while these
            # count > 4 — looks like a deliberately stricter bar for
            # diagnoses, but confirm before unifying.
            if float(case_data["Most Likely Evaluation"]["Score"]) > 4:
                most_likely_sample += 1
            if float(case_data["Other Possible Evaluation"]["Score"]) > 4:
                possible_sample += 1
            all_cases_list.append(case_data)

        acc_most_likely = most_likely_sample * 1.0 / total_sample
        avg_score_most_likely = most_likely_score * 1.0 / total_sample
        acc_possible_likely = possible_sample * 1.0 / total_sample
        avg_score_possible_likely = possible_score * 1.0 / total_sample

        iostream = IOStream.get_default()
        iostream.print(colored("Acc Most Likely:" , "yellow"), "response the ", colored(f"{acc_most_likely:.2%}", "light_green"), ":", flush=True)
        iostream.print(colored("Avg Score Most Likely:" , "yellow"), "response the ", colored(f"{avg_score_most_likely:.3}", "light_green"), ":", flush=True)
        iostream.print(colored("*" * 60, "light_cyan"), flush=True)
        iostream.print(colored("Acc Possible Likely:" , "yellow"), "response the ", colored(f"{acc_possible_likely:.2%}", "light_green"), ":", flush=True)
        iostream.print(colored("Avg Score Possible Likely:" , "yellow"), "response the ", colored(f"{avg_score_possible_likely:.3}", "light_green"), ":", flush=True)

    # Flat CSV of every evaluated case for downstream analysis.
    df_cases = pd.DataFrame(all_cases_list)
    out_csv_name = f"{args.stage}_case.csv"
    recorder_path = osp.join(evaluate_dir, out_csv_name)
    df_cases.to_csv(recorder_path, index=False)


if __name__ == "__main__":
    main()
def simple_retry(max_attempts=100, delay=1):
    """Decorator factory: keep re-invoking the wrapped callable until it succeeds.

    Up to max_attempts calls are made, sleeping `delay` seconds between
    attempts; the exception from the final failed attempt propagates unchanged.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_attempt = max_attempts - 1
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == last_attempt:
                        print(f"All {max_attempts} attempts failed. Last error: {str(e)}")
                        raise
                    print(f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {delay} second...")
                    time.sleep(delay)
        return wrapper
    return decorator


def parse_args():
    """CLI arguments for the self-consistency baseline run."""
    parser = argparse.ArgumentParser(description="Medagents Setting")
    add = parser.add_argument
    add("--config", type=str, default="configs/OAI_Config_List.json",
        help="the llm models")
    add("--model_name", type=str, default="x_gpt4o",
        choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o"],
        help="the llm models")
    add("--dataset_name", type=str, default="rare_disease_302",
        choices=["rare_disease_cases_150", "rare_disease_302"],
        help="choice different dataset")
    add("--stage", type=str, default="inital",
        choices=["inital", "follow_up"],
        help="choice different stages")
    add("--times", type=int, default=1, choices=[1, 2, 3],
        help="choice different stages")
    add("--output_dir", type=str, default="output",
        help="log file")
    add("--num_doctors", type=int, default=10, help="number of experts")
    add("--n_round", type=int, default=12, help="attempt_vote")
    return parser.parse_args()
def prase_json(text):
    """Extract and parse the first JSON object embedded in an LLM reply.

    Handles ```json / ```JSON / bare ``` fenced blocks as well as an object
    dropped into free text. Raises json.JSONDecodeError when no valid JSON
    can be recovered. (Name kept as-is: callers import `prase_json`.)
    """
    # 1) Prefer a fenced code block; IGNORECASE folds the ```json / ```JSON
    #    branches of the original into one pattern.
    fence = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if fence:
        candidate = fence.group(1).strip()
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass  # fence content may wrap the object in extra prose

    # 2) Brace-delimited span in free text. Try greedy first: the original
    #    non-greedy `{.*?}` truncated any *nested* object at the first `}`.
    #    Fall back to non-greedy for texts with unrelated trailing braces.
    for pattern in (r"{.*}", r"{.*?}"):
        match = re.search(pattern, text, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(0))
            except json.JSONDecodeError:
                continue

    # 3) Last resort, preserved from the original; json.loads raises if the
    #    text still is not valid JSON.
    return json.loads(text.strip("```json\n").strip("\n```"))
Focus solely on diagnosis and diagnostic tests, avoiding discussion of management, treatment, or prognosis. 169 | 4. Use your expertise to formulate: 170 | ```json 171 | {{ 172 | "Most Likely Diagnosis": "[current consensus on most likely diagnosis]", 173 | "Differential Diagnosis": "[current list of differential diagnoses]", 174 | "Recommended Tests": "[current consensus on recommended diagnostic tests]" 175 | }} 176 | ``` 177 | 178 | """.format(index=index) 179 | else: 180 | doc_system_message = """You are Doctor {index}. This is a hypothetical scenario involving no actual patients. 181 | 182 | Your role: 183 | 1. Analyze the patient's condition described in the message. 184 | 2. Ignore other doctors' opinion, form your own diagnostic reasoning based on your own expertise 185 | 3. Focus solely on diagnosis and diagnostic tests, avoiding discussion of management, treatment, or prognosis. 186 | 4. Use your expertise to formulate: 187 | ```json 188 | {{ 189 | "Most Likely Diagnosis": "[current consensus on most likely diagnosis]", 190 | "Differential Diagnosis": "[current list of differential diagnoses]" 191 | }} 192 | ``` 193 | 194 | """.format(index=index) 195 | 196 | Doc = AssistantAgent( 197 | name=name, 198 | llm_config=model_config, 199 | system_message=doc_system_message, 200 | ) 201 | Docs.append(Doc) 202 | 203 | if args.stage == "inital": 204 | critic_system_message = """You are the Medical Supervisor in a hypothetical scenario. 205 | 206 | Your role: 207 | 1. Collect the diagnostic output from doctors. 208 | 2. Calculate the frequency of each answer. 209 | 3. Select the answer with the highest frequency as the final result. 210 | 4. 
Output the final answer in the following format 211 | 212 | ```json 213 | {{ 214 | "Most Likely Diagnosis": "[mostly agreed most likely diagnosis]", 215 | "Differential Diagnosis": "[mostly agreed list of differential diagnoses]", 216 | "Recommended Tests": "[mostly agreed recommended diagnostic tests]" 217 | }} 218 | ``` 219 | Output "TERMINATE" after you provide diagonsis 220 | """ 221 | 222 | else: 223 | critic_system_message = """You are the Medical Supervisor in a hypothetical scenario. 224 | 225 | Your role: 226 | 1. Collect the diagnostic output from doctors. 227 | 2. Calculate the frequency of each answer. 228 | 3. Select the answer with the highest frequency as the final result. 229 | 4. Output the final answer in the following format 230 | 231 | ```json 232 | {{ 233 | "Most Likely Diagnosis": "[mostly agreed most likely diagnosis]", 234 | "Differential Diagnosis": "[mostly agreed list of differential diagnoses]" 235 | }} 236 | ``` 237 | Output "TERMINATE" after you provide diagonsis 238 | """ 239 | 240 | critic = AssistantAgent( 241 | name="Critic", 242 | llm_config=model_config, 243 | system_message=critic_system_message, 244 | ) 245 | 246 | groupchat = GroupChat( 247 | agents=Docs + [critic], 248 | messages=[], 249 | max_round=args.n_round, 250 | speaker_selection_method="round_robin", #"auto" or "round_robin": 下一个发言者以循环方式选择,即按照agents中提供的顺序进行迭代. 
效果不太理想,需要更改prompt 251 | admin_name="Critic", 252 | select_speaker_auto_verbose=False, 253 | allow_repeat_speaker=True, 254 | send_introductions=False, 255 | max_retries_for_selecting_speaker=args.n_round // (1 + args.num_doctors), 256 | ) 257 | 258 | 259 | time.sleep(5) 260 | manager = GroupChatManager( 261 | groupchat=groupchat, 262 | llm_config=model_config, 263 | is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0 , 264 | ) 265 | 266 | if args.stage == "inital": 267 | message = """ 268 | Here is a patient case for analysis, provide the final diagnosis, final differential diagnosis and recommended tests. 269 | {patient_history}""".format( 270 | patient_history=case_presentation 271 | ) 272 | else: 273 | message = """ 274 | Here is a patient case for analysis, provide the final diagnosis, final differential diagnosis. 275 | {patient_history}""".format( 276 | patient_history=case_presentation 277 | ) 278 | 279 | output = critic.initiate_chat( 280 | manager, 281 | message=message, 282 | ) 283 | 284 | # Save the complete conversation 285 | conversation_path = osp.join(output_dir, conversation_name) 286 | with open(conversation_path, "w") as file: 287 | json.dump(output.chat_history, file, indent=4) 288 | case_cost += output.cost["usage_including_cached_inference"]["total_cost"] 289 | critic_output = [ 290 | item 291 | for i, item in enumerate(output.chat_history) 292 | if item.get("name") == None 293 | and '"Most Likely Diagnosis":' in item.get("content") 294 | ] 295 | 296 | syn_report = critic_output[-1]["content"] 297 | 298 | json_output = prase_json(syn_report) 299 | 300 | case_info["Type"] = case_type 301 | case_info["Crl"] = case_crl 302 | case_info["Cost"] = case_cost 303 | case_info["Presentation"] = case_presentation 304 | case_info["Name"] = case_name 305 | case_info["Most Likely"] = json_output.get("Most Likely Diagnosis") 306 | case_info["Other Possible"] = json_output.get( 307 | "Differential" 308 | ) or json_output.get("Differential 
Diagnosis") 309 | 310 | if args.stage == "inital": 311 | case_info["Recommend Tests"] = json_output.get( 312 | "Recommend Tests" 313 | ) or json_output.get("Recommended Tests") 314 | 315 | recorder_path = osp.join(output_dir, json_name) 316 | with open(recorder_path, "w") as file: 317 | json.dump(case_info, file, indent=4) 318 | 319 | def main(): 320 | args = parse_args() 321 | 322 | filter_criteria = { 323 | "tags": [args.model_name], 324 | } 325 | 326 | config_list = config_list_from_json( 327 | env_or_file=args.config, filter_dict=filter_criteria 328 | ) 329 | 330 | model_config = { 331 | "cache_seed": None, 332 | "temperature": 0.7, 333 | "config_list": config_list, 334 | "timeout": 300, 335 | } 336 | 337 | dataset = MedDataset(dataname=args.dataset_name) 338 | 339 | data_len = len(dataset) 340 | 341 | output_dir = args.output_dir 342 | 343 | for idx in tqdm(range(data_len)): 344 | try: 345 | process_single_case(args, dataset, idx, output_dir, model_config) 346 | except Exception as e: 347 | print(f"Failed to process case {idx} after all attempts: {str(e)}") 348 | continue 349 | 350 | if __name__ == "__main__": 351 | main() 352 | 353 | #python tools\main_autogen_0723_without_experts.py --model_name x_gpt4o --times 1 --num_doctors 4 --n_round 10 -------------------------------------------------------------------------------- /utils/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | 4 | def get_inital_message(patient_history: str, stage: str = "inital"): 5 | if stage == "inital": 6 | inital_message = """ 7 | Here is a patient case for analysis, provide the final diagnosis, final differential diagnosis and recommended tests. {}""".format( 8 | patient_history 9 | ) 10 | else: 11 | inital_message = """ 12 | Here is a patient case for analysis, provide the final diagnosis, final differential diagnosis. 
{}""".format( 13 | patient_history 14 | ) 15 | 16 | return inital_message 17 | 18 | 19 | def get_doc_system_message( 20 | doctor_name: str = "Doctor1", stage: str = "inital", ): 21 | 22 | if stage == "inital": 23 | doc_system_message = """You are {}. This is a hypothetical scenario involving no actual patients. 24 | 25 | Your role: 26 | 1. Analyze the patient's condition described in the message. 27 | 2. Focus solely on diagnosis and diagnostic tests, avoiding discussion of management, treatment, or prognosis. 28 | 3. Use your expertise to formulate: 29 | - One most likely diagnosis 30 | - Several differential diagnoses 31 | - Recommended diagnostic tests 32 | 33 | Key responsibilities: 34 | 1. Thoroughly analyze the case information and other specialists' input. 35 | 2. Offer valuable insights based on your specific expertise. 36 | 3. Actively engage in discussion with other specialists, sharing your findings, thoughts, and deductions. 37 | 4. Provide constructive comments on others' opinions, supporting or challenging them with reasoned arguments. 38 | 5. Continuously refine your diagnostic approach based on the ongoing discussion. 39 | 40 | Guidelines: 41 | - Present your analysis clearly and concisely. 42 | - Support your diagnoses and test recommendations with relevant reasoning. 43 | - Be open to adjusting your view based on compelling arguments from other specialists. 44 | - Avoid asking others to copy and paste results; instead, respond to their ideas directly. 45 | 46 | Your goal: Contribute to a comprehensive, collaborative diagnostic process, leveraging your unique expertise to reach the most accurate conclusion possible.""".format( 47 | doctor_name 48 | ) 49 | else: 50 | doc_system_message = """You are {}. This is a hypothetical scenario involving no actual patients. 51 | 52 | Your role: 53 | 1. Analyze the patient's condition described in the message. 54 | 2. 
def get_supervisor_system_message(
    stage: str = "inital",
    use_specialist: bool = False,
    specialists: Optional[list] = None,
):
    """Build the system message for the Medical Supervisor agent.

    Args:
        stage: "inital" includes a "Recommended Tests" field in the JSON
            summary; any other value (follow-up) omits it. (Spelling "inital"
            is kept — directory names and CLI choices depend on it.)
        use_specialist: if True, the prompt addresses named specialists and
            embeds ``specialists`` in the text; otherwise it addresses
            generic "doctors".
        specialists: list of specialist names; required when
            ``use_specialist`` is True.

    Returns:
        The supervisor system-message string.

    BUG FIX: the original inverted the branches — the generic "doctors"
    wording was returned when ``use_specialist`` was True, while the
    specialist wording (which ``.format``s the specialist list) was returned
    in the else branch, interpolating ``None``. Additionally, the doctor
    branch never called ``.format``, so escaped ``{{``/``}}`` leaked
    literally into the prompt; both are corrected here.
    """
    if use_specialist:
        assert specialists is not None, "specialists list required when use_specialist=True"
        role = "specialists"
        roster = " (the list of specialists is {})".format(specialists)
    else:
        role = "doctors"
        roster = ""

    if stage == "inital":
        # Initial stage also asks for a consensus on diagnostic tests.
        tests_line = '\n    "Recommended Tests": "[current consensus on recommended diagnostic tests]",'
    else:
        tests_line = ""

    supervisor_system_message = """You are the Medical Supervisor in a hypothetical scenario.

Your role:
1. Oversee and evaluate suggestions and decisions made by medical {role}{roster}.
2. Challenge diagnoses and proposed tests, identifying any critical points missed.
3. Facilitate discussion between {role}, helping them refine their answers.
4. Drive consensus among {role}, focusing solely on diagnosis and diagnostic tests.

Key tasks:
- Identify inconsistencies and suggest modifications.
- Even when decisions seem consistent, critically assess if further modifications are necessary.
- Provide additional suggestions to enhance diagnostic accuracy.
- Ensure all {role}' views are completely aligned before concluding the discussion.

For each response:
1. Present your insights and challenges to the {role}' opinions.
2. Summarize the current state of diagnosis in the following JSON format:
```json
{{
    "Most Likely Diagnosis": "[current consensus on most likely diagnosis]",
    "Differential Diagnosis": "[current list of differential diagnoses]",{tests_line}
    "Areas of Disagreement": "[list any remaining points of contention or areas needing further discussion]"
}}
```

Guidelines:
- Promote discussion unless there's absolute consensus.
- Continue dialogue if any disagreement or room for refinement exists.
- Output "TERMINATE" only when:
    1. All {role} fully agree.
    2. No further discussion is needed.
    3. All diagnostic possibilities are explored.
    4. All recommended tests are justified and agreed upon.

Your goal: Ensure comprehensive, accurate diagnosis through collaborative expert discussion.""".format(
        role=role, roster=roster, tests_line=tests_line
    )

    return supervisor_system_message


def get_consultant_message(case_presentation: str, num_specialists: int):
    """Build the prompt asking a consultant agent to select the
    ``num_specialists`` most relevant specialists for a case.

    Args:
        case_presentation: the patient's medical history text.
        num_specialists: how many specialists to request (fills ``{top_k}``).

    Returns:
        The formatted consultant prompt string.
    """
    # NOTE(review): the step sentences below appear to have lost inline
    # tag references (e.g. "presented in ." / "in the with ,") — presumably
    # stripped during export; confirm against the original prompt text.
    consultant_message = """
candidate_specialists = ["Cardiologist", "Pulmonologist", "Gastroenterologist", "Neurologist", "Nephrologist", "Endocrinologist", "Hematologist", "Rheumatologist",
"Infectious disease specialist", "Oncologist", "General surgeon", "Cardiothoracic surgeon", "Neurosurgeon", "Orthopedic surgeon", "Urologist", "Plastic and reconstructive surgeon",
"Gynecologist", "Obstetrician", "Reproductive endocrinologist", "Neonatologist", "Pediatrician", "Pediatric surgeon", "Ophthalmologist", "Otolaryngologist",
"Dentist", "Dermatologist", "Psychiatrist", "Rehabilitation specialist", "Emergency physician", "Anesthesiologist", "Radiologist", "Ultrasonologist",
"Nuclear medicine physician", "Clinical laboratory scientist", "Pathologist", "Pharmacist", "Physical therapist", "Transfusion medicine specialist"]

patient's medical history = {case_presentation}

When recommending the appropriate specialist, you need to complete the following steps:
1. Carefully read the medical scenario presented in .
2. Based on the medical scenario, calculate the relevance of each specialist in the with , and select the top {top_k} most relevant specialists as top_k_specialists.

The output must be formatted in JSON as follows:
```json
{{
    "top_k_specialists": [top_k_specialist list]
}}
```
""".format(
        case_presentation=case_presentation, top_k=num_specialists
    )

    return consultant_message
def get_evaluate_prompts():
    """Return the three LLM-judge prompt templates used for evaluation.

    Returns:
        tuple[str, str, str]: ``str.format`` templates, in order:
          1. most-likely-diagnosis scorer (placeholders: ``correct_diagnosis``, ``diagnosis``)
          2. differential-diagnosis scorer (placeholders: ``correct_diagnosis``, ``possible_diagnoses``)
          3. recommended-tests scorer (placeholders: ``correct_diagnosis``, ``recommended_tests``)

    Fixes vs original: "numberic" typo corrected to "numeric" and the
    trailing comma removed from the JSON examples (it instructed the model
    to emit invalid JSON that downstream parsing would reject).
    """
    # NOTE(review): the first two rubrics skip score 1 (5,4,3,2,0); kept
    # unchanged since it may mirror the original grading scheme — confirm.
    most_prompt_template: str = """Your evaluation should be based on the correct diagnosis and according to the scoring criteria. The correct diagnosis
is "{correct_diagnosis}". The student's suggested diagnosis is "{diagnosis}".
Scoring Criteria:
- 5: The actual diagnosis was suggested
- 4: The suggestions included something very close, but not exact
- 3: The suggestions included something closely related that might have been helpful
- 2: The suggestions included something related, but unlikely to be helpful
- 0: No suggestions close
What would be the score based on these criteria?
Provide brief explanation for your choice. Do not expand the explanation, do not use line breaks, and write it in one paragraph.
Output the final answer in json:
```json
{{
    "Score": "[numeric]",
    "Explanation": "[Words]"
}}
```."""

    possible_prompt_template: str = """Your evaluation should be based on the correct diagnosis and according to the scoring criteria. The correct diagnosis
is "{correct_diagnosis}". The student's suggested possible diagnosis includes "{possible_diagnoses}".
Scoring Criteria:
- 5: The actual diagnosis was suggested in the differential
- 4: The suggestions included something very close, but not exact
- 3: The suggestions included something closely related that might have been helpful
- 2: The suggestions included something related, but unlikely to be helpful
- 0: No suggestions close
What would be the score based on these criteria?
Provide brief explanation for your choice. Do not expand the explanation, do not use line breaks, and write it in one paragraph.
Output the final answer in json:
```json
{{
    "Score": "[numeric]",
    "Explanation": "[Words]"
}}
```."""

    recommended_tests_prompt_template: str = """You should evaluate if the tests would be helpful in reaching the final diagnosis of "{correct_diagnosis}".
The student's recommended tests are "{recommended_tests}".
Scoring Criteria:
- 5: Strongly agree that the tests are helpful
- 4: Agree that the tests are helpful
- 3: Neutral
- 2: Disagree that the tests are helpful
- 1: Strongly Disagree that the tests are helpful
What would be the score based on these criteria?
Provide brief explanation for your choice. Do not expand the explanation, do not use line breaks, and write it in one paragraph.
Output the final answer in json:
```json
{{
    "Score": "[numeric]",
    "Explanation": "[Words]"
}}
```."""
    # Order matters: callers unpack (most, possible, recommended_tests).
    return most_prompt_template, possible_prompt_template, recommended_tests_prompt_template
class Prompt:
    """Shared configuration container for prompt construction.

    Holds the prefix/separator strings used when assembling few-shot style
    queries; subclasses specialize how queries are built.
    """

    def __init__(
        self,
        question_prefix: str,
        answer_prefix: str,
        intra_example_sep: str,
        inter_example_sep: str,
        engine: str = None,
        temperature: float = None,
    ) -> None:
        self.question_prefix = question_prefix
        self.answer_prefix = answer_prefix
        self.intra_example_sep = intra_example_sep
        self.inter_example_sep = inter_example_sep
        # engine/temperature are optional metadata; subclasses here do not use them.
        self.engine = engine
        self.temperature = temperature

    def make_query(self, prompt: str, question: str) -> str:
        """Append the question (with configured prefixes/separators) to a prompt."""
        return (
            f"{prompt}{self.question_prefix}{question}{self.intra_example_sep}{self.answer_prefix}"
        )


class ResponseGenTaskInit(Prompt):
    """Generates the first-draft diagnosis response for a patient case."""

    def __init__(self, engine: str) -> None:
        super().__init__(
            question_prefix="Conversation history: ",
            answer_prefix="Response: ",
            intra_example_sep="\n\n",
            inter_example_sep="\n\n###\n\n",
        )
        # NOTE: despite the parameter name, callers pass the pipeline *stage*
        # ("inital" / "follow_up"), and it is stored as such.
        self.stage = engine

    def make_query(self, context: str) -> list:
        """Build the single-user-message chat payload for the initial draft.

        Overrides Prompt.make_query with a different signature and returns a
        chat-message list, not a string (return annotation fixed accordingly).
        The initial stage additionally requests "Recommended Tests".
        """
        if self.stage == "inital":
            query = """
Patient Presentation:
{}

Output in the following JSON format:
```json
{{
    "Most Likely Diagnosis": "[current consensus on most likely diagnosis]",
    "Differential Diagnosis": "[current list of differential diagnoses]",
    "Recommended Tests": "[current consensus on recommended diagnostic tests]"
}}
```."""
        else:
            query = """
Patient Presentation:
{}

Output in the following JSON format:
```json
{{
    "Most Likely Diagnosis": "[current consensus on most likely diagnosis]",
    "Differential Diagnosis": "[current list of differential diagnoses]"
}}
```."""
        query = query.format(context)
        return [{"content": query, "role": "user"}]

    def __call__(self, agent, context: str) -> str:
        """Generate and return the agent's first-draft response as plain text."""
        message = self.make_query(context)
        generated_response = agent.generate_reply(message)
        return content_str(generated_response)
generated_response 92 | 93 | 94 | 95 | class ResponseGenFeedback(Prompt): 96 | def __init__(self, engine: str) -> None: 97 | super().__init__( 98 | question_prefix="", 99 | answer_prefix="", 100 | intra_example_sep="\n\n", 101 | inter_example_sep="\n\n###\n\n", 102 | ) 103 | self.stage = engine 104 | 105 | def __call__(self, agent, context: str, response: str): 106 | message = self.get_prompt_with_question(context=context, response=response) 107 | generated_response = agent.generate_reply(message) 108 | generated_feedback = content_str(generated_response) 109 | generated_feedback = parse_json(generated_feedback) 110 | 111 | return generated_feedback 112 | 113 | def get_prompt_with_question(self, context: str, response: str): 114 | context = context.replace('System: ', '').replace('User: ', '') 115 | query = self.make_query(context=context, response=response) 116 | 117 | message = [{"content": query, "role": "user"}] 118 | return message 119 | 120 | def make_query(self, context: str, response: str): 121 | if self.stage == "inital": 122 | query = """ 123 | Patient Presentation: 124 | {context} 125 | 126 | Response: 127 | {response} 128 | 129 | Based on the Patient Presentation, you should evaluate the corresponding score for the Response. 130 | Scoring Criteria: 131 | - 50: The diagnosis and recommended tests were suggested 132 | - 40: The suggestions included something very close, but not exact 133 | - 30: The suggestions included something closely related that might have been helpful 134 | - 20: The suggestions included something related, but unlikely to be helpful 135 | - 0: No suggestions close 136 | What would be the score based on these criteria? 137 | Be critical, always find mistake and be cautious with the score. 138 | The scores can be any integral between either two category, such as 5, 15, 25, 35, 45, 50. 139 | Provide brief explanation for your choice. Do not expand the explanation, do not use line breaks, and write it in one paragraph. 
140 | Output the final answer in json: 141 | ```json 142 | {{ 143 | "Score": "[numberic, 0~50]", 144 | "Explanation": "[Words]", 145 | }} 146 | ```.""" 147 | else: 148 | query = """ 149 | Patient Presentation: 150 | {context} 151 | 152 | Response: 153 | {response} 154 | 155 | Based on the Patient Presentation, you should evaluate the corresponding score for the Response. 156 | Scoring Criteria: 157 | - 50: The diagnosis was suggested 158 | - 40: The suggestions included something very close, but not exact 159 | - 30: The suggestions included something closely related that might have been helpful 160 | - 20: The suggestions included something related, but unlikely to be helpful 161 | - 0: No suggestions close 162 | What would be the score based on these criteria? 163 | Be critical, always find mistake and be cautious with the score. 164 | The scores can be any integral between either two category, such as 5, 15, 25, 35, 45, 50. 165 | Provide brief explanation for your choice. Do not expand the explanation, do not use line breaks, and write it in one paragraph. 166 | Output the final answer in json: 167 | ```json 168 | {{ 169 | "Score": "[numberic, 0~50]", 170 | "Explanation": "[Words]", 171 | }} 172 | ```.""" 173 | query = query.format(context=context, response=response) 174 | 175 | return query 176 | 177 | 178 | 179 | 180 | class ResponseGenTaskIterate(Prompt): 181 | def __init__(self, engine: str) -> None: 182 | super().__init__( 183 | question_prefix="", 184 | answer_prefix="", 185 | intra_example_sep="\n\n", 186 | inter_example_sep="\n\n###\n\n", 187 | ) 188 | self.stage = engine 189 | 190 | def make_query(self, example_input): 191 | """Given a list of examples that are incrementally improving, return a new example. 192 | """ 193 | 194 | instr = """We want to iteratively improve the provided responses. To help improve, suggestion for each response on desired traits are provided: 1) Score, 2) Explanation. 
195 | 196 | """ 197 | if self.stage == "inital": 198 | 199 | template = """Conversation history: 200 | {history} 201 | 202 | Output in the following JSON format: 203 | ```json 204 | {{ 205 | "Most Likely Diagnosis": "[current consensus on most likely diagnosis]", 206 | "Differential Diagnosis": "[current list of differential diagnoses]", 207 | "Recommended Tests": "[current consensus on recommended diagnostic tests]", 208 | }} 209 | ```. 210 | """ 211 | 212 | else: 213 | template = """ 214 | {history} 215 | 216 | Output in the following JSON format: 217 | ```json 218 | {{ 219 | "Most Likely Diagnosis": "[current consensus on most likely diagnosis]", 220 | "Differential Diagnosis": "[current list of differential diagnoses]", 221 | }} 222 | ```. 223 | """ 224 | 225 | prompt = template.format( 226 | history=example_input, 227 | ) 228 | 229 | query = instr + prompt 230 | return query.strip() 231 | 232 | 233 | def __call__( 234 | self, 235 | agent, 236 | responses_to_scores: Dict[str, str], 237 | ) -> str: 238 | example_input = self.make_input( 239 | responses_to_scores=responses_to_scores 240 | ) 241 | transfer_query = self.make_query(example_input) 242 | message = [{"content": transfer_query, "role": "user"}] 243 | modelresponse = agent.generate_reply(message) 244 | modelresponse = content_str(modelresponse) # Ensure it's a string 245 | 246 | return modelresponse 247 | 248 | 249 | def make_input( 250 | self, 251 | responses_to_scores: Dict[str, str], 252 | ) -> str: 253 | input_txt = "" 254 | for response, (context, score, explanation) in responses_to_scores.items(): 255 | context = context.replace('System: ', '').replace('User: ', '') 256 | input_txt += self._make_input( 257 | context=context, 258 | response=response, 259 | score=score, 260 | explanation=explanation, 261 | ) 262 | return input_txt 263 | 264 | 265 | def _make_input( 266 | self, 267 | context: str, 268 | response: str, 269 | score: str, 270 | explanation: str, 271 | ) -> str: 272 | context = 
def parse_args():
    """Parse command-line options for the self-refine pipeline.

    Fixes vs original: help strings for --config and --times were
    copy-pasted from other options ("the llm models" / "choice different
    stages"); defaults and choices are unchanged.
    """
    parser = argparse.ArgumentParser(description="Medagents Setting")
    parser.add_argument(
        "--config",
        type=str,
        # NOTE(review): repo ships configs/config_list.json — confirm this
        # default path is intentional.
        default="configs/OAI_Config_List.json",
        help="path to the LLM config-list JSON file",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="x_gpt4o",
        choices=["x_gpt35_turbo", "x_gpt4_turbo", "x_gpt4o", "llama3.1"],
        help="which LLM config tag to use",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="rare_disease_302",
        choices=["rare_disease_302", "rare_disease_150", "rare_disease_152"],
        help="which dataset to evaluate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="sampling temperature for the LLM",
    )
    parser.add_argument(
        "--stage",
        type=str,
        default="inital",  # spelling kept: output paths and scripts rely on it
        choices=["inital", "follow_up"],
        help="which pipeline stage to run",
    )
    parser.add_argument(
        "--times",
        type=int,
        default=1,
        choices=[1, 2, 3, 4, 5],
        help="which repetition (run index) this is",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="directory for per-case JSON results",
    )
    return parser.parse_args()


def load(data_path):
    """Load and return the JSON content of ``data_path``."""
    with open(data_path, "r") as file:
        return json.load(file)
def parse_json(text):
    """Extract and parse the first JSON object found in an LLM reply.

    Tries, in order: a ```json fenced block, a plain ``` fenced block, a
    bare ``{...}`` object (greedy, so nested objects survive; then
    non-greedy as a fallback for trailing text), and finally the whole
    text. Newlines are stripped before parsing since LLMs sometimes emit
    raw newlines inside string values.

    Raises:
        ValueError: if no candidate parses as JSON.

    BUG FIX: the original used only the non-greedy pattern ``({.*?})`` for
    bare objects, which truncated nested JSON (e.g. ``{"a": {"b": 1}}``)
    at the first ``}`` and then raised even though the full text was valid.
    """
    patterns = [
        r"```(?:json|JSON)\s*(.*?)\s*```",  # ```json fenced block
        r"```\s*(.*?)\s*```",               # plain fenced block
        r"({.*})",                          # bare object, greedy (handles nesting)
        r"({.*?})",                         # bare object, non-greedy fallback
    ]

    # Collect every candidate substring, then accept the first that parses.
    candidates = []
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            candidates.append(match.group(1).strip())
    candidates.append(text.strip())

    last_error = None
    for candidate in candidates:
        try:
            return json.loads(candidate.replace("\n", ""))
        except json.JSONDecodeError as e:
            last_error = e
    raise ValueError(f"Invalid JSON format: {last_error}")


def simple_retry(max_attempts=100, delay=1):
    """Decorator factory: retry the wrapped callable up to ``max_attempts``
    times, sleeping ``delay`` seconds between attempts.

    Prints each failure; re-raises the last exception when all attempts
    are exhausted.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt < max_attempts - 1:
                        print(
                            f"Attempt {attempt + 1} failed: {str(e)}. Retrying in {delay} second..."
                        )
                        time.sleep(delay)
                    else:
                        print(
                            f"All {max_attempts} attempts failed. Last error: {str(e)}"
                        )
                        raise

        return wrapper

    return decorator
# @simple_retry(max_attempts=100, delay=1)
def process_single_case(
    args,
    output_dir,
    json_name,
    case_type,
    case_crl,
    presentation,
    message,
    case_name,
    model_config,
    task_init,
    task_feedback,
    task_iterate,
    n_model
):
    """Run the self-refine loop (draft -> score -> refine) for one case and
    write the best-scoring result to ``output_dir/json_name``.

    ``message`` is the full prompt sent to the doctor agent (it becomes the
    conversation context); ``presentation`` is the raw case text recorded in
    the output JSON.
    NOTE(review): verify callers pass ``presentation`` and ``message`` in
    this order — the two string arguments are easy to swap at the call site.
    """
    case_info = {}

    # Two separate LLM agents: one drafts/refines responses, one scores them.
    doctor_agent = ConversableAgent(
        "Doctor",
        llm_config=model_config,
        human_input_mode="NEVER",  # Never ask for human input.
    )

    feedback_agent = ConversableAgent(
        "Fed",
        llm_config=model_config,
        human_input_mode="NEVER",  # Never ask for human input.
    )

    context = message

    # Up to 4 refinement rounds; keep only responses whose score does not drop.
    max_attempts = 4
    n_attempts = 0
    best_response = None
    responses_to_scores = dict()      # responses fed back into task_iterate
    all_responses_to_scores = dict()  # full history (kept for inspection only)

    best_score_so_far = 0
    reduce_window = 0  # NOTE(review): incremented below but never read — dead.


    iostream = IOStream.get_default()

    while n_attempts < max_attempts:
        # First round drafts from scratch; later rounds refine prior attempts.
        if n_attempts == 0:
            response = task_init(agent=doctor_agent, context=context)
        else:
            response = task_iterate(agent=doctor_agent, responses_to_scores=responses_to_scores)

        # Ensure response is a string
        response_str = str(response)

        # Usage totals are cumulative across the whole run (cached inference included).
        case_cost = gather_usage_summary([doctor_agent])["usage_including_cached_inference"]["total_cost"]
        total_tokens = gather_usage_summary([doctor_agent])["usage_including_cached_inference"][n_model]["total_tokens"]

        if total_tokens > 3000:
            reduce_window += 1
        if total_tokens > 3500:
            reduce_window += 1

        iostream.print(colored(doctor_agent.name, "yellow"), "response:", flush=True)
        iostream.print(colored(response_str, ), flush=True)
        iostream.print(colored("*" * 60, "light_cyan"), flush=True)
        iostream.print(colored("Costs: ", "yellow"), colored(case_cost, "red"), colored("Tokens: ", "yellow"), colored(total_tokens, "red"), flush=True)

        # Score the draft; task_feedback returns a dict with "Score"/"Explanation".
        scores = task_feedback(agent=feedback_agent, context=context, response=response_str)

        case_cost = gather_usage_summary([doctor_agent])["usage_including_cached_inference"]["total_cost"]
        total_tokens = gather_usage_summary([doctor_agent])["usage_including_cached_inference"][n_model]["total_tokens"]

        iostream.print(colored(doctor_agent.name, "yellow"), "scores:", flush=True)
        iostream.print(colored(str(scores), ), flush=True)
        iostream.print(colored("*" * 60, "light_cyan"), flush=True)
        iostream.print(colored("Costs: ", "yellow"), colored(case_cost, "red"), colored("Tokens: ", "yellow"), colored(total_tokens, "red"), flush=True)

        score = int(scores["Score"])

        all_responses_to_scores[response_str] = {
            "n_attempts": n_attempts,
            "score": score,
            "explanation": scores["Explanation"],
            "context": context,
        }

        if score >= best_score_so_far:  # Only iterate over responses that are improving
            best_score_so_far = score
            best_response = response_str
            responses_to_scores[response_str] = (context, scores["Score"], scores["Explanation"])
        else:
            print(f"Score of {response_str} is {score}, which is less than the current best of {best_score_so_far}")
        n_attempts += 1

    # At the end, 'best_response' is a string.
    # (The first round always sets it since any score >= the initial best of 0.)
    json_output = parse_json(best_response)

    case_info["Type"] = case_type
    case_info["Crl"] = case_crl
    case_info["Cost"] = case_cost
    case_info["Presentation"] = presentation
    case_info["Name"] = case_name
    case_info["Most Likely"] = json_output["Most Likely Diagnosis"]
    case_info["Other Possible"] = json_output["Differential Diagnosis"]

    # Tests are only requested (and thus only recorded) in the initial stage.
    if args.stage == "inital":
        case_info["Recommended Tests"] = json_output["Recommended Tests"]

    recorder_path = osp.join(output_dir, json_name)

    with open(recorder_path, "w") as file:
        json.dump(case_info, file, indent=4)
def main():
    """Run the self-refine pipeline over every case in the selected dataset,
    skipping cases whose result JSON already exists.

    BUG FIX: the original called ``process_single_case(..., message,
    presentation, ...)`` while the signature is ``(..., presentation,
    message, ...)``, so the recorded "Presentation" was actually the prompt
    and the agent context was the bare case text. The arguments are passed
    in the correct order here.
    """
    args = parse_args()

    # Select the model config entries tagged with the requested model name.
    filter_criteria = {
        "tags": [args.model_name],
    }
    config_list = config_list_from_json(
        env_or_file=args.config, filter_dict=filter_criteria
    )
    n_model = config_list[0]["model"]
    model_config = {
        "cache_seed": None,  # disable autogen's response cache seeding
        "temperature": args.temperature,
        "config_list": config_list,
        "timeout": 120,
    }

    dataset = MedDataset(dataname=args.dataset_name)

    # One output directory per (stage, model, temperature, run index).
    output_dir = osp.join(
        args.output_dir,
        "SELF_REFINE",
        args.stage,
        args.model_name,
        f"temp{str(args.temperature)}",
        str(args.times),
    )
    os.makedirs(output_dir, exist_ok=True)

    # The prompt builders depend only on the stage, so build them once
    # instead of re-instantiating per case.
    task_init = ResponseGenTaskInit(engine=args.stage)          # first draft
    task_feedback = ResponseGenFeedback(engine=args.stage)      # scoring
    task_iterate = ResponseGenTaskIterate(engine=args.stage)    # refinement

    for idx in tqdm(range(len(dataset))):
        (
            case_type,
            case_name,
            case_crl,
            case_initial_presentation,
            case_follow_up_presentation,
        ) = dataset[idx]

        json_name = f"{case_crl}.json"

        # Resume support: skip cases already written by a previous run.
        if osp.exists(osp.join(output_dir, json_name)):
            continue

        if args.stage == "inital":
            presentation = case_initial_presentation
            message = """
Here is a patient case for analysis, provide the final diagnosis, final differential diagnosis and recommended tests.
{patient_history}""".format(
                patient_history=presentation
            )
        else:
            presentation = case_follow_up_presentation
            message = """
Here is a patient case for analysis, provide the final diagnosis, final differential diagnosis.
{patient_history}""".format(
                patient_history=presentation
            )

        try:
            process_single_case(
                args,
                output_dir,
                json_name,
                case_type,
                case_crl,
                presentation,  # raw case text, recorded in the output JSON
                message,       # full prompt sent to the doctor agent
                case_name,
                model_config,
                task_init,
                task_feedback,
                task_iterate,
                n_model,
            )
        except Exception as e:
            print(f"Failed to process case {idx} after all attempts: {str(e)}")
            continue


if __name__ == "__main__":
    main()