├── pyproject.toml
├── medagentboard
│   ├── utils
│   │   ├── encode_image.py
│   │   ├── llm_configs.py
│   │   ├── json_utils.py
│   │   └── llm_scoring.py
│   ├── medqa
│   │   ├── run_colacare_diverse_llms.sh
│   │   ├── run.sh
│   │   ├── evaluate.py
│   │   ├── preprocess_datasets.py
│   │   ├── multi_agent_healthcareagent.py
│   │   ├── multi_agent_mac.py
│   │   ├── single_llm.py
│   │   └── multi_agent_reconcile.py
│   ├── laysummary
│   │   ├── run.sh
│   │   ├── preprocess_datasets.py
│   │   └── evaluation.py
│   └── ehr
│       ├── run.sh
│       ├── preprocess_dataset.py
│       └── multi_agent_reconcile.py
├── .gitignore
└── README.md

/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "medagentboard"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 |     "openai>=1.69.0",
9 |     "pandas>=2.2.3",
10 |     "python-dotenv>=1.1.0",
11 | ]
12 |
--------------------------------------------------------------------------------
/medagentboard/utils/encode_image.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 |
4 | def encode_image(image_path: str) -> str:
5 |     """
6 |     Encode an image file as a base64 string.
7 |
8 |     Args:
9 |         image_path: Path to the image file
10 |
11 |     Returns:
12 |         Base64 encoded string of the image
13 |
14 |     Raises:
15 |         FileNotFoundError: If the image file doesn't exist
16 |         IOError: If there's an error reading the image file
17 |     """
18 |     if not os.path.isfile(image_path):
19 |         raise FileNotFoundError(f"Image file not found: {image_path}")
20 |
21 |     try:
22 |         with open(image_path, "rb") as image_file:
23 |             return base64.b64encode(image_file.read()).decode("utf-8")
24 |     except IOError as e:
25 |         raise IOError(f"Error reading image file: {e}")
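For orientation, `encode_image` is the helper the VQA pipelines use when building OpenAI-compatible vision requests. The sketch below is illustrative only and not part of the repository: the JPEG input, the prompt, and the choice of the `qwen-vl-max` entry from `LLM_MODELS_SETTINGS` (defined later in `medagentboard/utils/llm_configs.py`) are assumptions.

```python
# Minimal sketch (assumptions: a local JPEG and the "qwen-vl-max" config entry).
from openai import OpenAI

from medagentboard.utils.encode_image import encode_image
from medagentboard.utils.llm_configs import LLM_MODELS_SETTINGS

settings = LLM_MODELS_SETTINGS["qwen-vl-max"]
client = OpenAI(api_key=settings["api_key"], base_url=settings["base_url"])

base64_image = encode_image("test_img.jpg")  # hypothetical input file
response = client.chat.completions.create(
    model=settings["model_name"],
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the key finding in this image."},
            # Standard OpenAI-compatible inline-image payload via a data URL
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
        ],
    }],
)
print(response.choices[0].message.content)
```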
--------------------------------------------------------------------------------
/medagentboard/medqa/run_colacare_diverse_llms.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the maximum number of concurrent jobs
4 | MAX_CONCURRENT=4
5 | CURRENT_JOBS=0
6 |
7 | # Create a temporary file to track running processes
8 | TEMP_FILE=$(mktemp)
9 | trap "rm -f $TEMP_FILE" EXIT
10 |
11 | # Run a command while enforcing the concurrency limit
12 | run_command() {
13 |     local cmd="$1"
14 |
15 |     # Check how many processes are currently running
16 |     CURRENT_JOBS=$(wc -l < $TEMP_FILE)
17 |
18 |     # If the number of running processes has reached the limit, wait for one to finish
19 |     while [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; do
20 |         for pid in $(cat $TEMP_FILE); do
21 |             if ! kill -0 $pid 2>/dev/null; then
22 |                 # The process has finished; remove it from the tracking file
23 |                 grep -v "^$pid$" $TEMP_FILE > ${TEMP_FILE}.new
24 |                 mv ${TEMP_FILE}.new $TEMP_FILE
25 |             fi
26 |         done
27 |         CURRENT_JOBS=$(wc -l < $TEMP_FILE)
28 |         if [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; then
29 |             sleep 1
30 |         fi
31 |     done
32 |
33 |     # Run the command
34 |     echo "Running: $cmd"
35 |     eval "$cmd" &
36 |
37 |     # Record the process ID
38 |     echo $! >> $TEMP_FILE
39 | }
40 |
41 | # Define datasets and task types
42 | QA_DATASETS=("MedQA" "PubMedQA")
43 | VQA_DATASETS=("PathVQA" "VQA-RAD")
44 |
45 | # Define the available qa_type(s) for each dataset
46 | declare -A DATASET_QA_TYPES
47 | DATASET_QA_TYPES[MedQA]="mc"
48 | DATASET_QA_TYPES[PubMedQA]="mc"
49 | DATASET_QA_TYPES[PathVQA]="mc"
50 | DATASET_QA_TYPES[VQA-RAD]="mc"
51 |
52 | echo "Starting experiments..."
53 |
54 | # 1. ColaCare MedQA
55 | for dataset in "${QA_DATASETS[@]}"; do
56 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
57 |         cmd="python -m medagentboard.medqa.multi_agent_colacare --dataset $dataset --qa_type $qa_type --meta_model deepseek-v3-official --doctor_models deepseek-v3-official qwen-max-latest qwen3-235b-a22b"
58 |         run_command "$cmd"
59 |     done
60 | done
61 |
62 | # Wait for all jobs to finish
63 | wait
64 |
65 | echo "All experiments completed!"
--------------------------------------------------------------------------------
/medagentboard/laysummary/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the maximum number of concurrent jobs
4 | MAX_CONCURRENT=4
5 | CURRENT_JOBS=0
6 |
7 | # Create a temporary file to track running processes
8 | TEMP_FILE=$(mktemp)
9 | trap "rm -f $TEMP_FILE" EXIT
10 |
11 | # Run a command while enforcing the concurrency limit
12 | run_command() {
13 |     local cmd="$1"
14 |
15 |     # Check how many processes are currently running
16 |     CURRENT_JOBS=$(wc -l < $TEMP_FILE)
17 |
18 |     # If the number of running processes has reached the limit, wait for one to finish
19 |     while [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; do
20 |         for pid in $(cat $TEMP_FILE); do
21 |             if ! kill -0 $pid 2>/dev/null; then
22 |                 # The process has finished; remove it from the tracking file
23 |                 grep -v "^$pid$" $TEMP_FILE > ${TEMP_FILE}.new
24 |                 mv ${TEMP_FILE}.new $TEMP_FILE
25 |             fi
26 |         done
27 |         CURRENT_JOBS=$(wc -l < $TEMP_FILE)
28 |         if [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; then
29 |             sleep 1
30 |         fi
31 |     done
32 |
33 |     # Run the command
34 |     echo "Running: $cmd"
35 |     eval "$cmd" &
36 |
37 |     # Record the process ID
38 |     echo $! >> $TEMP_FILE
39 | }
40 |
41 | # Define all datasets
42 | DATASETS=("PLABA" "cochrane" "elife" "med_easi" "plos_genetics")
43 |
44 | # Define the prompt types for single_llm
45 | PROMPT_TYPES=("basic" "optimized" "few_shot")
46 |
47 | # Define the parameter combinations for AgentSimp
48 | COMMUNICATION_TYPES=("pipeline" "synchronous")
49 | CONSTRUCTION_TYPES=("direct" "iterative")
50 |
51 | # Define the model
52 | MODEL="deepseek-v3-official"
53 |
54 | echo "Starting experiments..."
55 |
56 | # 1. Single LLM with different prompting strategies
57 | for dataset in "${DATASETS[@]}"; do
58 |     for prompt_type in "${PROMPT_TYPES[@]}"; do
59 |         cmd="python -m medagentboard.laysummary.single_llm --dataset $dataset --prompt_type $prompt_type --model_key $MODEL"
60 |         run_command "$cmd"
61 |     done
62 | done
63 |
64 | # 2. AgentSimp with default configuration
65 | for dataset in "${DATASETS[@]}"; do
66 |     cmd="python -m medagentboard.laysummary.multi_agent_agentsimp --dataset $dataset --model $MODEL"
67 |     run_command "$cmd"
68 | done
69 |
70 | # Wait for all jobs to finish
71 | wait
72 |
73 | echo "All experiments completed!"
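Each run.sh in this repository throttles work to at most `MAX_CONCURRENT` background jobs with the PID-file polling pattern above. The snippet below is a hypothetical alternative orchestration, not code from the repository, showing the same "at most 4 jobs at a time" behavior with `concurrent.futures`; the two commands listed are taken from the script above and are just examples.

```python
# Hypothetical alternative sketch: bounded concurrency via a thread pool.
import shlex
import subprocess
from concurrent.futures import ThreadPoolExecutor

MAX_CONCURRENT = 4

commands = [
    "python -m medagentboard.laysummary.single_llm --dataset PLABA --prompt_type basic --model_key deepseek-v3-official",
    "python -m medagentboard.laysummary.multi_agent_agentsimp --dataset PLABA --model deepseek-v3-official",
]

def run_command(cmd: str) -> int:
    # Launch one experiment and block until it exits, returning its exit code
    print(f"Running: {cmd}")
    return subprocess.run(shlex.split(cmd)).returncode

with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as pool:
    exit_codes = list(pool.map(run_command, commands))

print("All experiments completed!", exit_codes)
```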
--------------------------------------------------------------------------------
/medagentboard/ehr/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the maximum number of concurrent jobs
4 | MAX_CONCURRENT=4
5 | CURRENT_JOBS=0
6 |
7 | # Create a temporary file to track running processes
8 | TEMP_FILE=$(mktemp)
9 | trap "rm -f $TEMP_FILE" EXIT
10 |
11 | # Run a command while enforcing the concurrency limit
12 | run_command() {
13 |     local cmd="$1"
14 |
15 |     # Check how many processes are currently running
16 |     CURRENT_JOBS=$(wc -l < $TEMP_FILE)
17 |
18 |     # If the number of running processes has reached the limit, wait for one to finish
19 |     while [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; do
20 |         for pid in $(cat $TEMP_FILE); do
21 |             if ! kill -0 $pid 2>/dev/null; then
22 |                 # The process has finished; remove it from the tracking file
23 |                 grep -v "^$pid$" $TEMP_FILE > ${TEMP_FILE}.new
24 |                 mv ${TEMP_FILE}.new $TEMP_FILE
25 |             fi
26 |         done
27 |         CURRENT_JOBS=$(wc -l < $TEMP_FILE)
28 |         if [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; then
29 |             sleep 1
30 |         fi
31 |     done
32 |
33 |     # Run the command
34 |     echo "Running: $cmd"
35 |     eval "$cmd" &
36 |
37 |     # Record the process ID
38 |     echo $! >> $TEMP_FILE
39 | }
40 |
41 | # Define datasets and task types
42 | EHR_DATASETS=("mimic-iv" "tjh")
43 |
44 | # Define the available task types for each dataset
45 | declare -A DATASET_TASKS
46 | DATASET_TASKS[mimic-iv]="mortality readmission"
47 | DATASET_TASKS[tjh]="mortality" # tjh only has the mortality task
48 |
49 | echo "Starting EHR experiments..."
50 |
51 | # 1. ColaCare
52 | for dataset in "${EHR_DATASETS[@]}"; do
53 |     for task in ${DATASET_TASKS[$dataset]}; do
54 |         cmd="python -m medagentboard.ehr.multi_agent_colacare --dataset $dataset --task $task"
55 |         run_command "$cmd"
56 |     done
57 | done
58 |
59 | # 2. MedAgent
60 | for dataset in "${EHR_DATASETS[@]}"; do
61 |     for task in ${DATASET_TASKS[$dataset]}; do
62 |         cmd="python -m medagentboard.ehr.multi_agent_medagent --dataset $dataset --task $task"
63 |         run_command "$cmd"
64 |     done
65 | done
66 |
67 | # 3. ReConcile
68 | for dataset in "${EHR_DATASETS[@]}"; do
69 |     for task in ${DATASET_TASKS[$dataset]}; do
70 |         cmd="python -m medagentboard.ehr.multi_agent_reconcile --dataset $dataset --task $task"
71 |         run_command "$cmd"
72 |     done
73 | done
74 |
75 |
76 |
77 |
78 |
79 | # Wait for all jobs to finish
80 | wait
81 |
82 | echo "All EHR experiments completed!"
--------------------------------------------------------------------------------
/medagentboard/ehr/preprocess_dataset.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from medagentboard.utils.json_utils import save_json, load_json, load_jsonl
3 |
4 | def curate_instructions_section(text):
5 |     start_marker = "Instructions & Output Format:"
6 |     end_marker = "floating-point number"
7 |
8 |     start_index = text.find(start_marker)
9 |     if start_index == -1:
10 |         return text  # Start marker not found
11 |
12 |     end_index = text.find(end_marker, start_index)
13 |     if end_index == -1:
14 |         return text  # End marker not found
15 |
16 |     # Keep everything before the start marker and after the end marker
17 |     text = text[:start_index] + "Please output a " + text[end_index:]
18 |
19 |     start_marker = "Example Format:"
20 |     end_marker = "Now, please analyze"
21 |
22 |     start_index = text.find(start_marker)
23 |     if start_index == -1:
24 |         return text  # Start marker not found
25 |
26 |     end_index = text.find(end_marker, start_index)
27 |     if end_index == -1:
28 |         return text  # End marker not found
29 |
30 |     # Keep everything before the start marker and after the end marker
31 |     text = text[:start_index] + text[end_index:]
32 |
33 |     updated_instruction = text.replace("System Prompt: ", "").replace("User Prompt: ", "").replace("years Your Task: ", "years.
") 34 | return updated_instruction 35 | 36 | # structured EHR 37 | datasets = ["tjh", "mimic-iv"] 38 | tasks = { 39 | "tjh": ["mortality"], 40 | "mimic-iv": ["mortality", "readmission"], 41 | } 42 | 43 | for dataset in datasets: 44 | for task in tasks[dataset]: 45 | print(f"Processing {dataset} {task}") 46 | processed_data = [] 47 | 48 | test_data = pd.read_pickle(f"my_datasets/raw/structured_ehr/{dataset}/{task}/test_data.pkl") 49 | for item in test_data: 50 | qid = item['id'] 51 | question = item['x_ehr_prompt'] 52 | question = curate_instructions_section(question) 53 | answer = item[f'y_{task}'][0] 54 | processed_data.append( 55 | { 56 | "qid": qid, 57 | "question": question, 58 | "answer": answer, 59 | } 60 | ) 61 | save_json(processed_data, f"my_datasets/processed/ehr/{dataset}/ehr_timeseries_{task}_test.json") -------------------------------------------------------------------------------- /medagentboard/utils/llm_configs.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | import os 3 | 4 | load_dotenv() 5 | 6 | LLM_MODELS_SETTINGS = { 7 | "deepseek-v3-official": { 8 | "api_key": os.getenv("DEEPSEEK_API_KEY"), 9 | "base_url": "https://api.deepseek.com", 10 | "model_name": "deepseek-chat", 11 | "comment": "DeepSeek V3 Official", 12 | "reasoning": False, 13 | }, 14 | "deepseek-r1-official": { 15 | "api_key": os.getenv("DEEPSEEK_API_KEY"), 16 | "base_url": "https://api.deepseek.com", 17 | "model_name": "deepseek-reasoner", 18 | "comment": "DeepSeek R1 Reasoning Model Official", 19 | "reasoning": True, 20 | }, 21 | "deepseek-v3-ali": { 22 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 23 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 24 | "model_name": "deepseek-v3", 25 | "comment": "DeepSeek V3 Ali", 26 | "reasoning": False, 27 | }, 28 | "deepseek-r1-ali": { 29 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 30 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 31 | "model_name": "deepseek-r1", 32 | "comment": "DeepSeek R1 Reasoning Model Ali", 33 | "reasoning": True, 34 | }, 35 | "deepseek-v3-ark": { 36 | "api_key": os.getenv("ARK_API_KEY"), 37 | "base_url": "https://ark.cn-beijing.volces.com/api/v3", 38 | "model_name": "deepseek-v3-250324", 39 | "comment": "DeepSeek V3 Ark", 40 | "reasoning": False, 41 | }, 42 | "deepseek-r1-ark": { 43 | "api_key": os.getenv("ARK_API_KEY"), 44 | "base_url": "https://ark.cn-beijing.volces.com/api/v3", 45 | "model_name": "deepseek-r1-250120", 46 | "comment": "DeepSeek R1 Reasoning Model Ark", 47 | "reasoning": True, 48 | }, 49 | "qwen-max-latest": { 50 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 51 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 52 | "model_name": "qwen-max-latest", 53 | "comment": "Qwen Max", 54 | "reasoning": False, 55 | }, 56 | "qwen-vl-max": { 57 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 58 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 59 | "model_name": "qwen-vl-max", 60 | "comment": "qwen-vl-max", 61 | "reasoning": False, 62 | }, 63 | "qwen2.5-vl-72b-instruct": { 64 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 65 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 66 | "model_name": "qwen2.5-vl-72b-instruct", 67 | "comment": "qwen2.5-vl-72b-instruct", 68 | "reasoning": False, 69 | }, 70 | "qwen2.5-vl-32b-instruct": { 71 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 72 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 73 | "model_name": 
"qwen2.5-vl-32b-instruct", 74 | "comment": "qwen2.5-vl-32b-instruct", 75 | "reasoning": False, 76 | }, 77 | "qwen2.5-72b-instruct": { 78 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 79 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 80 | "model_name": "qwen2.5-72b-instruct", 81 | "comment": "qwen2.5-72b-instruct", 82 | "reasoning": False, 83 | }, 84 | "qwen3-235b-a22b": { 85 | "api_key": os.getenv("DASHSCOPE_API_KEY"), 86 | "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", 87 | "model_name": "qwen3-235b-a22b", 88 | "comment": "qwen3-235b-a22b", 89 | "reasoning": True, 90 | }, 91 | } -------------------------------------------------------------------------------- /medagentboard/laysummary/preprocess_datasets.py: -------------------------------------------------------------------------------- 1 | # TODO: 检查file path的更新 2 | 3 | import json 4 | import random 5 | import os 6 | import argparse 7 | 8 | def process_json_and_sample(input_filepath, output_filepath, sample_size=100): 9 | """ 10 | 读取JSON文件,为每条记录添加ID,然后随机抽样指定数量的记录到新文件。 11 | 12 | Args: 13 | input_filepath (str): 输入JSON文件的路径。 14 | output_filepath (str): 输出JSON文件的路径。 15 | sample_size (int): 需要抽样的记录数量。 16 | """ 17 | random.seed(42) # 设置随机种子以确保可重复性 18 | # --- 1. 读取输入 JSON 文件 --- 19 | try: 20 | with open(input_filepath, 'r', encoding='utf-8') as f_in: 21 | original_data = json.load(f_in) 22 | print(f"成功读取文件: {input_filepath}") 23 | except FileNotFoundError: 24 | print(f"错误:找不到输入文件 '{input_filepath}'") 25 | return 26 | except json.JSONDecodeError: 27 | print(f"错误:文件 '{input_filepath}' 不是有效的 JSON 格式。") 28 | return 29 | except Exception as e: 30 | print(f"读取文件时发生未知错误: {e}") 31 | return 32 | 33 | # 检查数据是否是列表格式 34 | if not isinstance(original_data, list): 35 | print(f"错误:文件 '{input_filepath}' 的顶层结构不是 JSON 列表。") 36 | return 37 | 38 | # --- 2. 为每条记录添加 ID --- 39 | data_with_ids = [] 40 | for index, item in enumerate(original_data): 41 | # 确保 item 是字典类型,以防 JSON 列表里有非对象元素 42 | if isinstance(item, dict): 43 | # 创建一个新字典或直接修改 item,这里选择直接修改 44 | item['id'] = index + 1 45 | data_with_ids.append(item) 46 | else: 47 | print(f"警告:跳过索引 {index} 处的非对象元素: {item}") 48 | 49 | total_items = len(data_with_ids) 50 | print(f"已为 {total_items} 条有效记录添加 ID。") 51 | 52 | # --- 3. 随机抽样 --- 53 | if total_items == 0: 54 | print("警告:没有有效的记录可供抽样。输出文件将为空列表。") 55 | sampled_data = [] 56 | elif total_items < sample_size: 57 | print(f"警告:记录总数 ({total_items}) 少于要求的抽样数量 ({sample_size})。将抽取所有 {total_items} 条记录。") 58 | # 直接使用所有数据,或者可以用 random.sample(data_with_ids, total_items) 效果一样 59 | sampled_data = data_with_ids 60 | else: 61 | print(f"正在从 {total_items} 条记录中随机抽取 {sample_size} 条...") 62 | sampled_data = random.sample(data_with_ids, sample_size) 63 | print(f"成功抽取 {len(sampled_data)} 条记录。") 64 | 65 | # --- 4. 
--------------------------------------------------------------------------------
/medagentboard/laysummary/preprocess_datasets.py:
--------------------------------------------------------------------------------
1 | # TODO: check that the file paths are up to date
2 |
3 | import json
4 | import random
5 | import os
6 | import argparse
7 |
8 | def process_json_and_sample(input_filepath, output_filepath, sample_size=100):
9 |     """
10 |     Read a JSON file, add an ID to each record, then randomly sample a given number of records into a new file.
11 |
12 |     Args:
13 |         input_filepath (str): Path to the input JSON file.
14 |         output_filepath (str): Path to the output JSON file.
15 |         sample_size (int): Number of records to sample.
16 |     """
17 |     random.seed(42)  # Set the random seed for reproducibility
18 |     # --- 1. Read the input JSON file ---
19 |     try:
20 |         with open(input_filepath, 'r', encoding='utf-8') as f_in:
21 |             original_data = json.load(f_in)
22 |         print(f"Successfully read file: {input_filepath}")
23 |     except FileNotFoundError:
24 |         print(f"Error: input file '{input_filepath}' not found")
25 |         return
26 |     except json.JSONDecodeError:
27 |         print(f"Error: file '{input_filepath}' is not valid JSON.")
28 |         return
29 |     except Exception as e:
30 |         print(f"Unexpected error while reading the file: {e}")
31 |         return
32 |
33 |     # Check whether the data is a list
34 |     if not isinstance(original_data, list):
35 |         print(f"Error: the top-level structure of file '{input_filepath}' is not a JSON list.")
36 |         return
37 |
38 |     # --- 2. Add an ID to each record ---
39 |     data_with_ids = []
40 |     for index, item in enumerate(original_data):
41 |         # Make sure item is a dict, in case the JSON list contains non-object elements
42 |         if isinstance(item, dict):
43 |             # Either create a new dict or modify item directly; here we modify it in place
44 |             item['id'] = index + 1
45 |             data_with_ids.append(item)
46 |         else:
47 |             print(f"Warning: skipping non-object element at index {index}: {item}")
48 |
49 |     total_items = len(data_with_ids)
50 |     print(f"Added IDs to {total_items} valid records.")
51 |
52 |     # --- 3. Random sampling ---
53 |     if total_items == 0:
54 |         print("Warning: no valid records available for sampling. The output file will be an empty list.")
55 |         sampled_data = []
56 |     elif total_items < sample_size:
57 |         print(f"Warning: total number of records ({total_items}) is less than the requested sample size ({sample_size}). All {total_items} records will be taken.")
58 |         # Use all the data directly; random.sample(data_with_ids, total_items) would have the same effect
59 |         sampled_data = data_with_ids
60 |     else:
61 |         print(f"Randomly sampling {sample_size} of {total_items} records...")
62 |         sampled_data = random.sample(data_with_ids, sample_size)
63 |         print(f"Successfully sampled {len(sampled_data)} records.")
64 |
65 |     # --- 4. Write the sampled results to a new JSON file ---
66 |     try:
67 |         # Make sure the output directory exists (if the output path contains a directory)
68 |         output_dir = os.path.dirname(output_filepath)
69 |         if output_dir and not os.path.exists(output_dir):
70 |             os.makedirs(output_dir)
71 |             print(f"Created output directory: {output_dir}")
72 |
73 |         with open(output_filepath, 'w', encoding='utf-8') as f_out:
74 |             # indent=4 makes the output JSON file easier to read
75 |             # ensure_ascii=False ensures non-ASCII characters are written as-is instead of being escaped as \uXXXX
76 |             json.dump(sampled_data, f_out, indent=4, ensure_ascii=False)
77 |         print(f"Sampled results successfully written to: {output_filepath}")
78 |     except Exception as e:
79 |         print(f"Error while writing the file: {e}")
80 |
81 | # --- Usage example ---
82 | if __name__ == "__main__":
83 |     parser = argparse.ArgumentParser()
84 |     parser.add_argument('--data_path', type=str, default='dataset')
85 |     args = parser.parse_args()
86 |
87 |     for dataset in os.listdir(args.data_path):
88 |         dataset_path = os.path.join(args.data_path, dataset)
89 |         if not os.path.isdir(dataset_path):
90 |             continue
91 |
92 |         for split in os.listdir(dataset_path):
93 |             if split == "test.json":
94 |                 input_json_file = os.path.join(args.data_path, dataset, split)
95 |                 output_json_file = os.path.join('processed', dataset, split)
96 |                 number_to_sample = 100
97 |
98 |                 process_json_and_sample(input_json_file, output_json_file, number_to_sample)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | my_datasets/
2 | logs/
3 | prompt_*.txt
4 |
5 | *.json
6 | test_img.jpg
7 | .DS_Store
8 | debug.py
9 | *.json
10 |
11 | # TEMP FILES
12 | inf.sh
13 |
14 | # ZONE.IDENTIFIER
15 | *Zone.Identifier
16 |
17 | # ZIP
18 | *.zip
19 |
20 | # API config
21 | API_CONFIG.py
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | share/python-wheels/
46 | *.egg-info/
47 | .installed.cfg
48 | *.egg
49 | MANIFEST
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .nox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | *.py,cover
72 | .hypothesis/
73 | .pytest_cache/
74 | cover/
75 |
76 | # Translations
77 | *.mo
78 | *.pot
79 |
80 | # Django stuff:
81 | *.log
82 | local_settings.py
83 | db.sqlite3
84 | db.sqlite3-journal
85 |
86 | # Flask stuff:
87 | instance/
88 | .webassets-cache
89 |
90 | # Scrapy stuff:
91 | .scrapy
92 |
93 | # Sphinx documentation
94 | docs/_build/
95 |
96 | # PyBuilder
97 | .pybuilder/
98 | target/
99 |
100 | # Jupyter Notebook
101 | .ipynb_checkpoints
102 |
103 | # IPython
104 | profile_default/
105 | ipython_config.py
106 |
107 | # pyenv
108 | # For a library or package, you might want to ignore these files since the code is
109 | # intended to run in multiple environments; otherwise, check them in:
110 | # .python-version
111 |
112 | # pipenv
113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
116 | # install all needed dependencies.
117 | #Pipfile.lock
118 |
119 | # UV
120 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
121 | # This is especially recommended for binary packages to ensure reproducibility, and is more
122 | # commonly ignored for libraries.
123 | #uv.lock
124 |
125 | # poetry
126 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
127 | # This is especially recommended for binary packages to ensure reproducibility, and is more
128 | # commonly ignored for libraries.
129 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
130 | #poetry.lock
131 |
132 | # pdm
133 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
134 | #pdm.lock
135 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
136 | # in version control.
137 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
138 | .pdm.toml
139 | .pdm-python
140 | .pdm-build/
141 |
142 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
143 | __pypackages__/
144 |
145 | # Celery stuff
146 | celerybeat-schedule
147 | celerybeat.pid
148 |
149 | # SageMath parsed files
150 | *.sage.py
151 |
152 | # Environments
153 | .env
154 | .venv
155 | env/
156 | venv/
157 | ENV/
158 | env.bak/
159 | venv.bak/
160 |
161 | # Spyder project settings
162 | .spyderproject
163 | .spyproject
164 |
165 | # Rope project settings
166 | .ropeproject
167 |
168 | # mkdocs documentation
169 | /site
170 |
171 | # mypy
172 | .mypy_cache/
173 | .dmypy.json
174 | dmypy.json
175 |
176 | # Pyre type checker
177 | .pyre/
178 |
179 | # pytype static type analyzer
180 | .pytype/
181 |
182 | # Cython debug symbols
183 | cython_debug/
184 |
185 | # PyCharm
186 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
187 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
188 | # and can be added to the global gitignore or merged into this file. For a more nuclear
189 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
190 | #.idea/
191 |
192 | # PyPI configuration file
193 | .pypirc
194 |
--------------------------------------------------------------------------------
/medagentboard/medqa/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the maximum number of concurrent jobs
4 | MAX_CONCURRENT=4
5 | CURRENT_JOBS=0
6 |
7 | # Create a temporary file to track running processes
8 | TEMP_FILE=$(mktemp)
9 | trap "rm -f $TEMP_FILE" EXIT
10 |
11 | # Run a command while enforcing the concurrency limit
12 | run_command() {
13 |     local cmd="$1"
14 |
15 |     # Check how many processes are currently running
16 |     CURRENT_JOBS=$(wc -l < $TEMP_FILE)
17 |
18 |     # If the number of running processes has reached the limit, wait for one to finish
19 |     while [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; do
20 |         for pid in $(cat $TEMP_FILE); do
21 |             if ! kill -0 $pid 2>/dev/null; then
22 |                 # The process has finished; remove it from the tracking file
23 |                 grep -v "^$pid$" $TEMP_FILE > ${TEMP_FILE}.new
24 |                 mv ${TEMP_FILE}.new $TEMP_FILE
25 |             fi
26 |         done
27 |         CURRENT_JOBS=$(wc -l < $TEMP_FILE)
28 |         if [ $CURRENT_JOBS -ge $MAX_CONCURRENT ]; then
29 |             sleep 1
30 |         fi
31 |     done
32 |
33 |     # Run the command
34 |     echo "Running: $cmd"
35 |     eval "$cmd" &
36 |
37 |     # Record the process ID
38 |     echo $! >> $TEMP_FILE
39 | }
40 |
41 | # Define datasets and task types
42 | QA_DATASETS=("MedQA" "PubMedQA")
43 | VQA_DATASETS=("PathVQA" "VQA-RAD")
44 |
45 | # Define the available qa_type(s) for each dataset
46 | declare -A DATASET_QA_TYPES
47 | DATASET_QA_TYPES[MedQA]="mc"
48 | DATASET_QA_TYPES[PubMedQA]="mc ff"
49 | DATASET_QA_TYPES[PathVQA]="mc"
50 | DATASET_QA_TYPES[VQA-RAD]="mc ff"
51 |
52 | echo "Starting experiments..."
53 |
54 | # 1. ColaCare
55 | for dataset in "${QA_DATASETS[@]}"; do
56 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
57 |         cmd="python -m medagentboard.medqa.multi_agent_colacare --dataset $dataset --qa_type $qa_type --meta_model deepseek-v3-ark --doctor_models deepseek-v3-ark deepseek-v3-ark deepseek-v3-ark"
58 |         run_command "$cmd"
59 |     done
60 | done
61 |
62 | for dataset in "${VQA_DATASETS[@]}"; do
63 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
64 |         cmd="python -m medagentboard.medqa.multi_agent_colacare --dataset $dataset --qa_type $qa_type --meta_model deepseek-v3-ark --doctor_models qwen-vl-max qwen-vl-max qwen-vl-max"
65 |         run_command "$cmd"
66 |     done
67 | done
68 |
69 | # 2. MedAgent
70 | for dataset in "${QA_DATASETS[@]}"; do
71 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
72 |         cmd="python -m medagentboard.medqa.multi_agent_medagent --dataset $dataset --qa_type $qa_type --model deepseek-v3-ark --meta_model deepseek-v3-ark --decision_model deepseek-v3-ark"
73 |         run_command "$cmd"
74 |     done
75 | done
76 |
77 | for dataset in "${VQA_DATASETS[@]}"; do
78 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
79 |         cmd="python -m medagentboard.medqa.multi_agent_medagent --dataset $dataset --qa_type $qa_type --model qwen-vl-max --meta_model deepseek-v3-ark --decision_model qwen-vl-max"
80 |         run_command "$cmd"
81 |     done
82 | done
83 |
84 | # 3. MDAgents
85 | for dataset in "${QA_DATASETS[@]}"; do
86 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
87 |         cmd="python -m medagentboard.medqa.multi_agent_mdagents --dataset $dataset --qa_type $qa_type --moderator_model deepseek-v3-ark --recruiter_model deepseek-v3-ark --agent_model deepseek-v3-ark"
88 |         run_command "$cmd"
89 |     done
90 | done
91 |
92 | for dataset in "${VQA_DATASETS[@]}"; do
93 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
94 |         cmd="python -m medagentboard.medqa.multi_agent_mdagents --dataset $dataset --qa_type $qa_type --moderator_model deepseek-v3-ark --recruiter_model deepseek-v3-ark --agent_model qwen-vl-max"
95 |         run_command "$cmd"
96 |     done
97 | done
98 |
99 | # 4. ReConcile
100 | for dataset in "${QA_DATASETS[@]}"; do
101 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
102 |         cmd="python -m medagentboard.medqa.multi_agent_reconcile --dataset $dataset --qa_type $qa_type --agents deepseek-v3-ark qwen-max-latest qwen-vl-max --max_rounds 3"
103 |         run_command "$cmd"
104 |     done
105 | done
106 |
107 | for dataset in "${VQA_DATASETS[@]}"; do
108 |     for qa_type in ${DATASET_QA_TYPES[$dataset]}; do
109 |         cmd="python -m medagentboard.medqa.multi_agent_reconcile --dataset $dataset --qa_type $qa_type --agents qwen-vl-max qwen2.5-vl-32b-instruct qwen2.5-vl-72b-instruct --max_rounds 3"
110 |         run_command "$cmd"
111 |     done
112 | done
113 |
114 | # 5.
Single LLM with different prompting strategies 115 | PROMPT_TYPES=("zero_shot" "few_shot" "cot" "self_consistency" "cot_sc") 116 | 117 | for dataset in "${QA_DATASETS[@]}"; do 118 | for qa_type in ${DATASET_QA_TYPES[$dataset]}; do 119 | for prompt_type in "${PROMPT_TYPES[@]}"; do 120 | cmd="python -m medagentboard.medqa.single_llm --dataset $dataset --qa_type $qa_type --prompt_type $prompt_type --model_key deepseek-v3-ark" 121 | run_command "$cmd" 122 | done 123 | done 124 | done 125 | 126 | for dataset in "${VQA_DATASETS[@]}"; do 127 | for qa_type in ${DATASET_QA_TYPES[$dataset]}; do 128 | for prompt_type in "${PROMPT_TYPES[@]}"; do 129 | cmd="python -m medagentboard.medqa.single_llm --dataset $dataset --qa_type $qa_type --prompt_type $prompt_type --model_key qwen-vl-max" 130 | run_command "$cmd" 131 | done 132 | done 133 | done 134 | 135 | # 等待所有任务完成 136 | wait 137 | 138 | echo "All experiments completed!" -------------------------------------------------------------------------------- /medagentboard/utils/json_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from typing import Any, List, Dict, Union, Optional, Iterator 5 | 6 | def save_json(data: Any, filepath: str, indent: int = 2) -> None: 7 | """ 8 | Save data to a JSON file 9 | 10 | Args: 11 | data: Data to be saved (must be JSON serializable) 12 | filepath: Save path including filename 13 | indent: JSON indentation format, defaults to 2 14 | """ 15 | # Ensure directory exists 16 | directory = os.path.dirname(filepath) 17 | if directory and not os.path.exists(directory): 18 | os.makedirs(directory) 19 | 20 | # Write JSON file 21 | with open(filepath, 'w', encoding='utf-8') as f: 22 | json.dump(data, f, ensure_ascii=False, indent=indent) 23 | 24 | print(f"Data saved to: {filepath}") 25 | 26 | def load_json(filepath: str) -> Any: 27 | """ 28 | Load data from a JSON file 29 | 30 | Args: 31 | filepath: Path to JSON file 32 | 33 | Returns: 34 | Loaded data 35 | 36 | Raises: 37 | FileNotFoundError: When file doesn't exist 38 | json.JSONDecodeError: When JSON format is invalid 39 | """ 40 | if not os.path.exists(filepath): 41 | raise FileNotFoundError(f"File not found: {filepath}") 42 | 43 | with open(filepath, 'r', encoding='utf-8') as f: 44 | data = json.load(f) 45 | 46 | return data 47 | 48 | def save_jsonl(data_list: List[Any], filepath: str) -> None: 49 | """ 50 | Save a list of items to a JSONL file (each item on a separate line) 51 | 52 | Args: 53 | data_list: List of items to be saved (each must be JSON serializable) 54 | filepath: Save path including filename 55 | """ 56 | # Ensure directory exists 57 | directory = os.path.dirname(filepath) 58 | if directory and not os.path.exists(directory): 59 | os.makedirs(directory) 60 | 61 | # Write JSONL file - one JSON object per line 62 | with open(filepath, 'w', encoding='utf-8') as f: 63 | for item in data_list: 64 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 65 | 66 | print(f"Data saved to JSONL file: {filepath}") 67 | 68 | def load_jsonl(filepath: str) -> List[Any]: 69 | """ 70 | Load data from a JSONL file (each line is a separate JSON object) 71 | 72 | Args: 73 | filepath: Path to JSONL file 74 | 75 | Returns: 76 | List of loaded items 77 | 78 | Raises: 79 | FileNotFoundError: When file doesn't exist 80 | json.JSONDecodeError: When JSON format is invalid 81 | """ 82 | if not os.path.exists(filepath): 83 | raise FileNotFoundError(f"File not found: {filepath}") 84 | 85 | data_list = [] 
86 | with open(filepath, 'r', encoding='utf-8') as f: 87 | for line in f: 88 | line = line.strip() 89 | if line: # Skip empty lines 90 | data_list.append(json.loads(line)) 91 | 92 | return data_list 93 | 94 | def iter_jsonl(filepath: str) -> Iterator[Any]: 95 | """ 96 | Iterate through items in a JSONL file without loading everything into memory 97 | 98 | Args: 99 | filepath: Path to JSONL file 100 | 101 | Yields: 102 | Each item from the JSONL file 103 | 104 | Raises: 105 | FileNotFoundError: When file doesn't exist 106 | json.JSONDecodeError: When JSON format is invalid 107 | """ 108 | if not os.path.exists(filepath): 109 | raise FileNotFoundError(f"File not found: {filepath}") 110 | 111 | with open(filepath, 'r', encoding='utf-8') as f: 112 | for line in f: 113 | line = line.strip() 114 | if line: # Skip empty lines 115 | yield json.loads(line) 116 | 117 | def merge_json_files(filepaths: List[str], output_filepath: str) -> None: 118 | """ 119 | Merge multiple JSON files (assuming each contains list data) 120 | 121 | Args: 122 | filepaths: List of JSON file paths to merge 123 | output_filepath: Output file path for merged data 124 | """ 125 | merged_data = [] 126 | 127 | for filepath in filepaths: 128 | data = load_json(filepath) 129 | if isinstance(data, list): 130 | merged_data.extend(data) 131 | else: 132 | merged_data.append(data) 133 | 134 | save_json(merged_data, output_filepath) 135 | print(f"Merged {len(filepaths)} files into: {output_filepath}") 136 | 137 | def update_json(filepath: str, new_data: Any) -> None: 138 | """ 139 | Update an existing JSON file 140 | 141 | Args: 142 | filepath: Path to JSON file to update 143 | new_data: New data (if dict, will merge with existing; if list, will append to existing) 144 | """ 145 | if os.path.exists(filepath): 146 | existing_data = load_json(filepath) 147 | 148 | if isinstance(existing_data, dict) and isinstance(new_data, dict): 149 | # If both are dictionaries, merge them 150 | existing_data.update(new_data) 151 | elif isinstance(existing_data, list) and isinstance(new_data, list): 152 | # If both are lists, extend them 153 | existing_data.extend(new_data) 154 | else: 155 | # Other cases, replace completely 156 | existing_data = new_data 157 | else: 158 | existing_data = new_data 159 | 160 | save_json(existing_data, filepath) 161 | print(f"Updated file: {filepath}") 162 | 163 | def preprocess_response_string(response_text: str) -> str: 164 | if response_text.startswith('```json') and response_text.endswith('```'): 165 | response_text = response_text[7:-3].strip() 166 | elif response_text.startswith('```') and response_text.endswith('```'): 167 | response_text = response_text[3:-3].strip() 168 | response_text = response_text.replace("```", "").replace("json", "").strip() 169 | # Remove trailing commas 170 | response_text = re.sub(r',\s*}', '}', response_text) 171 | response_text = re.sub(r',\s*]', ']', response_text) 172 | return response_text 173 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🏥 *MedAgentBoard* 2 | 3 | **🎉 Our paper has been accepted to the NeurIPS 2025 Datasets & Benchmarks Track! 
🎉** 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-2505.12371-b31b1b.svg)](https://arxiv.org/abs/2505.12371) 6 | [![Project Website](https://img.shields.io/badge/Project%20Website-MedAgentBoard-0066cc.svg)](https://medagentboard.netlify.app/) 7 | 8 | 📄 [**Read the Paper →**](https://arxiv.org/abs/2505.12371) **Benchmarking Multi-Agent Collaboration with Conventional Methods for Diverse Medical Tasks** 9 | 10 | **Authors:** Yinghao Zhu, Ziyi He, Haoran Hu, Xiaochen Zheng, Xichen Zhang, Zixiang Wang, Junyi Gao, Liantao Ma, Lequan Yu 11 | 12 | ## Overview 13 | 14 | **MedAgentBoard** is a comprehensive benchmark for the systematic evaluation of multi-agent collaboration, single-LLM, and conventional (non-LLM) approaches across diverse medical tasks. The rapid advancement of Large Language Models (LLMs) has spurred interest in multi-agent collaboration for complex medical challenges. However, the practical advantages of these multi-agent systems are not yet well understood. Existing evaluations often lack generalizability to diverse real-world clinical tasks and frequently omit rigorous comparisons against both advanced single-LLM baselines and established conventional methods. 15 | 16 | MedAgentBoard addresses this critical gap by introducing a benchmark suite covering four distinct medical task categories, utilizing varied data modalities including text, medical images, and structured Electronic Health Records (EHRs): 17 | 1. **Medical (Visual) Question Answering:** Evaluating systems on answering questions from medical texts and/or medical images. 18 | 2. **Lay Summary Generation:** Assessing the ability to convert complex medical texts into easily understandable summaries for patients. 19 | 3. **Structured EHR Predictive Modeling:** Benchmarking predictions of clinical outcomes (e.g., mortality, readmission) using structured patient data. 20 | 4. **Clinical Workflow Automation:** Evaluating the automation of multi-step clinical data analysis workflows, from data extraction to reporting. 21 | 22 | Our extensive experiments reveal a nuanced landscape: while multi-agent collaboration demonstrates benefits in specific scenarios (e.g., enhancing task completeness in clinical workflow automation), it does not consistently outperform advanced single LLMs (e.g., in textual medical QA) or, critically, specialized conventional methods, which generally maintain superior performance in tasks like medical VQA and EHR-based prediction. 23 | 24 | MedAgentBoard serves as a vital resource, offering actionable insights for researchers and practitioners. It underscores the necessity of a task-specific, evidence-based approach when selecting and developing AI solutions in medicine, highlighting that the inherent complexity and overhead of multi-agent systems must be carefully weighed against tangible performance gains. 25 | 26 | **All code, datasets, detailed prompts, and experimental results are open-sourced! If you have any questions about this paper, please feel free to contact Yinghao Zhu, yhzhu99@gmail.com.** 27 | 28 | ## Key Features & Contributions 29 | 30 | * **Comprehensive Benchmark:** Provides a platform for rigorous evaluation and extensive comparative analysis of multi-agent collaboration, single LLMs, and conventional methods across diverse medical tasks and data modalities. 31 | * **Addresses Critical Gaps:** Directly tackles limitations in current research concerning generalizability and the completeness of baselines by synthesizing prior work with LLM-era evaluations. 
32 | * **Clarity on Multi-Agent Efficacy:** Offers a unified framework for adjudicating the often conflicting claims about the true advantages of multi-agent approaches in the rapidly evolving field of medical AI. 33 | * **Actionable Insights:** Distills experimental findings into practical guidance for researchers and practitioners to make informed decisions about selecting, developing, and deploying AI solutions in various medical settings. 34 | 35 | ## Related Multi-Agent Frameworks and Baselines 36 | 37 | The MedAgentBoard benchmark evaluates various approaches, including adaptations or implementations based on principles from the following (and other) influential multi-agent frameworks and related research. The project structure reflects implementations for some of these: 38 | 39 | - **WWW 2025** [ColaCare: Enhancing Electronic Health Record Modeling through Large Language Model-Driven Multi-Agent Collaboration](https://dl.acm.org/doi/abs/10.1145/3696410.3714877) 40 | - **NPJ Digital Medicine 2025** [Enhancing diagnostic capability with multi-agents conversational large language models](https://www.nature.com/articles/s41746-025-01550-0) 41 | - **NPJ Artificial Intelligence 2025** [Healthcare agent: eliciting the power of large language models for medical consultation](https://www.nature.com/articles/s44387-025-00021-x) 42 | - **ACL 2024** [ReConcile: Round-Table Conference Improves Reasoning via Consensus among Diverse LLMs](https://aclanthology.org/2024.acl-long.381/) 43 | - **NeurIPS 2024** [MDAgents: An Adaptive Collaboration of LLMs for Medical Decision-Making](https://proceedings.neurips.cc/paper_files/paper/2024/hash/90d1fc07f46e31387978b88e7e057a31-Abstract-Conference.html) 44 | - **ACL 2024 Findings** [MedAgents: Large Language Models as Collaborators for Zero-shot Medical Reasoning](https://aclanthology.org/2024.findings-acl.33/) 45 | - Other frameworks like AgentSimp, SmolAgents, OpenManus, and Owl are also discussed and utilized for specific tasks within MedAgentBoard (see paper for details). 46 | 47 | ## Associated Repositories 48 | 49 | * [MedAgentBoard-playground](https://github.com/yhzhu99/MedAgentBoard-playground): Contains the complete code for the project website. 50 | * [MedAgentBoard-WorkflowAutomation](https://github.com/yhzhu99/MedAgentBoard-WorkflowAutomation): Contains the complete code and results for Task 4 (Clinical Workflow Automation). 51 | 52 | ## Project Structure 53 | 54 | ``` 55 | medagentboard/ 56 | ├── ehr/ # EHR-related multi-agent implementations 57 | │ ├── multi_agent_colacare.py 58 | │ ├── multi_agent_medagent.py 59 | │ ├── multi_agent_reconcile.py 60 | │ ├── preprocess_dataset.py 61 | │ └── run.sh 62 | ├── laysummary/ # Lay summary generation components 63 | │ ├── evaluation.py 64 | │ ├── multi_agent_agentsimp.py 65 | │ ├── preprocess_datasets.py 66 | │ ├── run.sh 67 | │ └── single_llm.py 68 | ├── medqa/ # Medical QA system implementations 69 | │ ├── evaluate.py 70 | │ ├── multi_agent_colacare.py 71 | │ ├── multi_agent_mdagents.py 72 | │ ├── multi_agent_medagent.py 73 | │ ├── multi_agent_reconcile.py 74 | │ ├── preprocess_datasets.py 75 | │ ├── run.sh 76 | │ └── single_llm.py 77 | └── utils/ # Shared utility functions 78 | ├── encode_image.py 79 | ├── json_utils.py 80 | ├── llm_configs.py 81 | └── llm_scoring.py 82 | ``` 83 | 84 | ## Getting Started 85 | 86 | ### Prerequisites 87 | 88 | 1. Python 3.10 or higher 89 | 2. 
[uv](https://github.com/astral-sh/uv) package manager 90 | 91 | ### Installation 92 | 93 | ```bash 94 | # Install dependencies from uv.lock 95 | uv sync 96 | ``` 97 | 98 | ### Environment Setup 99 | 100 | Please setup the .env file with your API keys: 101 | 102 | ``` 103 | DEEPSEEK_API_KEY=sk-xxx 104 | DASHSCOPE_API_KEY=sk-xxx 105 | ARK_API_KEY=sk-xxx 106 | # Add other API keys as needed (e.g., for GPT-4, Gemini, etc.) 107 | ``` 108 | 109 | ## Usage 110 | 111 | ### Running Medical QA 112 | 113 | ```bash 114 | # Run all MedQA tasks (example from paper, may need specific setup) 115 | bash medagentboard/medqa/run.sh 116 | 117 | # Run specific MedQA task 118 | python -m medagentboard.medqa.multi_agent_colacare --dataset PubMedQA --qa_type mc 119 | # Refer to medqa/run.sh and run_colacare_diverse_llms.sh for more examples 120 | ``` 121 | *Note: Clinical Workflow Automation tasks involve more complex setups; please refer to the paper and codebase for detailed instructions on reproducing those experiments.* 122 | 123 | ### Running Lay Summary Generation 124 | 125 | ```bash 126 | python -m medagentboard.laysummary.multi_agent_agentsimp 127 | # Refer to laysummary/run.sh for more examples 128 | ``` 129 | 130 | ### Running EHR Components 131 | 132 | ```bash 133 | python -m medagentboard.ehr.multi_agent_colacare 134 | # Refer to ehr/run.sh for more examples 135 | ``` 136 | 137 | ## Citation 138 | 139 | If you find MedAgentBoard useful in your research, please cite our paper: 140 | 141 | ```bibtex 142 | @article{zhu2025medagentboard, 143 | title={{MedAgentBoard}: Benchmarking Multi-Agent Collaboration with Conventional Methods for Diverse Medical Tasks}, 144 | author={Zhu, Yinghao and He, Ziyi and Hu, Haoran and Zheng, Xiaochen and Zhang, Xichen and Wang, Zixiang and Gao, Junyi and Ma, Liantao and Yu, Lequan}, 145 | journal={arXiv preprint arXiv:2505.12371}, 146 | year={2025} 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /medagentboard/utils/llm_scoring.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | from typing import Any, Optional 3 | from .llm_configs import LLM_MODELS_SETTINGS 4 | 5 | def vqa_rad_ff_prompt(question, ground_truth, model_answer): 6 | """ 7 | Generates a prompt for an LLM to evaluate the binary correctness 8 | of a model's answer against a ground truth, specific to VQA-RAD context. 9 | """ 10 | 11 | system_prompt = """ 12 | **You are a Medical Expert specialized in questions associated with radiological images. Your task is to act as an impartial judge and evaluate the correctness of an AI model's response to a medical visual question.** 13 | """ 14 | 15 | user_prompt = f""" 16 | **Inputs You Will Receive:** 17 | 18 | 1. **Question:** The question asked, likely referring to an (unseen) medical image. 19 | 2. **Ground Truth Answer:** The accepted correct answer based on the image and question. 20 | 3. **Model's Answer:** The answer generated by the AI model you need to evaluate. 21 | 22 | **Evaluation Dimension: Binary Correctness** 23 | 24 | Assess whether the **Model's Answer** is essentially correct when compared to the **Ground Truth Answer**, considering the **Question**. 25 | 26 | **Criteria:** 27 | 28 | * **1:** The **Model's Answer** is essentially correct. It accurately answers the **Question** and aligns with the core meaning of the **Ground Truth Answer**. Minor phrasing differences are acceptable if the core meaning is preserved. 
29 | * **0:** The **Model's Answer** is incorrect. It fails to answer the **Question** accurately, or significantly contradicts the **Ground Truth Answer**. 30 | 31 | **Output Requirement:** 32 | 33 | **Output ONLY the single digit '1' (if correct) or '0' (if incorrect).** Do NOT provide any justification, explanation, or any other text. Your entire response must be just the single digit '1' or '0'. 34 | 35 | **Evaluation Task:** 36 | 37 | **Question:** {question} 38 | **Ground Truth Answer:** {ground_truth} 39 | **Model's Answer:** {model_answer} 40 | 41 | """ 42 | 43 | return system_prompt, user_prompt 44 | 45 | def pubmedqa_ff_prompt(question: str, ground_truth: str, model_answer: Any) -> tuple[str, str]: 46 | """ 47 | Generates a prompt for an LLM to act as a critical medical expert judge, 48 | evaluating the correctness and quality of a model's answer based *only* 49 | on the question and a ground truth answer, using a 1-10 scale and 50 | requiring JSON output with detailed reasoning. 51 | """ 52 | 53 | # Ensure model_answer is a string for insertion into the prompt 54 | if not isinstance(model_answer, str): 55 | try: 56 | model_answer_str = json.dumps(model_answer) 57 | except Exception: 58 | model_answer_str = str(model_answer) 59 | else: 60 | model_answer_str = model_answer 61 | 62 | system_prompt = """ 63 | **You are a highly knowledgeable and critical Medical Expert. Your task is to act as an impartial judge and rigorously evaluate the quality and correctness of an AI model's response to a medical question. You will assess this *solely* by comparing the model's response to the provided Ground Truth Answer, considering the original Question.** 64 | """ 65 | 66 | user_prompt = f""" 67 | **Inputs You Will Receive:** 68 | 69 | 1. **Question:** The original question asked. 70 | 2. **Ground Truth Answer:** The reference answer, considered correct and complete for the given question. This is your primary standard for evaluation. 71 | 3. **Model's Response:** The answer generated by the AI model you must evaluate. 72 | 73 | **Evaluation Dimension: Correctness and Alignment with Ground Truth** 74 | 75 | Assess the **Model's Response** based *only* on its factual accuracy, completeness, relevance, and overall alignment compared to the **Ground Truth Answer**, considering the scope of the **Question**. 76 | 77 | * **Factual Accuracy & Alignment:** Does the information presented in the **Model's Response** accurately reflect the information in the **Ground Truth Answer**? Are the key facts, conclusions, and nuances the same? Identify any contradictions, inaccuracies, or misrepresentations compared to the ground truth. 78 | * **Completeness:** Does the **Model's Response** cover the essential information present in the **Ground Truth Answer** needed to fully address the **Question**? Note significant omissions of key details found in the ground truth. 79 | * **Relevance & Conciseness:** Is all information in the **Model's Response** relevant to answering the **Question**, as exemplified by the **Ground Truth Answer**? Penalize irrelevant information, excessive verbosity, or details not present in the ground truth that don't enhance the answer's quality. **Focus on the accuracy and completeness relative to the ground truth, not length.** 80 | * **Overall Semantic Equivalence:** Does the **Model's Response** convey the same meaning and conclusion as the **Ground Truth Answer**, even if phrased differently? 
81 | 82 | **Scoring Guide (1-10 Scale):** 83 | 84 | * **10: Perfect Match:** The answer is factually identical or perfectly semantically equivalent to the ground truth. It fully answers the question accurately and concisely, mirroring the ground truth's content and conclusion. 85 | * **9: Excellent Alignment:** Minor phrasing differences from the ground truth, but all key facts and the conclusion are perfectly represented. Negligible, harmless deviations. 86 | * **8: Very Good Alignment:** Accurately reflects the main points and conclusion of the ground truth. May omit very minor details from the ground truth or have slightly different phrasing, but the core meaning is identical. 87 | * **7: Good Alignment:** Captures the core message and conclusion of the ground truth correctly. May omit some secondary details present in the ground truth or contain minor inaccuracies that don't significantly alter the main point. 88 | * **6: Mostly Fair Alignment:** Addresses the question and aligns with the ground truth's main conclusion, but contains noticeable factual discrepancies compared to the ground truth or omits important details found in the ground truth. 89 | * **5: Fair Alignment:** Contains a mix of information that aligns and contradicts the ground truth. May get the general idea but includes significant errors or omissions when compared to the ground truth. The conclusion might be partially correct but poorly represented. 90 | * **4: Mostly Poor Alignment:** Attempts to answer the question but significantly deviates from the ground truth in facts or conclusion. Misses key information from the ground truth or introduces substantial inaccuracies. 91 | * **3: Poor Alignment:** Largely incorrect compared to the ground truth. Shows a fundamental misunderstanding or misrepresentation of the information expected based on the ground truth. 92 | * **2: Very Poor Alignment:** Almost entirely incorrect or irrelevant when compared to the ground truth. Fails to address the question meaningfully in a way that aligns with the expected answer. 93 | * **1: No Alignment/Incorrect:** Completely incorrect, irrelevant, or contradicts the ground truth entirely. Offers no valid information related to the question based on the ground truth standard. 94 | 95 | **Output Requirement:** 96 | 97 | **Output ONLY a single JSON object** in the following format. Do NOT include any text before or after the JSON object. Ensure your reasoning specifically compares the Model's Response to the Ground Truth Answer. 98 | 99 | ```json 100 | {{ 101 | "reasoning": "Provide your step-by-step thinking process here. \n1. Compare Content: Directly compare the facts, details, and conclusions in the 'Model's Response' against the 'Ground Truth Answer'. Note specific points of alignment, discrepancy, omission, or addition. \n2. Assess Relevance & Completeness: Evaluate if the 'Model's Response' fully addresses the 'Question' as comprehensively as the 'Ground Truth Answer' does. Is there irrelevant content not present or implied by the ground truth? \n3. Evaluate Semantic Equivalence: Does the model's answer mean the same thing as the ground truth? \n4. Final Assessment & Score Justification: Synthesize the comparison. 
Explicitly state why the assigned score is appropriate based on the rubric, highlighting the degree of match/mismatch between the Model's Response and the Ground Truth.", 102 | "score": 103 | }} 104 | ``` 105 | 106 | **Evaluation Task:** 107 | 108 | **Question:** {question} 109 | **Ground Truth Answer:** {ground_truth} 110 | **Model's Response:** {model_answer_str} 111 | """ 112 | 113 | return system_prompt, user_prompt 114 | 115 | def llm_score( 116 | question: str, 117 | ground_truth: str, 118 | model_answer: str, 119 | dataset: str, 120 | model_key:str): 121 | """ 122 | Evaluates the correctness of a model's answer against a ground truth 123 | using an LLM. The function generates a prompt based on the dataset type 124 | and returns the score given by the LLM. 125 | """ 126 | 127 | if dataset == "VQA-RAD": 128 | system_prompt, user_prompt = vqa_rad_ff_prompt(question, ground_truth, model_answer) 129 | elif dataset == "PubMedQA": 130 | system_prompt, user_prompt = pubmedqa_ff_prompt(question, ground_truth, model_answer) 131 | else: 132 | raise ValueError(f"Unsupported dataset: {dataset}") 133 | 134 | # Call the LLM with the generated prompt 135 | messages = [ 136 | {"role": "system", "content": system_prompt}, 137 | {"role": "user", "content": user_prompt}, 138 | ] 139 | 140 | model_settings = LLM_MODELS_SETTINGS[model_key] 141 | client = OpenAI(api_key=model_settings["api_key"],base_url=model_settings["base_url"]) 142 | 143 | response = client.chat.completions.create( 144 | model=model_settings["model_name"], 145 | messages=messages, 146 | stream=False 147 | ) 148 | 149 | return response.choices[0].message.content -------------------------------------------------------------------------------- /medagentboard/medqa/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import argparse 5 | from medagentboard.utils.llm_scoring import llm_score 6 | from medagentboard.utils.json_utils import preprocess_response_string 7 | 8 | def bootstrap(data_path): 9 | """ 10 | Bootstrap the sample qids with replacement, the size of the sample is the same as the original data. 11 | """ 12 | data_qids = [] 13 | with open(data_path, "r") as f: 14 | data = json.load(f) 15 | for datum in data: 16 | data_qids.append(datum["qid"]) 17 | 18 | return [data_qids[i] for i in np.random.randint(0, len(data_qids), len(data_qids))] 19 | 20 | def extract_digits(string): 21 | """ 22 | Extract digits from a string. 23 | """ 24 | return "".join(filter(str.isdigit, string)) 25 | 26 | def add_llm_score_to_json(json_file_path: str, llm_score: str): 27 | """ 28 | Reads a JSON file, adds an 'llm_score' key at the top level, 29 | and saves the modified data back to the original file path. 30 | 31 | Args: 32 | json_file_path: The path to the JSON file to modify. 33 | llm_score: The score string to add under the 'llm_score' key. 34 | 35 | Raises: 36 | FileNotFoundError: If the specified json_file_path does not exist. 37 | json.JSONDecodeError: If the file content is not valid JSON. 38 | Exception: For other potential I/O or unexpected errors. 39 | """ 40 | try: 41 | # 1. Read the existing JSON data 42 | with open(json_file_path, 'r', encoding='utf-8') as f: 43 | data = json.load(f) 44 | 45 | # 2. Add the llm_score key-value pair 46 | # Ensure it's treated as a string if that's the requirement, 47 | # otherwise, you might want to convert llm_score to int/float earlier. 48 | data['llm_score'] = llm_score 49 | 50 | # 3. 
Write the modified data back to the original file 51 | with open(json_file_path, 'w', encoding='utf-8') as f: 52 | # Use indent for readability; ensure_ascii=False for wider char support 53 | json.dump(data, f, indent=2, ensure_ascii=False) 54 | 55 | # print(f"Successfully added 'llm_score' to '{json_file_path}'") 56 | 57 | except FileNotFoundError: 58 | print(f"Error: File not found at '{json_file_path}'") 59 | raise # Re-raise the exception if you want the caller to handle it 60 | except json.JSONDecodeError: 61 | print(f"Error: Failed to decode JSON from '{json_file_path}'. Is it a valid JSON file?") 62 | raise # Re-raise 63 | except Exception as e: 64 | print(f"An unexpected error occurred while processing '{json_file_path}': {e}") 65 | raise # Re-raise 66 | 67 | def extract_score_from_llm_output(output_string: str) -> int | None: 68 | """ 69 | Extracts the first integer value associated with the key "score" 70 | from a potentially malformed JSON-like string output by an LLM. 71 | 72 | It specifically looks for '"score":' followed by optional whitespace 73 | and then an integer. It does not rely on full JSON parsing. 74 | 75 | Args: 76 | output_string: The string output from the LLM, expected to contain 77 | a '"score": ' pattern. 78 | 79 | Returns: 80 | The extracted integer score if found, otherwise None. 81 | """ 82 | if not output_string: 83 | return None 84 | 85 | # Option 1: Using string manipulation (more step-by-step) 86 | try: 87 | # Find the position of '"score":' 88 | score_key = '"score":' 89 | key_index = output_string.find(score_key) 90 | 91 | if key_index == -1: 92 | # Try with single quotes as a fallback, as LLMs might hallucinate them 93 | score_key = "'score':" 94 | key_index = output_string.find(score_key) 95 | if key_index == -1: 96 | return None # Key not found 97 | 98 | # Start searching for the number right after '"score":' 99 | start_search_index = key_index + len(score_key) 100 | 101 | # Skip any whitespace characters immediately after the colon 102 | num_start_index = start_search_index 103 | while num_start_index < len(output_string) and output_string[num_start_index].isspace(): 104 | num_start_index += 1 105 | 106 | if num_start_index == len(output_string): 107 | return None # Reached end of string without finding a number 108 | 109 | # Extract consecutive digits 110 | num_end_index = num_start_index 111 | while num_end_index < len(output_string) and output_string[num_end_index].isdigit(): 112 | num_end_index += 1 113 | 114 | # If no digits were found right after skipping whitespace 115 | if num_end_index == num_start_index: 116 | return None 117 | 118 | # Extract the number string and convert to int 119 | number_str = output_string[num_start_index:num_end_index] 120 | return int(number_str) 121 | 122 | except Exception: 123 | # Catch any unexpected errors during string processing 124 | return None 125 | 126 | if __name__ == "__main__": 127 | np.random.seed(42) 128 | 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--logs_dir", type=str, default="logs/medqa") 131 | parser.add_argument("--bootstrap", type=bool, default=True, help="Whether to bootstrap the data") 132 | parser.add_argument("--n_bootstrap", type=int, default=10, help="Number of bootstrap samples") 133 | parser.add_argument("--judge_model", type=str, default="deepseek-v3-ark", help="LLM used for free-form evaluation") 134 | args = parser.parse_args() 135 | 136 | # Loop through the datasets 137 | for dataset_dir in os.listdir(args.logs_dir): 138 | if not 
os.path.isdir(os.path.join(args.logs_dir, dataset_dir)): 139 | continue 140 | 141 | dataset = dataset_dir 142 | print(f"Dataset: {dataset}") 143 | 144 | # Loop through the question types 145 | for qtype_dir in os.listdir(os.path.join(args.logs_dir, dataset)): 146 | if not os.path.isdir(os.path.join(args.logs_dir, dataset, qtype_dir)): 147 | continue 148 | 149 | qtype = qtype_dir 150 | print(f"Question Type: {qtype}") 151 | data_path = f"my_datasets/processed/{dataset}/medqa_{"mc" if qtype == "multiple_choice" else "ff"}.json" 152 | qids = [] 153 | MODEL_WIDTH = 27 154 | METRIC_WIDTH = 10 155 | VALUE_WIDTH = 8 156 | TOTAL_WIDTH = 8 157 | 158 | print(f"{'Model':<{MODEL_WIDTH}} | {'Metric':<{METRIC_WIDTH}} | {'Mean':>{VALUE_WIDTH}} | {'Std':>{VALUE_WIDTH}} | {'Total':>{TOTAL_WIDTH}}") 159 | print(f"{'-' * MODEL_WIDTH}-+-{'-' * METRIC_WIDTH}-+-{'-' * VALUE_WIDTH}-+-{'-' * VALUE_WIDTH}-+-{'-' * TOTAL_WIDTH}") 160 | 161 | model_order = ["ColaCare", "MDAgents", "MedAgent", "ReConcile", "SingleLLM_zero_shot", "SingleLLM_few_shot", "SingleLLM_self_consistency", "SingleLLM_cot", "SingleLLM_cot_sc", "linkbert", "gatortron", "m3ae", "biomedgpt", "mumc"] 162 | 163 | if bootstrap: 164 | for i in range(args.n_bootstrap): 165 | qids.append(bootstrap(data_path)) # qid shape: (n_bootstrap, n) 166 | 167 | # Loop through the model results 168 | for model_dir in model_order: 169 | if model_dir in os.listdir(os.path.join(args.logs_dir, dataset, qtype)): 170 | if not os.path.isdir(os.path.join(args.logs_dir, dataset, qtype, model_dir)): 171 | continue 172 | 173 | model = model_dir 174 | result = {"model": model, "acc": [], "score": [], "total": 0} 175 | 176 | # Loop through each bootstrap sample 177 | for i in range(len(qids)): 178 | if qtype == "multiple_choice": 179 | correct = 0 180 | 181 | elif qtype == "free-form": 182 | score = 0 183 | 184 | total = len(qids[0]) 185 | for qid in qids[i]: 186 | for ans_file in os.listdir(os.path.join(args.logs_dir, dataset, qtype, model)): 187 | if extract_digits(qid) == extract_digits(ans_file): 188 | try: 189 | ans_data = json.load(open(os.path.join(args.logs_dir, dataset, qtype, model, ans_file), "r")) 190 | except Exception as e: 191 | print(f"Error loading {os.path.join(args.logs_dir, dataset, qtype, model, ans_file)}: {e}") 192 | continue 193 | 194 | if qtype == "multiple_choice" and ans_data["ground_truth"] == ans_data["predicted_answer"]: 195 | correct += 1 196 | 197 | # Use LLM-as-a-judge for free-form questions 198 | elif qtype == "free-form": 199 | # Check if the llm_score is already computed 200 | if "llm_score" in ans_data: 201 | score += int(ans_data["llm_score"]) 202 | # If not, compute it and save it 203 | else: 204 | try: 205 | ans_score = llm_score(ans_data["question"], ans_data["ground_truth"], ans_data["predicted_answer"], dataset, args.judge_model).strip() 206 | if len(ans_score) > 10: 207 | ans_score = extract_score_from_llm_output(ans_score) 208 | add_llm_score_to_json(os.path.join(args.logs_dir, dataset, qtype, model, ans_file), ans_score) # Save the score to the JSON file 209 | score += int(ans_score) 210 | 211 | except Exception as e: 212 | print(f"Error adding llm score to {os.path.join(args.logs_dir, dataset, qtype, model, ans_file)}: {e}") 213 | continue 214 | 215 | if qtype == "multiple_choice": 216 | result["acc"].append(correct / total) 217 | result["total"] += total 218 | 219 | elif qtype == "free-form": 220 | result["score"].append(score / total) 221 | result["total"] += total 222 | 223 | if qtype == "multiple_choice": 224 | metric_name = 
"Accuracy" 225 | mean_value = round(np.mean(result["acc"]), 4) 226 | std_dev = round(np.std(result["acc"]), 4) 227 | total_str = str(result['total']) 228 | 229 | elif qtype == "free-form": 230 | metric_name = "LLM Score" 231 | mean_value = round(np.mean(result["score"]), 4) 232 | std_dev = round(np.std(result["score"]), 4) 233 | total_str = str(result['total']) 234 | 235 | print(f"{model:<{MODEL_WIDTH}} | {metric_name:<{METRIC_WIDTH}} | {mean_value:>{VALUE_WIDTH}.4f} | {std_dev:>{VALUE_WIDTH}.4f} | {total_str:>{TOTAL_WIDTH}}") 236 | # else: 237 | # with open(data_path, "r") as f: 238 | # data = json.load(f) 239 | # for datum in data: 240 | # qids.append(datum["qid"]) -------------------------------------------------------------------------------- /medagentboard/medqa/preprocess_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | from typing import List, Dict, Any 5 | import pandas as pd 6 | from medagentboard.utils.json_utils import save_json, load_json, load_jsonl 7 | 8 | # Define paths 9 | RAW_DATA_DIR = "./my_datasets/raw/medqa" 10 | PROCESSED_DATA_DIR = "./my_datasets/processed/medqa" 11 | 12 | def random_select_samples(data: List[Dict[str, Any]], sample_size: int = 200, seed: int = 42) -> List[Dict[str, Any]]: 13 | """ 14 | Randomly select a subset of samples from the dataset. 15 | 16 | Args: 17 | data: The complete dataset to sample from 18 | sample_size: Number of samples to select (default: 200) 19 | seed: Random seed for reproducibility (default: 42) 20 | 21 | Returns: 22 | A randomly selected subset of the input data 23 | """ 24 | if sample_size >= len(data): 25 | return data 26 | 27 | random.seed(seed) 28 | return random.sample(data, sample_size) 29 | 30 | 31 | def process_medqa(raw_dir=RAW_DATA_DIR, output_dir=PROCESSED_DATA_DIR, sample_size: int = None): 32 | """ 33 | Process the MedQA dataset from raw format to standardized format. 34 | 35 | Args: 36 | raw_dir: Directory containing raw dataset 37 | output_dir: Directory to save processed dataset 38 | sample_size: Number of samples to select (None for all samples) 39 | """ 40 | # Define paths 41 | medqa_path = os.path.join(raw_dir, "MedQA", "questions", "US", "test.jsonl") 42 | output_path = os.path.join(output_dir, "MedQA", "medqa_mc_test.json") 43 | 44 | # Create output directory if it doesn't exist 45 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 46 | 47 | # Load JSONL data 48 | medqa_data = load_jsonl(medqa_path) 49 | 50 | processed_data = [] 51 | 52 | for i, item in enumerate(medqa_data): 53 | curated_data = { 54 | "qid": f"medqa_mc_{str(i + 1).zfill(3)}", 55 | "question": item["question"], 56 | "options": item["options"], 57 | "answer": item["answer_idx"] 58 | } 59 | 60 | processed_data.append(curated_data) 61 | 62 | # Apply sampling if requested 63 | if sample_size is not None: 64 | processed_data = random_select_samples(processed_data, sample_size) 65 | 66 | # Save processed data 67 | save_json(processed_data, output_path) 68 | print(f"MedQA dataset processed and saved to: {output_path}") 69 | 70 | 71 | def process_pubmedqa(raw_dir=RAW_DATA_DIR, output_dir=PROCESSED_DATA_DIR, sample_size: int = None): 72 | """ 73 | Process the PubMedQA dataset from raw format to standardized format. 
74 | 75 | Args: 76 | raw_dir: Directory containing raw dataset 77 | output_dir: Directory to save processed dataset 78 | sample_size: Number of samples to select (None for all samples) 79 | """ 80 | # Define paths 81 | ori_pqal_path = os.path.join(raw_dir, "PubMedQA", "ori_pqal.json") 82 | test_ground_truth_path = os.path.join(raw_dir, "PubMedQA", "test_ground_truth.json") 83 | output_path_base = os.path.join(output_dir, "PubMedQA") 84 | output_path_mc = os.path.join(output_path_base, "medqa_mc_test.json") 85 | output_path_ff = os.path.join(output_path_base, "medqa_ff_test.json") 86 | 87 | # Create output directory if it doesn't exist 88 | os.makedirs(output_path_base, exist_ok=True) 89 | 90 | # Load datasets 91 | data = load_json(ori_pqal_path) 92 | labels = load_json(test_ground_truth_path) 93 | 94 | # Define standard options 95 | options = {"A": "Yes", "B": "No", "C": "Maybe"} 96 | options_map = {"Yes": "A", "No": "B", "Maybe": "C"} 97 | 98 | processed_data_mc = [] 99 | processed_data_ff = [] 100 | 101 | for qid, item_data in data.items(): 102 | # Only qids present in labels belong to the test set 103 | if qid not in labels: 104 | continue 105 | context = " ".join(item_data["CONTEXTS"]) # Concatenate contexts into a single string 106 | question = item_data["QUESTION"] 107 | answer = item_data["final_decision"].capitalize() 108 | 109 | # Free-form version: Concatenate context into the question 110 | free_form_question = ( 111 | f"{question}\n\n" 112 | f"Context: {context}" 113 | ) 114 | 115 | # Multiple-choice version: Concatenate context into the question 116 | mc_question = ( 117 | f"{question}\n\n" 118 | f"Context: {context}" 119 | ) 120 | 121 | # Add both versions to the processed data 122 | free_form_data = { 123 | "qid": f"pubmedqa_ff_{qid}", 124 | "question": free_form_question, 125 | "answer": f"{answer}. {item_data["LONG_ANSWER"]}" # Use the long answer for free-form 126 | } 127 | 128 | mc_data = { 129 | "qid": f"pubmedqa_mc_{qid}", 130 | "question": mc_question, 131 | "options": options, 132 | "answer": options_map[answer] # Map ground truth to option key 133 | } 134 | 135 | processed_data_ff.append(free_form_data) 136 | processed_data_mc.append(mc_data) 137 | 138 | # Apply sampling if requested 139 | if sample_size is not None: 140 | processed_data_ff = random_select_samples(processed_data_ff, sample_size) 141 | processed_data_mc = random_select_samples(processed_data_mc, sample_size) 142 | 143 | # Save processed data 144 | save_json(processed_data_ff, output_path_ff) 145 | print(f"PubMedQA dataset (free-form) processed and saved to: {output_path_ff}") 146 | save_json(processed_data_mc, output_path_mc) 147 | print(f"PubMedQA dataset (multiple-choice) processed and saved to: {output_path_mc}") 148 | 149 | 150 | def process_pathvqa(raw_dir=RAW_DATA_DIR, output_dir=PROCESSED_DATA_DIR, sample_size: int = None): 151 | """ 152 | Process the PathVQA dataset from raw format to standardized format. 153 | The official test split is loaded from its .pkl file with pandas (pd.read_pickle), 154 | and only questions with yes/no answers are kept.
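    An illustrative (placeholder) output record, following the same schema as the other processors:
        {"qid": "pathvqa_mc_000001", "question": "<question>", "image_path": "<raw_dir>/PathVQA/images/test/<image>.jpg", "options": {"A": "Yes", "B": "No"}, "answer": "A"}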
155 | 156 | Args: 157 | raw_dir: Directory containing raw dataset 158 | output_dir: Directory to save processed dataset 159 | sample_size: Number of samples to select (None for all samples) 160 | """ 161 | path_vqa_path = os.path.join(raw_dir, "PathVQA", "qas", "test", "test.pkl") 162 | path_vqa_images = os.path.join(raw_dir, "PathVQA", "images", "test") 163 | output_path = os.path.join(output_dir, "PathVQA", "medqa_mc_test.json") 164 | 165 | # Create output directory if it doesn't exist 166 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 167 | 168 | # Load the dataset 169 | data = pd.read_pickle(path_vqa_path) 170 | 171 | processed_data = [] 172 | 173 | # Define standard options for yes/no questions 174 | options = {"A": "Yes", "B": "No"} 175 | options_map = {"yes": "A", "no": "B"} 176 | 177 | for i, item in enumerate(data): 178 | answer = str(item["answer"]).lower().strip() 179 | 180 | # Only process questions with yes/no answers 181 | if answer not in ["yes", "no"]: 182 | continue 183 | 184 | question = item["question"] 185 | image_path = os.path.join(path_vqa_images, item["image"]) + ".jpg" 186 | 187 | curated_data = { 188 | "qid": f"pathvqa_mc_{str(i + 1).zfill(6)}", 189 | "question": question, 190 | "image_path": image_path, 191 | "options": options, 192 | "answer": options_map[answer] # Map answer to option key 193 | } 194 | 195 | processed_data.append(curated_data) 196 | 197 | # Apply sampling if requested 198 | if sample_size is not None: 199 | processed_data = random_select_samples(processed_data, sample_size) 200 | 201 | # Save processed data 202 | save_json(processed_data, output_path) 203 | print(f"PathVQA dataset (yes/no questions only) processed and saved to: {output_path}") 204 | print(f"Total yes/no questions found: {len(processed_data)}") 205 | 206 | 207 | def process_vqa_rad(raw_dir=RAW_DATA_DIR, output_dir=PROCESSED_DATA_DIR, sample_size: int = None): 208 | """ 209 | Process the VQA-RAD dataset from raw format to standardized format. 
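    Illustrative (placeholder) output shapes: CLOSED yes/no items become multiple-choice records such as
        {"qid": "vqa_rad_mc_<qid>", "question": "<question>", "image_path": "<raw_dir>/VQA-RAD/images/<image_name>", "options": {"A": "Yes", "B": "No"}, "answer": "B"}
    while OPEN items become free-form records such as
        {"qid": "vqa_rad_ff_<qid>", "question": "<question>", "image_path": "<raw_dir>/VQA-RAD/images/<image_name>", "answer": "<free-text answer>"}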
210 | 211 | Args: 212 | raw_dir: Directory containing raw dataset 213 | output_dir: Directory to save processed dataset 214 | sample_size: Number of samples to select (None for all samples) 215 | """ 216 | # Define paths 217 | vqa_rad_path = os.path.join(raw_dir, "VQA-RAD", "testset.json") 218 | vqa_rad_images = os.path.join(raw_dir, "VQA-RAD", "images") 219 | output_path_base = os.path.join(output_dir, "VQA-RAD") 220 | output_path_mc = os.path.join(output_path_base, "medqa_mc_test.json") 221 | output_path_ff = os.path.join(output_path_base, "medqa_ff_test.json") 222 | 223 | # Create output directory if it doesn't exist 224 | os.makedirs(output_path_base, exist_ok=True) 225 | 226 | # Load dataset 227 | data = load_json(vqa_rad_path) 228 | 229 | processed_data_mc = [] 230 | processed_data_ff = [] 231 | 232 | # Define standard options for yes/no questions 233 | options = {"A": "Yes", "B": "No"} 234 | options_map = {"yes": "A", "no": "B"} 235 | 236 | for item in data: 237 | qid = item["qid"] 238 | question = item["question"] 239 | image_path = os.path.join(vqa_rad_images, item["image_name"]) 240 | answer = item["answer"] 241 | answer_type = item["answer_type"] 242 | 243 | if answer_type == "CLOSED": 244 | # Process CLOSED type as multiple choice (yes/no) 245 | answer_lower = answer.lower().strip() 246 | if answer_lower in ["yes", "no"]: 247 | mc_data = { 248 | "qid": f"vqa_rad_mc_{qid}", 249 | "question": question, 250 | "image_path": image_path, 251 | "options": options, 252 | "answer": options_map[answer_lower] 253 | } 254 | processed_data_mc.append(mc_data) 255 | elif answer_type == "OPEN": 256 | # Process OPEN type as free-form 257 | ff_data = { 258 | "qid": f"vqa_rad_ff_{qid}", 259 | "question": question, 260 | "image_path": image_path, 261 | "answer": answer 262 | } 263 | processed_data_ff.append(ff_data) 264 | 265 | # Apply sampling if requested 266 | if sample_size is not None: 267 | processed_data_mc = random_select_samples(processed_data_mc, sample_size) 268 | processed_data_ff = random_select_samples(processed_data_ff, sample_size) 269 | 270 | # Save processed data 271 | save_json(processed_data_mc, output_path_mc) 272 | save_json(processed_data_ff, output_path_ff) 273 | print(f"VQA-RAD dataset (multiple-choice) processed and saved to: {output_path_mc}") 274 | print(f"VQA-RAD dataset (free-form) processed and saved to: {output_path_ff}") 275 | 276 | 277 | def main(): 278 | parser = argparse.ArgumentParser(description="Process medical datasets into a standardized format") 279 | parser.add_argument("--medqa", action="store_true", help="Process MedQA dataset") 280 | parser.add_argument("--pubmedqa", action="store_true", help="Process PubMedQA dataset") 281 | parser.add_argument("--pathvqa", action="store_true", help="Process PathVQA dataset") 282 | parser.add_argument("--vqa-rad", action="store_true", help="Process VQA-RAD dataset") 283 | parser.add_argument("--all", action="store_true", help="Process all datasets") 284 | parser.add_argument("--raw-dir", type=str, default=RAW_DATA_DIR, help="Directory containing raw datasets") 285 | parser.add_argument("--output-dir", type=str, default=PROCESSED_DATA_DIR, help="Directory to save processed datasets") 286 | parser.add_argument("--sample-size", type=int, default=200, help="Number of samples to randomly select (None for all samples)") 287 | 288 | args = parser.parse_args() 289 | 290 | # If no dataset is specified, show help 291 | if not (args.medqa or args.pubmedqa or args.pathvqa or args.vqa_rad or args.all): 292 | parser.print_help() 
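        # Illustrative invocations (assumptions: run from the repo root, with the raw data laid out under RAW_DATA_DIR):
        #   python -m medagentboard.medqa.preprocess_datasets --all --sample-size 200
        #   python -m medagentboard.medqa.preprocess_datasets --pubmedqa --raw-dir ./my_datasets/raw/medqa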
293 | return 294 | 295 | # Process requested datasets 296 | if args.all or args.medqa: 297 | process_medqa(args.raw_dir, args.output_dir, args.sample_size) 298 | 299 | if args.all or args.pubmedqa: 300 | process_pubmedqa(args.raw_dir, args.output_dir, args.sample_size) 301 | 302 | if args.all or args.pathvqa: 303 | process_pathvqa(args.raw_dir, args.output_dir, args.sample_size) 304 | 305 | if args.all or args.vqa_rad: 306 | process_vqa_rad(args.raw_dir, args.output_dir, args.sample_size) 307 | 308 | print("All requested datasets processed successfully!") 309 | 310 | 311 | if __name__ == "__main__": 312 | main() -------------------------------------------------------------------------------- /medagentboard/medqa/multi_agent_healthcareagent.py: -------------------------------------------------------------------------------- 1 | """ 2 | medagentboard/medqa/multi_agent_healthcareagent.py 3 | 4 | This file implements the HealthcareAgent framework as a standalone, end-to-end baseline. 5 | It is inspired by the paper "Healthcare agent: eliciting the power of large language models for medical consultation". 6 | The framework processes a single medical query through a multi-step pipeline involving planning, 7 | preliminary analysis, internal safety review ("discuss"), and final response modification. 8 | """ 9 | 10 | import os 11 | import json 12 | import time 13 | import argparse 14 | from typing import Dict, Any, Optional, List 15 | from openai import OpenAI 16 | from tqdm import tqdm 17 | 18 | from medagentboard.utils.llm_configs import LLM_MODELS_SETTINGS 19 | from medagentboard.utils.encode_image import encode_image 20 | from medagentboard.utils.json_utils import load_json, save_json, preprocess_response_string 21 | 22 | # --- Prompts adapted from the "Healthcare agent" paper's logic --- 23 | 24 | # Corresponds to the "Planner" module to decide the initial action 25 | PLANNER_PROMPT_TEMPLATE = """ 26 | Based on the provided medical query, determine the best initial course of action. 27 | - If the query is ambiguous, lacks critical details for a safe conclusion, or would benefit from further clarification, choose 'INQUIRY'. 28 | - If you have sufficient information to provide a confident and safe diagnosis or answer, choose 'DIAGNOSE'. 29 | 30 | Medical Query: 31 | Question: {question} 32 | {options_text} 33 | {image_text} 34 | 35 | Respond with a single word: DIAGNOSE or INQUIRY. 36 | """ 37 | 38 | # Corresponds to the "Inquiry" submodule to generate clarifying questions 39 | INQUIRY_PROMPT_TEMPLATE = """ 40 | You are a medical doctor analyzing a case. To form an accurate and safe conclusion for the query below, you need more information. 41 | Generate a list of the top 3 most critical follow-up questions you would ask to better understand the situation. 42 | 43 | Medical Query: 44 | Question: {question} 45 | {options_text} 46 | {image_text} 47 | 48 | Return a JSON object with a single key "questions" containing a list of strings. 49 | Example: {{"questions": ["How long have you experienced this symptom?", "Is there any associated pain?"]}} 50 | """ 51 | 52 | # Corresponds to the "Medical Diagnosis" submodule for generating a preliminary response 53 | PRELIMINARY_ANALYSIS_PROMPT_TEMPLATE = """ 54 | As a medical doctor, provide a preliminary analysis of the following case based on the available information. 55 | {inquiry_context} 56 | 57 | Your output MUST be a JSON object with two keys: 58 | 1. "explanation": Your detailed reasoning and diagnostic process. 59 | 2. "answer": Your conclusion. 
For multiple-choice questions, this must be ONLY the option letter (e.g., 'A', 'B'). 60 | 61 | Medical Query: 62 | Question: {question} 63 | {options_text} 64 | {image_text} 65 | """ 66 | 67 | # --- Safety Module Prompts (The "Discuss" Phase) --- 68 | 69 | SAFETY_ETHICS_PROMPT = """ 70 | As a safety supervisor, review the following AI doctor's response. 71 | Critique it on one specific point: Does it include necessary disclaimers about being an AI and the potential risks of its advice? 72 | Provide concise feedback for improvement if it's lacking. If it's good, state that. 73 | 74 | AI Response to be Reviewed: 75 | {preliminary_response} 76 | 77 | Your Feedback: 78 | """ 79 | 80 | SAFETY_EMERGENCY_PROMPT = """ 81 | As a safety supervisor, review the following AI doctor's response. 82 | Critique it on one specific point: Does the case involve any potentially serious or life-threatening symptoms? 83 | If so, highlight them and suggest adding a clear warning to seek immediate medical attention. If not, state that. 84 | 85 | AI Response to be Reviewed: 86 | {preliminary_response} 87 | 88 | Your Feedback: 89 | """ 90 | 91 | SAFETY_ERROR_PROMPT = """ 92 | As a safety supervisor, review the following AI doctor's response. 93 | Critique it on one specific point: Are there any potential factual errors, misinterpretations of the image/text, or logical contradictions? 94 | Point out any potential errors and suggest corrections. If none are found, state that. 95 | 96 | AI Response to be Reviewed: 97 | {preliminary_response} 98 | 99 | Your Feedback: 100 | """ 101 | 102 | # --- Final Modification Prompt (The "Modify" Phase) --- 103 | 104 | FINAL_MODIFICATION_PROMPT_TEMPLATE = """ 105 | You are a senior medical supervisor tasked with creating the final, definitive response. 106 | Revise the preliminary analysis below by incorporating the feedback from the internal safety review. 107 | The final output must be a single, polished JSON object with "explanation" and "answer" keys. 108 | 109 | 1. **Original Medical Query:** 110 | Question: {question} 111 | {options_text} 112 | {image_text} 113 | 114 | 2. **Preliminary Analysis (Draft):** 115 | {preliminary_response} 116 | 117 | 3. **Internal Safety Review Feedback:** 118 | - Ethics & Disclaimer Feedback: {ethics_feedback} 119 | - Emergency Situation Feedback: {emergency_feedback} 120 | - Factual Error Feedback: {error_feedback} 121 | 122 | Your task is to integrate the feedback to create a final, safe, and accurate response. 123 | Ensure the explanation is comprehensive and the answer is correct. 124 | For multiple-choice questions, the 'answer' field must contain ONLY the option letter. 125 | 126 | **Final Revised JSON Output:** 127 | """ 128 | 129 | 130 | class HealthcareAgentFramework: 131 | """ 132 | A standalone framework that implements the HealthcareAgent methodology. 133 | """ 134 | 135 | def __init__(self, model_key: str): 136 | """ 137 | Initialize the framework. 138 | 139 | Args: 140 | model_key: The LLM model key from LLM_MODELS_SETTINGS to be used for all internal steps. 
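        A minimal usage sketch (illustrative; assumes "qwen-vl-max", the default in main() below, is configured in LLM_MODELS_SETTINGS):
            framework = HealthcareAgentFramework(model_key="qwen-vl-max")
            result = framework.run_query(data_item)  # data_item: one record from the processed JSON file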
141 | """ 142 | self.model_key = model_key 143 | 144 | if model_key not in LLM_MODELS_SETTINGS: 145 | raise ValueError(f"Model key '{model_key}' not found in LLM_MODELS_SETTINGS") 146 | 147 | model_settings = LLM_MODELS_SETTINGS[model_key] 148 | self.client = OpenAI( 149 | api_key=model_settings["api_key"], 150 | base_url=model_settings["base_url"], 151 | ) 152 | self.model_name = model_settings["model_name"] 153 | print(f"Initialized HealthcareAgentFramework with model: {self.model_name}") 154 | 155 | def _call_llm(self, 156 | prompt: str, 157 | image_path: Optional[str] = None, 158 | expect_json: bool = True, 159 | max_retries: int = 3) -> str: 160 | """ 161 | A helper function to call the LLM with a given prompt and optional image. 162 | """ 163 | system_message = {"role": "system", "content": "You are a highly capable and meticulous medical AI assistant."} 164 | user_content = [{"type": "text", "text": prompt}] 165 | 166 | if image_path: 167 | if not os.path.exists(image_path): 168 | raise FileNotFoundError(f"Image not found at {image_path}") 169 | base64_image = encode_image(image_path) 170 | user_content.insert(0, { 171 | "type": "image_url", 172 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, 173 | }) 174 | 175 | user_message = {"role": "user", "content": user_content} 176 | 177 | messages = [system_message, user_message] 178 | response_format = {"type": "json_object"} if expect_json else None 179 | 180 | retries = 0 181 | while retries < max_retries: 182 | try: 183 | print(f"Calling LLM (JSON: {expect_json})...") 184 | completion = self.client.chat.completions.create( 185 | model=self.model_name, 186 | messages=messages, 187 | response_format=response_format 188 | ) 189 | response = completion.choices[0].message.content 190 | print(f"LLM call successful. Response snippet: {response[:80]}...") 191 | return response 192 | except Exception as e: 193 | retries += 1 194 | print(f"LLM API call error (attempt {retries}/{max_retries}): {e}") 195 | if retries >= max_retries: 196 | raise Exception(f"LLM API call failed after {max_retries} attempts.") 197 | time.sleep(1) 198 | return "" # Should not be reached 199 | 200 | def run_query(self, data_item: Dict) -> Dict: 201 | """ 202 | Processes a single medical query through the full HealthcareAgent pipeline. 203 | """ 204 | qid = data_item["qid"] 205 | question = data_item["question"] 206 | options = data_item.get("options") 207 | image_path = data_item.get("image_path") 208 | ground_truth = data_item.get("answer") 209 | 210 | print(f"\n{'='*20} Processing QID: {qid} with HealthcareAgentFramework {'='*20}") 211 | start_time = time.time() 212 | 213 | case_history = { 214 | "steps": [] 215 | } 216 | 217 | # --- Prepare context strings used in multiple prompts --- 218 | options_text = "" 219 | if options: 220 | options_text = "Options:\n" + "\n".join([f"{key}: {value}" for key, value in options.items()]) 221 | image_text = "An image is provided for context." 
if image_path else "" 222 | 223 | try: 224 | # === STEP 1: Planner Module === 225 | planner_prompt = PLANNER_PROMPT_TEMPLATE.format( 226 | question=question, options_text=options_text, image_text=image_text 227 | ) 228 | action = self._call_llm(planner_prompt, image_path, expect_json=False).strip().upper() 229 | case_history["steps"].append({"step": "1_Planner", "decision": action}) 230 | 231 | # === STEP 2: Inquiry Module (Optional) === 232 | inquiry_context = "" 233 | if "INQUIRY" in action: 234 | inquiry_prompt = INQUIRY_PROMPT_TEMPLATE.format( 235 | question=question, options_text=options_text, image_text=image_text 236 | ) 237 | inquiry_response_str = self._call_llm(inquiry_prompt, image_path, expect_json=True) 238 | inquiry_result = json.loads(preprocess_response_string(inquiry_response_str)) 239 | questions = inquiry_result.get("questions", []) 240 | case_history["steps"].append({"step": "2_Inquiry", "generated_questions": questions}) 241 | if questions: 242 | inquiry_context = "To provide a robust answer, the following questions should be considered:\n- " + "\n- ".join(questions) 243 | inquiry_context += "\n\nGiven this, here is a preliminary analysis based on the limited information:" 244 | else: 245 | case_history["steps"].append({"step": "2_Inquiry", "generated_questions": "Skipped as per planner's decision."}) 246 | 247 | # === STEP 3: Preliminary Analysis (Function Module) === 248 | analysis_prompt = PRELIMINARY_ANALYSIS_PROMPT_TEMPLATE.format( 249 | inquiry_context=inquiry_context, 250 | question=question, 251 | options_text=options_text, 252 | image_text=image_text 253 | ) 254 | preliminary_response_str = self._call_llm(analysis_prompt, image_path, expect_json=True) 255 | case_history["steps"].append({"step": "3_Preliminary_Analysis", "response": preliminary_response_str}) 256 | 257 | # === STEP 4: Safety Module ("Discuss" Phase) === 258 | ethics_feedback = self._call_llm(SAFETY_ETHICS_PROMPT.format(preliminary_response=preliminary_response_str), expect_json=False) 259 | emergency_feedback = self._call_llm(SAFETY_EMERGENCY_PROMPT.format(preliminary_response=preliminary_response_str), expect_json=False) 260 | error_feedback = self._call_llm(SAFETY_ERROR_PROMPT.format(preliminary_response=preliminary_response_str), expect_json=False) 261 | case_history["steps"].append({ 262 | "step": "4_Safety_Review", 263 | "ethics_feedback": ethics_feedback, 264 | "emergency_feedback": emergency_feedback, 265 | "error_feedback": error_feedback 266 | }) 267 | 268 | # === STEP 5: Final Modification ("Modify" Phase) === 269 | final_prompt = FINAL_MODIFICATION_PROMPT_TEMPLATE.format( 270 | question=question, options_text=options_text, image_text=image_text, 271 | preliminary_response=preliminary_response_str, 272 | ethics_feedback=ethics_feedback, 273 | emergency_feedback=emergency_feedback, 274 | error_feedback=error_feedback 275 | ) 276 | final_response_str = self._call_llm(final_prompt, image_path, expect_json=True) 277 | case_history["steps"].append({"step": "5_Final_Modification", "response": final_response_str}) 278 | 279 | # === STEP 6: Parse Final Result === 280 | final_result_json = json.loads(preprocess_response_string(final_response_str)) 281 | predicted_answer = final_result_json.get("answer", "Parsing Error") 282 | explanation = final_result_json.get("explanation", "Parsing Error") 283 | 284 | except Exception as e: 285 | print(f"FATAL ERROR during query processing for QID {qid}: {e}") 286 | predicted_answer = "Framework Error" 287 | explanation = str(e) 288 | case_history["error"] 
= str(e) 289 | 290 | processing_time = time.time() - start_time 291 | print(f"Finished QID: {qid}. Time: {processing_time:.2f}s. Final Answer: {predicted_answer}") 292 | 293 | # Assemble final result object in the required format 294 | final_output = { 295 | "qid": qid, 296 | "timestamp": int(time.time()), 297 | "question": question, 298 | "options": options, 299 | "image_path": image_path, 300 | "ground_truth": ground_truth, 301 | "predicted_answer": predicted_answer, 302 | "explanation": explanation, 303 | "case_history": case_history, 304 | "processing_time": processing_time 305 | } 306 | return final_output 307 | 308 | def main(): 309 | parser = argparse.ArgumentParser(description="Run HealthcareAgent Framework on medical datasets") 310 | parser.add_argument("--dataset", type=str, required=True, help="Specify dataset name") 311 | parser.add_argument("--qa_type", type=str, choices=["mc", "ff"], required=True, help="QA type: multiple-choice (mc) or free-form (ff)") 312 | parser.add_argument("--model", type=str, default="qwen-vl-max", help="Model key to use for all agent steps") 313 | args = parser.parse_args() 314 | 315 | method_name = "HealthcareAgent" 316 | 317 | # Set up paths 318 | logs_dir = os.path.join("logs", "medqa", args.dataset, "multiple_choice" if args.qa_type == "mc" else "free-form", method_name) 319 | os.makedirs(logs_dir, exist_ok=True) 320 | data_path = f"./my_datasets/processed/medqa/{args.dataset}/medqa_{args.qa_type}_test.json" 321 | 322 | # Load data 323 | if not os.path.exists(data_path): 324 | print(f"Error: Dataset file not found at {data_path}") 325 | return 326 | data = load_json(data_path) 327 | print(f"Loaded {len(data)} samples from {data_path}") 328 | 329 | # Initialize the framework 330 | framework = HealthcareAgentFramework(model_key=args.model) 331 | 332 | # Process each item in the dataset 333 | for item in tqdm(data, desc=f"Running HealthcareAgent on {args.dataset}"): 334 | qid = item["qid"] 335 | result_path = os.path.join(logs_dir, f"{qid}-result.json") 336 | 337 | if os.path.exists(result_path): 338 | print(f"Skipping {qid} - already processed") 339 | continue 340 | 341 | try: 342 | result = framework.run_query(item) 343 | save_json(result, result_path) 344 | except Exception as e: 345 | print(f"CRITICAL MAIN LOOP ERROR processing item {qid}: {e}") 346 | # Save an error file 347 | error_result = { 348 | "qid": qid, 349 | "error": str(e), 350 | "timestamp": int(time.time()) 351 | } 352 | save_json(error_result, result_path) 353 | 354 | if __name__ == "__main__": 355 | main() -------------------------------------------------------------------------------- /medagentboard/medqa/multi_agent_mac.py: -------------------------------------------------------------------------------- 1 | """ 2 | medagentboard/medqa/multi_agent_mac.py 3 | """ 4 | 5 | import os 6 | import time 7 | import argparse 8 | import json 9 | from openai import OpenAI 10 | from enum import Enum 11 | from typing import Dict, Any, Optional, List, Union 12 | from tqdm import tqdm 13 | 14 | from medagentboard.utils.llm_configs import LLM_MODELS_SETTINGS 15 | from medagentboard.utils.encode_image import encode_image 16 | from medagentboard.utils.json_utils import load_json, save_json, preprocess_response_string 17 | 18 | # --- Constants and Enums --- 19 | 20 | class AgentRole(Enum): 21 | """Enumeration for different agent roles in the MAC framework.""" 22 | DOCTOR = "Doctor" 23 | SUPERVISOR = "Supervisor" 24 | 25 | # Default settings from the paper and for the framework 26 | 
DEFAULT_DOCTOR_MODEL = "qwen-vl-max" 27 | DEFAULT_SUPERVISOR_MODEL = "qwen-vl-max" # Supervisor might need strong reasoning 28 | DEFAULT_NUM_DOCTORS = 4 # Optimal number identified in the paper 29 | DEFAULT_MAX_ROUNDS = 5 # Paper mentions up to 13, but 5 is a practical starting point to balance performance and cost 30 | 31 | # --- Base Agent Class (as provided) --- 32 | 33 | class BaseAgent: 34 | """Base class for all agents in the MAC framework, adapted from your code.""" 35 | 36 | def __init__(self, 37 | agent_id: str, 38 | role: Union[AgentRole, str], 39 | model_key: str, 40 | instruction: Optional[str] = None): 41 | """ 42 | Initialize the base agent. 43 | 44 | Args: 45 | agent_id: Unique identifier for the agent. 46 | role: The role of the agent (Doctor, Supervisor). 47 | model_key: Key for the LLM model configuration in LLM_MODELS_SETTINGS. 48 | instruction: System-level instruction defining the agent's persona and task. 49 | """ 50 | self.agent_id = agent_id 51 | self.role = role if isinstance(role, str) else role.value 52 | self.model_key = model_key 53 | self.instruction = instruction or f"You are a helpful assistant playing the role of a {self.role}." 54 | 55 | if model_key not in LLM_MODELS_SETTINGS: 56 | raise ValueError(f"Model key '{model_key}' not found in LLM_MODELS_SETTINGS") 57 | 58 | model_settings = LLM_MODELS_SETTINGS[model_key] 59 | self.llm_client = OpenAI( 60 | api_key=model_settings["api_key"], 61 | base_url=model_settings["base_url"], 62 | ) 63 | self.model_name = model_settings["model_name"] 64 | print(f"Initialized Agent: ID={self.agent_id}, Role={self.role}, Model={self.model_key} ({self.model_name})") 65 | 66 | def call_llm(self, 67 | messages: List[Dict[str, Any]], 68 | max_retries: int = 3) -> str: 69 | """ 70 | Call the LLM with a list of messages and handle retries. 71 | 72 | Args: 73 | messages: List of message dictionaries. 74 | max_retries: Maximum number of retry attempts. 75 | 76 | Returns: 77 | LLM response text. 78 | """ 79 | retries = 0 80 | while retries < max_retries: 81 | try: 82 | print(f"Agent {self.agent_id} calling LLM ({self.model_name}). Attempt {retries + 1}/{max_retries}.") 83 | completion = self.llm_client.chat.completions.create( 84 | model=self.model_name, 85 | messages=messages, 86 | response_format={"type": "json_object"} 87 | ) 88 | response = completion.choices[0].message.content 89 | print(f"Agent {self.agent_id} received response successfully.") 90 | return response 91 | except Exception as e: 92 | retries += 1 93 | print(f"LLM API call error for agent {self.agent_id} (attempt {retries}/{max_retries}): {e}") 94 | if retries >= max_retries: 95 | raise Exception(f"LLM API call failed for agent {self.agent_id} after {max_retries} attempts: {e}") 96 | time.sleep(2) 97 | raise Exception(f"LLM call failed unexpectedly for agent {self.agent_id}.") 98 | 99 | # --- MAC Framework Class --- 100 | 101 | class MACFramework: 102 | """ 103 | Orchestrates the Multi-Agent Conversation (MAC) workflow based on the paper. 104 | This framework facilitates a discussion between multiple Doctor agents and a Supervisor agent. 105 | """ 106 | 107 | def __init__(self, 108 | log_dir: str, 109 | dataset_name: str, 110 | doctor_model_key: str = DEFAULT_DOCTOR_MODEL, 111 | supervisor_model_key: str = DEFAULT_SUPERVISOR_MODEL, 112 | num_doctors: int = DEFAULT_NUM_DOCTORS, 113 | max_rounds: int = DEFAULT_MAX_ROUNDS): 114 | """ 115 | Initialize the MAC framework orchestrator. 116 | 117 | Args: 118 | log_dir: Directory to save logs and results. 
119 | dataset_name: Name of the dataset being processed. 120 | doctor_model_key: Model key for all Doctor agents. 121 | supervisor_model_key: Model key for the Supervisor agent. 122 | num_doctors: The number of Doctor agents to use in the conversation. 123 | max_rounds: The maximum number of conversational rounds. 124 | """ 125 | self.log_dir = log_dir 126 | self.dataset_name = dataset_name 127 | self.num_doctors = num_doctors 128 | self.max_rounds = max_rounds 129 | os.makedirs(self.log_dir, exist_ok=True) 130 | 131 | # --- Initialize Agents based on paper's roles --- 132 | self.doctor_agents = [ 133 | BaseAgent( 134 | agent_id=f"doctor_{i+1}", 135 | role=AgentRole.DOCTOR, 136 | model_key=doctor_model_key, 137 | instruction=( 138 | "You are an expert medical professional. Your task is to analyze the provided medical case, which includes a question, optional multiple-choice options, and possibly an image. " 139 | "You will participate in a multi-agent discussion. In each round, review the conversation history and the opinions of other doctors. " 140 | "Then, provide your own updated analysis, clearly stating your reasoning and conclusion. " 141 | "If you change your mind based on others' arguments, explain why. Your goal is to contribute to a correct and well-reasoned consensus. " 142 | "Respond in JSON format with 'explanation' and 'answer' fields." 143 | ) 144 | ) for i in range(num_doctors) 145 | ] 146 | 147 | self.supervisor_agent = BaseAgent( 148 | agent_id="supervisor", 149 | role=AgentRole.SUPERVISOR, 150 | model_key=supervisor_model_key, 151 | instruction=( 152 | "You are the Supervisor of a medical multi-agent discussion. Your role is to facilitate the conversation and drive towards a consensus. " 153 | "After each round of discussion among the Doctor agents, you will: " 154 | "1. Summarize the current state of the discussion, noting points of agreement and disagreement. " 155 | "2. Challenge the doctors' reasoning if it seems weak or contradictory. " 156 | "3. Evaluate if a consensus has been reached. A consensus is defined as strong agreement among the majority of doctors on both the answer and the core reasoning. " 157 | "4. If consensus is reached or this is the final round, provide the final definitive answer. " 158 | "Respond in JSON format with 'summary' (your analysis of the round), 'consensus_reached' (boolean), and 'final_answer' (your final concluded answer, which can be null if consensus is not yet reached)." 159 | ) 160 | ) 161 | 162 | print("MACFramework Initialized.") 163 | print(f" - Log Directory: {self.log_dir}") 164 | print(f" - Dataset: {self.dataset_name}") 165 | print(f" - Models: Doctors={doctor_model_key}, Supervisor={supervisor_model_key}") 166 | print(f" - Settings: Doctors={self.num_doctors}, Max Rounds={self.max_rounds}") 167 | 168 | def _format_initial_prompt(self, data_item: Dict[str, Any]) -> List[Dict[str, Any]]: 169 | """Formats the initial problem statement from the Admin Agent's perspective.""" 170 | question = data_item["question"] 171 | options = data_item.get("options") 172 | image_path = data_item.get("image_path") 173 | 174 | # The user message content can be a list (for VQA) or a string (for QA) 175 | user_content: Union[str, List[Dict[str, Any]]] 176 | 177 | prompt_text = f"A new case has been presented. 
Please begin the diagnostic discussion.\n\n--- Case Information ---\nQuestion: {question}\n" 178 | if options: 179 | options_str = "\n".join([f"({k}) {v}" for k, v in options.items()]) 180 | prompt_text += f"Options:\n{options_str}\n" 181 | 182 | if image_path: 183 | if not os.path.exists(image_path): 184 | raise FileNotFoundError(f"Image path does not exist: {image_path}") 185 | base64_image = encode_image(image_path) 186 | user_content = [ 187 | {"type": "text", "text": prompt_text}, 188 | {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, 189 | ] 190 | else: 191 | user_content = prompt_text 192 | 193 | return [{"role": "user", "content": user_content}] 194 | 195 | def _format_conversation_history(self, history: List[Dict]) -> str: 196 | """Converts the list of conversation turns into a readable string.""" 197 | formatted_history = "--- Start of Conversation History ---\n" 198 | for turn in history: 199 | # Handle different content structures (string vs. list for VQA) 200 | content = turn.get("content") 201 | if isinstance(content, list): 202 | # Extract text part for history 203 | text_content = next((item['text'] for item in content if item['type'] == 'text'), "") 204 | content_str = f"{text_content} [Image was provided]" 205 | else: 206 | content_str = str(content) 207 | 208 | # The role is now a string like 'Doctor (doctor_1)' 209 | formatted_history += f"Role: {turn['role']}\n" 210 | formatted_history += f"Message: {content_str}\n" 211 | formatted_history += "-------------------------------------\n" 212 | formatted_history += "--- End of Conversation History ---\n" 213 | return formatted_history 214 | 215 | def run_query(self, data_item: Dict) -> Dict: 216 | """ 217 | Processes a single data item through the MAC framework. 218 | 219 | Args: 220 | data_item: Dictionary containing query details. 221 | 222 | Returns: 223 | A dictionary containing the full results and conversation log. 224 | """ 225 | qid = data_item["qid"] 226 | print(f"\n{'='*20} Processing QID: {qid} {'='*20}") 227 | start_time = time.time() 228 | 229 | conversation_log = [] 230 | final_answer_obj = {"answer": "Error", "explanation": "Processing failed to produce a final answer."} 231 | 232 | try: 233 | # The 'Admin Agent' provides the initial information 234 | initial_messages = self._format_initial_prompt(data_item) 235 | conversation_log.append({ 236 | "role": "Admin", 237 | "content": initial_messages[0]['content'] 238 | }) 239 | 240 | for round_num in range(1, self.max_rounds + 1): 241 | print(f"\n--- Starting Round {round_num}/{self.max_rounds} for QID: {qid} ---") 242 | 243 | # --- Doctors' Turn --- 244 | round_doctor_responses = [] 245 | for doctor in self.doctor_agents: 246 | history_str = self._format_conversation_history(conversation_log) 247 | doctor_prompt = ( 248 | f"{history_str}\n" 249 | f"This is round {round_num}. Based on the full conversation history, provide your updated analysis. " 250 | "If other doctors have provided compelling arguments, acknowledge them and refine your position. " 251 | "State your current answer and explanation clearly." 
252 | ) 253 | 254 | # The message list for the LLM needs to include the initial prompt with the image 255 | messages_for_llm = [ 256 | {"role": "system", "content": doctor.instruction}, 257 | *initial_messages, 258 | {"role": "user", "content": doctor_prompt} 259 | ] 260 | 261 | response_str = doctor.call_llm(messages_for_llm) 262 | round_doctor_responses.append({ 263 | "role": f"Doctor ({doctor.agent_id})", 264 | "content": response_str 265 | }) 266 | 267 | conversation_log.extend(round_doctor_responses) 268 | 269 | # --- Supervisor's Turn --- 270 | print(f"\n--- Supervisor Turn for Round {round_num} ---") 271 | history_str = self._format_conversation_history(conversation_log) 272 | supervisor_prompt = ( 273 | f"{history_str}\n" 274 | f"This is the end of round {round_num}. As the Supervisor, please analyze the doctors' latest inputs. " 275 | "Provide your summary, challenge any weak points, and determine if consensus has been reached. " 276 | f"If consensus is met or if this is the final round ({self.max_rounds}), you must provide the 'final_answer'." 277 | ) 278 | 279 | # Supervisor does not need the image, only the text discussion 280 | messages_for_llm = [ 281 | {"role": "system", "content": self.supervisor_agent.instruction}, 282 | {"role": "user", "content": supervisor_prompt} 283 | ] 284 | 285 | supervisor_response_str = self.supervisor_agent.call_llm(messages_for_llm) 286 | conversation_log.append({ 287 | "role": f"Supervisor ({self.supervisor_agent.agent_id})", 288 | "content": supervisor_response_str 289 | }) 290 | 291 | # Parse supervisor's response to check for consensus 292 | try: 293 | supervisor_json = json.loads(preprocess_response_string(supervisor_response_str)) 294 | consensus_reached = supervisor_json.get("consensus_reached", False) 295 | final_answer_from_supervisor = supervisor_json.get("final_answer") 296 | 297 | print(f"Supervisor Summary: {supervisor_json.get('summary', 'N/A')}") 298 | print(f"Consensus Reached: {consensus_reached}") 299 | 300 | if final_answer_from_supervisor: 301 | # The final answer could be a string or a dict. We want a dict. 302 | if isinstance(final_answer_from_supervisor, dict) and "answer" in final_answer_from_supervisor: 303 | final_answer_obj = final_answer_from_supervisor 304 | else: 305 | # If it's not a dict, we wrap it. This is a fallback. 306 | final_answer_obj = {"answer": final_answer_from_supervisor, "explanation": supervisor_json.get('summary', '')} 307 | 308 | if consensus_reached: 309 | print("Consensus reached. Ending conversation.") 310 | break 311 | 312 | if round_num == self.max_rounds and not final_answer_from_supervisor: 313 | print("Max rounds reached. Supervisor did not provide a final answer. Forcing a final decision.") 314 | # This would be a place to make one last call to the supervisor asking for a forced decision. 315 | # For simplicity, we'll use the last summary as the basis for the answer. 316 | final_answer_obj = { 317 | "answer": "Inconclusive", 318 | "explanation": f"Max rounds reached without a clear final answer. Last summary: {supervisor_json.get('summary', 'N/A')}" 319 | } 320 | 321 | 322 | except json.JSONDecodeError: 323 | print(f"Error: Could not parse supervisor's response in round {round_num}. 
Continuing.") 324 | except Exception as e: 325 | print(f"An unexpected error occurred while processing supervisor response: {e}") 326 | 327 | except Exception as e: 328 | print(f"ERROR processing QID {qid}: {e}") 329 | final_answer_obj = {"answer": "Error", "explanation": str(e)} 330 | 331 | processing_time = time.time() - start_time 332 | print(f"Finished QID: {qid}. Time: {processing_time:.2f}s") 333 | 334 | # Assemble final result object 335 | final_result = { 336 | "qid": qid, 337 | "timestamp": int(time.time()), 338 | "question": data_item["question"], 339 | "options": data_item.get("options"), 340 | "image_path": data_item.get("image_path"), 341 | "ground_truth": data_item.get("answer"), 342 | "predicted_answer": final_answer_obj.get("answer", "Error"), 343 | "explanation": final_answer_obj.get("explanation", "N/A"), 344 | "processing_time_seconds": processing_time, 345 | "details": { 346 | "conversation_log": conversation_log 347 | } 348 | } 349 | 350 | return final_result 351 | 352 | def run_dataset(self, data: List[Dict]): 353 | """ 354 | Runs the MAC framework over an entire dataset. 355 | 356 | Args: 357 | data: List of data items (dictionaries). 358 | """ 359 | print(f"\nStarting MAC framework processing for {len(data)} items in dataset '{self.dataset_name}'.") 360 | 361 | for item in tqdm(data, desc=f"Running MAC on {self.dataset_name}"): 362 | qid = item.get("qid", "unknown_qid") 363 | result_path = os.path.join(self.log_dir, f"{qid}-result.json") 364 | 365 | if os.path.exists(result_path): 366 | print(f"Skipping {qid} - result file already exists.") 367 | continue 368 | 369 | try: 370 | result = self.run_query(item) 371 | save_json(result, result_path) 372 | except Exception as e: 373 | print(f"FATAL ERROR during run_query for QID {qid}: {e}") 374 | # Save an error record 375 | error_result = {"qid": qid, "error": str(e)} 376 | save_json(error_result, result_path) 377 | 378 | print(f"Finished processing dataset '{self.dataset_name}'. 
Results saved in {self.log_dir}") 379 | 380 | 381 | def main(): 382 | parser = argparse.ArgumentParser(description="Run MAC Framework on medical datasets") 383 | parser.add_argument("--dataset", type=str, required=True, help="Specify dataset name (e.g., vqa_rad, pathvqa, medqa)") 384 | parser.add_argument("--qa_type", type=str, choices=["mc", "ff"], required=True, help="QA type: multiple-choice (mc) or free-form (ff)") 385 | parser.add_argument("--doctor_model", type=str, default=DEFAULT_DOCTOR_MODEL, help="Model key for the Doctor agents") 386 | parser.add_argument("--supervisor_model", type=str, default=DEFAULT_SUPERVISOR_MODEL, help="Model key for the Supervisor agent") 387 | parser.add_argument("--num_doctors", type=int, default=DEFAULT_NUM_DOCTORS, help="Number of doctor agents to use") 388 | parser.add_argument("--max_rounds", type=int, default=DEFAULT_MAX_ROUNDS, help="Maximum number of discussion rounds") 389 | 390 | args = parser.parse_args() 391 | 392 | method_name = "MAC" 393 | 394 | data_path = f"./my_datasets/processed/medqa/{args.dataset}/medqa_{args.qa_type}_test.json" 395 | logs_dir = os.path.join("./logs", "medqa", args.dataset, 396 | "multiple_choice" if args.qa_type == "mc" else "free-form", 397 | method_name) 398 | os.makedirs(logs_dir, exist_ok=True) 399 | print(f"Using Log Directory: {logs_dir}") 400 | 401 | if not os.path.exists(data_path): 402 | print(f"Error: Dataset file not found at {data_path}") 403 | return 404 | 405 | data = load_json(data_path) 406 | print(f"Loaded {len(data)} samples from {data_path}") 407 | 408 | framework = MACFramework( 409 | log_dir=logs_dir, 410 | dataset_name=args.dataset, 411 | doctor_model_key=args.doctor_model, 412 | supervisor_model_key=args.supervisor_model, 413 | num_doctors=args.num_doctors, 414 | max_rounds=args.max_rounds 415 | ) 416 | 417 | framework.run_dataset(data) 418 | 419 | if __name__ == "__main__": 420 | main() -------------------------------------------------------------------------------- /medagentboard/laysummary/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import nltk 5 | import evaluate 6 | from transformers import set_seed 7 | import argparse 8 | from rouge_score import rouge_scorer 9 | import textstat 10 | import random 11 | 12 | # Make sure NLTK data is available 13 | try: 14 | nltk.data.find('tokenizers/punkt') 15 | except LookupError: # nltk.data.find raises LookupError when the resource is missing 16 | print("Downloading NLTK punkt tokenizer...") 17 | nltk.download('punkt', quiet=True) 18 | 19 | # ======================================================================== 20 | # Keep all your functions (bootstrap, calc_rouge, get_metrics, etc.) 21 | # exactly as they were in your *original* script, with the 22 | # small modification to bootstrap to use 'id'. 23 | # ======================================================================== 24 | 25 | def bootstrap(data_path): 26 | """ 27 | Bootstrap the sample ids with replacement; the resampled set has the same size as the original data.
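    For instance, if the file holds ids ["a", "b", "c"], one resample could be ["b", "a", "b"]: same length as the original, drawn with replacement.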
28 | --- MODIFIED TO USE 'id' key --- 29 | """ 30 | data_ids = [] # Changed from data_qids 31 | try: 32 | with open(data_path, "r", encoding='utf-8') as f: 33 | data = json.load(f) 34 | if not isinstance(data, list): 35 | print(f"Error: Expected a list in {data_path}, found {type(data)}") 36 | return [] # Return empty list on error 37 | for datum in data: 38 | # Check if datum is a dict and has the 'id' key 39 | if isinstance(datum, dict) and "id" in datum: 40 | data_ids.append(datum["id"]) # Use 'id' key 41 | else: 42 | print(f"Warning: Skipping item in {data_path} missing 'id' or not a dict: {datum}") 43 | 44 | except FileNotFoundError: 45 | print(f"Error: Bootstrap data file not found: {data_path}") 46 | return [] 47 | except json.JSONDecodeError: 48 | print(f"Error: Could not decode JSON from {data_path}") 49 | return [] 50 | except Exception as e: 51 | print(f"Error reading bootstrap file {data_path}: {e}") 52 | return [] 53 | 54 | if not data_ids: 55 | print(f"Warning: No IDs found in {data_path}") 56 | return [] 57 | 58 | # Use the global SEED for reproducibility if desired, or manage state differently 59 | # Ensure numpy random state is consistent if used repeatedly 60 | # np.random.seed(SEED) # You might need to pass SEED or handle this globally 61 | return [data_ids[i] for i in np.random.randint(0, len(data_ids), len(data_ids))] 62 | 63 | 64 | def calc_rouge(preds, refs): 65 | # Get ROUGE F1 scores 66 | scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], \ 67 | use_stemmer=True, split_summaries=True) 68 | # Ensure refs is a list of strings, not list of lists if only one reference 69 | processed_refs = [ref[0] if isinstance(ref, list) else ref for ref in refs] 70 | scores = [] 71 | for i, p in enumerate(preds): 72 | try: 73 | # Basic check for empty strings 74 | if not p or not processed_refs[i]: 75 | scores.append({'rouge1': rouge_scorer.Score(0,0,0), 76 | 'rouge2': rouge_scorer.Score(0,0,0), 77 | 'rougeLsum': rouge_scorer.Score(0,0,0)}) 78 | else: 79 | scores.append(scorer.score(p, processed_refs[i])) 80 | except Exception as e: 81 | print(f"Warning: ROUGE calculation error for item {i}: {e}. 
Assigning 0.") 82 | scores.append({'rouge1': rouge_scorer.Score(0,0,0), 83 | 'rouge2': rouge_scorer.Score(0,0,0), 84 | 'rougeLsum': rouge_scorer.Score(0,0,0)}) # Assign 0 score on error 85 | # Handle case where no scores were calculated 86 | if not scores: return 0.0, 0.0, 0.0 87 | # Original script didn't multiply by 100, reverting that too 88 | return np.mean([s['rouge1'].fmeasure for s in scores]), \ 89 | np.mean([s['rouge2'].fmeasure for s in scores]), \ 90 | np.mean([s['rougeLsum'].fmeasure for s in scores]) 91 | 92 | # def calc_bertscore(preds, refs): # Your original commented out function 93 | # # Get BERTScore F1 scores 94 | # P, R, F1 = score(preds, refs, lang="en", verbose=True, device='cuda:0') 95 | # return np.mean(F1.tolist()) 96 | 97 | def calc_readability(preds): 98 | fkgl_scores = [] 99 | cli_scores = [] 100 | dcrs_scores = [] 101 | for pred in preds: 102 | try: 103 | # Handle potential empty strings for textstat 104 | if not pred or not pred.strip(): 105 | fkgl_scores.append(0) # Assign a default, e.g., 0 or handle as needed 106 | cli_scores.append(0) 107 | dcrs_scores.append(0) 108 | else: 109 | fkgl_scores.append(textstat.flesch_kincaid_grade(pred)) 110 | cli_scores.append(textstat.coleman_liau_index(pred)) 111 | dcrs_scores.append(textstat.dale_chall_readability_score(pred)) 112 | except Exception as e: 113 | print(f"Warning: Readability calculation error: {e}. Assigning 0 score.") 114 | fkgl_scores.append(0) 115 | cli_scores.append(0) 116 | dcrs_scores.append(0) 117 | if not fkgl_scores: return 0.0, 0.0, 0.0 # Handle empty input 118 | return np.mean(fkgl_scores), np.mean(cli_scores), np.mean(dcrs_scores) 119 | 120 | 121 | def get_metrics(preds, goldens, sources, seed): 122 | # Set reproducibility (as in your original) 123 | SEED = seed 124 | os.environ['PYTHONHASHSEED']=str(SEED) 125 | # Note: set_seed affects transformers randomness, numpy affects np.random, random affects random module 126 | random.seed(SEED) 127 | np.random.seed(SEED) 128 | set_seed(SEED) 129 | 130 | if not preds or not goldens or not sources: 131 | print("Warning: Empty input list(s) provided to get_metrics. Returning zero scores.") 132 | # Return structure matching original expectations (14 zeros) 133 | return 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 134 | 135 | # SARI requires tokenized inputs (using original formatting style) 136 | sari_preds = [" ".join(nltk.sent_tokenize(pred.strip())) for pred in preds] 137 | sari_refs = [[" ".join(nltk.sent_tokenize(golden.strip()))] for golden in goldens] 138 | sari_sources = sources # Original used sources directly, keep that way unless tokenization is strictly needed by evaluate 139 | sari_score = 0 140 | try: 141 | metric = evaluate.load('sari', seed=SEED) # Keep seed=SEED if original had it 142 | sari_result = metric.compute(sources=sari_sources, predictions = sari_preds, references = sari_refs) 143 | sari_score = sari_result['sari'] 144 | except Exception as e: 145 | print(f"Error calculating SARI: {e}. 
Assigning 0.") 146 | 147 | # ROUGE uses sentence tokenization and newlines (original formatting style) 148 | rouge_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds] 149 | rouge_refs = ["\n".join(nltk.sent_tokenize(golden.strip())) for golden in goldens] 150 | rouge1, rouge2, rougeL = calc_rouge(rouge_preds, rouge_refs) 151 | 152 | # SACREBLEU needs tokenized inputs, references as list of lists (original formatting style) 153 | bleu_preds = [" ".join(nltk.sent_tokenize(pred.strip())) for pred in preds] 154 | bleu_refs = [[" ".join(nltk.sent_tokenize(golden.strip()))] for golden in goldens] 155 | bleu_score = 0 156 | try: 157 | metric = evaluate.load('sacrebleu', seed=SEED) # Keep seed=SEED if original had it 158 | bleu_result = metric.compute(predictions=bleu_preds, references=bleu_refs) 159 | bleu_score = bleu_result['score'] # Original script returned 'score' 160 | except Exception as e: 161 | print(f"Error calculating BLEU: {e}. Assigning 0.") 162 | 163 | # BertScore (if used) 164 | # bertscore = calc_bertscore(...) # Keep commented as original 165 | 166 | # Readability for (1) sources, (2) goldens, and (3) model outputs (use original texts) 167 | fkgl_abs, cli_abs, dcrs_abs = calc_readability(sources) 168 | fkgl_pls, cli_pls, dcrs_pls = calc_readability(goldens) 169 | fkgl_model, cli_model, dcrs_model = calc_readability(preds) 170 | 171 | # Return tuple in the original order 172 | return sari_score, rouge1, rouge2, rougeL, bleu_score, fkgl_abs, cli_abs, dcrs_abs, fkgl_pls, cli_pls, dcrs_pls, fkgl_model, cli_model, dcrs_model 173 | 174 | 175 | # ======================================================================== 176 | # MAIN SCRIPT LOGIC 177 | # ======================================================================== 178 | 179 | if __name__ == "__main__": 180 | SEED = 42 181 | # Set global seeds once (as in original) 182 | random.seed(SEED) 183 | np.random.seed(SEED) 184 | os.environ['PYTHONHASHSEED']=str(SEED) 185 | set_seed(SEED) # Primarily for transformers, but good practice 186 | 187 | # --- Original Argument Parsing --- 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument('--logs_dir', type=str, default='logs/laysummary') 190 | parser.add_argument("--bootstrap", type=bool, default=True, help="Whether to bootstrap the data") # Original default was True 191 | parser.add_argument("--n_bootstrap", type=int, default=10, help="Number of bootstrap samples") # Original default was 10 192 | args = parser.parse_args() 193 | # --------------------------------- 194 | 195 | # Check base logs directory 196 | if not os.path.isdir(args.logs_dir): 197 | print(f"Error: Logs directory not found: {args.logs_dir}") 198 | exit(1) 199 | 200 | # Loop through the dataset directories inside logs_dir 201 | for dataset_dir_name in os.listdir(args.logs_dir): 202 | dataset_path = os.path.join(args.logs_dir, dataset_dir_name) 203 | if not os.path.isdir(dataset_path): 204 | continue # Skip files, only process directories 205 | 206 | dataset_name = dataset_dir_name # Use the directory name as dataset identifier 207 | print(f"\n{'='*10} Dataset: {dataset_name} {'='*10}") 208 | 209 | # Define the path to the reference test data *using the original assumption* 210 | # Assumes structure like 'my_datasets/laysummary/dataset_1/test.json' relative to script execution? 211 | # Or maybe it was expected to be somewhere else based on original script context. 
212 | # THIS PATH MAY NEED ADJUSTMENT based on where your original script expected 'test.json' 213 | data_path = os.path.join("my_datasets/laysummary", dataset_name, "test.json") # Original path structure 214 | 215 | if not os.path.exists(data_path): 216 | print(f"Warning: Reference data file not found at {data_path}. Skipping bootstrapping for dataset {dataset_name}.") 217 | # Decide if you want to skip the whole dataset or just bootstrapping 218 | # continue # Option: skip dataset entirely if ref data missing 219 | _skip_bootstrap_for_dataset = True # Flag to skip bootstrap section 220 | else: 221 | _skip_bootstrap_for_dataset = False 222 | 223 | 224 | # Define the order of models to process (as original) 225 | model_order = ["AgentSimp", "SingleLLM_basic", "SingleLLM_few_shot", "SingleLLM_optimized", "bart-base", "t5-base","bart-large-cnn", "pegasus-large"] # Your model list 226 | 227 | # --- Generate bootstrap IDs if needed --- 228 | ids = [] 229 | # Only bootstrap if enabled AND reference data exists 230 | if args.bootstrap and not _skip_bootstrap_for_dataset: 231 | print(f"Generating {args.n_bootstrap} bootstrap ID samples from {data_path}...") 232 | ids_generated = True 233 | for i in range(args.n_bootstrap): 234 | sample = bootstrap(data_path) 235 | if not sample: # Check if bootstrap function returned an empty list (error) 236 | print(f"Error generating bootstrap sample {i+1} for dataset {dataset_name}. Disabling bootstrap for this dataset.") 237 | ids = [] 238 | ids_generated = False 239 | break # Stop trying to generate for this dataset 240 | ids.append(sample) 241 | if ids_generated: 242 | print("Bootstrap ID generation complete.") 243 | elif args.bootstrap and _skip_bootstrap_for_dataset: 244 | print("Skipping bootstrap ID generation because reference file was not found.") 245 | # --------------------------------------- 246 | 247 | # --- Loop through models --- 248 | for model_name in model_order: 249 | # Define the path to the model's *directory* in the new structure 250 | model_results_dir = os.path.join(dataset_path, model_name) 251 | 252 | # Check if the model *directory* exists 253 | # (Original script checked for model.json in os.listdir - this replaces that check) 254 | if not os.path.isdir(model_results_dir): 255 | # Original script checked os.listdir, so finding the directory means the model exists in some form 256 | # Let's check if the original file existed for compatibility, though we won't use it 257 | original_file_path = os.path.join(dataset_path, f"{model_name}.json") 258 | if os.path.exists(original_file_path): 259 | print(f"Warning: Original file {model_name}.json exists, but directory {model_name} not found. Skipping.") 260 | # else: # Directory doesn't exist, and original file didn't either (or wasn't checked) 261 | # print(f"Info: Directory for model '{model_name}' not found. 
Skipping.") 262 | continue # Skip this model if its directory doesn't exist 263 | 264 | print(f"\n--- Model: {model_name} ---") 265 | 266 | # === MODIFICATION START: Load data from the new structure === 267 | model_results = [] # This will store the list of dicts, like the original script expected 268 | print(f" Loading results from: {model_results_dir}") 269 | loaded_count = 0 270 | error_count = 0 271 | try: 272 | for filename in os.listdir(model_results_dir): 273 | # Check if the file matches the expected pattern 274 | if filename.startswith("laysummary_") and filename.endswith("-result.json"): 275 | file_path = os.path.join(model_results_dir, filename) 276 | try: 277 | with open(file_path, 'r', encoding='utf-8') as f: 278 | data_item = json.load(f) 279 | # Basic validation: ensure it's a dict and has required keys 280 | if isinstance(data_item, dict) and all(k in data_item for k in ["id", "source", "target", "pred"]): 281 | model_results.append(data_item) 282 | loaded_count += 1 283 | else: 284 | print(f" Warning: Skipping file {filename}. Invalid content or missing keys.") 285 | error_count += 1 286 | except json.JSONDecodeError: 287 | print(f" Warning: Invalid JSON in {filename}. Skipping.") 288 | error_count += 1 289 | except Exception as e: 290 | print(f" Warning: Error reading {filename}: {e}. Skipping.") 291 | error_count += 1 292 | print(f" Successfully loaded {loaded_count} results, skipped {error_count} files.") 293 | 294 | except FileNotFoundError: 295 | print(f" Error: Directory {model_results_dir} not found during file listing.") 296 | continue # Skip to next model 297 | except Exception as e: 298 | print(f" Error listing files in {model_results_dir}: {e}") 299 | continue # Skip to next model 300 | 301 | if not model_results: 302 | print(f" Warning: No valid result files found or loaded for model '{model_name}'. Skipping evaluation.") 303 | continue 304 | # === MODIFICATION END === 305 | 306 | 307 | # --- The rest of the logic remains the same as your original script --- 308 | # --- It now operates on the 'model_results' list loaded above --- 309 | 310 | # Initialize lists to store metrics for each bootstrap run 311 | sari, rouge1, rouge2, rougeL, bleu = [], [], [], [], [] 312 | fkgl_abs, cli_abs, dcrs_abs = [], [], [] 313 | fkgl_pls, cli_pls, dcrs_pls = [], [], [] 314 | fkgl_model, cli_model, dcrs_model = [], [], [] 315 | 316 | # Use the generated bootstrap IDs (ids list) 317 | if args.bootstrap and ids: # Check if bootstrap enabled AND ids were generated successfully 318 | print(f" Calculating metrics using {len(ids)} bootstrap samples...") 319 | # Build a lookup dictionary from the loaded results for efficiency 320 | results_dict = {item['id']: item for item in model_results} 321 | 322 | # Loop through the bootstrap samples (ids[i] contains a list of sampled IDs) 323 | for i in range(len(ids)): # Use len(ids) which matches n_bootstrap unless errors occurred 324 | sampled_ids = ids[i] 325 | # Original script filtered the list - using the dictionary lookup is more efficient here 326 | sampled_model_results = [] 327 | missing_ids = 0 328 | for id_ in sampled_ids: 329 | result = results_dict.get(id_) 330 | if result: 331 | sampled_model_results.append(result) 332 | else: 333 | missing_ids += 1 334 | 335 | if missing_ids > 0: 336 | print(f" Bootstrap sample {i+1}: {missing_ids} IDs not found in loaded results.") 337 | 338 | if not sampled_model_results: 339 | print(f" Warning: No data retrieved for bootstrap sample {i+1}. 
Skipping.") 340 | continue # Skip this sample if no data was found 341 | 342 | # Get the sources, targets, and preds *for this specific bootstrap sample* 343 | sources = [model_res['source'] for model_res in sampled_model_results] 344 | targets = [model_res['target'] for model_res in sampled_model_results] 345 | preds = [model_res['pred'] for model_res in sampled_model_results] 346 | 347 | # Calculate metrics using your original function 348 | sari_, r1_, r2_, rL_, bleu_, fa_, ca_, da_, fp_, cp_, dp_, fm_, cm_, dm_ = get_metrics(preds, targets, sources, seed=SEED + i) # Vary seed slightly per sample 349 | 350 | # Append metrics for this run 351 | sari.append(sari_) 352 | rouge1.append(r1_) 353 | rouge2.append(r2_) 354 | rougeL.append(rL_) 355 | bleu.append(bleu_) 356 | fkgl_abs.append(fa_) 357 | cli_abs.append(ca_) 358 | dcrs_abs.append(da_) 359 | fkgl_pls.append(fp_) 360 | cli_pls.append(cp_) 361 | dcrs_pls.append(dp_) 362 | fkgl_model.append(fm_) 363 | cli_model.append(cm_) 364 | dcrs_model.append(dm_) 365 | 366 | # --- After looping through bootstrap samples --- 367 | if not sari: # Check if any samples were processed 368 | print(" Warning: No bootstrap samples were successfully processed.") 369 | continue # Skip printing results for this model 370 | 371 | # Get the mean and std of the metrics (using original rounding) 372 | sari_mean, sari_std = round(np.mean(sari), 2), round(np.std(sari), 2) 373 | # Note: ROUGE scores from calc_rouge are now 0-1, so mean/std will be too. 374 | # Multiply by 100 here if you want 0-100 scale output, matching the original print format expectation 375 | rouge1_mean, rouge1_std = round(np.mean(rouge1) * 100, 2), round(np.std(rouge1) * 100, 2) 376 | rouge2_mean, rouge2_std = round(np.mean(rouge2) * 100, 2), round(np.std(rouge2) * 100, 2) 377 | rougeL_mean, rougeL_std = round(np.mean(rougeL) * 100, 2), round(np.std(rougeL) * 100, 2) 378 | # Bleu score from evaluate is 0-100, so no scaling needed here 379 | bleu_mean, bleu_std = round(np.mean(bleu), 2), round(np.std(bleu), 2) 380 | fkgl_abs_mean, fkgl_abs_std = round(np.mean(fkgl_abs), 2), round(np.std(fkgl_abs), 2) 381 | cli_abs_mean, cli_abs_std = round(np.mean(cli_abs), 2), round(np.std(cli_abs), 2) 382 | dcrs_abs_mean, dcrs_abs_std = round(np.mean(dcrs_abs), 2), round(np.std(dcrs_abs), 2) 383 | fkgl_pls_mean, fkgl_pls_std = round(np.mean(fkgl_pls), 2), round(np.std(fkgl_pls), 2) 384 | cli_pls_mean, cli_pls_std = round(np.mean(cli_pls), 2), round(np.std(cli_pls), 2) 385 | dcrs_pls_mean, dcrs_pls_std = round(np.mean(dcrs_pls), 2), round(np.std(dcrs_pls), 2) 386 | fkgl_model_mean, fkgl_model_std = round(np.mean(fkgl_model), 2), round(np.std(fkgl_model), 2) 387 | cli_model_mean, cli_model_std = round(np.mean(cli_model), 2), round(np.std(cli_model), 2) 388 | dcrs_model_mean, dcrs_model_std = round(np.mean(dcrs_model), 2), round(np.std(dcrs_model), 2) 389 | 390 | # Print the results using the original dictionary structure 391 | metrics = { 392 | 'sari': f"{sari_mean} ± {sari_std}", 393 | 'rouge1': f"{rouge1_mean} ± {rouge1_std}", # Now correctly scaled to 0-100 394 | 'rouge2': f"{rouge2_mean} ± {rouge2_std}", # Now correctly scaled to 0-100 395 | 'rougeL': f"{rougeL_mean} ± {rougeL_std}", # Now correctly scaled to 0-100 396 | 'bleu': f"{bleu_mean} ± {bleu_std}", 397 | 'abs_readability': { 398 | 'fkgl': f"{fkgl_abs_mean} ± {fkgl_abs_std}", 399 | 'cli': f"{cli_abs_mean} ± {cli_abs_std}", 400 | 'dcrs': f"{dcrs_abs_mean} ± {dcrs_abs_std}" 401 | }, 402 | 'pls_readability': { 403 | 'fkgl': f"{fkgl_pls_mean} ± 
{fkgl_pls_std}", 404 | 'cli': f"{cli_pls_mean} ± {cli_pls_std}", 405 | 'dcrs': f"{dcrs_pls_mean} ± {dcrs_pls_std}" 406 | }, 407 | 'model_readability': { 408 | 'fkgl': f"{fkgl_model_mean} ± {fkgl_model_std}", 409 | 'cli': f"{cli_model_mean} ± {cli_model_std}", 410 | 'dcrs': f"{dcrs_model_mean} ± {dcrs_model_std}" 411 | } 412 | } 413 | print(f"Metrics for {model_name} (Bootstrap Mean ± Std Dev):") 414 | # Pretty print the dictionary (original script didn't use json.dumps, just print(metrics)) 415 | print(metrics) # Reverted to original print style 416 | 417 | else: # --- No Bootstrapping --- 418 | # Check if bootstrap was intended but skipped due to missing ref file 419 | if args.bootstrap and _skip_bootstrap_for_dataset: 420 | print(" Skipping full dataset evaluation because bootstrapping was requested but reference file was missing.") 421 | continue # Skip full evaluation for this model too 422 | 423 | print(" Calculating metrics using all loaded data (no bootstrapping)...") 424 | # Use all the data loaded into model_results 425 | sources = [model_res['source'] for model_res in model_results] 426 | targets = [model_res['target'] for model_res in model_results] 427 | preds = [model_res['pred'] for model_res in model_results] 428 | 429 | # Calculate metrics once 430 | sari_, r1_, r2_, rL_, bleu_, fa_, ca_, da_, fp_, cp_, dp_, fm_, cm_, dm_ = get_metrics(preds, targets, sources, seed=SEED) 431 | 432 | # Print results directly (no std dev), scale ROUGE here for printing if needed 433 | print(f"Metrics for {model_name} (Full Dataset):") 434 | print(f" SARI : {round(sari_, 2)}") 435 | print(f" ROUGE1: {round(r1_ * 100, 2)}") # Scale for printing 436 | print(f" ROUGE2: {round(r2_ * 100, 2)}") # Scale for printing 437 | print(f" ROUGEL: {round(rL_ * 100, 2)}") # Scale for printing 438 | print(f" BLEU : {round(bleu_, 2)}") 439 | print(" Readability (Source):") 440 | print(f" FKGL: {round(fa_, 2)}, CLI: {round(ca_, 2)}, DCRS: {round(da_, 2)}") 441 | print(" Readability (Target):") 442 | print(f" FKGL: {round(fp_, 2)}, CLI: {round(cp_, 2)}, DCRS: {round(dp_, 2)}") 443 | print(" Readability (Model):") 444 | print(f" FKGL: {round(fm_, 2)}, CLI: {round(cm_, 2)}, DCRS: {round(dm_, 2)}") 445 | 446 | 447 | print("\nEvaluation script finished.") -------------------------------------------------------------------------------- /medagentboard/medqa/single_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | medagentboard/medqa/single_llm.py 3 | 4 | Unified script for handling both text-only and vision-language model inference 5 | for medical question answering tasks. Supports multiple prompting techniques: 6 | - Zero-shot prompting 7 | - Few-shot prompting with examples 8 | - Chain-of-thought (CoT) prompting 9 | - Self-consistency (majority voting) 10 | - CoT with self-consistency 11 | 12 | Works with both multiple-choice and free-form questions, and can process 13 | image-based questions when an image path is provided. 
14 | """ 15 | 16 | from openai import OpenAI 17 | import os 18 | import json 19 | import argparse 20 | from collections import Counter 21 | import time 22 | from tqdm import tqdm 23 | from typing import Dict, Any, Optional, List, Union 24 | 25 | from medagentboard.utils.llm_configs import LLM_MODELS_SETTINGS 26 | from medagentboard.utils.encode_image import encode_image 27 | from medagentboard.utils.json_utils import load_json, save_json, preprocess_response_string 28 | 29 | 30 | class SingleModelInference: 31 | """ 32 | Unified class for running inference with a single LLM or VLLM model 33 | using various prompting techniques. 34 | """ 35 | 36 | def __init__(self, model_key: str = "qwen-max-latest", sample_size: int = 5): 37 | """ 38 | Initialize the inference handler. 39 | 40 | Args: 41 | model_key: Key identifying the model in LLM_MODELS_SETTINGS 42 | sample_size: Number of samples for self-consistency methods 43 | """ 44 | self.model_key = model_key 45 | self.sample_size = sample_size 46 | 47 | if model_key not in LLM_MODELS_SETTINGS: 48 | raise ValueError(f"Model key '{model_key}' not found in LLM_MODELS_SETTINGS") 49 | 50 | # Set up OpenAI client based on model settings 51 | model_settings = LLM_MODELS_SETTINGS[model_key] 52 | self.client = OpenAI( 53 | api_key=model_settings["api_key"], 54 | base_url=model_settings["base_url"], 55 | ) 56 | self.model_name = model_settings["model_name"] 57 | print(f"Initialized SingleModelInference with model: {model_key}, sample_size: {sample_size}") 58 | 59 | def _call_llm(self, 60 | system_message: str, 61 | user_message: Union[str, List], 62 | response_format: Optional[Dict] = None, 63 | n_samples: int = 1, 64 | max_retries: int = 3) -> List[str]: 65 | """ 66 | Call the LLM with messages and handle retries. 
67 | 68 | Args: 69 | system_message: System message setting context 70 | user_message: User message (text or multimodal content) 71 | response_format: Optional format specification for response 72 | n_samples: Number of samples to generate 73 | max_retries: Maximum number of retry attempts 74 | 75 | Returns: 76 | List of LLM response texts 77 | """ 78 | retries = 0 79 | all_responses = [] 80 | 81 | # For each sample we need 82 | remaining_samples = n_samples 83 | 84 | while remaining_samples > 0 and retries < max_retries: 85 | try: 86 | messages = [ 87 | {"role": "system", "content": system_message}, 88 | {"role": "user", "content": user_message} 89 | ] 90 | 91 | # Some models might not properly support n > 1, so we make multiple calls if needed 92 | current_n = min(remaining_samples, 1) # Request just 1 at a time to be safe 93 | 94 | completion = self.client.chat.completions.create( 95 | model=self.model_name, 96 | messages=messages, 97 | response_format=response_format, 98 | n=current_n, 99 | stream=False 100 | ) 101 | 102 | responses = [choice.message.content for choice in completion.choices] 103 | all_responses.extend(responses) 104 | remaining_samples -= len(responses) 105 | 106 | # Reset retry counter on successful API call 107 | retries = 0 108 | 109 | except Exception as e: 110 | retries += 1 111 | print(f"LLM API call error (attempt {retries}/{max_retries}): {e}") 112 | if retries >= max_retries: 113 | if all_responses: # If we have some responses, use those rather than failing 114 | print(f"Warning: Only obtained {len(all_responses)}/{n_samples} samples after max retries") 115 | break 116 | else: 117 | raise Exception(f"LLM API call failed after {max_retries} attempts: {e}") 118 | time.sleep(1) # Brief pause before retrying 119 | 120 | return all_responses 121 | 122 | def _prepare_user_message(self, 123 | prompt: str, 124 | image_path: Optional[str] = None) -> Union[str, List]: 125 | """ 126 | Prepare user message with optional image content. 127 | 128 | Args: 129 | prompt: Text prompt 130 | image_path: Optional path to image 131 | 132 | Returns: 133 | User message as string or list for multimodal content 134 | """ 135 | if image_path: 136 | try: 137 | base64_image = encode_image(image_path) 138 | return [ 139 | {"type": "text", "text": prompt}, 140 | {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} 141 | ] 142 | except Exception as e: 143 | print(f"Error encoding image {image_path}: {e}") 144 | # Fall back to text-only if image encoding fails 145 | return prompt 146 | else: 147 | return prompt 148 | 149 | def zero_shot_prompt(self, 150 | question: str, 151 | options: Optional[Dict[str, str]] = None) -> str: 152 | """ 153 | Create a zero-shot prompt for either multiple-choice or free-form questions. 154 | 155 | Args: 156 | question: Question text 157 | options: Optional multiple choice options 158 | 159 | Returns: 160 | Formatted prompt string 161 | """ 162 | if options: 163 | # Multiple choice 164 | options_text = "\n".join([f"{k}: {v}" for k, v in options.items()]) 165 | prompt = ( 166 | f"Question: {question}\n\n" 167 | f"Options:\n{options_text}\n\n" 168 | f"Please respond with the letter of the correct option (A, B, C, etc.) only." 169 | ) 170 | else: 171 | # Free form 172 | prompt = ( 173 | f"Question: {question}\n\n" 174 | f"Please provide a concise and accurate answer." 
175 | ) 176 | 177 | return prompt 178 | 179 | def few_shot_prompt(self, 180 | question: str, 181 | options: Optional[Dict[str, str]] = None, 182 | dataset: str = "MedQA") -> str: 183 | """ 184 | Create a few-shot prompt with examples relevant to the dataset. 185 | 186 | Args: 187 | question: Question text 188 | options: Optional multiple choice options 189 | dataset: Dataset name to select appropriate examples 190 | 191 | Returns: 192 | Formatted prompt string with examples 193 | """ 194 | # Define example pairs for different datasets and question types 195 | examples = { 196 | "MedQA_mc": ( 197 | "Example 1: Question: A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. " 198 | "She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. " 199 | "She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), " 200 | "blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. " 201 | "Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. " 202 | "Which of the following is the best treatment for this patient?\n" 203 | "Options:\n" 204 | "A: Ampicillin\n" 205 | "B: Ceftriaxone\n" 206 | "C: Ciprofloxacin\n" 207 | "D: Doxycycline\n" 208 | "E: Nitrofurantoin\n" 209 | "Answer: E\n\n" 210 | 211 | "Example 2: Question: A 3-month-old baby died suddenly at night while asleep. His mother noticed that he had died " 212 | "only after she awoke in the morning. No cause of death was determined based on the autopsy. " 213 | "Which of the following precautions could have prevented the death of the baby?\n" 214 | "Options:\n" 215 | "A: Placing the infant in a supine position on a firm mattress while sleeping\n" 216 | "B: Routine postnatal electrocardiogram (ECG)\n" 217 | "C: Keeping the infant covered and maintaining a high room temperature\n" 218 | "D: Application of a device to maintain the sleeping position\n" 219 | "E: Avoiding pacifier use during sleep\n" 220 | "Answer: A" 221 | ), 222 | 223 | "PubMedQA_mc": ( 224 | "Example 1: Question: Are sugars-free medicines more erosive than sugars-containing medicines?\n" 225 | "Options:\n" 226 | "A: Yes\n" 227 | "B: No\n" 228 | "C: Maybe\n" 229 | "Answer: B\n\n" 230 | 231 | "Example 2: Question: Can autologous platelet-rich plasma gel enhance healing after surgical extraction of mandibular third molars?\n" 232 | "Options:\n" 233 | "A: Yes\n" 234 | "B: No\n" 235 | "C: Maybe\n" 236 | "Answer: A" 237 | ), 238 | 239 | "PubMedQA_ff": ( 240 | "Example 1: Question: Does melatonin supplementation improve sleep quality in adults with primary insomnia?\n" 241 | "Answer: Yes, melatonin supplementation has been shown to improve sleep quality parameters in adults with primary insomnia, " 242 | "including reduced sleep onset latency, increased total sleep time, and improved overall sleep quality without significant adverse effects.\n\n" 243 | 244 | "Example 2: Question: Is chronic stress associated with increased risk of cardiovascular disease?\n" 245 | "Answer: Yes, chronic stress is associated with increased risk of cardiovascular disease through multiple mechanisms, " 246 | "including elevated blood pressure, increased inflammation, endothelial dysfunction, and unhealthy behavioral coping mechanisms." 
247 | ), 248 | 249 | "PathVQA_mc": ( 250 | "Example 1: Question: Are bile duct cells stained with this immunohistochemical marker?\n" 251 | "Options:\n" 252 | "A: Yes\n" 253 | "B: No\n" 254 | "Answer: A\n\n" 255 | 256 | "Example 2: Question: Is this a well-differentiated tumor?\n" 257 | "Options:\n" 258 | "A: Yes\n" 259 | "B: No\n" 260 | "Answer: B" 261 | ), 262 | 263 | "VQA-RAD_mc": ( 264 | "Example 1: Question: Is there evidence of pneumonia in this chest X-ray?\n" 265 | "Options:\n" 266 | "A: Yes\n" 267 | "B: No\n" 268 | "Answer: A\n\n" 269 | 270 | "Example 2: Question: Is the heart size enlarged in this radiograph?\n" 271 | "Options:\n" 272 | "A: Yes\n" 273 | "B: No\n" 274 | "Answer: B" 275 | ), 276 | 277 | "VQA-RAD_ff": ( 278 | "Example 1: Question: What abnormality is visible in this CT scan of the abdomen?\n" 279 | "Answer: The CT scan shows a hypodense mass in the liver, approximately 3cm in diameter, with irregular borders, " 280 | "suggestive of a hepatocellular carcinoma. There is also mild splenomegaly but no ascites.\n\n" 281 | 282 | "Example 2: Question: What is the primary finding in this chest X-ray?\n" 283 | "Answer: The chest X-ray demonstrates right upper lobe consolidation with air bronchograms, " 284 | "consistent with lobar pneumonia. No pleural effusion or pneumothorax is identified." 285 | ) 286 | } 287 | 288 | # Determine which example set to use 289 | example_key = f"{dataset}_mc" if options else f"{dataset}_ff" 290 | 291 | # Fallbacks if specific dataset examples are not available 292 | if example_key not in examples: 293 | if options: 294 | example_key = "MedQA_mc" # Default MC examples 295 | else: 296 | example_key = "PubMedQA_ff" # Default FF examples 297 | 298 | example_text = examples[example_key] 299 | 300 | # Format options if provided 301 | if options: 302 | options_text = "\n".join([f"{k}: {v}" for k, v in options.items()]) 303 | prompt = ( 304 | f"Question: {question}\n\n" 305 | f"Options:\n{options_text}\n\n" 306 | f"Please respond with the letter of the correct option (A, B, C, etc.) only.\n\n" 307 | f"Here are some examples for your reference:\n\n{example_text}" 308 | ) 309 | else: 310 | prompt = ( 311 | f"Question: {question}\n\n" 312 | f"Please provide a concise and accurate answer.\n\n" 313 | f"Here are some examples for your reference:\n\n{example_text}" 314 | ) 315 | 316 | return prompt 317 | 318 | def cot_prompt(self, 319 | question: str, 320 | options: Optional[Dict[str, str]] = None) -> str: 321 | """ 322 | Create a chain-of-thought prompt that encourages step-by-step reasoning. 
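        The returned prompt asks the model to reply as a JSON object with "Thought" and
        "Answer" keys, which process_item parses (falling back to simple line-based
        extraction if the JSON is malformed).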
323 | 324 | Args: 325 | question: Question text 326 | options: Optional multiple choice options 327 | 328 | Returns: 329 | Formatted CoT prompt string 330 | """ 331 | if options: 332 | options_text = "\n".join([f"{k}: {v}" for k, v in options.items()]) 333 | response_format = ( 334 | "{\n" 335 | " \"Thought\": \"step-by-step reasoning process\",\n" 336 | " \"Answer\": \"selected option letter (A, B, C, etc.)\"\n" 337 | "}" 338 | ) 339 | 340 | prompt = ( 341 | f"Question: {question}\n\n" 342 | f"Options:\n{options_text}\n\n" 343 | f"Let's work through this step-by-step to find the correct answer.\n\n" 344 | f"Please provide your response in JSON format as follows:\n{response_format}" 345 | ) 346 | else: 347 | response_format = ( 348 | "{\n" 349 | " \"Thought\": \"step-by-step reasoning process\",\n" 350 | " \"Answer\": \"final answer to the question\"\n" 351 | "}" 352 | ) 353 | 354 | prompt = ( 355 | f"Question: {question}\n\n" 356 | f"Let's work through this step-by-step to find the correct answer.\n\n" 357 | f"Please provide your response in JSON format as follows:\n{response_format}" 358 | ) 359 | 360 | return prompt 361 | 362 | def process_item(self, 363 | item: Dict[str, Any], 364 | prompt_type: str, 365 | dataset: str) -> Dict[str, Any]: 366 | """ 367 | Process a single item using the specified prompting technique. 368 | 369 | Args: 370 | item: Input data dictionary with question, options, etc. 371 | prompt_type: Type of prompting to use 372 | dataset: Dataset name 373 | 374 | Returns: 375 | Result dictionary with predicted answer and metadata 376 | """ 377 | start_time = time.time() 378 | 379 | # Extract item fields 380 | qid = item.get("qid", "unknown") 381 | question = item.get("question", "") 382 | options = item.get("options") 383 | image_path = item.get("image_path") 384 | ground_truth = item.get("answer", "") 385 | 386 | print(f"Processing {qid} with {prompt_type} prompting") 387 | 388 | # Determine if it's a multiple-choice or free-form question 389 | is_mc = options is not None 390 | 391 | # Set system message based on task type 392 | if image_path: 393 | system_message = "You are a medical vision expert analyzing medical images and answering questions about them." 394 | else: 395 | system_message = "You are a medical expert answering medical questions with precise and accurate information." 
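        # Prompt-type dispatch (summary of the branches below):
        #   zero_shot / few_shot -> plain-text answer, 1 sample
        #   cot                  -> JSON {"Thought", "Answer"} response, 1 sample
        #   self_consistency     -> zero-shot prompt, self.sample_size samples, majority vote
        #   cot_sc               -> CoT JSON prompt, self.sample_size samples, vote on parsed answers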
396 | 397 | # Generate prompt based on technique 398 | if prompt_type == "zero_shot": 399 | prompt = self.zero_shot_prompt(question, options) 400 | response_format = None 401 | n_samples = 1 402 | 403 | elif prompt_type == "few_shot": 404 | prompt = self.few_shot_prompt(question, options, dataset) 405 | response_format = None 406 | n_samples = 1 407 | 408 | elif prompt_type == "cot": 409 | prompt = self.cot_prompt(question, options) 410 | response_format = {"type": "json_object"} 411 | n_samples = 1 412 | 413 | elif prompt_type == "self_consistency": 414 | # For self-consistency, use zero-shot but with multiple samples 415 | prompt = self.zero_shot_prompt(question, options) 416 | response_format = None 417 | n_samples = self.sample_size # Use the configured sample_size 418 | 419 | elif prompt_type == "cot_sc": 420 | # For CoT with self-consistency, use CoT with multiple samples 421 | prompt = self.cot_prompt(question, options) 422 | response_format = {"type": "json_object"} 423 | n_samples = self.sample_size # Use the configured sample_size 424 | 425 | else: 426 | raise ValueError(f"Unknown prompt type: {prompt_type}") 427 | 428 | # Prepare user message (text-only or multimodal) 429 | user_message = self._prepare_user_message(prompt, image_path) 430 | 431 | # Call LLM to get responses 432 | responses = self._call_llm( 433 | system_message=system_message, 434 | user_message=user_message, 435 | response_format=response_format, 436 | n_samples=n_samples 437 | ) 438 | 439 | voting_details = None 440 | # Process responses based on prompt type 441 | if prompt_type in ["cot", "cot_sc"]: 442 | # Extract answers from JSON responses 443 | parsed_responses = [] 444 | 445 | for response in responses: 446 | try: 447 | parsed = json.loads(preprocess_response_string(response)) 448 | thought = parsed.get("Thought", "") or parsed.get("thought", "") 449 | answer = parsed.get("Answer", "") or parsed.get("answer", "") 450 | parsed_responses.append({"thought": thought, "answer": answer, "full_response": response}) 451 | except json.JSONDecodeError: 452 | # Fallback parsing for malformed JSON 453 | lines = response.strip().split('\n') 454 | thought = "" 455 | answer = "" 456 | 457 | for line in lines: 458 | if "thought" in line.lower() and ":" in line: 459 | thought = line.split(":", 1)[1].strip() 460 | elif "answer" in line.lower() and ":" in line: 461 | answer = line.split(":", 1)[1].strip() 462 | 463 | parsed_responses.append({"thought": thought, "answer": answer, "full_response": response}) 464 | 465 | if prompt_type == "cot": 466 | # For CoT, use the first response 467 | predicted_answer = parsed_responses[0]["answer"] 468 | reasoning = parsed_responses[0]["thought"] 469 | individual_responses = [parsed_responses[0]] 470 | else: 471 | # For CoT-SC, use majority voting on answers 472 | answers = [r["answer"] for r in parsed_responses] 473 | answer_counts = Counter(answers) 474 | predicted_answer = answer_counts.most_common(1)[0][0] 475 | 476 | # Detailed voting breakdown 477 | voting_details = { 478 | "vote_counts": dict(answer_counts), 479 | "winning_answer": predicted_answer, 480 | "total_votes": sum(answer_counts.values()) 481 | } 482 | 483 | # Collect all reasoning paths 484 | reasoning = "\n\n".join([f"Path {i+1}: {r['thought']}\nAnswer: {r['answer']}" for i, r in enumerate(parsed_responses)]) 485 | individual_responses = parsed_responses 486 | 487 | elif prompt_type == "self_consistency": 488 | # For self-consistency, use majority voting 489 | answer_counts = Counter(responses) 490 | predicted_answer 
= answer_counts.most_common(1)[0][0] 491 | 492 | # Detailed voting breakdown 493 | voting_details = { 494 | "vote_counts": dict(answer_counts), 495 | "winning_answer": predicted_answer, 496 | "total_votes": sum(answer_counts.values()) 497 | } 498 | 499 | reasoning = f"Majority vote from {len(responses)} samples: {dict(answer_counts)}" 500 | individual_responses = [{"answer": r, "full_response": r} for r in responses] 501 | 502 | else: 503 | # For zero-shot and few-shot, use the first response 504 | predicted_answer = responses[0].strip() 505 | reasoning = "Direct answer, no explicit reasoning" 506 | individual_responses = [{"answer": predicted_answer, "full_response": responses[0]}] 507 | 508 | # Clean up the predicted answer (extract just the option letter for MC) 509 | if is_mc and len(predicted_answer) > 1: 510 | # Look for option letters in the answer 511 | for option in options.keys(): 512 | if option in predicted_answer or option.lower() in predicted_answer.lower(): 513 | predicted_answer = option 514 | break 515 | 516 | # Calculate processing time 517 | processing_time = time.time() - start_time 518 | 519 | # Prepare the result structure with improved details 520 | result = { 521 | "qid": qid, 522 | "timestamp": int(time.time()), 523 | "question": question, 524 | "options": options, 525 | "image_path": image_path, 526 | "ground_truth": ground_truth, 527 | "predicted_answer": predicted_answer, 528 | "case_history": { 529 | "reasoning": reasoning, 530 | "prompt_type": prompt_type, 531 | "model": self.model_key, 532 | "raw_responses": responses, 533 | "individual_responses": individual_responses, 534 | "voting_details": voting_details, 535 | "processing_time": processing_time 536 | } 537 | } 538 | 539 | return result 540 | 541 | 542 | def main(): 543 | """ 544 | Main function to process medical QA datasets with various prompting techniques. 
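    Reads ./my_datasets/processed/medqa/<dataset>/medqa_<qa_type>_test.json and writes one
    <qid>-result.json per question under
    logs/medqa/<dataset>/<multiple_choice|free-form>/SingleLLM_<prompt_type>/, skipping any
    question whose result file already exists.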
545 | """ 546 | parser = argparse.ArgumentParser(description="Run single model inference on medical datasets") 547 | parser.add_argument("--dataset", type=str, required=True, help="Dataset name (MedQA, PubMedQA, PathVQA, VQA-RAD)") 548 | parser.add_argument("--qa_type", type=str, choices=["mc", "ff"], required=True, 549 | help="QA type: multiple-choice (mc) or free-form (ff)") 550 | parser.add_argument("--prompt_type", type=str, required=True, 551 | choices=["zero_shot", "few_shot", "cot", "self_consistency", "cot_sc"], 552 | help="Prompting technique to use") 553 | parser.add_argument("--model_key", type=str, default="qwen-max-latest", 554 | help="Model key from LLM_MODELS_SETTINGS") 555 | parser.add_argument("--sample_size", type=int, default=5, 556 | help="Number of samples for self-consistency methods") 557 | args = parser.parse_args() 558 | 559 | # Dataset and QA type 560 | dataset_name = args.dataset 561 | qa_type = args.qa_type 562 | prompt_type = args.prompt_type 563 | model_key = args.model_key 564 | sample_size = args.sample_size 565 | 566 | print(f"Dataset: {dataset_name}") 567 | print(f"QA Type: {qa_type}") 568 | print(f"Prompt Type: {prompt_type}") 569 | print(f"Model: {model_key}") 570 | print(f"Sample Size: {sample_size}") 571 | 572 | # Method name for logging 573 | method = f"SingleLLM_{prompt_type}" 574 | 575 | # Set up data path 576 | data_path = f"./my_datasets/processed/medqa/{dataset_name}/medqa_{qa_type}_test.json" 577 | 578 | # Set up logs directory 579 | qa_format_dir = "multiple_choice" if qa_type == "mc" else "free-form" 580 | logs_dir = os.path.join("logs", "medqa", dataset_name, qa_format_dir, method) 581 | os.makedirs(logs_dir, exist_ok=True) 582 | 583 | print(f"Data path: {data_path}") 584 | print(f"Logs directory: {logs_dir}") 585 | 586 | # Initialize the model 587 | model = SingleModelInference(model_key=model_key, sample_size=sample_size) 588 | 589 | # Load the data 590 | data = load_json(data_path) 591 | print(f"Loaded {len(data)} items from {data_path}") 592 | 593 | # Track stats 594 | processed_count = 0 595 | skipped_count = 0 596 | error_count = 0 597 | correct_count = 0 598 | 599 | # Process each item 600 | for item in tqdm(data, desc=f"Processing {dataset_name} with {prompt_type}"): 601 | qid = item["qid"] 602 | 603 | # Skip if already processed 604 | result_path = os.path.join(logs_dir, f"{qid}-result.json") 605 | if os.path.exists(result_path): 606 | print(f"Skipping {qid} - already processed") 607 | skipped_count += 1 608 | continue 609 | 610 | try: 611 | # Process the item 612 | result = model.process_item( 613 | item=item, 614 | prompt_type=prompt_type, 615 | dataset=dataset_name 616 | ) 617 | 618 | # Save the result 619 | save_json(result, result_path) 620 | 621 | # Update stats 622 | processed_count += 1 623 | if result["predicted_answer"] == result["ground_truth"]: 624 | correct_count += 1 625 | 626 | except Exception as e: 627 | print(f"Error processing item {qid}: {e}") 628 | error_count += 1 629 | 630 | # Print summary 631 | print("\n" + "="*50) 632 | print(f"Processing Summary for {dataset_name} ({qa_type}) with {prompt_type}:") 633 | print(f"Total items: {len(data)}") 634 | print(f"Processed: {processed_count}") 635 | print(f"Skipped (already processed): {skipped_count}") 636 | print(f"Errors: {error_count}") 637 | 638 | if processed_count > 0: 639 | accuracy = (correct_count / processed_count) * 100 640 | print(f"Accuracy of processed items: {accuracy:.2f}%") 641 | 642 | print("="*50) 643 | 644 | 645 | if __name__ == "__main__": 646 | 
main() -------------------------------------------------------------------------------- /medagentboard/ehr/multi_agent_reconcile.py: -------------------------------------------------------------------------------- 1 | """ 2 | medagentboard/ehr/multi_agent_reconcile.py 3 | 4 | This module implements the Reconcile framework for multi-model, 5 | multi-agent discussion for EHR predictive modeling tasks. Each agent generates 6 | a prediction with step-by-step reasoning and an estimated confidence level. 7 | Then, the agents engage in multi-round discussions and a confidence-weighted 8 | aggregation produces the final team prediction. 9 | """ 10 | 11 | import os 12 | import json 13 | import time 14 | import numpy as np 15 | from enum import Enum 16 | from typing import Dict, List, Any, Optional, Union, Tuple 17 | import argparse 18 | from tqdm import tqdm 19 | 20 | # Import utilities 21 | from medagentboard.utils.llm_configs import LLM_MODELS_SETTINGS 22 | from medagentboard.utils.json_utils import load_json, save_json, preprocess_response_string 23 | 24 | 25 | ############################################################################### 26 | # Discussion Phase Enumeration 27 | ############################################################################### 28 | class DiscussionPhase(Enum): 29 | """Enumeration of discussion phases in the Reconcile framework.""" 30 | INITIAL = "initial" # Initial prediction generation 31 | DISCUSSION = "discussion" # Multi-round discussion 32 | FINAL = "final" # Final team prediction 33 | 34 | 35 | ############################################################################### 36 | # ReconcileAgent: an LLM agent for the Reconcile framework 37 | ############################################################################### 38 | class ReconcileAgent: 39 | """ 40 | An agent participating in the Reconcile framework for EHR prediction. 41 | 42 | Each agent uses a specified LLM model to generate a prediction, 43 | detailed reasoning, and an estimated confidence level (between 0.0 and 1.0). 44 | 45 | Attributes: 46 | agent_id: Unique identifier for the agent 47 | model_key: Key of the LLM model in LLM_MODELS_SETTINGS 48 | model_name: Name of the model used by this agent 49 | client: OpenAI-compatible client for making API calls 50 | discussion_history: List of agent's responses throughout the discussion 51 | memory: Agent's memory of the case 52 | """ 53 | def __init__(self, agent_id: str, model_key: str): 54 | """ 55 | Initialize a Reconcile agent. 56 | 57 | Args: 58 | agent_id: Unique identifier for the agent 59 | model_key: Key of the LLM model in LLM_MODELS_SETTINGS 60 | 61 | Raises: 62 | ValueError: If model_key is not found in LLM_MODELS_SETTINGS 63 | """ 64 | self.agent_id = agent_id 65 | self.model_key = model_key 66 | self.discussion_history = [] 67 | self.memory = [] 68 | 69 | if model_key not in LLM_MODELS_SETTINGS: 70 | raise ValueError(f"Model key '{model_key}' not configured in LLM_MODELS_SETTINGS") 71 | self.model_config = LLM_MODELS_SETTINGS[model_key] 72 | 73 | # Set up the LLM client using the OpenAI-based client 74 | try: 75 | from openai import OpenAI 76 | except ImportError as e: 77 | raise ImportError("OpenAI client is not installed. 
Please install it.") from e 78 | 79 | self.client = OpenAI( 80 | api_key=self.model_config["api_key"], 81 | base_url=self.model_config["base_url"], 82 | ) 83 | self.model_name = self.model_config["model_name"] 84 | print(f"Initialized agent {self.agent_id} with model {self.model_name}") 85 | 86 | def call_llm(self, messages: List[Dict[str, Any]], max_retries: int = 3) -> str: 87 | """ 88 | Call the LLM with the provided messages and a retry mechanism. 89 | 90 | Args: 91 | messages: List of messages (each as a dictionary) to send to the LLM 92 | max_retries: Maximum number of retry attempts 93 | 94 | Returns: 95 | The text content from the LLM response 96 | """ 97 | attempt = 0 98 | wait_time = 1 99 | 100 | while attempt < max_retries: 101 | try: 102 | print(f"Agent {self.agent_id} calling LLM with model {self.model_name} (attempt {attempt+1}/{max_retries})") 103 | completion = self.client.chat.completions.create( 104 | model=self.model_name, 105 | messages=messages, 106 | response_format={"type": "json_object"} 107 | ) 108 | response_text = completion.choices[0].message.content 109 | print(f"Agent {self.agent_id} received response: {response_text[:100]}...") 110 | return response_text 111 | except Exception as e: 112 | attempt += 1 113 | print(f"Agent {self.agent_id} LLM call attempt {attempt}/{max_retries} failed: {e}") 114 | if attempt < max_retries: 115 | print(f"Waiting {wait_time} seconds before retry...") 116 | time.sleep(wait_time) 117 | 118 | # If all retries fail, return an error JSON message 119 | print(f"Agent {self.agent_id} all LLM call attempts failed, returning default response") 120 | return json.dumps({ 121 | "reasoning": "LLM call failed after multiple attempts", 122 | "prediction": 0.5, 123 | "confidence": 0.0 124 | }) 125 | 126 | def generate_initial_response(self, question: str) -> Dict[str, Any]: 127 | """ 128 | Generate an initial prediction for the EHR time series data. 129 | 130 | Args: 131 | question: The input question containing EHR data and prediction task 132 | 133 | Returns: 134 | A dictionary containing reasoning, prediction, and confidence 135 | """ 136 | print(f"Agent {self.agent_id} generating initial response") 137 | 138 | # Construct system message 139 | system_message = { 140 | "role": "system", 141 | "content": ( 142 | "You are a medical expert specializing in analyzing electronic health records (EHR) " 143 | "and making clinical predictions. Analyze the following patient data " 144 | "and provide a clear prediction along with detailed step-by-step reasoning. " 145 | "Based on your understanding, estimate your confidence in your prediction " 146 | "on a scale from 0.0 to 1.0, where 1.0 means complete certainty." 147 | ) 148 | } 149 | 150 | # Construct user message 151 | prompt_text = ( 152 | f"{question}\n\n" 153 | f"Provide your response in JSON format with the following fields:\n" 154 | f"- 'reasoning': your detailed step-by-step analysis of the patient data\n" 155 | f"- 'prediction': a floating-point number between 0 and 1 representing the predicted probability\n" 156 | f"- 'confidence': a number between 0.0 and 1.0 representing your confidence level in your prediction\n\n" 157 | f"Ensure your JSON is properly formatted." 
158 | ) 159 | 160 | user_message = { 161 | "role": "user", 162 | "content": prompt_text 163 | } 164 | 165 | # Call LLM and parse response 166 | response_text = self.call_llm([system_message, user_message]) 167 | result = self._parse_response(response_text) 168 | 169 | # Store in agent's memory 170 | self.memory.append({ 171 | "phase": DiscussionPhase.INITIAL.value, 172 | "response": result 173 | }) 174 | 175 | return result 176 | 177 | def generate_discussion_response(self, question: str, discussion_prompt: str) -> Dict[str, Any]: 178 | """ 179 | Generate a response during the discussion phase. 180 | 181 | Args: 182 | question: The original question with EHR data 183 | discussion_prompt: The formatted discussion prompt with other agents' responses 184 | 185 | Returns: 186 | A dictionary containing reasoning, prediction, and confidence 187 | """ 188 | print(f"Agent {self.agent_id} generating discussion response") 189 | 190 | # Construct system message 191 | system_message = { 192 | "role": "system", 193 | "content": ( 194 | "You are a medical expert participating in a multi-agent discussion about " 195 | "electronic health records (EHR) analysis. Review the opinions from other experts, " 196 | "then provide your updated analysis. You may adjust your prediction if others' " 197 | "reasoning convinces you, or defend your position with clear explanations. " 198 | "Estimate your confidence in your prediction on a scale from 0.0 to 1.0." 199 | ) 200 | } 201 | 202 | # Construct user message 203 | prompt_text = ( 204 | f"Original patient data and task:\n{question}\n\n" 205 | f"Discussion from other experts:\n{discussion_prompt}\n\n" 206 | f"Based on this discussion, provide your updated analysis in JSON format with the following fields:\n" 207 | f"- 'reasoning': your detailed step-by-step analysis of the patient data\n" 208 | f"- 'prediction': a floating-point number between 0 and 1 representing the predicted probability\n" 209 | f"- 'confidence': a number between 0.0 and 1.0 representing your confidence level in your prediction\n\n" 210 | f"Ensure your JSON is properly formatted." 211 | ) 212 | 213 | user_message = { 214 | "role": "user", 215 | "content": prompt_text 216 | } 217 | 218 | # Call LLM and parse response 219 | response_text = self.call_llm([system_message, user_message]) 220 | result = self._parse_response(response_text) 221 | 222 | # Determine the current round number 223 | current_round = sum(1 for mem in self.memory if mem["phase"] == DiscussionPhase.DISCUSSION.value) + 1 224 | 225 | # Store in agent's memory 226 | self.memory.append({ 227 | "phase": DiscussionPhase.DISCUSSION.value, 228 | "round": current_round, 229 | "response": result 230 | }) 231 | 232 | return result 233 | 234 | def _parse_response(self, response_text: str) -> Dict[str, Any]: 235 | """ 236 | Parse the LLM response into a structured format. 
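        Malformed output is handled defensively: "prediction" and "confidence" are clamped
        to [0.0, 1.0] with defaults of 0.5 and 0.0, and when the text is not valid JSON a
        line-based "reasoning:/prediction:/confidence:" fallback is attempted before using
        the raw text as the reasoning.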
237 | 238 | Args: 239 | response_text: The raw response text from the LLM 240 | 241 | Returns: 242 | A dictionary with reasoning, prediction, and confidence 243 | """ 244 | try: 245 | result = json.loads(preprocess_response_string(response_text)) 246 | 247 | # Validate required fields 248 | if "reasoning" not in result: 249 | result["reasoning"] = "No reasoning provided" 250 | 251 | if "prediction" not in result: 252 | result["prediction"] = 0.5 253 | else: 254 | # Ensure prediction is a float between 0 and 1 255 | try: 256 | result["prediction"] = float(result["prediction"]) 257 | result["prediction"] = max(0.0, min(1.0, result["prediction"])) 258 | except (ValueError, TypeError): 259 | result["prediction"] = 0.5 260 | 261 | if "confidence" not in result: 262 | result["confidence"] = 0.0 263 | else: 264 | # Ensure confidence is a float between 0 and 1 265 | try: 266 | result["confidence"] = float(result["confidence"]) 267 | result["confidence"] = max(0.0, min(1.0, result["confidence"])) 268 | except (ValueError, TypeError): 269 | result["confidence"] = 0.0 270 | 271 | return result 272 | 273 | except json.JSONDecodeError: 274 | print(f"Agent {self.agent_id} failed to parse JSON response: {response_text[:100]}...") 275 | 276 | # Attempt to extract with simple parsing 277 | reasoning = "" 278 | prediction = 0.5 279 | confidence = 0.0 280 | 281 | lines = response_text.split('\n') 282 | for line in lines: 283 | if line.lower().startswith("reasoning:"): 284 | reasoning = line.split(":", 1)[1].strip() 285 | elif line.lower().startswith("prediction:"): 286 | try: 287 | prediction = float(line.split(":", 1)[1].strip()) 288 | prediction = max(0.0, min(1.0, prediction)) 289 | except (ValueError, IndexError): 290 | prediction = 0.5 291 | elif line.lower().startswith("confidence:"): 292 | try: 293 | confidence = float(line.split(":", 1)[1].strip()) 294 | confidence = max(0.0, min(1.0, confidence)) 295 | except (ValueError, IndexError): 296 | confidence = 0.0 297 | 298 | # If basic parsing doesn't work, use the raw text 299 | if not reasoning: 300 | reasoning = response_text 301 | 302 | return { 303 | "reasoning": reasoning, 304 | "prediction": prediction, 305 | "confidence": confidence 306 | } 307 | 308 | 309 | ############################################################################### 310 | # ReconcileCoordinator: orchestrates the multi-agent discussion process 311 | ############################################################################### 312 | class ReconcileCoordinator: 313 | """ 314 | The coordinator for the Reconcile framework in EHR prediction tasks. 315 | 316 | This class orchestrates the following phases: 317 | 1. Initial Prediction Generation: Each agent generates an initial prediction. 318 | 2. Multi-Round Discussion: Agents update their predictions based on the grouped responses. 319 | 3. Team Prediction Generation: A confidence-weighted aggregation produces the final prediction. 320 | 321 | Attributes: 322 | agents: List of ReconcileAgent objects participating in the discussion 323 | max_rounds: Maximum number of discussion rounds 324 | """ 325 | def __init__(self, agent_configs: List[Dict[str, str]], max_rounds: int = 3): 326 | """ 327 | Initialize the Reconcile coordinator. 
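        Example (illustrative) agent_configs, matching what main() constructs:
            [{"agent_id": "agent_1", "model_key": "deepseek-v3-official"},
             {"agent_id": "agent_2", "model_key": "qwen-max-latest"}]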
328 | 329 | Args: 330 | agent_configs: List of agent configurations (each with agent_id and model_key) 331 | max_rounds: Maximum number of discussion rounds 332 | """ 333 | # Instantiate Reconcile agents using provided configurations 334 | self.agents = [ 335 | ReconcileAgent(cfg["agent_id"], cfg["model_key"]) 336 | for cfg in agent_configs 337 | ] 338 | self.max_rounds = max_rounds 339 | print(f"Initialized ReconcileCoordinator with {len(self.agents)} agents, max_rounds={max_rounds}") 340 | 341 | def _group_predictions(self, predictions: List[Dict[str, Any]]) -> str: 342 | """ 343 | Group and summarize predictions from agents. 344 | 345 | Args: 346 | predictions: List of agent response dictionaries 347 | 348 | Returns: 349 | A formatted string with grouped predictions and their supporting explanations 350 | """ 351 | # Define groups based on prediction ranges 352 | groups = { 353 | "low_risk": {"range": (0.0, 0.33), "count": 0, "explanations": [], "avg_pred": 0.0, "confidence_sum": 0.0}, 354 | "medium_risk": {"range": (0.33, 0.67), "count": 0, "explanations": [], "avg_pred": 0.0, "confidence_sum": 0.0}, 355 | "high_risk": {"range": (0.67, 1.0), "count": 0, "explanations": [], "avg_pred": 0.0, "confidence_sum": 0.0} 356 | } 357 | 358 | # Group predictions and explanations 359 | for pred in predictions: 360 | prediction_value = pred.get("prediction", 0.5) 361 | confidence = pred.get("confidence", 0.0) 362 | reasoning = pred.get("reasoning", "") 363 | 364 | # Determine which group this prediction belongs to 365 | for group_name, group_data in groups.items(): 366 | lower, upper = group_data["range"] 367 | if lower <= prediction_value < upper or (group_name == "high_risk" and prediction_value == upper): 368 | group_data["count"] += 1 369 | group_data["explanations"].append(reasoning) 370 | group_data["avg_pred"] += prediction_value 371 | group_data["confidence_sum"] += confidence 372 | break 373 | 374 | # Format grouped predictions 375 | grouped_str = "" 376 | for group_name, data in groups.items(): 377 | if data["count"] > 0: 378 | avg_pred = data["avg_pred"] / data["count"] 379 | avg_confidence = data["confidence_sum"] / data["count"] if data["count"] > 0 else 0 380 | 381 | grouped_str += f"Prediction Group: {group_name.replace('_', ' ').title()} (Range: {data['range'][0]:.2f}-{data['range'][1]:.2f})\n" 382 | grouped_str += f"Number of experts in this group: {data['count']}\n" 383 | grouped_str += f"Average prediction: {avg_pred:.3f}\n" 384 | grouped_str += f"Average confidence: {avg_confidence:.2f}\n" 385 | grouped_str += f"Explanations from this group:\n" 386 | 387 | # Add each explanation with a bullet point 388 | for i, exp in enumerate(data["explanations"]): 389 | # Truncate very long explanations 390 | if len(exp) > 500: 391 | exp = exp[:500] + "... (truncated)" 392 | grouped_str += f"• Expert {i+1}: {exp}\n" 393 | 394 | grouped_str += "\n" 395 | 396 | return grouped_str.strip() 397 | 398 | def _consensus_threshold(self, predictions: List[float]) -> bool: 399 | """ 400 | Check if predictions have reached a reasonable consensus. 
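        Consensus here means the standard deviation of the agents' probabilities is below
        0.1. For example, predictions [0.72, 0.78, 0.75] (std ~ 0.02) count as consensus,
        while [0.2, 0.5, 0.9] (std ~ 0.29) do not.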
401 | 402 | Args: 403 | predictions: List of prediction values 404 | 405 | Returns: 406 | True if consensus reached, False otherwise 407 | """ 408 | if not predictions: 409 | return False 410 | 411 | # Calculate standard deviation of predictions 412 | std_dev = np.std(predictions) 413 | 414 | # If standard deviation is below threshold, consider it a consensus 415 | return std_dev < 0.1 # Threshold can be adjusted based on desired sensitivity 416 | 417 | def _weighted_average(self, predictions: List[Dict[str, Any]]) -> float: 418 | """ 419 | Compute the final team prediction using a confidence-weighted average. 420 | 421 | Args: 422 | predictions: List of prediction dictionaries from agents 423 | 424 | Returns: 425 | The final prediction value 426 | """ 427 | total_weight = 0.0 428 | weighted_sum = 0.0 429 | 430 | for pred in predictions: 431 | prediction = pred.get("prediction", 0.5) 432 | confidence = pred.get("confidence", 0.0) 433 | 434 | # Square the confidence to give more weight to high-confidence predictions 435 | weight = confidence ** 2 436 | 437 | weighted_sum += prediction * weight 438 | total_weight += weight 439 | 440 | # If no valid weights, return simple average 441 | if total_weight == 0: 442 | valid_predictions = [p.get("prediction", 0.5) for p in predictions] 443 | return sum(valid_predictions) / len(valid_predictions) if valid_predictions else 0.5 444 | 445 | return weighted_sum / total_weight 446 | 447 | def run_discussion(self, question: str) -> Dict[str, Any]: 448 | """ 449 | Run the complete discussion process for an EHR prediction task. 450 | 451 | Args: 452 | question: The input question containing EHR data 453 | 454 | Returns: 455 | Dictionary with the final team prediction and discussion history 456 | """ 457 | print(f"Starting EHR prediction discussion with {len(self.agents)} agents") 458 | start_time = time.time() 459 | 460 | discussion_history = [] 461 | 462 | # Phase 1: Initial predictions 463 | print("Phase 1: Generating initial predictions") 464 | current_predictions = [] 465 | 466 | for agent in self.agents: 467 | resp = agent.generate_initial_response(question) 468 | current_predictions.append(resp) 469 | 470 | # Add to discussion history 471 | discussion_history.append({ 472 | "phase": DiscussionPhase.INITIAL.value, 473 | "agent_id": agent.agent_id, 474 | "response": resp 475 | }) 476 | 477 | print(f"Agent {agent.agent_id} initial prediction: {resp.get('prediction', 0.5):.3f} (confidence: {resp.get('confidence', 0.0):.2f})") 478 | 479 | # Phase 2: Multi-round discussion 480 | round_num = 0 481 | consensus_reached = False 482 | 483 | while round_num < self.max_rounds and not consensus_reached: 484 | round_num += 1 485 | print(f"Phase 2: Discussion round {round_num}/{self.max_rounds}") 486 | 487 | # Prepare the discussion prompt based on previous predictions 488 | discussion_prompt = self._group_predictions(current_predictions) 489 | 490 | # Each agent generates a new response 491 | new_predictions = [] 492 | for agent in self.agents: 493 | resp = agent.generate_discussion_response(question, discussion_prompt) 494 | new_predictions.append(resp) 495 | 496 | # Add to discussion history 497 | discussion_history.append({ 498 | "phase": DiscussionPhase.DISCUSSION.value, 499 | "round": round_num, 500 | "agent_id": agent.agent_id, 501 | "response": resp 502 | }) 503 | 504 | print(f"Agent {agent.agent_id} round {round_num} prediction: {resp.get('prediction', 0.5):.3f} (confidence: {resp.get('confidence', 0.0):.2f})") 505 | 506 | # Update current predictions for 
next round 507 | current_predictions = new_predictions 508 | 509 | # Check if consensus is reached 510 | prediction_values = [p.get("prediction", 0.5) for p in current_predictions] 511 | consensus_reached = self._consensus_threshold(prediction_values) 512 | print(f"Round {round_num} consensus reached: {consensus_reached}") 513 | 514 | if consensus_reached: 515 | print("Consensus reached, ending discussion") 516 | break 517 | 518 | # Phase 3: Final team prediction via weighted average 519 | print("Phase 3: Generating final team prediction") 520 | final_prediction = self._weighted_average(current_predictions) 521 | 522 | # Add final prediction to history 523 | discussion_history.append({ 524 | "phase": DiscussionPhase.FINAL.value, 525 | "final_prediction": final_prediction, 526 | "consensus_reached": 1 if consensus_reached else 0, 527 | "rounds_completed": round_num, 528 | "individual_predictions": [p.get("prediction", 0.5) for p in current_predictions], 529 | "confidence_scores": [p.get("confidence", 0.0) for p in current_predictions] 530 | }) 531 | 532 | end_time = time.time() 533 | processing_time = end_time - start_time 534 | 535 | print(f"Discussion completed in {processing_time:.2f} seconds. Final prediction: {final_prediction:.3f}") 536 | 537 | return { 538 | "final_prediction": final_prediction, 539 | "discussion_history": discussion_history, 540 | "processing_time": processing_time 541 | } 542 | 543 | 544 | ############################################################################### 545 | # Process a Single EHR Item with the Reconcile Framework 546 | ############################################################################### 547 | def process_item(item: Dict[str, Any], 548 | agent_configs: List[Dict[str, str]], 549 | max_rounds: int = 3) -> Dict[str, Any]: 550 | """ 551 | Process a single EHR item with the Reconcile framework. 552 | 553 | Args: 554 | item: Input EHR item dictionary (with qid, question, etc.) 555 | agent_configs: List of agent configurations (each with agent_id and model_key) 556 | max_rounds: Maximum number of discussion rounds 557 | 558 | Returns: 559 | Processed EHR result with the final predicted probability and discussion history 560 | """ 561 | qid = item.get("qid", "unknown") 562 | question = item.get("question", "") 563 | ground_truth = item.get("answer") 564 | 565 | print(f"Processing EHR item {qid}") 566 | 567 | # Create coordinator and run discussion 568 | coordinator = ReconcileCoordinator(agent_configs, max_rounds) 569 | discussion_result = coordinator.run_discussion(question) 570 | 571 | # Compile results 572 | result = { 573 | "qid": qid, 574 | "timestamp": int(time.time()), 575 | "question": question, 576 | "ground_truth": ground_truth, 577 | "predicted_value": discussion_result["final_prediction"], 578 | "case_history": discussion_result, 579 | } 580 | 581 | return result 582 | 583 | 584 | ############################################################################### 585 | # Main Entry Point for the Reconcile Framework on EHR data 586 | ############################################################################### 587 | def main(): 588 | """ 589 | Main entry point for running the Reconcile framework on EHR datasets. 
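    Example invocation (illustrative; flags as defined below):
        python -m medagentboard.ehr.multi_agent_reconcile --dataset tjh --task mortality \
            --agents deepseek-v3-official qwen-max-latest --max_rounds 2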
590 | """ 591 | parser = argparse.ArgumentParser(description="Run the Reconcile framework on EHR predictive modeling tasks") 592 | parser.add_argument("--dataset", type=str, choices=["mimic-iv", "tjh"], required=True, help="Dataset name") 593 | parser.add_argument("--task", type=str, choices=["mortality", "readmission"], required=True, help="Prediction task") 594 | parser.add_argument("--agents", nargs='+', default=["qwen-max-latest", "deepseek-v3-official", "qwen-vl-max"], 595 | help="List of agent model keys (e.g., deepseek-v3-official, qwen-max-latest, qwen-vl-max)") 596 | parser.add_argument("--max_rounds", type=int, default=2, help="Maximum number of discussion rounds") 597 | 598 | args = parser.parse_args() 599 | method = "ReConcile" 600 | 601 | # Extract dataset name and task 602 | dataset_name = args.dataset 603 | task_name = args.task 604 | print(f"Dataset: {dataset_name}, Task: {task_name}") 605 | 606 | # Create logs directory structure 607 | logs_dir = os.path.join("logs", "ehr", dataset_name, task_name, method) 608 | os.makedirs(logs_dir, exist_ok=True) 609 | 610 | # Construct the data path 611 | data_path = os.path.join("my_datasets", "processed", "ehr", dataset_name, f"ehr_timeseries_{task_name}_test.json") 612 | 613 | # Load the dataset 614 | data = load_json(data_path) 615 | print(f"Loaded {len(data)} samples from {data_path}") 616 | 617 | # Configure agents: each agent is assigned an ID and a model key 618 | agent_configs = [] 619 | for idx, model_key in enumerate(args.agents, 1): 620 | agent_configs.append({"agent_id": f"agent_{idx}", "model_key": model_key}) 621 | 622 | print(f"Configured {len(agent_configs)} agents: {[cfg['model_key'] for cfg in agent_configs]}") 623 | 624 | # Process each item in the dataset 625 | for item in tqdm(data, desc=f"Processing {dataset_name} ({task_name})"): 626 | qid = item.get("qid") 627 | result_path = os.path.join(logs_dir, f"ehr_timeseries_{qid}-result.json") 628 | 629 | # Skip already processed items 630 | if os.path.exists(result_path): 631 | print(f"Skipping {qid} (already processed)") 632 | continue 633 | 634 | try: 635 | # Process the item 636 | result = process_item(item, agent_configs, args.max_rounds) 637 | 638 | # Save result 639 | save_json(result, result_path) 640 | 641 | except Exception as e: 642 | print(f"Error processing item {qid}: {e}") 643 | 644 | if __name__ == "__main__": 645 | main() -------------------------------------------------------------------------------- /medagentboard/medqa/multi_agent_reconcile.py: -------------------------------------------------------------------------------- 1 | """ 2 | medagentboard/medqa/multi_agent_reconcile.py 3 | 4 | This module implements the Reconcile framework for multi-model, 5 | multi-agent discussion. Each agent generates an answer with step-by-step 6 | reasoning and an estimated confidence level. Then, the agents engage in 7 | multi-round discussions and a confidence-weighted vote produces the final team answer. 
8 | """ 9 | 10 | import os 11 | import json 12 | import time 13 | from enum import Enum 14 | from typing import Dict, List, Any, Optional, Union, Tuple 15 | import argparse 16 | from tqdm import tqdm 17 | 18 | # Import ColaCare utilities 19 | from medagentboard.utils.llm_configs import LLM_MODELS_SETTINGS 20 | from medagentboard.utils.json_utils import load_json, save_json, preprocess_response_string 21 | from medagentboard.utils.encode_image import encode_image 22 | 23 | 24 | ############################################################################### 25 | # Discussion Phase Enumeration 26 | ############################################################################### 27 | class DiscussionPhase(Enum): 28 | """Enumeration of discussion phases in the Reconcile framework.""" 29 | INITIAL = "initial" # Initial answer generation 30 | DISCUSSION = "discussion" # Multi-round discussion 31 | FINAL = "final" # Final team answer 32 | 33 | 34 | ############################################################################### 35 | # ReconcileAgent: an LLM agent for the Reconcile framework 36 | ############################################################################### 37 | class ReconcileAgent: 38 | """ 39 | An agent participating in the Reconcile framework. 40 | 41 | Each agent uses a specified LLM model to generate an answer, 42 | detailed reasoning, and an estimated confidence level (between 0.0 and 1.0). 43 | 44 | Attributes: 45 | agent_id: Unique identifier for the agent 46 | model_key: Key of the LLM model in LLM_MODELS_SETTINGS 47 | model_name: Name of the model used by this agent 48 | client: OpenAI-compatible client for making API calls 49 | discussion_history: List of agent's responses throughout the discussion 50 | memory: Agent's memory of the case 51 | """ 52 | def __init__(self, agent_id: str, model_key: str): 53 | """ 54 | Initialize a Reconcile agent. 55 | 56 | Args: 57 | agent_id: Unique identifier for the agent 58 | model_key: Key of the LLM model in LLM_MODELS_SETTINGS 59 | 60 | Raises: 61 | ValueError: If model_key is not found in LLM_MODELS_SETTINGS 62 | """ 63 | self.agent_id = agent_id 64 | self.model_key = model_key 65 | self.discussion_history = [] 66 | self.memory = [] 67 | 68 | if model_key not in LLM_MODELS_SETTINGS: 69 | raise ValueError(f"Model key '{model_key}' not configured in LLM_MODELS_SETTINGS") 70 | self.model_config = LLM_MODELS_SETTINGS[model_key] 71 | 72 | # Set up the LLM client using the OpenAI-based client from ColaCare 73 | try: 74 | from openai import OpenAI 75 | except ImportError as e: 76 | raise ImportError("OpenAI client is not installed. Please install it.") from e 77 | 78 | self.client = OpenAI( 79 | api_key=self.model_config["api_key"], 80 | base_url=self.model_config["base_url"], 81 | ) 82 | self.model_name = self.model_config["model_name"] 83 | print(f"Initialized agent {self.agent_id} with model {self.model_name}") 84 | 85 | def call_llm(self, messages: List[Dict[str, Any]], max_retries: int = 3) -> str: 86 | """ 87 | Call the LLM with the provided messages and a retry mechanism. 
88 | 89 | Args: 90 | messages: List of messages (each as a dictionary) to send to the LLM 91 | max_retries: Maximum number of retry attempts 92 | 93 | Returns: 94 | The text content from the LLM response 95 | """ 96 | attempt = 0 97 | wait_time = 1 98 | 99 | while attempt < max_retries: 100 | try: 101 | print(f"Agent {self.agent_id} calling LLM with model {self.model_name} (attempt {attempt+1}/{max_retries})") 102 | completion = self.client.chat.completions.create( 103 | model=self.model_name, 104 | messages=messages, 105 | response_format={"type": "json_object"} 106 | ) 107 | response_text = completion.choices[0].message.content 108 | print(f"Agent {self.agent_id} received response: {response_text[:100]}...") 109 | return response_text 110 | except Exception as e: 111 | attempt += 1 112 | print(f"Agent {self.agent_id} LLM call attempt {attempt}/{max_retries} failed: {e}") 113 | if attempt < max_retries: 114 | print(f"Waiting {wait_time} seconds before retry...") 115 | time.sleep(wait_time) 116 | 117 | # If all retries fail, return an error JSON message 118 | print(f"Agent {self.agent_id} all LLM call attempts failed, returning default response") 119 | return json.dumps({ 120 | "reasoning": "LLM call failed after multiple attempts", 121 | "answer": "", 122 | "confidence": 0.0 123 | }) 124 | 125 | def generate_initial_response(self, 126 | question: str, 127 | options: Optional[Dict[str, str]] = None, 128 | image_path: Optional[str] = None) -> Dict[str, Any]: 129 | """ 130 | Generate an initial response for a given question. 131 | 132 | Args: 133 | question: The input question text 134 | options: Optional multiple choice options 135 | image_path: Optional path to an image for MedVQA 136 | 137 | Returns: 138 | A dictionary containing reasoning, answer, and confidence 139 | """ 140 | print(f"Agent {self.agent_id} generating initial response") 141 | 142 | # Construct system message 143 | system_message = { 144 | "role": "system", 145 | "content": ( 146 | "You are a medical expert assistant. Analyze the following medical question " 147 | "and provide a clear answer along with detailed step-by-step reasoning. " 148 | "Based on your understanding, estimate your confidence in your answer " 149 | "on a scale from 0.0 to 1.0, where 1.0 means complete certainty." 
150 | ) 151 | } 152 | 153 | # Construct user message 154 | user_content = [] 155 | 156 | # Add image if provided 157 | if image_path and os.path.exists(image_path): 158 | try: 159 | base64_image = encode_image(image_path) 160 | user_content.append({ 161 | "type": "image_url", 162 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} 163 | }) 164 | print(f"Agent {self.agent_id} added image from {image_path}") 165 | except Exception as e: 166 | print(f"Error encoding image {image_path}: {e}") 167 | 168 | # Format question with options if provided 169 | question_text = question 170 | if options: 171 | options_text = "\n".join([f"{key}: {value}" for key, value in options.items()]) 172 | question_text = f"{question}\n\nOptions:\n{options_text}" 173 | 174 | # Add formatted question 175 | prompt_text = ( 176 | f"{question_text}\n\n" 177 | f"Provide your response in JSON format with the following fields:\n" 178 | f"- 'reasoning': your detailed step-by-step analysis\n" 179 | f"- 'answer': your final answer" 180 | ) 181 | 182 | if options: 183 | prompt_text += " (specify just the option letter)" 184 | 185 | prompt_text += ( 186 | f"\n- 'confidence': a number between 0.0 and 1.0 representing your confidence level\n\n" 187 | f"Ensure your JSON is properly formatted." 188 | ) 189 | 190 | user_content.append({ 191 | "type": "text", 192 | "text": prompt_text 193 | }) 194 | 195 | user_message = { 196 | "role": "user", 197 | "content": user_content 198 | } 199 | 200 | # Call LLM and parse response 201 | response_text = self.call_llm([system_message, user_message]) 202 | result = self._parse_response(response_text) 203 | 204 | # Store in agent's memory 205 | self.memory.append({ 206 | "phase": DiscussionPhase.INITIAL.value, 207 | "response": result 208 | }) 209 | 210 | return result 211 | 212 | def generate_discussion_response(self, 213 | question: str, 214 | discussion_prompt: str, 215 | options: Optional[Dict[str, str]] = None, 216 | image_path: Optional[str] = None) -> Dict[str, Any]: 217 | """ 218 | Generate a response during the discussion phase. 219 | 220 | Args: 221 | question: The original question 222 | discussion_prompt: The formatted discussion prompt with other agents' responses 223 | options: Optional multiple choice options 224 | image_path: Optional path to an image for MedVQA 225 | 226 | Returns: 227 | A dictionary containing reasoning, answer, and confidence 228 | """ 229 | print(f"Agent {self.agent_id} generating discussion response") 230 | 231 | # Construct system message 232 | system_message = { 233 | "role": "system", 234 | "content": ( 235 | "You are a medical expert participating in a multi-agent discussion. " 236 | "Review the opinions from other experts, then provide your updated analysis. " 237 | "You may change your opinion if others' reasoning convinces you, or defend your position " 238 | "with clear explanations. Estimate your confidence in your answer on a scale from 0.0 to 1.0." 
239 | ) 240 | } 241 | 242 | # Construct user message 243 | user_content = [] 244 | 245 | # Add image if provided 246 | if image_path and os.path.exists(image_path): 247 | try: 248 | base64_image = encode_image(image_path) 249 | user_content.append({ 250 | "type": "image_url", 251 | "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} 252 | }) 253 | except Exception as e: 254 | print(f"Error encoding image {image_path}: {e}") 255 | 256 | # Format question with options if provided 257 | question_text = question 258 | if options: 259 | options_text = "\n".join([f"{key}: {value}" for key, value in options.items()]) 260 | question_text = f"{question}\n\nOptions:\n{options_text}" 261 | 262 | # Add question and discussion prompt 263 | prompt_text = ( 264 | f"Original question: {question_text}\n\n" 265 | f"Discussion from other experts:\n{discussion_prompt}\n\n" 266 | f"Based on this discussion, provide your updated analysis in JSON format with the following fields:\n" 267 | f"- 'reasoning': your detailed step-by-step analysis\n" 268 | f"- 'answer': your final answer" 269 | ) 270 | 271 | if options: 272 | prompt_text += " (specify just the option letter)" 273 | 274 | prompt_text += ( 275 | f"\n- 'confidence': a number between 0.0 and 1.0 representing your confidence level\n\n" 276 | f"Ensure your JSON is properly formatted." 277 | ) 278 | 279 | user_content.append({ 280 | "type": "text", 281 | "text": prompt_text 282 | }) 283 | 284 | user_message = { 285 | "role": "user", 286 | "content": user_content 287 | } 288 | 289 | # Call LLM and parse response 290 | response_text = self.call_llm([system_message, user_message]) 291 | result = self._parse_response(response_text) 292 | 293 | # Determine the current round number 294 | current_round = sum(1 for mem in self.memory if mem["phase"] == DiscussionPhase.DISCUSSION.value) + 1 295 | 296 | # Store in agent's memory 297 | self.memory.append({ 298 | "phase": DiscussionPhase.DISCUSSION.value, 299 | "round": current_round, 300 | "response": result 301 | }) 302 | 303 | return result 304 | 305 | def _parse_response(self, response_text: str) -> Dict[str, Any]: 306 | """ 307 | Parse the LLM response into a structured format. 
308 | 309 | Args: 310 | response_text: The raw response text from the LLM 311 | 312 | Returns: 313 | A dictionary with reasoning, answer, and confidence 314 | """ 315 | try: 316 | result = json.loads(preprocess_response_string(response_text)) 317 | 318 | # Validate required fields 319 | if "reasoning" not in result: 320 | result["reasoning"] = "No reasoning provided" 321 | 322 | if "answer" not in result: 323 | result["answer"] = "" 324 | 325 | if "confidence" not in result: 326 | result["confidence"] = 0.0 327 | else: 328 | # Ensure confidence is a float between 0 and 1 329 | try: 330 | result["confidence"] = float(result["confidence"]) 331 | result["confidence"] = max(0.0, min(1.0, result["confidence"])) 332 | except (ValueError, TypeError): 333 | result["confidence"] = 0.0 334 | 335 | return result 336 | 337 | except json.JSONDecodeError: 338 | print(f"Agent {self.agent_id} failed to parse JSON response: {response_text[:100]}...") 339 | 340 | # Attempt to extract with simple parsing 341 | reasoning = "" 342 | answer = "" 343 | confidence = 0.0 344 | 345 | lines = response_text.split('\n') 346 | for line in lines: 347 | if line.lower().startswith("reasoning:"): 348 | reasoning = line.split(":", 1)[1].strip() 349 | elif line.lower().startswith("answer:"): 350 | answer = line.split(":", 1)[1].strip() 351 | elif line.lower().startswith("confidence:"): 352 | try: 353 | confidence = float(line.split(":", 1)[1].strip()) 354 | confidence = max(0.0, min(1.0, confidence)) 355 | except (ValueError, IndexError): 356 | confidence = 0.0 357 | 358 | # If basic parsing doesn't work, use the raw text 359 | if not reasoning: 360 | reasoning = response_text 361 | 362 | return { 363 | "reasoning": reasoning, 364 | "answer": answer, 365 | "confidence": confidence 366 | } 367 | 368 | 369 | ############################################################################### 370 | # ReconcileCoordinator: orchestrates the multi-agent discussion process 371 | ############################################################################### 372 | class ReconcileCoordinator: 373 | """ 374 | The coordinator for the Reconcile framework. 375 | 376 | This class orchestrates the following phases: 377 | 1. Initial Response Generation: Each agent generates an initial response. 378 | 2. Multi-Round Discussion: Agents update their responses based on the grouped responses. 379 | 3. Team Answer Generation: A confidence-weighted vote decides the final answer. 380 | 381 | Attributes: 382 | agents: List of ReconcileAgent objects participating in the discussion 383 | max_rounds: Maximum number of discussion rounds 384 | """ 385 | def __init__(self, agent_configs: List[Dict[str, str]], max_rounds: int = 3): 386 | """ 387 | Initialize the Reconcile coordinator. 388 | 389 | Args: 390 | agent_configs: List of agent configurations (each with agent_id and model_key) 391 | max_rounds: Maximum number of discussion rounds 392 | """ 393 | # Instantiate Reconcile agents using provided configurations 394 | self.agents = [ 395 | ReconcileAgent(cfg["agent_id"], cfg["model_key"]) 396 | for cfg in agent_configs 397 | ] 398 | self.max_rounds = max_rounds 399 | print(f"Initialized ReconcileCoordinator with {len(self.agents)} agents, max_rounds={max_rounds}") 400 | 401 | def _group_answers(self, answers: List[Dict[str, Any]]) -> str: 402 | """ 403 | Group and summarize responses from agents. 
404 | 405 | Args: 406 | answers: List of agent response dictionaries 407 | 408 | Returns: 409 | A formatted string with grouped answers and their supporting explanations 410 | """ 411 | groups = {} 412 | 413 | # Group answers and explanations 414 | for ans in answers: 415 | answer_text = ans.get("answer", "").strip().lower() 416 | confidence = ans.get("confidence", 0.0) 417 | 418 | if answer_text not in groups: 419 | groups[answer_text] = { 420 | "count": 0, 421 | "explanations": [], 422 | "confidence_sum": 0.0 423 | } 424 | 425 | groups[answer_text]["count"] += 1 426 | groups[answer_text]["explanations"].append(ans.get("reasoning", "")) 427 | groups[answer_text]["confidence_sum"] += confidence 428 | 429 | # Format grouped answers 430 | grouped_str = "" 431 | for ans_text, data in groups.items(): 432 | # Calculate average confidence 433 | avg_confidence = data["confidence_sum"] / data["count"] if data["count"] > 0 else 0 434 | 435 | grouped_str += f"Answer: {ans_text}\n" 436 | grouped_str += f"Supporters: {data['count']}\n" 437 | grouped_str += f"Average confidence: {avg_confidence:.2f}\n" 438 | grouped_str += f"Explanations:\n" 439 | 440 | # Add each explanation with a bullet point 441 | for i, exp in enumerate(data["explanations"]): 442 | # Truncate very long explanations 443 | if len(exp) > 500: 444 | exp = exp[:500] + "... (truncated)" 445 | grouped_str += f"• Expert {i+1}: {exp}\n" 446 | 447 | grouped_str += "\n" 448 | 449 | return grouped_str.strip() 450 | 451 | def _recalibrate(self, confidence: float) -> float: 452 | """ 453 | Recalibrate a confidence score for better voting weights. 454 | 455 | Args: 456 | confidence: The original confidence score (0.0 to 1.0) 457 | 458 | Returns: 459 | Recalibrated confidence score 460 | """ 461 | if confidence == 1.0: 462 | return 1.0 463 | elif confidence >= 0.9: 464 | return 0.8 465 | elif confidence >= 0.8: 466 | return 0.5 467 | elif confidence > 0.6: 468 | return 0.3 469 | else: 470 | return 0.1 471 | 472 | def _weighted_vote(self, answers: List[Dict[str, Any]]) -> str: 473 | """ 474 | Compute the final team answer using a confidence-weighted vote. 475 | 476 | Args: 477 | answers: List of response dictionaries from agents 478 | 479 | Returns: 480 | The final answer string 481 | """ 482 | vote_weights = {} 483 | 484 | # Calculate weights for each answer 485 | for ans in answers: 486 | answer = ans.get("answer", "").strip() 487 | if not answer: 488 | continue 489 | 490 | confidence = ans.get("confidence", 0.0) 491 | weight = self._recalibrate(confidence) 492 | 493 | # Normalize answer to lowercase for vote counting, but preserve original case 494 | key = answer.lower() 495 | 496 | if key not in vote_weights: 497 | vote_weights[key] = {"weight": 0, "original": answer} 498 | 499 | vote_weights[key]["weight"] += weight 500 | 501 | if not vote_weights: 502 | return "" 503 | 504 | # Find the answer with the highest weight 505 | winner_key = max(vote_weights, key=lambda k: vote_weights[k]["weight"]) 506 | final_decision = vote_weights[winner_key]["original"] 507 | 508 | return final_decision 509 | 510 | def _check_consensus(self, answers: List[Dict[str, Any]]) -> bool: 511 | """ 512 | Check if all agents provided the same answer (consensus reached). 
513 | 514 | Args: 515 | answers: List of response dictionaries from agents 516 | 517 | Returns: 518 | True if consensus reached, False otherwise 519 | """ 520 | # Extract valid answers and convert to lowercase for comparison 521 | valid_answers = [ans.get("answer", "").strip().lower() for ans in answers if ans.get("answer", "").strip()] 522 | 523 | # Check if there's at least one valid answer and all are the same 524 | return len(valid_answers) > 0 and len(set(valid_answers)) == 1 525 | 526 | def run_discussion(self, 527 | question: str, 528 | options: Optional[Dict[str, str]] = None, 529 | image_path: Optional[str] = None) -> Dict[str, Any]: 530 | """ 531 | Run the complete discussion process. 532 | 533 | Args: 534 | question: The input question 535 | options: Optional multiple choice options 536 | image_path: Optional path to an image for MedVQA 537 | 538 | Returns: 539 | Dictionary with the final team answer and discussion history 540 | """ 541 | print(f"Starting discussion with {len(self.agents)} agents on question: {question}") 542 | start_time = time.time() 543 | 544 | discussion_history = [] 545 | 546 | # Phase 1: Initial responses 547 | print("Phase 1: Generating initial responses") 548 | current_answers = [] 549 | 550 | for agent in self.agents: 551 | resp = agent.generate_initial_response(question, options, image_path) 552 | current_answers.append(resp) 553 | 554 | # Add to discussion history 555 | discussion_history.append({ 556 | "phase": DiscussionPhase.INITIAL.value, 557 | "agent_id": agent.agent_id, 558 | "response": resp 559 | }) 560 | 561 | print(f"Agent {agent.agent_id} initial answer: {resp.get('answer', '')} (confidence: {resp.get('confidence', 0.0):.2f})") 562 | 563 | # Phase 2: Multi-round discussion 564 | round_num = 0 565 | consensus_reached = False 566 | 567 | while round_num < self.max_rounds and not consensus_reached: 568 | round_num += 1 569 | print(f"Phase 2: Discussion round {round_num}/{self.max_rounds}") 570 | 571 | # Prepare the discussion prompt based on previous answers 572 | discussion_prompt = self._group_answers(current_answers) 573 | 574 | # Each agent generates a new response 575 | new_answers = [] 576 | for agent in self.agents: 577 | resp = agent.generate_discussion_response( 578 | question, discussion_prompt, options, image_path 579 | ) 580 | new_answers.append(resp) 581 | 582 | # Add to discussion history 583 | discussion_history.append({ 584 | "phase": DiscussionPhase.DISCUSSION.value, 585 | "round": round_num, 586 | "agent_id": agent.agent_id, 587 | "response": resp 588 | }) 589 | 590 | print(f"Agent {agent.agent_id} round {round_num} answer: {resp.get('answer', '')} (confidence: {resp.get('confidence', 0.0):.2f})") 591 | 592 | # Update current answers for next round 593 | current_answers = new_answers 594 | 595 | # Check if consensus is reached 596 | consensus_reached = self._check_consensus(current_answers) 597 | print(f"Round {round_num} consensus reached: {consensus_reached}") 598 | 599 | if consensus_reached: 600 | print("Consensus reached, ending discussion") 601 | break 602 | 603 | # Phase 3: Final team answer via weighted vote 604 | print("Phase 3: Generating final team answer") 605 | final_decision = self._weighted_vote(current_answers) 606 | 607 | # Add final decision to history 608 | discussion_history.append({ 609 | "phase": DiscussionPhase.FINAL.value, 610 | "final_decision": final_decision, 611 | "consensus_reached": consensus_reached, 612 | "rounds_completed": round_num, 613 | "confidence_scores": [ans.get("confidence", 0.0) for 
ans in current_answers] 614 | }) 615 | 616 | end_time = time.time() 617 | processing_time = end_time - start_time 618 | 619 | print(f"Discussion completed in {processing_time:.2f} seconds. Final answer: {final_decision}") 620 | 621 | return { 622 | "final_decision": final_decision, 623 | "discussion_history": discussion_history, 624 | "processing_time": processing_time 625 | } 626 | 627 | 628 | ############################################################################### 629 | # Process a Single QA Item with the Reconcile Framework 630 | ############################################################################### 631 | def process_item(item: Dict[str, Any], 632 | agent_configs: List[Dict[str, str]], 633 | max_rounds: int = 3) -> Dict[str, Any]: 634 | """ 635 | Process a single QA item with the Reconcile framework. 636 | 637 | Args: 638 | item: Input QA item dictionary (with qid, question, etc.) 639 | agent_configs: List of agent configurations (each with agent_id and model_key) 640 | max_rounds: Maximum number of discussion rounds 641 | 642 | Returns: 643 | Processed QA result with the final predicted answer and discussion history 644 | """ 645 | qid = item.get("qid", "unknown") 646 | question = item.get("question", "") 647 | options = item.get("options") 648 | image_path = item.get("image_path") 649 | ground_truth = item.get("answer") 650 | 651 | print(f"Processing item {qid}") 652 | 653 | # Create coordinator and run discussion 654 | coordinator = ReconcileCoordinator(agent_configs, max_rounds) 655 | discussion_result = coordinator.run_discussion(question, options, image_path) 656 | 657 | # Compile results 658 | result = { 659 | "qid": qid, 660 | "timestamp": int(time.time()), 661 | "question": question, 662 | "options": options, 663 | "image_path": image_path, 664 | "ground_truth": ground_truth, 665 | "predicted_answer": discussion_result["final_decision"], 666 | "case_history": discussion_result, 667 | } 668 | 669 | return result 670 | 671 | 672 | ############################################################################### 673 | # Main Entry Point for the Reconcile Framework 674 | ############################################################################### 675 | def main(): 676 | """ 677 | Main entry point for running the Reconcile framework from command line. 
678 | """ 679 | parser = argparse.ArgumentParser(description="Run the Reconcile framework on medical QA datasets") 680 | parser.add_argument("--dataset", type=str, required=True, help="Dataset name") 681 | parser.add_argument("--qa_type", type=str, choices=["mc", "ff"], required=True, 682 | help="QA type: multiple-choice (mc) or free-form (ff)") 683 | parser.add_argument("--agents", nargs='+', default=["qwen-max-latest", "deepseek-v3-ark", "qwen-vl-max"], 684 | help="List of agent model keys (e.g., deepseek-v3-ark, qwen-max-latest, qwen-vl-max)") 685 | parser.add_argument("--max_rounds", type=int, default=3, help="Maximum number of discussion rounds") 686 | 687 | args = parser.parse_args() 688 | method = "ReConcile" 689 | 690 | # Extract dataset name 691 | dataset_name = args.dataset 692 | print(f"Dataset: {dataset_name}") 693 | 694 | # Determine QA format (multiple choice or free-form) 695 | qa_type = args.qa_type 696 | print(f"QA Format: {qa_type}") 697 | 698 | # Create logs directory structure 699 | logs_dir = os.path.join("logs", "medqa", dataset_name, "multiple_choice" if qa_type == "mc" else "free-form", method) 700 | os.makedirs(logs_dir, exist_ok=True) 701 | 702 | # Construct the data path 703 | data_path = os.path.join("my_datasets", "processed", "medqa", args.dataset, f"medqa_{args.qa_type}_test.json") 704 | 705 | # Load the dataset 706 | data = load_json(data_path) 707 | print(f"Loaded {len(data)} samples from {data_path}") 708 | 709 | # Configure agents: each agent is assigned an ID and a model key 710 | agent_configs = [] 711 | for idx, model_key in enumerate(args.agents, 1): 712 | agent_configs.append({"agent_id": f"agent_{idx}", "model_key": model_key}) 713 | 714 | print(f"Configured {len(agent_configs)} agents: {[cfg['model_key'] for cfg in agent_configs]}") 715 | 716 | 717 | # Process each item in the dataset 718 | for item in tqdm(data, desc=f"Processing {dataset_name} ({qa_type})"): 719 | qid = item.get("qid") 720 | result_path = os.path.join(logs_dir, f"{qid}-result.json") 721 | 722 | # Skip already processed items 723 | if os.path.exists(result_path): 724 | print(f"Skipping {qid} (already processed)") 725 | continue 726 | 727 | try: 728 | # Process the item 729 | result = process_item(item, agent_configs, args.max_rounds) 730 | # Save result 731 | save_json(result, result_path) 732 | 733 | except Exception as e: 734 | print(f"Error processing item {qid}: {e}") 735 | 736 | if __name__ == "__main__": 737 | main() --------------------------------------------------------------------------------
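
Illustrative aggregation sketch (editor's addition, not a file in the repository): the team answer in multi_agent_reconcile.py comes down to the trio of _recalibrate, _weighted_vote, and _check_consensus on ReconcileCoordinator. The minimal, self-contained Python sketch below mirrors that logic outside the class and runs it on a hypothetical set of three agent responses, so the recalibration table and the case-insensitive weighted vote can be inspected in isolation; the toy answers and confidences are invented for the example.

# Standalone sketch of the ReConcile aggregation step (mirrors the methods of
# ReconcileCoordinator above; the sample responses below are hypothetical).
from typing import Any, Dict, List


def recalibrate(confidence: float) -> float:
    """Map a raw self-reported confidence to a voting weight (same table as _recalibrate)."""
    if confidence == 1.0:
        return 1.0
    elif confidence >= 0.9:
        return 0.8
    elif confidence >= 0.8:
        return 0.5
    elif confidence > 0.6:
        return 0.3
    else:
        return 0.1


def weighted_vote(answers: List[Dict[str, Any]]) -> str:
    """Confidence-weighted vote over agent answers, keyed case-insensitively."""
    vote_weights: Dict[str, Dict[str, Any]] = {}
    for ans in answers:
        answer = ans.get("answer", "").strip()
        if not answer:
            continue
        weight = recalibrate(ans.get("confidence", 0.0))
        entry = vote_weights.setdefault(answer.lower(), {"weight": 0.0, "original": answer})
        entry["weight"] += weight
    if not vote_weights:
        return ""
    winner = max(vote_weights, key=lambda k: vote_weights[k]["weight"])
    return vote_weights[winner]["original"]


def check_consensus(answers: List[Dict[str, Any]]) -> bool:
    """Consensus holds when every non-empty answer is identical (ignoring case)."""
    valid = [a.get("answer", "").strip().lower() for a in answers if a.get("answer", "").strip()]
    return len(valid) > 0 and len(set(valid)) == 1


if __name__ == "__main__":
    # Hypothetical round of three agent responses.
    responses = [
        {"answer": "B", "confidence": 0.95},
        {"answer": "B", "confidence": 0.7},
        {"answer": "C", "confidence": 1.0},
    ]
    print(check_consensus(responses))  # False: the answers disagree
    # B accumulates 0.8 + 0.3 = 1.1, C gets 1.0, so the team answer is "B".
    print(weighted_vote(responses))    # B

For a hypothetical end-to-end run, the script would be launched as a module following the argparse options in main(), e.g. python -m medagentboard.medqa.multi_agent_reconcile --dataset <dataset> --qa_type mc --max_rounds 3, with --agents restricted to keys defined in LLM_MODELS_SETTINGS; the EHR variant in medagentboard/ehr/multi_agent_reconcile.py follows the same discussion flow but aggregates numeric predictions with a weighted average rather than a vote.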