├── scripts └── run_code.sh ├── tools ├── __pycache__ │ ├── tabtools.cpython-39.pyc │ └── calculator.cpython-39.pyc ├── calculator.py └── tabtools.py ├── requirements.txt ├── ehragent ├── config.py ├── evaluate.py ├── question_difficulty.py ├── toolset_high.py ├── main.py ├── medagent.py ├── prompts_eicu.py └── prompts_mimic.py └── README.md /scripts/run_code.sh: -------------------------------------------------------------------------------- 1 | python main.py --llm XXX --dataset mimic_iii --data_path XXX --logs_path XXX --num_questions -1 --seed 0 -------------------------------------------------------------------------------- /tools/__pycache__/tabtools.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wshi83/EhrAgent/HEAD/tools/__pycache__/tabtools.cpython-39.pyc -------------------------------------------------------------------------------- /tools/__pycache__/calculator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wshi83/EhrAgent/HEAD/tools/__pycache__/calculator.cpython-39.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autogen==1.0.16 2 | jsonlines==3.1.0 3 | matplotlib==3.7.1 4 | numpy==1.24.3 5 | openai==1.7.2 6 | pandas==2.1.4 7 | pyautogen==0.2.0 8 | python_Levenshtein==0.23.0 9 | seaborn==0.13.1 10 | termcolor==2.4.0 11 | wolframalpha==5.0.0 12 | -------------------------------------------------------------------------------- /tools/calculator.py: -------------------------------------------------------------------------------- 1 | ''' 2 | input: formula strings 3 | output: the answer of the mathematical formula 4 | ''' 5 | import os 6 | import re 7 | from operator import pow, truediv, mul, add, sub 8 | import wolframalpha 9 | query = '1+2*3' 10 | 11 
| def calculator(query: str): 12 | operators = { 13 | '+': add, 14 | '-': sub, 15 | '*': mul, 16 | '/': truediv, 17 | } 18 | query = re.sub(r'\s+', '', query) 19 | if query.isdigit(): 20 | return float(query) 21 | for c in operators.keys(): 22 | left, operator, right = query.partition(c) 23 | if operator in operators: 24 | return round(operators[operator](calculator(left), calculator(right)),2) 25 | 26 | def WolframAlphaCalculator(input_query: str): 27 | try: 28 | wolfram_alpha_appid = "" 29 | wolfram_client = wolframalpha.Client(wolfram_alpha_appid) 30 | res = wolfram_client.query(input_query) 31 | assumption = next(res.pods).text 32 | answer = next(res.results).text 33 | except: 34 | raise Exception("Invalid input query for Calculator. Please check the input query or use other functions to do the computation.") 35 | # return f"Assumption: {assumption} \nAnswer: {answer}" 36 | return answer 37 | 38 | if __name__ == "__main__": 39 | query = 'max(37.97,76.1)' 40 | print(WolframAlphaCalculator(query)) -------------------------------------------------------------------------------- /ehragent/config.py: -------------------------------------------------------------------------------- 1 | def openai_config(model): 2 | if model == '': 3 | config = { 4 | "model": "", 5 | "api_key": "", 6 | "base_url": "", 7 | "api_version": "", 8 | "api_type": "AZURE" 9 | } 10 | elif model == '': 11 | config = { 12 | "model": "", 13 | "api_key": "", 14 | "base_url": "", 15 | "api_version": "", 16 | "api_type": "AZURE" 17 | } 18 | return config 19 | 20 | def llm_config_list(seed, config_list): 21 | llm_config_list = { 22 | "functions": [ 23 | { 24 | "name": "python", 25 | "description": "run the entire code and return the execution result. 
Only generate the code.", 26 | "parameters": { 27 | "type": "object", 28 | "properties": { 29 | "cell": { 30 | "type": "string", 31 | "description": "Valid Python code to execute.", 32 | } 33 | }, 34 | "required": ["cell"], 35 | }, 36 | }, 37 | ], 38 | "config_list": config_list, 39 | "timeout": 120, 40 | "cache_seed": seed, 41 | "temperature": 0, 42 | } 43 | return llm_config_list -------------------------------------------------------------------------------- /ehragent/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def judge(pred, ans): 5 | old_flag = True 6 | if not ans in pred: 7 | old_flag = False 8 | if "True" in pred: 9 | pred = pred.replace("True", "1") 10 | else: 11 | pred = pred.replace("False", "0") 12 | if ans == "False" or ans == "false": 13 | ans = "0" 14 | if ans == "True" or ans == "true": 15 | ans = "1" 16 | if ans == "None" or ans == "none": 17 | ans = "0" 18 | if ", " in ans: 19 | ans = ans.split(', ') 20 | if ans[-2:] == ".0": 21 | ans = ans[:-2] 22 | if not type(ans) == list: 23 | ans = [ans] 24 | new_flag = True 25 | for i in range(len(ans)): 26 | if not ans[i] in pred: 27 | new_flag = False 28 | break 29 | return (old_flag or new_flag) 30 | 31 | logs_path = "" 32 | files = os.listdir(logs_path) 33 | 34 | # read the files 35 | answer_book = "" 36 | with open(answer_book, 'r') as f: 37 | contents = json.load(f) 38 | answers = {} 39 | for i in range(len(contents)): 40 | answers[contents[i]['id']] = contents[i]['answer'] 41 | 42 | stats = {"total_num": 0, "correct": 0, "unfinished": 0, "incorrect": 0} 43 | 44 | for file in files: 45 | if not file.split('.')[0] in answers.keys(): 46 | continue 47 | with open(logs_path+file, 'r') as f: 48 | logs = f.read() 49 | split_logs = logs.split('\n----------------------------------------------------------\n') 50 | question = split_logs[0] 51 | answer = answers[file.split('.')[0]] 52 | if type(answer) == list: 53 | answer = 
', '.join(answer) 54 | stats["total_num"] += 1 55 | if not "TERMINATE" in logs: 56 | stats["unfinished"] += 1 57 | else: 58 | if '"cell": "' in logs: 59 | last_code_start = logs.rfind('"cell": "') 60 | last_code_end = logs.rfind('"\n}') 61 | last_code = logs[last_code_start+9:last_code_end] 62 | else: 63 | last_code_end = logs.rfind('Solution:') 64 | prediction_end = logs.rfind('TERMINATE') 65 | prediction = logs[last_code_end:prediction_end] 66 | logs = logs.split('TERMINATE')[0] 67 | result = judge(prediction, answer) 68 | if result: 69 | stats["correct"] += 1 70 | else: 71 | stats["incorrect"] += 1 72 | 73 | print(stats) 74 | -------------------------------------------------------------------------------- /ehragent/question_difficulty.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import matplotlib.pyplot as plt 4 | import collections 5 | import matplotlib 6 | import seaborn as sns 7 | from collections import defaultdict 8 | import os 9 | 10 | 11 | matplotlib.use("Agg") 12 | matplotlib.rcParams.update({'font.family': 'Times New Roman'}) 13 | matplotlib.rcParams['pdf.fonttype'] = 42 14 | matplotlib.rcParams['ps.fonttype'] = 42 15 | 16 | sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1, rc={'grid.linestyle': ':', 'axes.grid': True}) 17 | 18 | for dataset in ["mimic_iii", "eicu"]: 19 | with open(f"", 'r') as f_in, \ 20 | open(f"", "w") as f_out: 21 | list_num_q_tag_var = [] 22 | list_num_tables = [] 23 | list_num_columns = [] 24 | num_q_tag_var_dict = defaultdict(list) 25 | num_tables_dict = defaultdict(list) 26 | num_columns_dict = defaultdict(list) 27 | for lines in f_in: 28 | x = json.loads(lines) 29 | 30 | # number of variables in q_tag 31 | num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[") 32 | 33 | # number of different tables used in the sql 34 | # (we only count the tables in the original datasets, but not the new ones created in the sql) 35 | tables = 
re.findall(r'\bfrom\s+(\w+)\b', x["query"]) 36 | num_tables = len(set(tables)) 37 | 38 | # number of different columns used in the sql 39 | # (we only count the columns in the original datasets, but not the new ones created in the sql) 40 | columns = re.findall(r'\b\w*\.\w*\b', x["query"]) 41 | columns = [item for item in columns if not re.match(r'^t\d', item)] 42 | num_columns = len(set(columns)) 43 | 44 | x["num_q_tag_var"] = num_q_tag_var 45 | x["num_tables"] = num_tables 46 | x["num_columns"] = num_columns 47 | 48 | f_out.write(json.dumps(x) + "\n") 49 | 50 | list_num_q_tag_var.append(num_q_tag_var) 51 | list_num_tables.append(num_tables) 52 | list_num_columns.append(num_columns) 53 | 54 | num_q_tag_var_dict[num_q_tag_var].append(x["id"]) 55 | num_tables_dict[num_tables].append(x["id"]) 56 | num_columns_dict[num_columns].append(x["id"]) 57 | 58 | os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True) 59 | for k, id_list in num_q_tag_var_dict.items(): 60 | with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f: 61 | json.dump(id_list, f) 62 | os.makedirs(f"{dataset}/num_tables/", exist_ok=True) 63 | for k, id_list in num_tables_dict.items(): 64 | with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f: 65 | json.dump(id_list, f) 66 | os.makedirs(f"{dataset}/num_columns/", exist_ok=True) 67 | for k, id_list in num_columns_dict.items(): 68 | with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f: 69 | json.dump(id_list, f) 70 | 71 | 72 | # plot the distribution 73 | all_list = [list_num_q_tag_var, list_num_tables, list_num_columns] 74 | xlabel = ["# q_tag variables", "# tables", "# columns"] 75 | file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"] 76 | 77 | os.makedirs(f"{dataset}/figures/", exist_ok=True) 78 | for i in range(len(all_list)): 79 | c = collections.Counter(all_list[i]) 80 | c = sorted(c.items()) 81 | plt.figure(figsize=(6, 5.5), dpi=120) 82 | ax = sns.barplot(x=[i[0] for i in c], y=[i[1] for i in c], 
color=sns.color_palette()[0]) 83 | plt.ylabel("Frequency", size=30) 84 | plt.xlabel(xlabel[i], size=30) 85 | plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05]) 86 | plt.savefig(f"{dataset}/figures/{file_name[i]}") 87 | -------------------------------------------------------------------------------- /ehragent/toolset_high.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import openai 3 | import autogen 4 | import time 5 | import os 6 | from config import openai_config 7 | from openai import AzureOpenAI 8 | import traceback 9 | 10 | def run_code(cell): 11 | """ 12 | Returns the path to the python interpreter. 13 | """ 14 | # import prompts 15 | from prompts_mimic import CodeHeader 16 | try: 17 | global_var = {"answer": 0} 18 | exec(CodeHeader+cell, global_var) 19 | cell = "\n".join([line for line in cell.split("\n") if line.strip() and not line.strip().startswith("#")]) 20 | if not 'answer' in cell.split('\n')[-1]: 21 | return "Please save the answer to the question in the variable 'answer'." 
22 | return str(global_var['answer']) 23 | except Exception as e: 24 | error_info = traceback.format_exc() 25 | code = CodeHeader + cell 26 | if "SyntaxError" in str(repr(e)): 27 | error_line = str(repr(e)) 28 | 29 | error_type = error_line.split('(')[0] 30 | # then parse out the error message 31 | error_message = error_line.split(',')[0].split('(')[1] 32 | # then parse out the error line 33 | error_line = error_line.split('"')[1] 34 | elif "KeyError" in str(repr(e)): 35 | code = code.split('\n') 36 | key = str(repr(e)).split("'")[1] 37 | error_type = str(repr(e)).split('(')[0] 38 | for i in range(len(code)): 39 | if key in code[i]: 40 | error_line = code[i] 41 | error_message = str(repr(e)) 42 | elif "TypeError" in str(repr(e)): 43 | error_type = str(repr(e)).split('(')[0] 44 | error_message = str(e) 45 | function_mapping_dict = {"get_value": "GetValue", "data_filter": "FilterDB", "db_loader": "LoadDB", "sql_interpreter": "SQLInterpreter", "date_calculator": "Calendar"} 46 | error_key = "" 47 | for key in function_mapping_dict.keys(): 48 | if key in error_message: 49 | error_message = error_message.replace(key, function_mapping_dict[key]) 50 | error_key = function_mapping_dict[key] 51 | code = code.split('\n') 52 | error_line = "" 53 | for i in range(len(code)): 54 | if error_key in code[i]: 55 | error_line = code[i] 56 | else: 57 | error_type = "" 58 | error_message = str(repr(e)).split("('")[-1].split("')")[0] 59 | error_line = "" 60 | # use one sentence to introduce the previous parsed error information 61 | if error_type != "" and error_line != "": 62 | error_info = f'{error_type}: {error_message}. The error messages occur in the code line "{error_line}".' 63 | else: 64 | error_info = f'Error: {error_message}.' 65 | error_info += '\nPlease make modifications accordingly and make sure the rest code works well with the modification.' 
66 | 67 | return error_info 68 | 69 | 70 | def llm_agent(config_list): 71 | llm_config = { 72 | "functions": [ 73 | { 74 | "name": "python", 75 | "description": "run cell in ipython and return the execution result.", 76 | "parameters": { 77 | "type": "object", 78 | "properties": { 79 | "cell": { 80 | "type": "string", 81 | "description": "Valid Python cell to execute.", 82 | } 83 | }, 84 | "required": ["cell"], 85 | }, 86 | }, 87 | ], 88 | "config_list": config_list, 89 | "request_timeout": 120, 90 | } 91 | chatbot = autogen.AssistantAgent( 92 | name="chatbot", 93 | system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.", 94 | llm_config=llm_config, 95 | ) 96 | return chatbot 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

⚕️EHRAgent🤖

3 |
4 | 5 | The official repository for the code of the paper ["EHRAgent: Code Empowers Large Language Models for Complex Tabular Reasoning on Electronic Health Records"](https://arxiv.org/abs/2401.07128). EHRAgent is an LLM agent empowered with a code interface to autonomously generate and execute code for complex clinical tasks within electronic health records (EHRs). The project page is available at [this link](https://wshi83.github.io/EHR-Agent-page/). 6 | 7 | ### Features 8 | 9 | - EHRAgent is an LLM agent augmented with tools and medical knowledge to solve complex tabular reasoning tasks derived from EHRs; 10 | - Planning with a code interface, EHRAgent enables the LLM agent to formulate a clinical problem-solving process as an executable code plan of action sequences, along with a code executor; 11 | - We introduce interactive coding between the LLM agent and the code executor, iteratively refining plan generation and optimizing code execution by examining environment feedback in depth. 12 | 13 | ### Data Preparation 14 | 15 | We use the [EHRSQL](https://github.com/glee4810/EHRSQL) benchmark for evaluation. The original dataset targets text-to-SQL tasks, and we have adapted it for our evaluation. We release our cleaned and pre-processed version of the data as [EHRSQL-EHRAgent](https://drive.google.com/file/d/1EE_g3kroKJW_2Op6T2PiZbDSrIQRMtps/view?usp=sharing). Please download the data and note its path; it is passed to ``main.py`` via ``--data_path``. 16 | 17 | ### Credentials Preparation 18 | Our experiments use the Azure OpenAI API. Please record your API keys and other credentials in ``./ehragent/config.py``. 19 | 20 | ### Setup 21 | 22 | See ``requirements.txt``. The package versions pinned in ``requirements.txt`` are the ones used to test the code; other versions have not been fully tested but may also work. We recommend running this code with Python ``>=3.9``.
Install the required libraries with the following command: 23 | 24 | ```bash 25 | pip3 install -r requirements.txt 26 | ``` 27 | 28 | ### Instructions 29 | 30 | Logs are written to ``YOUR_LOGS_PATH/<num_shots>/<question_id>.txt`` (set by the ``--logs_path`` and ``--num_shots`` options). Use the following command to run our code: 31 | ```bash 32 | python main.py --llm YOUR_LLM_NAME --dataset mimic_iii --data_path YOUR_DATA_PATH --logs_path YOUR_LOGS_PATH --num_questions -1 --seed 0 33 | ``` 34 | 35 | We also support a debugging mode that runs only the question given by ``--debug_id``: 36 | ```bash 37 | python main.py --llm YOUR_LLM_NAME --dataset mimic_iii --data_path YOUR_DATA_PATH --logs_path YOUR_LOGS_PATH --debug --debug_id QUESTION_ID_TO_DEBUG 38 | ``` 39 | 40 | For the **eICU** dataset, change the dataset option to ``--dataset eicu``. 41 | 42 | ### Citation 43 | If you find this repository useful, please consider citing: 44 | ```bibtex 45 | @inproceedings{shi-etal-2024-ehragent, 46 | title = "{EHRA}gent: Code Empowers Large Language Models for Few-shot Complex Tabular Reasoning on Electronic Health Records", 47 | author = "Shi, Wenqi and 48 | Xu, Ran and 49 | Zhuang, Yuchen and 50 | Yu, Yue and 51 | Zhang, Jieyu and 52 | Wu, Hang and 53 | Zhu, Yuanda and 54 | Ho, Joyce C. and 55 | Yang, Carl and 56 | Wang, May Dongmei", 57 | editor = "Al-Onaizan, Yaser and 58 | Bansal, Mohit and 59 | Chen, Yun-Nung", 60 | booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing", 61 | month = nov, 62 | year = "2024", 63 | address = "Miami, Florida, USA", 64 | publisher = "Association for Computational Linguistics", 65 | url = "https://aclanthology.org/2024.emnlp-main.1245", 66 | doi = "10.18653/v1/2024.emnlp-main.1245", 67 | pages = "22315--22339", 68 | abstract = "Clinicians often rely on data engineers to retrieve complex patient information from electronic health record (EHR) systems, a process that is both inefficient and time-consuming.
We propose EHRAgent, a large language model (LLM) agent empowered with accumulative domain knowledge and robust coding capability. EHRAgent enables autonomous code generation and execution to facilitate clinicians in directly interacting with EHRs using natural language. Specifically, we formulate a multi-tabular reasoning task based on EHRs as a tool-use planning process, efficiently decomposing a complex task into a sequence of manageable actions with external toolsets. We first inject relevant medical information to enable EHRAgent to effectively reason about the given query, identifying and extracting the required records from the appropriate tables. By integrating interactive coding and execution feedback, EHRAgent then effectively learns from error messages and iteratively improves its originally generated code. Experiments on three real-world EHR datasets show that EHRAgent outperforms the strongest baseline by up to 29.6{\%} in success rate, verifying its strong capacity to tackle complex clinical tasks with minimal demonstrations.", 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /ehragent/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | import argparse 6 | import autogen 7 | from toolset_high import * 8 | from medagent import MedAgent 9 | from config import openai_config, llm_config_list 10 | import time 11 | 12 | def judge(pred, ans): 13 | old_flag = True 14 | if not ans in pred: 15 | old_flag = False 16 | if "True" in pred: 17 | pred = pred.replace("True", "1") 18 | else: 19 | pred = pred.replace("False", "0") 20 | if ans == "False" or ans == "false": 21 | ans = "0" 22 | if ans == "True" or ans == "true": 23 | ans = "1" 24 | if ans == "No" or ans == "no": 25 | ans = "0" 26 | if ans == "Yes" or ans == "yes": 27 | ans = "1" 28 | if ans == "None" or ans == "none": 29 | ans = "0" 30 | if ", 
" in ans: 31 | ans = ans.split(', ') 32 | if ans[-2:] == ".0": 33 | ans = ans[:-2] 34 | if not type(ans) == list: 35 | ans = [ans] 36 | new_flag = True 37 | for i in range(len(ans)): 38 | if not ans[i] in pred: 39 | new_flag = False 40 | break 41 | return (old_flag or new_flag) 42 | 43 | def set_seed(seed): 44 | random.seed(seed) 45 | np.random.seed(seed) 46 | 47 | def main(): 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--llm", type=str, default="") 50 | parser.add_argument("--num_questions", type=int, default=1) 51 | parser.add_argument("--dataset", type=str, default="mimic_iii") 52 | parser.add_argument("--data_path", type=str, default="") 53 | parser.add_argument("--logs_path", type=str, default="") 54 | parser.add_argument("--seed", type=int, default=42) 55 | parser.add_argument("--debug", action="store_true") 56 | parser.add_argument("--debug_id", type=str, default="521fd2885f51641a963f8d3e") 57 | parser.add_argument("--start_id", type=int, default=0) 58 | parser.add_argument("--num_shots", type=int, default=4) 59 | args = parser.parse_args() 60 | set_seed(args.seed) 61 | if args.dataset == 'mimic_iii': 62 | from prompts_mimic import EHRAgent_4Shots_Knowledge 63 | else: 64 | from prompts_eicu import EHRAgent_4Shots_Knowledge 65 | 66 | config_list = [openai_config(args.llm)] 67 | llm_config = llm_config_list(args.seed, config_list) 68 | 69 | chatbot = autogen.agentchat.AssistantAgent( 70 | name="chatbot", 71 | system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done. Save the answers to the questions in the variable 'answer'. 
Please only generate the code.", 72 | llm_config=llm_config, 73 | ) 74 | 75 | user_proxy = MedAgent( 76 | name="user_proxy", 77 | is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"), 78 | human_input_mode="NEVER", 79 | max_consecutive_auto_reply=10, 80 | code_execution_config={"work_dir": "coding"}, 81 | config_list=config_list, 82 | ) 83 | 84 | # register the functions 85 | user_proxy.register_function( 86 | function_map={ 87 | "python": run_code 88 | } 89 | ) 90 | 91 | user_proxy.register_dataset(args.dataset) 92 | 93 | file_path = args.data_path 94 | # read from json file 95 | with open(file_path, 'r') as f: 96 | contents = json.load(f) 97 | 98 | # random shuffle 99 | import random 100 | random.shuffle(contents) 101 | file_path = "{}/{}/".format(args.logs_path, args.num_shots) + "{id}.txt" 102 | 103 | start_time = time.time() 104 | if args.num_questions == -1: 105 | args.num_questions = len(contents) 106 | long_term_memory = [] 107 | init_memory = EHRAgent_4Shots_Knowledge 108 | init_memory = init_memory.split('\n\n') 109 | for i in range(len(init_memory)): 110 | item = init_memory[i] 111 | item = item.split('Question:')[-1] 112 | question = item.split('\nKnowledge:\n')[0] 113 | item = item.split('\nKnowledge:\n')[-1] 114 | knowledge = item.split('\nSolution:')[0] 115 | code = item.split('\nSolution:')[-1] 116 | new_item = {"question": question, "knowledge": knowledge, "code": code} 117 | long_term_memory.append(new_item) 118 | 119 | for i in range(args.start_id, args.num_questions): 120 | if args.debug and contents[i]['id'] != args.debug_id: 121 | continue 122 | question = contents[i]['template'] 123 | answer = contents[i]['answer'] 124 | try: 125 | user_proxy.update_memory(args.num_shots, long_term_memory) 126 | user_proxy.initiate_chat( 127 | chatbot, 128 | message=question, 129 | ) 130 | logs = user_proxy._oai_messages 131 | 132 | logs_string = [] 133 | logs_string.append(str(question)) 134 | 
logs_string.append(str(answer)) 135 | for agent in list(logs.keys()): 136 | for j in range(len(logs[agent])): 137 | if logs[agent][j]['content'] != None: 138 | logs_string.append(logs[agent][j]['content']) 139 | else: 140 | argums = logs[agent][j]['function_call']['arguments'] 141 | if type(argums) == dict and 'cell' in argums.keys(): 142 | logs_string.append(argums['cell']) 143 | else: 144 | logs_string.append(argums) 145 | except Exception as e: 146 | logs_string = [str(e)] 147 | print(logs_string) 148 | file_directory = file_path.format(id=contents[i]['id']) 149 | # f = open(file_directory, 'w') 150 | if type(answer) == list: 151 | answer = ', '.join(answer) 152 | logs_string.append("Ground-Truth Answer ---> "+answer) 153 | with open(file_directory, 'w') as f: 154 | f.write('\n----------------------------------------------------------\n'.join(logs_string)) 155 | logs_string = '\n----------------------------------------------------------\n'.join(logs_string) 156 | if '"cell": "' in logs_string: 157 | last_code_start = logs_string.rfind('"cell": "') 158 | last_code_end = logs_string.rfind('"\n}') 159 | last_code = logs_string[last_code_start+9:last_code_end] 160 | else: 161 | last_code_end = logs_string.rfind('Solution:') 162 | prediction_end = logs_string.rfind('TERMINATE') 163 | prediction = logs_string[last_code_end:prediction_end] 164 | result = judge(prediction, answer) 165 | if result: 166 | new_item = {"question": question, "knowledge": user_proxy.knowledge, "code": user_proxy.code} 167 | long_term_memory.append(new_item) 168 | end_time = time.time() 169 | print("Time elapsed: ", end_time - start_time) 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /ehragent/medagent.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Dict, List, Optional, Union, Callable, Literal, Optional, Union 3 | import logging 4 
| import asyncio 5 | import openai 6 | import json 7 | from openai import OpenAI, AzureOpenAI 8 | from autogen.agentchat import Agent, UserProxyAgent, ConversableAgent 9 | from termcolor import colored 10 | import Levenshtein 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class MedAgent(UserProxyAgent): 15 | def __init__( 16 | self, 17 | name: str, 18 | is_termination_msg: Optional[Callable[[Dict], bool]] = None, 19 | max_consecutive_auto_reply: Optional[int] = None, 20 | human_input_mode: Optional[str] = "ALWAYS", 21 | function_map: Optional[Dict[str, Callable]] = None, 22 | code_execution_config: Optional[Union[Dict, Literal[False]]] = None, 23 | default_auto_reply: Optional[Union[str, Dict, None]] = "", 24 | llm_config: Optional[Union[Dict, Literal[False]]] = False, 25 | system_message: Optional[Union[str, List]] = "", 26 | config_list: Optional[List[Dict]] = None, 27 | ): 28 | super().__init__( 29 | name=name, 30 | system_message=system_message, 31 | is_termination_msg=is_termination_msg, 32 | max_consecutive_auto_reply=max_consecutive_auto_reply, 33 | human_input_mode=human_input_mode, 34 | function_map=function_map, 35 | code_execution_config=code_execution_config, 36 | llm_config=llm_config, 37 | default_auto_reply=default_auto_reply, 38 | ) 39 | self.config_list = config_list 40 | self.question = '' 41 | self.code = '' 42 | self.knowledge = '' 43 | 44 | def retrieve_knowledge(self, config, query): 45 | # import prompt 46 | if self.dataset == 'mimic_iii': 47 | from prompts_mimic import RetrKnowledge 48 | else: 49 | from prompts_eicu import RetrKnowledge 50 | # Returns the related information to the given query. 
51 | patience = 2 52 | sleep_time = 30 53 | openai.api_type = config["api_type"] 54 | openai.api_base = config["base_url"] 55 | openai.api_version = config["api_version"] 56 | openai.api_key = config["api_key"] 57 | engine = config["model"] 58 | query_message = RetrKnowledge.format(question=query) 59 | messages = [{"role":"system","content":"You are an AI assistant that helps people find information."}, 60 | {"role":"user","content": query_message}] 61 | client = AzureOpenAI( 62 | api_key=config["api_key"], 63 | azure_endpoint=config["base_url"], 64 | api_version=config["api_version"], 65 | ) 66 | while patience > 0: 67 | patience -= 1 68 | try: 69 | response = client.chat.completions.create( 70 | model=engine, 71 | messages = messages, 72 | temperature=0, 73 | max_tokens=800, 74 | top_p=0.95, 75 | frequency_penalty=0, 76 | presence_penalty=0, 77 | stop=None) 78 | prediction = response.choices[0].message.content.strip() 79 | if prediction != "" and prediction != None: 80 | return prediction 81 | except Exception as e: 82 | print(e) 83 | if sleep_time > 0: 84 | time.sleep(sleep_time) 85 | return "Fail to retrieve related knowledge, please try again later." 
86 | 87 | def retrieve_examples(self, query): 88 | levenshtein_dist = {} 89 | for i in range(len(self.memory)): 90 | question = self.memory[i]["question"] 91 | levenshtein_dist[i] = Levenshtein.distance(query, question) 92 | levenshtein_dist = sorted(levenshtein_dist.items(), key=lambda x: x[1], reverse=False) 93 | selected_indexes = [levenshtein_dist[i][0] for i in range(min(self.num_shots, len(levenshtein_dist)))] 94 | examples = [] 95 | for i in selected_indexes: 96 | template = "Question: {}\nKnowledge:\n{}\nSolution:\n{}\n".format(self.memory[i]["question"], self.memory[i]["knowledge"], self.memory[i]["code"]) 97 | examples.append(template) 98 | examples = '\n'.join(examples) 99 | return examples 100 | 101 | def generate_init_message(self, **context): 102 | # import prompt 103 | if self.dataset == 'mimic_iii': 104 | from prompts_mimic import EHRAgent_Message_Prompt 105 | else: 106 | from prompts_eicu import EHRAgent_Message_Prompt 107 | self.question = context["message"] 108 | knowledge = self.retrieve_knowledge(self.config_list[0], context["message"]) 109 | self.knowledge = knowledge 110 | 111 | examples = self.retrieve_examples(context["message"]) 112 | 113 | init_message = EHRAgent_Message_Prompt.format(examples=examples, knowledge=knowledge, question=context["message"]) 114 | return init_message 115 | 116 | def send(self, message: Union[Dict, str], recipient: Agent, request_reply: Optional[bool]=None, silent: Optional[bool]=False): 117 | valid = self._append_oai_message(message, "assistant", recipient) 118 | if valid: 119 | recipient.receive(message, self, request_reply, silent) 120 | else: 121 | raise ValueError( 122 | "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided." 
123 | ) 124 | 125 | def initiate_chat(self, recipient: "ConversableAgent", clear_history: Optional[bool]=True, silent: Optional[bool]=False, **context,): 126 | self._prepare_chat(recipient, clear_history) 127 | self.send(self.generate_init_message(**context), recipient, silent=silent) 128 | 129 | def receive( 130 | self, 131 | message: Union[Dict, str], 132 | sender: Agent, 133 | request_reply: Optional[bool] = None, 134 | silent: Optional[bool] = False, 135 | ): 136 | self._process_received_message(message, sender, silent) 137 | if request_reply is False or request_reply is None and self.reply_at_receive[sender] is False: 138 | return 139 | reply = self.generate_reply(messages=self.chat_messages[sender], sender=sender) 140 | if reply is not None: 141 | self.send(reply, sender, silent=silent) 142 | 143 | def error_debugger(self, config, code, error_info): 144 | # import prompt 145 | if self.dataset == 'mimic_iii': 146 | from prompts_mimic import CodeDebugger 147 | else: 148 | from prompts_eicu import CodeDebugger 149 | # Returns the related information to the given query. 150 | patience = 2 151 | sleep_time = 30 152 | openai.api_type = config["api_type"] 153 | openai.api_base = config["base_url"] 154 | openai.api_version = config["api_version"] 155 | openai.api_key = config["api_key"] 156 | engine = config["model"] 157 | query_message = CodeDebugger.format(question=self.question, code=code, error_info=error_info) 158 | messages = [{"role":"system","content":"You are an AI assistant that helps people debug their code. 
Only list one most possible reason to the errors."}, 159 | {"role":"user","content": query_message}] 160 | client = AzureOpenAI( 161 | api_key=config["api_key"], 162 | azure_endpoint=config["base_url"], 163 | api_version=config["api_version"], 164 | ) 165 | while patience > 0: 166 | patience -= 1 167 | try: 168 | response = client.chat.completions.create( 169 | model=engine, 170 | messages = messages, 171 | temperature=0, 172 | max_tokens=800, 173 | top_p=0.95, 174 | frequency_penalty=0, 175 | presence_penalty=0, 176 | stop=None) 177 | prediction = response.choices[0].message.content.strip() 178 | if prediction != "" and prediction != None: 179 | return prediction 180 | except Exception as e: 181 | print(e) 182 | if sleep_time > 0: 183 | time.sleep(sleep_time) 184 | return "Fail to diagnose the reasons to the errors." 185 | 186 | def execute_function(self, func_call): 187 | """Execute a function call and return the result. 188 | 189 | Override this function to modify the way to execute a function call. 190 | 191 | Args: 192 | func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments". 193 | 194 | Returns: 195 | A tuple of (is_exec_success, result_dict). 196 | is_exec_success (boolean): whether the execution is successful. 197 | result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function". 198 | """ 199 | func_name = func_call.get("name", "") 200 | func = self._function_map.get(func_name, None) 201 | 202 | is_exec_success = False 203 | if func is not None: 204 | # Extract arguments from a json-like string and put it into a dict. 
205 | input_string = self._format_json_str(func_call.get("arguments", "{}")) 206 | try: 207 | arguments = json.loads(input_string) 208 | except json.JSONDecodeError as e: 209 | arguments = None 210 | arguments_string = func_call["arguments"].split(': "')[-1] 211 | arguments_string = arguments_string.split('", ')[0] 212 | arguments = {"cell": arguments_string} 213 | # content = f"Error: {e}\n You argument should follow json format." 214 | content = f"Error: {e}\n There might be compilation errors in the code. Please check the code and try again." 215 | 216 | # Try to execute the function 217 | if arguments is not None: 218 | print( 219 | colored(f"\n>>>>>>>> EXECUTING FUNCTION {func_name}...", "magenta"), 220 | flush=True, 221 | ) 222 | self.code = arguments["cell"] 223 | try: 224 | content = func(**arguments) 225 | is_exec_success = True 226 | except Exception as e: 227 | content = f"Error: {e}" 228 | else: 229 | content = f"Error: Function {func_name} not found." 230 | if "error" in content or "Error" in content: 231 | reasons = self.error_debugger(self.config_list[0], self.code, content) 232 | content = content + '\nPotential Reasons: ' + reasons 233 | 234 | return is_exec_success, { 235 | "name": func_name, 236 | "role": "function", 237 | "content": str(content), 238 | } 239 | 240 | def update_memory(self, num_shots, memory): 241 | self.num_shots = num_shots 242 | self.memory = memory 243 | 244 | def register_dataset(self, dataset): 245 | self.dataset = dataset -------------------------------------------------------------------------------- /tools/tabtools.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import jsonlines 3 | import json 4 | import re 5 | import sqlite3 6 | import sys 7 | import Levenshtein 8 | def db_loader(target_ehr): 9 | ehr_dict = {"admissions":"/ehrsql/mimic_iii/ADMISSIONS.csv", 10 | "chartevents":"/ehrsql/mimic_iii/CHARTEVENTS.csv", 11 | "cost":"/ehrsql/mimic_iii/COST.csv", 12 
| "d_icd_diagnoses":"/ehrsql/mimic_iii/D_ICD_DIAGNOSES.csv", 13 | "d_icd_procedures":"/ehrsql/mimic_iii/D_ICD_PROCEDURES.csv", 14 | "d_items":"/ehrsql/mimic_iii/D_ITEMS.csv", 15 | "d_labitems":"/ehrsql/mimic_iii/D_LABITEMS.csv", 16 | "diagnoses_icd":"/ehrsql/mimic_iii/DIAGNOSES_ICD.csv", 17 | "icustays":"/ehrsql/mimic_iii/ICUSTAYS.csv", 18 | "inputevents_cv":"/ehrsql/mimic_iii/INPUTEVENTS_CV.csv", 19 | "labevents":"/ehrsql/mimic_iii/LABEVENTS.csv", 20 | "microbiologyevents":"/ehrsql/mimic_iii/MICROBIOLOGYEVENTS.csv", 21 | "outputevents":"/mimic_iii/OUTPUTEVENTS.csv", 22 | "patients":"/ehrsql/mimic_iii/PATIENTS.csv", 23 | "prescriptions":"/ehrsql/mimic_iii/PRESCRIPTIONS.csv", 24 | "procedures_icd":"/ehrsql/mimic_iii/PROCEDURES_ICD.csv", 25 | "transfers":"/ehrsql/mimic_iii/TRANSFERS.csv", 26 | } 27 | data = pd.read_csv(ehr_dict[target_ehr]) 28 | # data = data.astype(str) 29 | column_names = ', '.join(data.columns.tolist()) 30 | return data 31 | # def get_column_names(self, target_db): 32 | # return ', '.join(data.columns.tolist()) 33 | 34 | def data_filter(data, argument): 35 | # commands = re.sub(r' ', '', argument) 36 | backup_data = data 37 | # print('-->', argument) 38 | commands = argument.split('||') 39 | for i in range(len(commands)): 40 | try: 41 | # commands[i] = commands[i].replace(' ', '') 42 | if '>=' in commands[i]: 43 | command = commands[i].split('>=') 44 | column_name = command[0] 45 | value = command[1] 46 | try: 47 | value = type(data[column_name][0])(value) 48 | except: 49 | value = value 50 | data = data[data[column_name] >= value] 51 | elif '<=' in commands[i]: 52 | command = commands[i].split('<=') 53 | column_name = command[0] 54 | value = command[1] 55 | try: 56 | value = type(data[column_name][0])(value) 57 | except: 58 | value = value 59 | data = data[data[column_name] <= value] 60 | elif '>' in commands[i]: 61 | command = commands[i].split('>') 62 | column_name = command[0] 63 | value = command[1] 64 | try: 65 | value = 
type(data[column_name][0])(value) 66 | except: 67 | value = value 68 | data = data[data[column_name] > value] 69 | elif '<' in commands[i]: 70 | command = commands[i].split('<') 71 | column_name = command[0] 72 | value = command[1] 73 | if value[0] == "'" or value[0] == '"': 74 | value = value[1:-1] 75 | try: 76 | value = type(data[column_name][0])(value) 77 | except: 78 | value = value 79 | data = data[data[column_name] < value] 80 | elif '=' in commands[i]: 81 | command = commands[i].split('=') 82 | column_name = command[0] 83 | value = command[1] 84 | # print(command) 85 | # print(value) 86 | if value[0] == "'" or value[0] == '"': 87 | value = value[1:-1] 88 | try: 89 | examplar = backup_data[column_name].tolist()[0] 90 | value = type(examplar)(value) 91 | # print(value, type(value), type(examplar)) 92 | except: 93 | value = value 94 | # print('--', value, type(value), type(examplar)) 95 | # print('------', len(data)) 96 | data = data[data[column_name] == value] 97 | # print('======', len(data)) 98 | elif ' in ' in commands[i]: 99 | command = commands[i].split(' in ') 100 | column_name = command[0] 101 | value = command[1] 102 | value_list = [s.strip() for s in value.strip("[]").split(',')] 103 | value_list = [s.strip("'").strip('"') for s in value_list] 104 | # print(command) 105 | # print(column_name) 106 | # print(value) 107 | # print(value_list) 108 | value_list = list(map(type(data[column_name][0]), value_list)) 109 | # print(len(data)) 110 | data = data[data[column_name].isin(value_list)] 111 | # print(len(data)) 112 | elif 'max' in commands[i]: 113 | command = commands[i].split('max(') 114 | column_name = command[1].split(')')[0] 115 | data = data[data[column_name] == data[column_name].max()] 116 | elif 'min' in commands[i]: 117 | command = commands[i].split('min(') 118 | column_name = command[1].split(')')[0] 119 | data = data[data[column_name] == data[column_name].min()] 120 | except: 121 | if column_name not in data.columns.tolist(): 122 | columns = ', 
'.join(data.columns.tolist()) 123 | raise Exception("The filtering query {} is incorrect. Please modify the column name or use LoadDB to read another table. The column names in the current DB are {}.".format(commands[i], columns)) 124 | if column_name == '' or value == '': 125 | raise Exception("The filtering query {} is incorrect. There is syntax error in the command. Please modify the condition or use LoadDB to read another table.".format(commands[i])) 126 | if len(data) == 0: 127 | # get 5 examples from the backup data what is in the same column 128 | column_values = list(set(backup_data[column_name].tolist())) 129 | if ('=' in commands[i]) and (not value in column_values) and (not '>=' in commands[i]) and (not '<=' in commands[i]): 130 | levenshtein_dist = {} 131 | for cv in column_values: 132 | levenshtein_dist[cv] = Levenshtein.distance(str(cv), str(value)) 133 | levenshtein_dist = sorted(levenshtein_dist.items(), key=lambda x: x[1], reverse=False) 134 | column_values = [i[0] for i in levenshtein_dist[:5]] 135 | column_values = ', '.join([str(i) for i in column_values]) 136 | raise Exception("The filtering query {} is incorrect. There is no {} value in the column. Five example values in the column are {}. Please check if you get the correct {} value.".format(commands[i], value, column_values, column_name)) 137 | else: 138 | return data 139 | return data 140 | 141 | def get_value(data, argument): 142 | try: 143 | commands = argument.split(', ') 144 | if len(commands) == 1: 145 | column = argument 146 | while column[0] == '[' or column[0] == "'": 147 | column = column[1:] 148 | while column[-1] == ']' or column[-1] == "'": 149 | column = column[:-1] 150 | if len(data) == 1: 151 | return str(data.iloc[0][column]) 152 | else: 153 | answer_list = list(set(data[column].tolist())) 154 | answer_list = [str(i) for i in answer_list] 155 | return ', '.join(answer_list) 156 | # else: 157 | # return "Get the value. But there are too many returned values. 
Please double-check the code and make necessary changes." 158 | else: 159 | column = commands[0] 160 | if 'mean' in commands[-1]: 161 | res_list = data[column].tolist() 162 | res_list = [float(i) for i in res_list] 163 | return sum(res_list)/len(res_list) 164 | elif 'max' in commands[-1]: 165 | res_list = data[column].tolist() 166 | try: 167 | res_list = [float(i) for i in res_list] 168 | except: 169 | res_list = [str(i) for i in res_list] 170 | return max(res_list) 171 | elif 'min' in commands[-1]: 172 | res_list = data[column].tolist() 173 | try: 174 | res_list = [float(i) for i in res_list] 175 | except: 176 | res_list = [str(i) for i in res_list] 177 | return min(res_list) 178 | elif 'sum' in commands[-1]: 179 | res_list = data[column].tolist() 180 | res_list = [float(i) for i in res_list] 181 | return sum(res_list) 182 | elif 'list' in commands[-1]: 183 | res_list = data[column].tolist() 184 | res_list = [str(i) for i in res_list] 185 | return list(res_list) 186 | else: 187 | raise Exception("The operation {} contains syntax errors. Please check the arguments.".format(commands[-1])) 188 | except: 189 | column_values = ', '.join(data.columns.tolist()) 190 | raise Exception("The column name {} is incorrect. Please check the column name and make necessary changes. The columns in this table include {}.".format(column, column_values)) 191 | 192 | def sql_interpreter(command): 193 | con = sqlite3.connect("/ehrsql/mimic_iii/mimic_iii.db") 194 | cur = con.cursor() 195 | results = cur.execute(command).fetchall() 196 | return results 197 | 198 | def date_calculator(argument): 199 | try: 200 | con = sqlite3.connect("/ehrsql/mimic_iii/mimic_iii.db") 201 | cur = con.cursor() 202 | command = "select datetime(current_time, '{}')".format(argument) 203 | results = cur.execute(command).fetchall()[0][0] 204 | except: 205 | raise Exception("The date calculator {} is incorrect. Please check the syntax and make necessary changes. 
For the current date and time, please call Calendar('0 year').".format(argument)) 206 | return results 207 | 208 | if __name__ == "__main__": 209 | db = table_toolkits() 210 | print(db.db_loader("microbiologyevents")) 211 | # print(db.data_filter("SPEC_TYPE_DESC=peripheral blood lymphocytes")) 212 | print(db.data_filter("HADM_ID=107655")) 213 | print(db.data_filter("SPEC_TYPE_DESC=peripheral blood lymphocytes")) 214 | print(db.get_value('CHARTTIME')) 215 | # results = db.sql_interpreter("select max(t1.c1) from ( select sum(cost.cost) as c1 from cost where cost.hadm_id in ( select diagnoses_icd.hadm_id from diagnoses_icd where diagnoses_icd.icd9_code = ( select d_icd_diagnoses.icd9_code from d_icd_diagnoses where d_icd_diagnoses.short_title = 'comp-oth vasc dev/graft' ) ) and datetime(cost.chargetime) >= datetime(current_time,'-1 year') group by cost.hadm_id ) as t1") 216 | # results = [result[0] for result in results] 217 | # if len(results) == 1: 218 | # print(results[0]) 219 | # else: 220 | # print(results) 221 | # print(db.date_calculator('-1 year')) -------------------------------------------------------------------------------- /ehragent/prompts_eicu.py: -------------------------------------------------------------------------------- 1 | CodeHeader = """from tools import tabtools, calculator 2 | Calculate = calculator.WolframAlphaCalculator 3 | LoadDB = tabtools.db_loader 4 | FilterDB = tabtools.data_filter 5 | GetValue = tabtools.get_value 6 | SQLInterpreter = tabtools.sql_interpreter 7 | Calendar = tabtools.date_calculator 8 | """ 9 | 10 | RetrKnowledge = """Read the following data descriptions, generate the background knowledge as the context information that could be helpful for answering the question. 11 | (1) Data include vital signs, laboratory measurements, medications, APACHE components, care plan information, admission diagnosis, patient history, time-stamped diagnoses from a structured problem list, and similarly chosen treatments. 
12 | (2) Data from each patient is collected into a common warehouse only if certain “interfaces” are available. Each interface is used to transform and load a certain type of data: vital sign interfaces incorporate vital signs, laboratory interfaces provide measurements on blood samples, and so on. 13 | (3) It is important to be aware that different care units may have different interfaces in place, and that the lack of an interface will result in no data being available for a given patient, even if those measurements were made in reality. The data is provided as a relational database, comprising multiple tables joined by keys. 14 | (4) All the databases are used to record information associated to patient care, such as allergy, cost, diagnosis, intakeoutput, lab, medication, microlab, patient, treatment, vitalperiodic. 15 | For different tables, they contain the following information: 16 | (1) allergy: allergyid, patientunitstayid, drugname, allergyname, allergytime 17 | (2) cost: costid, uniquepid, patienthealthsystemstayid, eventtype, eventid, chargetime, cost 18 | (3) diagnosis: diagnosisid, patientunitstayid, icd9code, diagnosisname, diagnosistime 19 | (4) intakeoutput: intakeoutputid, patientunitstayid, cellpath, celllabel, cellvaluenumeric, intakeoutputtime 20 | (5) lab: labid, patientunitstayid, labname, labresult, labresulttime 21 | (6) medication: medicationid, patientunitstayid, drugname, dosage, routeadmin, drugstarttime, drugstoptime 22 | (7) microlab: microlabid, patientunitstayid, culturesite, organism, culturetakentime 23 | (8) patient: patientunitstayid, patienthealthsystemstayid, gender, age, ethnicity, hospitalid, wardid, admissionheight, hospitaladmitsource, hospitaldischargestatus, admissionweight, dischargeweight, uniquepid, hospitaladmittime, unitadmittime, unitdischargetime, hospitaldischargetime 24 | (9) treatment: treatmentid, patientunitstayid, treatmentname, treatmenttime 25 | (10) vitalperiodic: vitalperiodicid, patientunitstayid, 
temperature, sao2, heartrate, respiration, systemicsystolic, systemicdiastolic, systemicmean, observationtime 26 | 27 | Question: was the fluticasone-salmeterol 250-50 mcg/dose in aepb prescribed to patient 035-2205 on their current hospital encounter? 28 | Knowledge: 29 | - We can find the patient 035-2205 information in the patient database. 30 | - As fluticasone-salmeterol 250-50 mcg/dose in aepb is a drug, we can find the drug information in the medication database. 31 | - We can find the patientunitstayid in the patient database and use it to find the drug precsription information in the medication database. 32 | 33 | Question: in the last hospital encounter, when was patient 031-22988's first microbiology test time? 34 | Knowledge: 35 | - We can find the patient 031-22988 information in the patient database. 36 | - We can find the microbiology test information in the microlab database. 37 | - We can find the patientunitstayid in the patient database and use it to find the microbiology test information in the microlab database. 38 | 39 | Question: what is the minimum hospital cost for a drug with a name called albumin 5% since 6 years ago? 40 | Knowledge: 41 | - As albumin 5% is a drug, we can find the drug information in the medication database. 42 | - We can find the patientunitstayid in the medication database and use it to find the patienthealthsystemstayid information in the patient database. 43 | - We can use the patienthealthsystemstayid information to find the cost information in the cost database. 44 | 45 | Question: what are the number of patients who have had a magnesium test the previous year? 46 | Knowledge: 47 | - As magnesium is a lab test, we can find the lab test information in the lab database. 48 | - We can find the patientunitstayid in the lab database and use it to find the patient information in the patient database. 49 | 50 | Question: {question} 51 | Knowledge: 52 | """ 53 | 54 | SYSTEM_PROMPT = """You are a helpful AI assistant. 
Solve tasks using your coding and language skills. 55 | In the following cases, suggest python code (in a python coding block) or shell script (in a sh 56 | coding block) for the user to execute. 57 | 1. When you need to collect info, use the code to output the info you need, for example, browse or 58 | search the web, download/read a file, print the content of a webpage or a file, get the current 59 | date/time. After sufficient info is printed and the task is ready to be solved based on your 60 | language skill, you can solve the task by yourself. 61 | 2. When you need to perform some task with code, use the code to perform the task and output the 62 | result. Finish the task smartly. 63 | Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be 64 | clear which step uses code, and which step uses your language skill. 65 | When using code, you must indicate the script type in the code block. The user cannot provide any 66 | other feedback or perform any other action beyond executing the code you suggest. The user can't 67 | modify your code. So do not suggest incomplete code which requires users to modify. Don't use a 68 | code block if it's not intended to be executed by the user. 69 | If you want the user to save the code in a file before executing it, put # filename: 70 | inside the code block as the first line. Don't include multiple code blocks in one response. Do not 71 | ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. 72 | Check the execution result returned by the user. 73 | If the result indicates there is an error, fix the error and output the code again. Suggest the 74 | full code instead of partial code or code changes. If the error can't be fixed or if the task is 75 | not solved even after the code is executed successfully, analyze the problem, revisit your 76 | assumption, collect additional info you need, and think of a different approach to try. 
77 | When you find an answer, verify the answer carefully. Include verifiable evidence in your response 78 | if possible. 79 | Reply "TERMINATE" in the end when everything is done.""" 80 | 81 | EHRAgent_Message_Prompt = """Assume you have knowledge of several tables: 82 | (1) Tables are linked by identifiers which usually have the suffix 'ID'. For example, SUBJECT_ID refers to a unique patient, HADM_ID refers to a unique admission to the hospital, and ICUSTAY_ID refers to a unique admission to an intensive care unit. 83 | (2) Charted events such as notes, laboratory tests, and fluid balance are stored in a series of 'events' tables. For example the outputevents table contains all measurements related to output for a given patient, while the labevents table contains laboratory test results for a patient. 84 | (3) Tables prefixed with 'd_' are dictionary tables and provide definitions for identifiers. For example, every row of chartevents is associated with a single ITEMID which represents the concept measured, but it does not contain the actual name of the measurement. By joining chartevents and d_items on ITEMID, it is possible to identify the concept represented by a given ITEMID. 85 | (4) For the databases, four of them are used to define and track patient stays: admissions, patients, icustays, and transfers. Another four tables are dictionaries for cross-referencing codes against their respective definitions: d_icd_diagnoses, d_icd_procedures, d_items, and d_labitems. The remaining tables, including chartevents, cost, inputevents_cv, labevents, microbiologyevents, outputevents, prescriptions, procedures_icd, contain data associated with patient care, such as physiological measurements, caregiver observations, and billing information. 86 | Write a python code to solve the given question. You can use the following functions: 87 | (1) Calculate(FORMULA), which calculates the FORMULA and returns the result. 
88 | (2) LoadDB(DBNAME) which loads the database DBNAME and returns the database. The DBNAME can be one of the following: allergy, cost, diagnosis, intakeoutput, lab, medication, microlab, patient, treatment, vitalperiodic. 89 | (3) FilterDB(DATABASE, CONDITIONS), which filters the DATABASE according to the CONDITIONS and returns the filtered database. The CONDITIONS is a string composed of multiple conditions, each of which consists of the column_name, the relation and the value (e.g., COST<10). The CONDITIONS is one single string (e.g., "admissions, SUBJECT_ID=24971"). Different conditions are separated by '||'. 90 | (4) GetValue(DATABASE, ARGUMENT), which returns a string containing all the values of the column in the DATABASE (if multiple values, separated by ", "). When there is no additional operations on the values, the ARGUMENT is the column_name in demand. If the values need to be returned with certain operations, the ARGUMENT is composed of the column_name and the operation (like COST, sum). Please do not contain " or ' in the argument. 91 | (5) SQLInterpreter(SQL), which interprets the query SQL and returns the result. 92 | (6) Calendar(DURATION), which returns the date after the duration of time. 93 | Use the variable 'answer' to store the answer of the code. 
Here are some examples: 94 | {examples} 95 | (END OF EXAMPLES) 96 | Knowledge: 97 | {knowledge} 98 | Question: {question} 99 | Solution: """ 100 | 101 | DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS = { 102 | "ALWAYS": "An attentive HUMAN user who can answer questions about the task, and can perform tasks such as running Python code or inputting command line commands at a Linux terminal and reporting back the execution results.", 103 | "TERMINATE": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.", 104 | "NEVER": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.", 105 | } 106 | 107 | CodeDebugger = """Given a question: 108 | {question} 109 | The user have written code with the following functions: 110 | (1) Calculate(FORMULA), which calculates the FORMULA and returns the result. 111 | (2) LoadDB(DBNAME) which loads the database DBNAME and returns the database. The DBNAME can be one of the following: allergy, cost, diagnosis, intakeoutput, lab, medication, microlab, patient, treatment, vitalperiodic. 112 | (3) FilterDB(DATABASE, CONDITIONS), which filters the DATABASE according to the CONDITIONS. The CONDITIONS is a string composed of multiple conditions, each of which consists of the column_name, the relation and the value (e.g., COST<10). The CONDITIONS is one single string (e.g., "admissions, SUBJECT_ID=24971"). Different conditions are separated by '||'. 113 | (4) GetValue(DATABASE, ARGUMENT), which returns the values of the column in the DATABASE. When there is no additional operations on the values, the ARGUMENT is the column_name in demand. If the values need to be returned with certain operations, the ARGUMENT is composed of the column_name and the operation (like COST, sum). Please do not contain " or ' in the argument. 114 | (5) SQLInterpreter(SQL), which interprets the query SQL and returns the result. 
115 | (6) Calendar(DURATION), which returns the date after the duration of time. 116 | 117 | The code is as follows: 118 | {code} 119 | 120 | The execution result is: 121 | {error_info} 122 | 123 | Please check the code and point out the most possible reason to the error. 124 | """ 125 | 126 | EHRAgent_4Shots_Knowledge = """Question: was the fluticasone-salmeterol 250-50 mcg/dose in aepb prescribed to patient 035-2205 on their current hospital encounter? 127 | Knowledge: 128 | - We can find the patient 035-2205 information in the patient database. 129 | - As fluticasone-salmeterol 250-50 mcg/dose in aepb is a drug, we can find the drug information in the medication database. 130 | - We can find the patientunitstayid in the patient database and use it to find the drug precsription information in the medication database. 131 | Solution: patient_db = LoadDB('patient') 132 | filtered_patient_db = FilterDB(patient_db, 'uniquepid=035-2205||hospitaldischargetime=null') 133 | patientunitstayid = GetValue(filtered_patient_db, 'patientunitstayid') 134 | medication_db = LoadDB('medication') 135 | filtered_medication_db = FilterDB(medication_db, 'patientunitstayid={}||drugname=fluticasone-salmeterol 250-50 mcg/dose in aepb'.format(patientunitstayid)) 136 | if len(filtered_medication_db) > 0: 137 | answer = 1 138 | else: 139 | answer = 0 140 | 141 | Question: in the last hospital encounter, when was patient 031-22988's first microbiology test time? 142 | Knowledge: 143 | - We can find the patient 031-22988 information in the patient database. 144 | - We can find the microbiology test information in the microlab database. 145 | - We can find the patientunitstayid in the patient database and use it to find the microbiology test information in the microlab database. 
146 | Solution: patient_db = LoadDB('patient') 147 | filtered_patient_db = FilterDB(patient_db, 'uniquepid=031-22988||max(hospitaladmittime)') 148 | patientunitstayid = GetValue(filtered_patient_db, 'patientunitstayid') 149 | microlab_db = LoadDB('microlab') 150 | filtered_microlab_db = FilterDB(microlab_db, 'patientunitstayid={}||min(culturetakentime)'.format(patientunitstayid)) 151 | culturetakentime = GetValue(filtered_microlab_db, 'culturetakentime') 152 | answer = culturetakentime 153 | 154 | Question: what is the minimum hospital cost for a drug with a name called albumin 5% since 6 years ago? 155 | Knowledge: 156 | - As albumin 5% is a drug, we can find the drug information in the medication database. 157 | - We can find the patientunitstayid in the medication database and use it to find the patienthealthsystemstayid information in the patient database. 158 | - We can use the patienthealthsystemstayid information to find the cost information in the cost database. 159 | Solution: date = Calendar('-6 year') 160 | medication_db = LoadDB('medication') 161 | filtered_medication_db = FilterDB(medication_db, 'drugname=albumin 5%') 162 | patientunitstayid_list = GetValue(filtered_medication_db, 'patientunitstayid, list') 163 | patient_db = LoadDB('patient') 164 | filtered_patient_db = FilterDB(patient_db, 'patientunitstayid in {}'.format(patientunitstayid_list)) 165 | patienthealthsystemstayid_list = GetValue(filtered_patient_db, 'patienthealthsystemstayid, list') 166 | cost_db = LoadDB('cost') 167 | min_cost = 1e9 168 | for patienthealthsystemstayid in patienthealthsystemstayid_list: 169 | filtered_cost_db = FilterDB(cost_db, 'patienthealthsystemstayid={}||chargetime>{}'.format(patienthealthsystemstayid, date)) 170 | cost = GetValue(filtered_cost_db, 'cost, sum') 171 | if cost < min_cost: 172 | min_cost = cost 173 | answer = min_cost 174 | 175 | Question: what are the number of patients who have had a magnesium test the previous year? 
176 | Knowledge: 177 | - As magnesium is a lab test, we can find the lab test information in the lab database. 178 | - We can find the patientunitstayid in the lab database and use it to find the patient information in the patient database. 179 | Solution: answer = SQLInterpreter[select count( distinct patient.uniquepid ) from patient where patient.patientunitstayid in ( select lab.patientunitstayid from lab where lab.labname = 'magnesium' and datetime(lab.labresulttime,'start of year') = datetime(current_time,'start of year','-1 year') )] 180 | """ -------------------------------------------------------------------------------- /ehragent/prompts_mimic.py: -------------------------------------------------------------------------------- 1 | CodeHeader = """from tools import tabtools, calculator 2 | Calculate = calculator.WolframAlphaCalculator 3 | LoadDB = tabtools.db_loader 4 | FilterDB = tabtools.data_filter 5 | GetValue = tabtools.get_value 6 | SQLInterpreter = tabtools.sql_interpreter 7 | Calendar = tabtools.date_calculator 8 | """ 9 | 10 | RetrKnowledge = """Read the following data descriptions, generate the background knowledge as the context information that could be helpful for answering the question. 11 | (1) Tables are linked by identifiers which usually have the suffix 'ID'. For example, SUBJECT_ID refers to a unique patient, HADM_ID refers to a unique admission to the hospital, and ICUSTAY_ID refers to a unique admission to an intensive care unit. 12 | (2) Charted events such as notes, laboratory tests, and fluid balance are stored in a series of 'events' tables. For example the outputevents table contains all measurements related to output for a given patient, while the labevents table contains laboratory test results for a patient. 13 | (3) Tables prefixed with 'd_' are dictionary tables and provide definitions for identifiers. 
For example, every row of chartevents is associated with a single ITEMID which represents the concept measured, but it does not contain the actual name of the measurement. By joining chartevents and d_items on ITEMID, it is possible to identify the concept represented by a given ITEMID. 14 | (4) For the databases, four of them are used to define and track patient stays: admissions, patients, icustays, and transfers. Another four tables are dictionaries for cross-referencing codes against their respective definitions: d_icd_diagnoses, d_icd_procedures, d_items, and d_labitems. The remaining tables, including chartevents, cost, inputevents_cv, labevents, microbiologyevents, outputevents, prescriptions, procedures_icd, contain data associated with patient care, such as physiological measurements, caregiver observations, and billing information. 15 | For different tables, they contain the following information: 16 | (1) admissions: ROW_ID, SUBJECT_ID, HADM_ID, ADMITTIME, DISCHTIME, ADMISSION_TYPE, ADMISSION_LOCATION, DISCHARGE_LOCATION, INSURANCE, LANGUAGE, MARITAL_STATUS, ETHNICITY, AGE 17 | (2) chartevents: ROW_ID, SUBJECT_ID, HADM_ID, ICUSTAY_ID, ITEMID, CHARTTIME, VALUENUM, VALUEUOM 18 | (3) cost: ROW_ID, SUBJECT_ID, HADM_ID, EVENT_TYPE, EVENT_ID, CHARGETIME, COST 19 | (4) d_icd_diagnoses: ROW_ID, ICD9_CODE, SHORT_TITLE, LONG_TITLE 20 | (5) d_icd_procedures: ROW_ID, ICD9_CODE, SHORT_TITLE, LONG_TITLE 21 | (6) d_items: ROW_ID, ITEMID, LABEL, LINKSTO 22 | (7) d_labitems: ROW_ID, ITEMID, LABEL 23 | (8) dianoses_icd: ROW_ID, SUBJECT_ID, HADM_ID, ICD9_CODE, CHARTTIME 24 | (9) icustays: ROW_ID, SUBJECT_ID, HADM_ID, ICUSTAY_ID, FIRST_CAREUNIT, LAST_CAREUNIT, FIRST_WARDID, LAST_WARDID, INTIME, OUTTIME 25 | (10) inputevents_cv: ROW_ID, SUBJECT_ID, HADM_ID, ICUSTAY_ID, CHARTTIME, ITEMID, AMOUNT 26 | (11) labevents: ROW_ID, SUBJECT_ID, HADM_ID, ITEMID, CHARTTIME, VALUENUM, VALUEUOM 27 | (12) microbiologyevents: RROW_ID, SUBJECT_ID, HADM_ID, CHARTTIME, SPEC_TYPE_DESC, ORG_NAME 
28 | (13) outputevents: ROW_ID, SUBJECT_ID, HADM_ID, ICUSTAY_ID, CHARTTIME, ITEMID, VALUE 29 | (14) patients: ROW_ID, SUBJECT_ID, GENDER, DOB, DOD 30 | (15) prescriptions: ROW_ID, SUBJECT_ID, HADM_ID, STARTDATE, ENDDATE, DRUG, DOSE_VAL_RX, DOSE_UNIT_RX, ROUTE 31 | (16) procedures_icd: ROW_ID, SUBJECT_ID, HADM_ID, ICD9_CODE, CHARTTIME 32 | (17) transfers: ROW_ID, SUBJECT_ID, HADM_ID, ICUSTAY_ID, EVENTTYPE, CAREUNIT, WARDID, INTIME, OUTTIME 33 | 34 | Question: What is the maximum total hospital cost that involves a diagnosis named comp-oth vasc dev/graft since 1 year ago? 35 | Knowledge: 36 | - As comp-oth vasc dev/graft is a diagnose, the corresponding ICD9_CODE can be found in the d_icd_diagnoses database. 37 | - The ICD9_CODE can be used to find the corresponding HADM_ID in the diagnoses_icd database. 38 | - The HADM_ID can be used to find the corresponding COST in the cost database. 39 | 40 | Question: had any tpn w/lipids been given to patient 2238 in their last hospital visit? 41 | Knowledge: 42 | - We can find the visiting information of patient 2238 in the admissions database. 43 | - As tpn w/lipids is an item, we can find the corresponding information in the d_items database. 44 | - As admissions only contains the visiting information of patients, we need to find the corresponding ICUSTAY_ID in the icustays database. 45 | - We will check the inputevents_cv database to see if there is any record of tpn w/lipids given to patient 2238 in their last hospital visit. 46 | 47 | Question: what was the name of the procedure that was given two or more times to patient 58730? 48 | Knowledge: 49 | - We can find the visiting information of patient 58730 in the admissions database. 50 | - As procedures are stored in the procedures_icd database, we can find the corresponding ICD9_CODE in the procedures_icd database. 51 | - As we only need to find the name of the procedure, we can find the corresponding SHORT_TITLE as the name in the d_icd_procedures database. 
52 | 53 | Question: {question} 54 | Knowledge: 55 | """ 56 | 57 | SYSTEM_PROMPT = """You are a helpful AI assistant. Solve tasks using your coding and language skills. 58 | In the following cases, suggest python code (in a python coding block) or shell script (in a sh 59 | coding block) for the user to execute. 60 | 1. When you need to collect info, use the code to output the info you need, for example, browse or 61 | search the web, download/read a file, print the content of a webpage or a file, get the current 62 | date/time. After sufficient info is printed and the task is ready to be solved based on your 63 | language skill, you can solve the task by yourself. 64 | 2. When you need to perform some task with code, use the code to perform the task and output the 65 | result. Finish the task smartly. 66 | Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be 67 | clear which step uses code, and which step uses your language skill. 68 | When using code, you must indicate the script type in the code block. The user cannot provide any 69 | other feedback or perform any other action beyond executing the code you suggest. The user can't 70 | modify your code. So do not suggest incomplete code which requires users to modify. Don't use a 71 | code block if it's not intended to be executed by the user. 72 | If you want the user to save the code in a file before executing it, put # filename: <filename> 73 | inside the code block as the first line. Don't include multiple code blocks in one response. Do not 74 | ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. 75 | Check the execution result returned by the user. 76 | If the result indicates there is an error, fix the error and output the code again. Suggest the 77 | full code instead of partial code or code changes. 
If the error can't be fixed or if the task is 78 | not solved even after the code is executed successfully, analyze the problem, revisit your 79 | assumption, collect additional info you need, and think of a different approach to try. 80 | When you find an answer, verify the answer carefully. Include verifiable evidence in your response 81 | if possible. 82 | Reply "TERMINATE" in the end when everything is done.""" 83 | 84 | EHRAgent_Message_Prompt = """Assume you have knowledge of several tables: 85 | (1) Tables are linked by identifiers which usually have the suffix 'ID'. For example, SUBJECT_ID refers to a unique patient, HADM_ID refers to a unique admission to the hospital, and ICUSTAY_ID refers to a unique admission to an intensive care unit. 86 | (2) Charted events such as notes, laboratory tests, and fluid balance are stored in a series of 'events' tables. For example the outputevents table contains all measurements related to output for a given patient, while the labevents table contains laboratory test results for a patient. 87 | (3) Tables prefixed with 'd_' are dictionary tables and provide definitions for identifiers. For example, every row of chartevents is associated with a single ITEMID which represents the concept measured, but it does not contain the actual name of the measurement. By joining chartevents and d_items on ITEMID, it is possible to identify the concept represented by a given ITEMID. 88 | (4) For the databases, four of them are used to define and track patient stays: admissions, patients, icustays, and transfers. Another four tables are dictionaries for cross-referencing codes against their respective definitions: d_icd_diagnoses, d_icd_procedures, d_items, and d_labitems. 
The remaining tables, including chartevents, cost, inputevents_cv, labevents, microbiologyevents, outputevents, prescriptions, procedures_icd, contain data associated with patient care, such as physiological measurements, caregiver observations, and billing information. 89 | Write a python code to solve the given question. You can use the following functions: 90 | (1) Calculate(FORMULA), which calculates the FORMULA and returns the result. 91 | (2) LoadDB(DBNAME) which loads the database DBNAME and returns the database. The DBNAME can be one of the following: admissions, chartevents, cost, d_icd_diagnoses, d_icd_procedures, d_items, d_labitems, diagnoses_icd, icustays, inputevents_cv, labevents, microbiologyevents, outputevents,patients, prescriptions, procedures_icd, transfers. 92 | (3) FilterDB(DATABASE, CONDITIONS), which filters the DATABASE according to the CONDITIONS and returns the filtered database. The CONDITIONS is a string composed of multiple conditions, each of which consists of the column_name, the relation and the value (e.g., COST<10). The CONDITIONS is one single string (e.g., "admissions, SUBJECT_ID=24971"). 93 | (4) GetValue(DATABASE, ARGUMENT), which returns a string containing all the values of the column in the DATABASE (if multiple values, separated by ", "). When there is no additional operations on the values, the ARGUMENT is the column_name in demand. If the values need to be returned with certain operations, the ARGUMENT is composed of the column_name and the operation (like COST, sum). Please do not contain " or ' in the argument. 94 | (5) SQLInterpreter(SQL), which interprets the query SQL and returns the result. 95 | (6) Calendar(DURATION), which returns the date after the duration of time. 96 | Use the variable 'answer' to store the answer of the code. 
Here are some examples: 97 | {examples} 98 | (END OF EXAMPLES) 99 | Knowledge: 100 | {knowledge} 101 | Question: {question} 102 | Solution: """ 103 | 104 | DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS = { 105 | "ALWAYS": "An attentive HUMAN user who can answer questions about the task, and can perform tasks such as running Python code or inputting command line commands at a Linux terminal and reporting back the execution results.", 106 | "TERMINATE": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.", 107 | "NEVER": "A user that can run Python code or input command line commands at a Linux terminal and report back the execution results.", 108 | } 109 | 110 | CodeDebugger = """Given a question: 111 | {question} 112 | The user has written code with the following functions: 113 | (1) Calculate(FORMULA), which calculates the FORMULA and returns the result. 114 | (2) LoadDB(DBNAME) which loads the database DBNAME and returns the database. The DBNAME can be one of the following: admissions, chartevents, cost, d_icd_diagnoses, d_icd_procedures, d_items, d_labitems, diagnoses_icd, icustays, inputevents_cv, labevents, microbiologyevents, outputevents, patients, prescriptions, procedures_icd, transfers. 115 | (3) FilterDB(DATABASE, CONDITIONS), which filters the DATABASE according to the CONDITIONS. The CONDITIONS is a string composed of multiple conditions, each of which consists of the column_name, the relation and the value (e.g., COST<10). The CONDITIONS is one single string (e.g., "admissions, SUBJECT_ID=24971"). 116 | (4) GetValue(DATABASE, ARGUMENT), which returns the values of the column in the DATABASE. When there is no additional operations on the values, the ARGUMENT is the column_name in demand. If the values need to be returned with certain operations, the ARGUMENT is composed of the column_name and the operation (like COST, sum). Please do not contain " or ' in the argument. 
117 | (5) SQLInterpreter(SQL), which interprets the query SQL and returns the result. 118 | (6) Calendar(DURATION), which returns the date after the duration of time. 119 | 120 | The code is as follows: 121 | {code} 122 | 123 | The execution result is: 124 | {error_info} 125 | 126 | Please check the code and point out the most likely reason for the error. 127 | """ 128 | 129 | EHRAgent_4Shots_Knowledge = """Question: What is the maximum total hospital cost that involves a diagnosis named comp-oth vasc dev/graft since 1 year ago? 130 | Knowledge: 131 | - As comp-oth vasc dev/graft is a diagnosis, the corresponding ICD9_CODE can be found in the d_icd_diagnoses database. 132 | - The ICD9_CODE can be used to find the corresponding HADM_ID in the diagnoses_icd database. 133 | - The HADM_ID can be used to find the corresponding COST in the cost database. 134 | Solution: date = Calendar('-1 year') 135 | # As comp-oth vasc dev/graft is a diagnosis, the corresponding ICD9_CODE can be found in the d_icd_diagnoses database. 136 | diagnosis_db = LoadDB('d_icd_diagnoses') 137 | filtered_diagnosis_db = FilterDB(diagnosis_db, 'SHORT_TITLE=comp-oth vasc dev/graft') 138 | icd_code = GetValue(filtered_diagnosis_db, 'ICD9_CODE') 139 | # The ICD9_CODE can be used to find the corresponding HADM_ID in the diagnoses_icd database. 140 | diagnoses_icd_db = LoadDB('diagnoses_icd') 141 | filtered_diagnoses_icd_db = FilterDB(diagnoses_icd_db, 'ICD9_CODE={}'.format(icd_code)) 142 | hadm_id_list = GetValue(filtered_diagnoses_icd_db, 'HADM_ID, list') 143 | # The HADM_ID can be used to find the corresponding COST in the cost database. 
144 | max_cost = 0 145 | for hadm_id in hadm_id_list: 146 | cost_db = LoadDB('cost') 147 | filtered_cost_db = FilterDB(cost_db, 'HADM_ID={}'.format(hadm_id)) 148 | cost = GetValue(filtered_cost_db, 'COST, sum') 149 | if cost > max_cost: 150 | max_cost = cost 151 | answer = max_cost 152 | 153 | Question: had any tpn w/lipids been given to patient 2238 in their last hospital visit? 154 | Knowledge: 155 | - We can find the visiting information of patient 2238 in the admissions database. 156 | - As tpn w/lipids is an item, we can find the corresponding information in the d_items database. 157 | - As admissions only contains the visiting information of patients, we need to find the corresponding ICUSTAY_ID in the icustays database. 158 | - We will check the inputevents_cv database to see if there is any record of tpn w/lipids given to patient 2238 in their last hospital visit. 159 | Solution: # We can find the visiting information of patient 2238 in the admissions database. 160 | patient_db = LoadDB('admissions') 161 | filtered_patient_db = FilterDB(patient_db, 'SUBJECT_ID=2238||min(DISCHTIME)') 162 | hadm_id = GetValue(filtered_patient_db, 'HADM_ID') 163 | # As tpn w/lipids is an item, we can find the corresponding information in the d_items database. 164 | d_items_db = LoadDB('d_items') 165 | filtered_d_items_db = FilterDB(d_items_db, 'LABEL=tpn w/lipids') 166 | item_id = GetValue(filtered_d_items_db, 'ITEMID') 167 | # As admissions only contains the visiting information of patients, we need to find the corresponding ICUSTAY_ID in the icustays database. 168 | icustays_db = LoadDB('icustays') 169 | filtered_icustays_db = FilterDB(icustays_db, 'HADM_ID={}'.format(hadm_id)) 170 | icustay_id = GetValue(filtered_icustays_db, 'ICUSTAY_ID') 171 | # We will check the inputevents_cv database to see if there is any record of tpn w/lipids given to patient 2238 in their last hospital visit. 
172 | inputevents_cv_db = LoadDB('inputevents_cv') 173 | filtered_inputevents_cv_db = FilterDB(inputevents_cv_db, 'HADM_ID={}||ICUSTAY_ID={}||ITEMID={}'.format(hadm_id, icustay_id, item_id)) 174 | if len(filtered_inputevents_cv_db) > 0: 175 | answer = 1 176 | else: 177 | answer = 0 178 | 179 | Question: what was the name of the procedure that was given two or more times to patient 58730? 180 | Knowledge: 181 | - We can find the visiting information of patient 58730 in the admissions database. 182 | - As procedures are stored in the procedures_icd database, we can find the corresponding ICD9_CODE in the procedures_icd database. 183 | - As we only need to find the name of the procedure, we can find the corresponding SHORT_TITLE as the name in the d_icd_procedures database. 184 | Solution: answer = SQLInterpreter('select d_icd_procedures.short_title from d_icd_procedures where d_icd_procedures.icd9_code in ( select t1.icd9_code from ( select procedures_icd.icd9_code, count( procedures_icd.charttime ) as c1 from procedures_icd where procedures_icd.hadm_id in ( select admissions.hadm_id from admissions where admissions.subject_id = 58730 ) group by procedures_icd.icd9_code ) as t1 where t1.c1 >= 2 )') 185 | 186 | Question: calculate the length of stay of the first stay of patient 27392 in the icu. 187 | Knowledge: 188 | - We can find the visiting information of patient 27392 in the admissions database. 189 | - As we only need to find the length of stay, we can find the corresponding INTIME and OUTTIME in the icustays database. 
190 | Solution: from datetime import datetime 191 | patient_db = LoadDB('admissions') 192 | filtered_patient_db = FilterDB(patient_db, 'SUBJECT_ID=27392||min(ADMITTIME)') 193 | hadm_id = GetValue(filtered_patient_db, 'HADM_ID') 194 | icustays_db = LoadDB('icustays') 195 | filtered_icustays_db = FilterDB(icustays_db, 'HADM_ID={}'.format(hadm_id)) 196 | intime = GetValue(filtered_icustays_db, 'INTIME') 197 | outtime = GetValue(filtered_icustays_db, 'OUTTIME') 198 | intime = datetime.strptime(intime, '%Y-%m-%d %H:%M:%S') 199 | outtime = datetime.strptime(outtime, '%Y-%m-%d %H:%M:%S') 200 | length_of_stay = outtime - intime 201 | if length_of_stay.seconds // 3600 > 12: 202 | answer = length_of_stay.days + 1 203 | else: 204 | answer = length_of_stay.days 205 | """ --------------------------------------------------------------------------------