├── ui ├── __init__.py ├── chat_display.py ├── query_confirmation.py └── icons.py ├── utils ├── __init__.py ├── save_interaction.py ├── chain_processor.py ├── my_llm.py ├── memory_handler.py ├── error_handler.py ├── message_handler.py ├── umls_mapper.py ├── umls.py └── my_langchain_tools.py ├── requirements.txt ├── readme.MD ├── LICENSE ├── interesting_q.txt ├── interaction.jsonl ├── evaluation ├── evaluator.py ├── evaluation.ipynb ├── drugbot_data.json └── qa_set.json ├── .gitignore ├── config.py ├── app.py ├── graph_definition.py └── my_db_specifics.py /ui/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | duckdb==1.1.1 2 | langchain==0.3.8 3 | langchain_community==0.3.8 4 | langchain_core==0.3.21 5 | langchain_openai==0.2.9 6 | langgraph==0.2.53 7 | pandas==2.2.3 8 | pydantic==2.10.1 9 | PyYAML==6.0.2 10 | SQLAlchemy==2.0.35 11 | sqlparse==0.5.2 12 | streamlit==1.40.0 13 | streamlit_monaco==0.1.3 14 | trustcall==0.0.26 15 | 16 | -------------------------------------------------------------------------------- /utils/save_interaction.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | def save_interaction(question: str, query: str, filename: str = "interaction.jsonl"): 5 | """ 6 | Save the interaction to a JSONL file. 7 | 8 | Args: 9 | question (str): The user's question 10 | query (str): The generated and confirmed query 11 | filename (str): The name of the JSONL file to save to 12 | """ 13 | interaction = { 14 | "input": question, 15 | "query": query 16 | } 17 | 18 | with open(filename, "a", encoding='utf-8') as myfile: 19 | myfile.write(json.dumps({"input": question, "query": query}) + "\n") -------------------------------------------------------------------------------- /readme.MD: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # Introduction 6 | 7 | 8 | 9 | This repository contains code and data for my article "[Tailor a Multi-Model Chatbot for a Multi-Model DuckDB](https://dgg32.medium.com/tailor-a-multi-model-chatbot-for-a-multi-model-duckdb-8afb7ac4c1fa)". 10 | 11 | 12 | 13 | # Prerequisite 14 | 15 | OpenAI API 16 | 17 | Python and its DuckDB API 18 | 19 | Streamlit 20 | 21 | 22 | # Run 23 | 24 | 1. Fill in the token in config.yaml for vector embedding and vector search 25 | 26 | > openai_token: [YOUR_OPENAI_TOKEN] 27 | 28 | 29 | 2. Install DuckDB by following the instruction in the article 30 | 31 | 3. Run streamlit run app.py to start the chatbot 32 | 33 | 4. Play with the chatbot 34 | 35 | ## Authors 36 | 37 | 38 | 39 | * **Sixing Huang** - *Concept and Coding* 40 | 41 | 42 | 43 | 44 | 45 | ## License 46 | 47 | 48 | 49 | 50 | 51 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sixing Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/chain_processor.py: -------------------------------------------------------------------------------- 1 | # File: utils/chain_processor.py 2 | 3 | import streamlit as st 4 | from typing import List, Dict, Any, Union 5 | from utils.my_langchain_tools import * 6 | #from utils.message_handler import store_ai_message 7 | 8 | def process_chain_response(query: str, tool: str, tool_call_id: str, prompt: str) -> None: 9 | """ 10 | Process the response from the LangChain generation process and handle 11 | different tool executions appropriately. 12 | 13 | Args: 14 | response (List[Dict[str, Any]]): The response from generate_query_chain 15 | prompt (str): The original user prompt 16 | """ 17 | #print ("I am in process_chain_response, response", response) 18 | # if not isinstance(response, list) or len(response) == 0: 19 | # st.error("Invalid response from chain. Please try again.") 20 | # return 21 | 22 | # tool_call = response[0] 23 | 24 | # # Store tool information in session state 25 | # st.session_state.tool_name = tool_call["name"] 26 | 27 | 28 | # setup_confirmation_state(tool_call, prompt) 29 | 30 | st.session_state.awaiting_confirmation = True 31 | st.session_state.current_query = query 32 | st.session_state.tool_name = tool 33 | st.session_state.current_chain_input = prompt 34 | st.session_state.tool_call_id = tool_call_id 35 | 36 | 37 | -------------------------------------------------------------------------------- /interesting_q.txt: -------------------------------------------------------------------------------- 1 | double blind & Alzheimer's disease 2 | 3 | What diseases can fluocinolone acetonide treat? Show me 4 results. Use the graph query. 4 | 5 | What diseases can hydroflumethiazide treat? 6 | 7 | What is the MOA of hydroflumethiazide? 8 | 9 | Show 5 liver-related disorder? 10 | 11 | What is the MOA of tretoquinol? 12 | 13 | What is Fascioliasis? 14 | 15 | Show 5 trials that tested drugs against the top 7 liver-related disorders 16 | 17 | What disorder can this drug treat? 18 | 19 | 20 | Colorectal Neoplasms, Peritoneal Neoplasms, Uterine Cervical Neoplasm belong to "Class A diseases". Give me 5 drugs that may treat the Class A diseases 21 | 22 | 23 | Colorectal Neoplasms, Peritoneal Neoplasms, Uterine Cervical Neoplasm belong to "Class A diseases". List the Class A diseases' definitions. using sql 24 | 25 | 26 | The "Three Small Players" pharmaceutical companies are Astellas, Novartis, and ViiV 27 | 28 | Tabulate the amount of trials sponsored by "the Three small Players". 29 | 30 | 31 | 32 | We also define "the Market leader" company as GSK 33 | 34 | 35 | Calculate the amount of trials sponsored by the market leader. 36 | 37 | 38 | 'Four Horsemen' are cardiovascular diseases, carcinoma, alzheimer's disease, and diabetes mellitus 39 | 40 | 41 | cardiovascular diseases, carcinoma, Alzheimer's disease, and diabetes mellitus as the 'Four Horsemen Diseases' -------------------------------------------------------------------------------- /utils/my_llm.py: -------------------------------------------------------------------------------- 1 | # import yaml 2 | # from langchain_openai import ChatOpenAI 3 | # import os 4 | 5 | # with open("config.yaml", "r") as stream: 6 | # try: 7 | # PARAM = yaml.safe_load(stream) 8 | # except yaml.YAMLError as exc: 9 | # print(exc) 10 | 11 | # # Set up your OpenAI API key 12 | # os.environ["OPENAI_API_KEY"] = PARAM['openai_api'] 13 | # llm = ChatOpenAI(model_name="gpt-4o-mini") 14 | 15 | import yaml 16 | from langchain_openai import AzureChatOpenAI 17 | from openai import AzureOpenAI 18 | from langchain_openai import AzureOpenAIEmbeddings 19 | 20 | 21 | #with open("../config.yaml", "r") as stream: 22 | with open("config.yaml", "r") as stream: 23 | try: 24 | PARAM = yaml.safe_load(stream) 25 | except yaml.YAMLError as exc: 26 | print(exc) 27 | 28 | # settings 29 | base_url = PARAM['azure_base_url'] 30 | api_version = PARAM['azure_api_version'] 31 | llm_key = PARAM['azure_key'] 32 | chat_deployment_name = 'explore' 33 | embeddings_deployment_name = 'embeddings' 34 | embeddings_api_version = 'api_version' 35 | 36 | 37 | embeddings = AzureOpenAIEmbeddings( 38 | model="text-embedding-3-small", 39 | azure_endpoint="https://geminidata-enterprise.openai.azure.com/openai/deployments/text-embedding-3-small/embeddings?api-version=2023-05-15", 40 | #openai_api_version=embeddings_api_version, 41 | api_key = llm_key 42 | ) 43 | 44 | llm=AzureChatOpenAI( 45 | azure_endpoint=base_url, 46 | deployment_name=chat_deployment_name, 47 | api_version=api_version, 48 | api_key=llm_key) 49 | 50 | 51 | client = AzureOpenAI( 52 | azure_endpoint=base_url, 53 | api_key=llm_key, 54 | api_version=api_version 55 | ) -------------------------------------------------------------------------------- /interaction.jsonl: -------------------------------------------------------------------------------- 1 | {"input": "Give me 3 liver-related disorders", "query": "SELECT name, definition\n FROM Disorder\n ORDER BY array_distance(definitionEmbedding, embeddings('liver-related disorders')::FLOAT[1536])\n LIMIT 3;", "tool_name": "mimicking"} 2 | {"input": "Show 5 trials and their drugs. At least one of the drugs must be used to treat against the Non-small cell lung carcinoma?", "query": "SELECT Trials.StudyTitle as StudyTitle, drug_for_disease.drug_name\n FROM Trials,\n GRAPH_TABLE(\n drug_graph\n MATCH\n (i:Drug)-[m:MAY_TREAT]->(c:Disorder WHERE LOWER(c.name) = LOWER('Non-small cell lung carcinoma'))\n COLUMNS (i.drug_cui AS drug_cui, i.name AS drug_name)\n ) drug_for_disease\n WHERE list_contains(Trials.drug_cui, drug_for_disease.drug_cui)\n LIMIT 5;", "tool_name": "mimicking"} 3 | {"input": "Show 3 trials that tested drugs against the top 10 joint-related disorders", "query": "SELECT Trials.PostingID AS trial_id, Trials.StudyTitle AS StudyTitle, target_disease.name, drug_for_disorder.drug_name\n FROM Trials,\n (SELECT disorder_cui, name FROM Disorder\n ORDER BY array_distance(definitionEmbedding, embeddings('joint-related disorders')::FLOAT[1536])\n LIMIT 10) target_disease,\n GRAPH_TABLE(drug_graph MATCH (i:Drug)-[m:MAY_TREAT]->(c:Disorder)\n COLUMNS (i.drug_cui AS drug_cui, i.name AS drug_name,c.disorder_cui AS disorder_cui)) drug_for_disorder\n WHERE target_disease.disorder_cui = drug_for_disorder.disorder_cui AND list_contains(Trials.drug_cui, drug_for_disorder.drug_cui)\n LIMIT 3", "tool_name": "mimicking"} 4 | {"input": "Show 4 trials that tested drugs with the MOA of 'UGT1A9 Inhibitors'", "query": "SELECT Trials.PostingID AS trial_id, Trials.StudyTitle AS StudyTitle, drug_with_moa.drug_name FROM Trials,\n GRAPH_TABLE(drug_graph MATCH (i:Drug)-[m:HAS_MOA]->(a:MOA WHERE a.name='UGT1A9 Inhibitors')\n COLUMNS (i.drug_cui AS drug_cui, i.name AS drug_name)) drug_with_moa\n WHERE list_contains(Trials.drug_cui, drug_with_moa.drug_cui) LIMIT 4", "tool_name": "mimicking"} 5 | -------------------------------------------------------------------------------- /evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import autoevals 2 | import os 3 | import yaml 4 | import braintrust 5 | from openai import AsyncOpenAI 6 | 7 | from braintrust import EvalAsync 8 | 9 | 10 | with open("../config.yaml", "r") as stream: 11 | try: 12 | PARAM = yaml.safe_load(stream) 13 | except yaml.YAMLError as exc: 14 | print(exc) 15 | 16 | os.environ["OPENAI_API_KEY"] = PARAM['openai_api'] 17 | 18 | braintrust.login(api_key=PARAM["BRAINTRUST_API_KEY"]) 19 | 20 | 21 | PROMPT = """\ 22 | You are comparing a submitted answer to an expert answer on a given question. Here is the data: 23 | [BEGIN DATA] 24 | ************ 25 | [Question]: {{input}} 26 | ************ 27 | [Expert]: {{expected}} 28 | ************ 29 | [Chatbot_output]: {{output}} 30 | ************ 31 | [END DATA] 32 | 33 | Compare the factual content of the submitted answer with the expert answer. Ignore any differences in style, grammar, or punctuation. 34 | The submitted answer may either be a subset or superset of the expert answer, or it may conflict with it. Determine which case applies. Answer the question by selecting one of the following options: 35 | (A) The submitted answer is a subset of the expert answer and is fully consistent with it. 36 | (B) The submitted answer is a superset of the expert answer and is fully consistent with it. 37 | (C) The submitted answer contains all the same details as the expert answer. 38 | (D) There is a disagreement between the submitted answer and the expert answer. 39 | (E) The answers differ, but these differences don't matter from the perspective of factuality. 40 | 41 | Answer the question by calling `select_choice` with your reasoning in a step-by-step matter to be 42 | sure that your conclusion is correct. Avoid simply stating the correct answer at the outset. Select a 43 | single choice by setting the `choice` parameter to a single choice from A, B, C, D, or E. 44 | """ 45 | 46 | llm_classifier = autoevals.LLMClassifier( 47 | name="Chatbot evaluator", 48 | prompt_template=PROMPT, 49 | choice_scores={"A": 0.5, "B": 1, "C": 1, "D": 0, "E": 1}, 50 | use_cot=True, 51 | ) 52 | 53 | async def task(input): 54 | return await llm_classifier.eval_async( 55 | input=input["question"], 56 | output=input["generated_answer"], 57 | expected=input["expected_answer"], 58 | ) 59 | 60 | 61 | def five_grader(output): 62 | return abs(output.score) 63 | 64 | 65 | 66 | async def run(qa_set, metadata={}): 67 | def data(): 68 | for one_set in qa_set: 69 | yield dict( 70 | input=one_set, expected=0 71 | ) 72 | 73 | 74 | await EvalAsync( 75 | "LLM-as-a-judge", 76 | data=data, 77 | task=task, 78 | scores=[five_grader], 79 | experiment_name="Classifier", 80 | max_concurrency=10, 81 | metadata=metadata, 82 | ) -------------------------------------------------------------------------------- /evaluation/evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import evaluator\n", 10 | "import json" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "qa_set = json.load(open('drugbot_data.json'))" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 3, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "Experiment Classifier-f384ea48 is running at https://www.braintrust.dev/app/Gemini%20Data/p/LLM-as-a-judge/experiments/Classifier-f384ea48\n", 32 | "LLM-as-a-judge [experiment_name=Classifier] (data): 16it [00:00, 62253.12it/s]\n" 33 | ] 34 | }, 35 | { 36 | "data": { 37 | "application/vnd.jupyter.widget-view+json": { 38 | "model_id": "ef69ddf6a19d48d7a6a3a7ddb0cff641", 39 | "version_major": 2, 40 | "version_minor": 0 41 | }, 42 | "text/plain": [ 43 | "LLM-as-a-judge [experiment_name=Classifier] (tasks): 0%| | 0/16 [00:00 None: 7 | """ 8 | Display a user message in the chat interface with appropriate styling and icon. 9 | 10 | Args: 11 | message (HumanMessage): The user's message to display 12 | """ 13 | icon, _ = get_tool_icon_and_description(message) 14 | 15 | with st.chat_message("user", avatar=icon): 16 | st.markdown(message.content) 17 | 18 | def display_assistant_message(message: AIMessage, user_question: str = None) -> None: 19 | """ 20 | Display an assistant message in the chat interface with appropriate styling, 21 | icon, and any tool information. 22 | 23 | Args: 24 | message (AIMessage): The assistant's message to display 25 | user_question (str): The corresponding user question for save functionality 26 | """ 27 | tool_name = (message.tool_calls[0]["name"] 28 | if hasattr(message, 'tool_calls') and message.tool_calls 29 | else "default") 30 | 31 | icon, description = get_tool_icon_and_description(message) 32 | 33 | with st.chat_message("assistant", avatar=icon): 34 | st.markdown(message.content) 35 | 36 | if hasattr(message, 'tool_calls') and message.tool_calls: 37 | with st.expander("See query details"): 38 | st.markdown(f"**Tool Used:** {tool_name.replace('_', ' ')}") 39 | if description: 40 | st.markdown(f"**Description:** {description}") 41 | 42 | if 'args' in message.tool_calls[0]: 43 | args = message.tool_calls[0]['args'] 44 | query = None 45 | 46 | if isinstance(args, dict): 47 | query = (args.get('executed_query') or 48 | args.get('query')) 49 | elif isinstance(args, str): 50 | query = args 51 | 52 | if query: 53 | st.markdown("**Query Used:**") 54 | st.code(query, language='sql') 55 | 56 | 57 | st.markdown('

Would you like to save this interaction?

', 58 | unsafe_allow_html=True) 59 | 60 | if st.button("💾 Save Interaction", 61 | key=f"save_{id(message)}", 62 | help="Save this question and query to the examples file", 63 | type="primary"): 64 | if query: 65 | save_interaction.save_interaction(user_question, query) 66 | st.success("Interaction saved successfully!") 67 | 68 | def display_message_pair(messages: list, i: int) -> bool: 69 | """ 70 | Display a pair of user and assistant messages, if they exist. 71 | 72 | Args: 73 | messages (list): List of all messages 74 | i (int): Current index in the message list 75 | 76 | Returns: 77 | bool: True if an assistant message was displayed after the user message 78 | """ 79 | 80 | #print ("+++++++display_message_pair+++++++++ messages[i]", messages[i]) 81 | 82 | question = "" 83 | if isinstance(messages[i], HumanMessage): 84 | question = messages[i].content 85 | display_user_message(messages[i]) 86 | 87 | if i + 1 < len(messages) and isinstance(messages[i + 1], AIMessage): 88 | 89 | display_assistant_message(messages[i + 1], question) 90 | st.markdown('
', unsafe_allow_html=True) 91 | return True 92 | return False 93 | 94 | def display_chat_messages() -> None: 95 | """ 96 | Display all chat messages with appropriate styling and separation. 97 | """ 98 | #print ("+++++++display_chat_messages+++++++++ st.session_state", st.session_state) 99 | messages = st.session_state.messages 100 | 101 | #print ("+++++++display_chat_messages+++++++++", st.session_state) 102 | i = 0 103 | while i < len(messages): 104 | assistant_displayed = display_message_pair(messages, i) 105 | i += 2 if assistant_displayed else 1 -------------------------------------------------------------------------------- /ui/query_confirmation.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from ui.icons import TOOL_DESCRIPTIONS, TOOL_ICONS, svg_to_base64 3 | #from streamlit_ace import st_ace 4 | from streamlit_monaco import st_monaco 5 | import sqlparse 6 | 7 | def create_query_confirmation_ui() -> str | None: 8 | """ 9 | Create the query confirmation UI and return the confirmed query. 10 | Returns: 11 | str | None: Returns either: 12 | - "waiting" if no button was pressed 13 | - the confirmed/edited query if Confirm was pressed 14 | - None if Reject was pressed 15 | """ 16 | container = st.container() 17 | with container: 18 | # Get the SVG icon for the current tool 19 | 20 | tool_icon = TOOL_ICONS.get(st.session_state.tool_name, TOOL_ICONS["default"]) 21 | icon_b64 = svg_to_base64(tool_icon.strip()) 22 | description = TOOL_DESCRIPTIONS.get(st.session_state.tool_name, "") 23 | 24 | #print ("st.session_state.tool_name", st.session_state.tool_name) 25 | 26 | 27 | ###### AI cannot answer the question without some claification 28 | if st.session_state.tool_name == "clarifying": 29 | print ("in clarifying") 30 | st.markdown(f""" 31 |
32 | 33 |

{st.session_state.current_query}

34 |
35 | 36 | """, unsafe_allow_html=True) 37 | st.markdown('
', unsafe_allow_html=True) 38 | elif st.session_state.tool_name == "memory": 39 | print ("in memory") 40 | st.markdown(f""" 41 |
42 | 43 |

Updated the memory

44 |
45 | """, unsafe_allow_html=True) 46 | st.markdown('
', unsafe_allow_html=True) 47 | #Enough information to answer the question 48 | else: 49 | st.markdown(f""" 50 |
51 | 52 |

This query uses the {st.session_state.tool_name.replace("_", " ")}

53 |
54 | """, unsafe_allow_html=True) 55 | 56 | st.markdown(f"*{description}*") 57 | st.markdown(f"**Review and edit the query if needed, press Confirm to proceed:**") 58 | 59 | # If there was an error in the previous attempt, show it 60 | if st.session_state.last_error: 61 | st.error(f"Error in previous query: {st.session_state.last_error}") 62 | st.markdown("Please fix the query and try again, or reject to move to the next question.") 63 | 64 | # Code editor 65 | #editor_response = st_ace(language='sql', value=st.session_state.current_query) 66 | print ("--------before query_to_edit. st.session_state", st.session_state) 67 | #query_to_edit = f"/*{st.session_state.current_chain_input}*/" 68 | query_to_edit = sqlparse.format(f"/*{st.session_state.current_chain_input}*/\n" + st.session_state.current_query, reindent=True, keyword_case='upper') 69 | if st.session_state.tool_name == "graph": 70 | query_to_edit = f"/*{st.session_state.current_chain_input}*/\n" + st.session_state.current_query 71 | editor_response = st_monaco(value=query_to_edit, height="150px", language="sql") 72 | 73 | # Add spacing after code editor 74 | st.markdown("
", unsafe_allow_html=True) 75 | 76 | # Create a container for buttons with padding bottom to avoid overlap 77 | button_container = st.container() 78 | with button_container: 79 | # Create two columns for the buttons 80 | col1, col2 = st.columns([1, 1]) 81 | 82 | # Confirm button in first column 83 | with col1: 84 | confirm = st.button( 85 | "✓ Confirm", 86 | use_container_width=True, 87 | type="primary", 88 | key="confirm_button" 89 | ) 90 | if confirm: 91 | st.session_state.awaiting_confirmation = False 92 | return editor_response 93 | 94 | # Reject button in second column 95 | with col2: 96 | reject = st.button( 97 | "✕ Reject", 98 | use_container_width=True, 99 | type="secondary", 100 | key="reject_button" 101 | ) 102 | if reject: 103 | st.session_state.awaiting_confirmation = False 104 | st.session_state.last_error = None 105 | st.session_state.retry_count = 0 106 | return None 107 | 108 | # Add spacing at the bottom to prevent overlap with chat input 109 | st.markdown("
", unsafe_allow_html=True) 110 | 111 | return "waiting" -------------------------------------------------------------------------------- /utils/error_handler.py: -------------------------------------------------------------------------------- 1 | # File: utils/error_handler.py 2 | 3 | import streamlit as st 4 | #from typing import Optional 5 | 6 | def handle_query_error(error: Exception) -> None: 7 | """ 8 | Handle errors that occur during query execution, managing retry attempts 9 | and error states. 10 | 11 | Args: 12 | error (Exception): The error that occurred during query execution 13 | """ 14 | # Store the error message in session state 15 | st.session_state.last_error = str(error) 16 | st.session_state.retry_count += 1 17 | 18 | # If we've had too many retries, auto-reject and reset state 19 | if st.session_state.retry_count >= 3: 20 | st.error("Maximum retry attempts reached. Moving to next question.") 21 | clear_error_state() 22 | clear_confirmation_state() 23 | else: 24 | # Keep the confirmation UI open for another try 25 | st.session_state.awaiting_confirmation = True 26 | 27 | def clear_error_state() -> None: 28 | """ 29 | Reset error-related session state variables. 30 | """ 31 | st.session_state.last_error = None 32 | st.session_state.retry_count = 0 33 | 34 | def clear_confirmation_state() -> None: 35 | """ 36 | Reset all confirmation-related session state variables. 37 | """ 38 | st.session_state.awaiting_confirmation = False 39 | st.session_state.current_query = None 40 | st.session_state.current_chain_input = None 41 | st.session_state.tool_name = None 42 | 43 | # def get_friendly_error_message(error: Exception) -> str: 44 | # """ 45 | # Convert technical error messages into user-friendly messages. 46 | 47 | # Args: 48 | # error (Exception): The original error 49 | 50 | # Returns: 51 | # str: A user-friendly error message 52 | # """ 53 | # error_str = str(error) 54 | 55 | # # Common SQL errors 56 | # if "syntax error" in error_str.lower(): 57 | # return "There appears to be a syntax error in the SQL query. Please check the query structure." 58 | # elif "permission denied" in error_str.lower(): 59 | # return "The query doesn't have permission to access this data. Please try a different approach." 60 | # elif "relation does not exist" in error_str.lower(): 61 | # return "The table or view referenced in the query doesn't exist. Please verify the table names." 62 | # elif "column does not exist" in error_str.lower(): 63 | # return "One of the columns referenced in the query doesn't exist. Please check the column names." 64 | # elif "division by zero" in error_str.lower(): 65 | # return "The query attempted to divide by zero. Please check any calculations in the query." 66 | 67 | # # Graph query errors 68 | # elif "cycle detected" in error_str.lower(): 69 | # return "The graph query contains a circular reference. Please modify the query to avoid cycles." 70 | # elif "depth limit exceeded" in error_str.lower(): 71 | # return "The graph query is too deep. Please reduce the number of relationship levels." 72 | 73 | # # Vector query errors 74 | # elif "dimension mismatch" in error_str.lower(): 75 | # return "The vector dimensions don't match. Please ensure you're using the correct embedding size." 76 | # elif "invalid vector format" in error_str.lower(): 77 | # return "The vector format is invalid. Please check the vector representation." 78 | 79 | # # Fulltext query errors 80 | # elif "invalid text query" in error_str.lower(): 81 | # return "The text search query is invalid. Please check the search syntax." 82 | 83 | # # Default case 84 | # return f"An error occurred while executing the query: {error_str}" 85 | 86 | # def handle_tool_specific_error(error: Exception, tool_name: str) -> Optional[str]: 87 | # """ 88 | # Handle tool-specific errors and provide appropriate suggestions. 89 | 90 | # Args: 91 | # error (Exception): The error that occurred 92 | # tool_name (str): The name of the tool that caused the error 93 | 94 | # Returns: 95 | # Optional[str]: A suggestion for fixing the error, if available 96 | # """ 97 | # error_str = str(error).lower() 98 | 99 | # if tool_name == "SQL_QueryTool": 100 | # if "syntax error" in error_str: 101 | # return """ 102 | # Try the following: 103 | # 1. Check for missing or extra commas 104 | # 2. Verify table and column names 105 | # 3. Ensure all SQL keywords are properly used 106 | # 4. Check that parentheses are properly closed 107 | # """ 108 | 109 | # elif tool_name == "Graph_QueryTool": 110 | # if "depth limit" in error_str: 111 | # return """ 112 | # Try to simplify the query by: 113 | # 1. Reducing the number of relationship levels 114 | # 2. Using LIMIT to restrict the result set 115 | # 3. Adding more specific filters 116 | # """ 117 | 118 | # elif tool_name == "Vector_QueryTool": 119 | # if "dimension" in error_str: 120 | # return """ 121 | # Ensure your vector query: 122 | # 1. Uses 1536-dimensional vectors 123 | # 2. Has proper array syntax 124 | # 3. Includes all required vector components 125 | # """ 126 | 127 | # elif tool_name == "Fulltext_QueryTool": 128 | # if "invalid" in error_str: 129 | # return """ 130 | # Check your text search query: 131 | # 1. Use proper text search operators 132 | # 2. Ensure search terms are properly quoted 133 | # 3. Verify language settings if applicable 134 | # """ 135 | 136 | # return None 137 | 138 | # Example usage in the main app: 139 | """ 140 | try: 141 | process_confirmed_query(confirmation_result) 142 | except Exception as e: 143 | handle_query_error(e) 144 | 145 | # Get user-friendly error message 146 | friendly_message = get_friendly_error_message(e) 147 | st.error(friendly_message) 148 | 149 | # Get tool-specific suggestions if available 150 | if suggestion := handle_tool_specific_error(e, st.session_state.tool_name): 151 | st.info("Suggestions for fixing the error:" + suggestion) 152 | """ -------------------------------------------------------------------------------- /utils/message_handler.py: -------------------------------------------------------------------------------- 1 | # File: utils/message_handler.py 2 | 3 | from langchain.schema import AIMessage 4 | from typing import Union, Dict, Any 5 | import streamlit as st 6 | 7 | def store_ai_message(query_response: Union[str, AIMessage, Any], executed_query: str) -> None: 8 | """ 9 | Store an AI message in the session state with proper formatting and tool information. 10 | 11 | Args: 12 | query_response: The response from the query execution, can be: 13 | - str: A simple string response 14 | - AIMessage: A pre-formatted AI message 15 | - Any: Any other type of response that will be converted to string 16 | executed_query: The actual query that was executed 17 | """ 18 | # Create tool call information 19 | tool_call = { 20 | "name": st.session_state.tool_name, 21 | "args": create_tool_args(executed_query), 22 | "id": generate_tool_call_id() 23 | } 24 | 25 | # Handle different types of responses 26 | if isinstance(query_response, str): 27 | ai_message = AIMessage( 28 | content=query_response, 29 | tool_calls=[tool_call] 30 | ) 31 | elif isinstance(query_response, AIMessage): 32 | # If it's already an AIMessage, ensure it has tool calls 33 | if not hasattr(query_response, 'tool_calls'): 34 | query_response.tool_calls = [tool_call] 35 | ai_message = query_response 36 | else: 37 | # For any other type of response, convert to string 38 | ai_message = AIMessage( 39 | content=str(query_response), 40 | tool_calls=[tool_call] 41 | ) 42 | 43 | # Store the message in session state 44 | st.session_state.messages.append(ai_message) 45 | 46 | def create_tool_args(executed_query: str) -> Dict[str, str]: 47 | """ 48 | Create the arguments dictionary for a tool call. 49 | 50 | Args: 51 | executed_query: The query that was executed 52 | 53 | Returns: 54 | Dict containing the query information and context 55 | """ 56 | return { 57 | "query": st.session_state.current_query, 58 | "executed_query": executed_query, 59 | "original_question": st.session_state.current_chain_input 60 | } 61 | 62 | def generate_tool_call_id() -> str: 63 | """ 64 | Generate a unique ID for a tool call. 65 | 66 | Returns: 67 | str: A unique identifier for the tool call 68 | """ 69 | # Get the current number of messages 70 | message_count = len(st.session_state.messages) 71 | return f"call_{message_count + 1}" 72 | 73 | # def format_query_response(response: Any, tool_name: str) -> str: 74 | # """ 75 | # Format a query response based on the tool type. 76 | 77 | # Args: 78 | # response: The raw response from the tool 79 | # tool_name: The name of the tool that generated the response 80 | 81 | # Returns: 82 | # str: Formatted response string 83 | # """ 84 | # if tool_name == "SQL_QueryTool": 85 | # return format_sql_response(response) 86 | # elif tool_name == "Graph_QueryTool": 87 | # return format_graph_response(response) 88 | # elif tool_name == "Vector_QueryTool": 89 | # return format_vector_response(response) 90 | # elif tool_name == "Fulltext_QueryTool": 91 | # return format_fulltext_response(response) 92 | # else: 93 | # return str(response) 94 | 95 | # def format_sql_response(response: Any) -> str: 96 | # """Format SQL query results""" 97 | # if not response: 98 | # return "No results found." 99 | 100 | # if isinstance(response, list): 101 | # # Convert list of results to markdown table 102 | # if not response[0]: 103 | # return "No results found." 104 | 105 | # headers = response[0].keys() 106 | # table = "| " + " | ".join(headers) + " |\n" 107 | # table += "| " + " | ".join(["---" for _ in headers]) + " |\n" 108 | 109 | # for row in response: 110 | # table += "| " + " | ".join(str(row[col]) for col in headers) + " |\n" 111 | # return table 112 | 113 | # return str(response) 114 | 115 | # def format_graph_response(response: Any) -> str: 116 | # """Format graph query results""" 117 | # if not response: 118 | # return "No graph relationships found." 119 | 120 | # if isinstance(response, list): 121 | # # Format graph relationships 122 | # formatted = "Found the following relationships:\n\n" 123 | # for rel in response: 124 | # formatted += f"- {rel['start']} → {rel['type']} → {rel['end']}\n" 125 | # return formatted 126 | 127 | # return str(response) 128 | 129 | # def format_vector_response(response: Any) -> str: 130 | # """Format vector query results""" 131 | # if not response: 132 | # return "No similar items found." 133 | 134 | # if isinstance(response, list): 135 | # # Format similarity results 136 | # formatted = "Found similar items (with similarity scores):\n\n" 137 | # for item in response: 138 | # score = item.get('similarity', 0) 139 | # formatted += f"- {item['content']} (similarity: {score:.2f})\n" 140 | # return formatted 141 | 142 | # return str(response) 143 | 144 | # def format_fulltext_response(response: Any) -> str: 145 | # """Format fulltext search results""" 146 | # if not response: 147 | # return "No matching documents found." 148 | 149 | # if isinstance(response, list): 150 | # # Format text search results 151 | # formatted = "Found the following matches:\n\n" 152 | # for doc in response: 153 | # score = doc.get('score', 0) 154 | # formatted += f"- [{score:.2f}] {doc['content']}\n" 155 | # return formatted 156 | 157 | # return str(response) 158 | 159 | # Example usage: 160 | """ 161 | def process_confirmed_query(query: str) -> None: 162 | with st.spinner("Processing response..."): 163 | # Execute the query 164 | query_response = langchain_tools.execute_query( 165 | st.session_state.current_chain_input, 166 | query 167 | ) 168 | 169 | # Store the response 170 | store_ai_message(query_response, query) 171 | 172 | # Clear states 173 | clear_confirmation_state() 174 | """ -------------------------------------------------------------------------------- /utils/umls_mapper.py: -------------------------------------------------------------------------------- 1 | #from openai import OpenAI 2 | from typing import Set, Dict, List 3 | import json 4 | from pydantic import BaseModel 5 | import yaml 6 | import os 7 | import utils.umls as umls 8 | import utils.my_llm as my_llm 9 | #import umls 10 | #import umls as umls 11 | #import my_llm as my_llm 12 | 13 | #with open("../config.yaml", "r") as stream: 14 | with open("config.yaml", "r") as stream: 15 | try: 16 | PARAM = yaml.safe_load(stream) 17 | except yaml.YAMLError as exc: 18 | print(exc) 19 | 20 | # Set up your OpenAI API key 21 | os.environ["OPENAI_API_KEY"] = PARAM['openai_api'] 22 | 23 | from typing import Set, Dict, List 24 | from pydantic import BaseModel 25 | import re 26 | 27 | class MedicalTerms(BaseModel): 28 | drugs: List[str] 29 | disorders: List[str] 30 | mechanisms: List[str] 31 | 32 | def term_extractor(question: str) -> Set[str]: 33 | """ 34 | Extract only explicitly mentioned medical terms from a question. 35 | 36 | Args: 37 | question (str): The input question to analyze 38 | 39 | Returns: 40 | Set[str]: Set of extracted medical terms that are explicitly present in the input 41 | """ 42 | 43 | print ("*****term_extractor question", question) 44 | try: 45 | # Define the function for term extraction 46 | tools = [ 47 | { 48 | "type": "function", 49 | "function": { 50 | "name": "extract_medical_terms", 51 | "description": "Extract ONLY the drugs, disorders, and mechanisms of action that are EXPLICITLY mentioned in the user input. DO NOT include any terms that aren't word-for-word present in the input.", 52 | "parameters": { 53 | "type": "object", 54 | "properties": { 55 | "drugs": { 56 | "type": "array", 57 | "items": {"type": "string"}, 58 | "description": "List of drug names that appear verbatim in the text. Must be exact matches only." 59 | }, 60 | "disorders": { 61 | "type": "array", 62 | "items": {"type": "string"}, 63 | "description": "List of disorders or medical conditions that appear verbatim in the text. Must be exact matches only." 64 | }, 65 | "mechanisms": { 66 | "type": "array", 67 | "items": {"type": "string"}, 68 | "description": "List of mechanisms of action of drugs that appear verbatim in the text. Must be exact matches only." 69 | } 70 | }, 71 | "required": ["drugs", "disorders", "mechanisms"] 72 | } 73 | } 74 | } 75 | ] 76 | 77 | # Call OpenAI API with function calling 78 | response = my_llm.client.chat.completions.create( 79 | model=my_llm.chat_deployment_name, 80 | messages=[ 81 | { 82 | "role": "system", 83 | "content": """You are a precise medical term extractor. Your ONLY job is to find terms that are explicitly written in the input text. 84 | Rules: 85 | 1. NEVER suggest or infer terms that aren't explicitly present 86 | 2. Only extract exact matches - no synonyms or related terms 87 | 3. If you're unsure if something is a medical term, don't extract it 88 | 4. Ignore hypothetical or question words (e.g., "what drug" should not extract "drug") 89 | 5. Case-sensitive matching only 90 | """ 91 | }, 92 | { 93 | "role": "user", 94 | "content": question 95 | } 96 | ], 97 | tools=tools, 98 | tool_choice={"type": "function", "function": {"name": "extract_medical_terms"}}, 99 | temperature=0 100 | ) 101 | 102 | # Extract and validate the terms 103 | if response.choices[0].message.tool_calls: 104 | tool_call = response.choices[0].message.tool_calls[0] 105 | function_args = eval(tool_call.function.arguments) 106 | 107 | print ("function_args", function_args) 108 | 109 | # Additional validation: Only keep terms that are actually in the input 110 | validated_args = { 111 | 'drugs': [term for term in function_args['drugs'] if term.lower() in question.lower()], 112 | 'disorders': [term for term in function_args['disorders'] if term.lower() in question.lower()], 113 | 'mechanisms': [term for term in function_args['mechanisms'] if term.lower() in question.lower()] 114 | } 115 | 116 | # Validate with Pydantic 117 | terms = MedicalTerms(**validated_args) 118 | 119 | # Combine all terms into a set 120 | terms_set = set() 121 | terms_set.update(terms.drugs) 122 | terms_set.update(terms.disorders) 123 | terms_set.update(terms.mechanisms) 124 | 125 | return terms_set 126 | 127 | return set() 128 | 129 | except Exception as e: 130 | print(f"Error in term extraction: {e}") 131 | return set() 132 | 133 | 134 | def entity_recognition(terms: Set[str]): 135 | results = {} 136 | 137 | try: 138 | for term in terms: 139 | # Call UMLS API for each term 140 | umls_results = umls.search(term, umls_token=PARAM["umls_token"], amount_of_results=1) 141 | 142 | if umls_results: 143 | results[term] = {"name": umls_results[0]["name"], "cui": umls_results[0]["cui"]} 144 | except Exception as e: 145 | print(f"Error in UMLS entity recognition: {e}") 146 | return results 147 | 148 | def expand_question(original_question: str, terms: Dict[str, str]) -> str: 149 | """ 150 | Expand the original question with the extracted medical terms. 151 | 152 | Args: 153 | original_question (str): The original question 154 | terms (Set[str]): Set of extracted medical terms 155 | 156 | Returns: 157 | str: The expanded question 158 | """ 159 | expanded_question = original_question 160 | 161 | for term in terms: 162 | expanded_question = expanded_question.replace(term, f"{term} ({{UMLS_name: \"{terms[term]['name']}\", CUI: \"{terms[term]['cui']}\"}})") 163 | 164 | return expanded_question 165 | 166 | # Example usage 167 | if __name__ == "__main__": 168 | 169 | 170 | # # Replace with your actual OpenAI API key 171 | 172 | # # Test with sample questions 173 | # for question in sample_questions: 174 | # print("\nQuestion:", question) 175 | # terms = term_extractor(question) 176 | # print("Extracted terms:", terms) 177 | 178 | question = "What drug may treat type 2 diabetes? Give me 5 results" 179 | 180 | terms = term_extractor(question) 181 | print ("terms", terms) 182 | 183 | umls_terms = entity_recognition(terms) 184 | print ("umls_terms", umls_terms) 185 | 186 | # Expand the question with UMLS terms 187 | expanded_question = expand_question(question, umls_terms) 188 | 189 | print ("expanded_question", expanded_question) -------------------------------------------------------------------------------- /ui/icons.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from langchain.schema import HumanMessage, AIMessage 3 | 4 | # Tool icons mapping 5 | TOOL_ICONS = { 6 | "sql": """ 7 | 8 | 9 | 11 | SQL 12 | 13 | 14 | """, 15 | 16 | "graph": """ 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | """, 37 | 38 | "vector": """ 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | """, 53 | 54 | "fulltext": """ 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | """, 66 | 67 | "mimicking": """ 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | """, 86 | 87 | "user": """ 88 | 89 | 90 | 91 | 92 | """, 93 | 94 | "default": """ 95 | 96 | 97 | 98 | 99 | 100 | """, 101 | 102 | "clarifying": """ 103 | 104 | 105 | 106 | 107 | 108 | 110 | 111 | 112 | 113 | 114 | """, 115 | 116 | "memory": """ 117 | 118 | 119 | 125 | 126 | 127 | 135 | 136 | """ 137 | 138 | } 139 | 140 | 141 | # Tool descriptions for tooltips 142 | TOOL_DESCRIPTIONS = { 143 | "sql": "For all tables in the DrugDB", 144 | "graph": "For the relation-rich drugs, disorders, and MOA data", 145 | "vector": "Only for the disorder definitions", 146 | "fulltext": "Only for the study titles of clinical trials", 147 | "mimicking": "Generate complex queries by mimicking the examples" 148 | } 149 | 150 | 151 | def svg_to_base64(svg_string: str) -> str: 152 | """Convert SVG string to base64 encoded data URL""" 153 | svg_string = ' '.join(svg_string.split()).strip() 154 | b64 = base64.b64encode(svg_string.encode('utf-8')).decode('utf-8') 155 | return f"data:image/svg+xml;base64,{b64}" 156 | 157 | def get_tool_icon_and_description(message): 158 | """Get the appropriate icon and description for a message based on tool used""" 159 | if isinstance(message, HumanMessage): 160 | return svg_to_base64(TOOL_ICONS["user"].strip()), None 161 | elif isinstance(message, AIMessage): 162 | if hasattr(message, 'tool_calls') and message.tool_calls: 163 | tool_name = message.tool_calls[0]["name"] 164 | return svg_to_base64(TOOL_ICONS.get(tool_name, TOOL_ICONS["default"]).strip()), TOOL_DESCRIPTIONS.get(tool_name) 165 | return svg_to_base64(TOOL_ICONS["default"].strip()), None 166 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | def init_session_state(): 4 | """Initialize session state variables""" 5 | session_vars = { 6 | "messages": [], 7 | "awaiting_confirmation": False, 8 | "current_query": None, 9 | "current_chain_input": None, 10 | "tool_name": None, 11 | "last_error": None, 12 | "retry_count": 0, 13 | "tool_call_id": "dummy_id" 14 | } 15 | 16 | for var, default in session_vars.items(): 17 | if var not in st.session_state: 18 | st.session_state[var] = default 19 | 20 | 21 | def add_button_styles(): 22 | """Add custom CSS styles to the interface""" 23 | st.markdown(""" 24 | 245 | 246 | 247 | """, unsafe_allow_html=True) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # File: app.py 2 | 3 | ## It has two states: 4 | ## 1. from LangGraph, for internal sanitization, UMLS expansion. The sanitized will only be shown in the comment section of the sql 5 | ## 2. from st.session_state, this is for UI to display the original question and the final answer. 6 | 7 | from config import init_session_state, add_button_styles 8 | from ui.chat_display import display_chat_messages 9 | from ui.query_confirmation import create_query_confirmation_ui 10 | from langchain.schema import HumanMessage 11 | import streamlit as st 12 | import json 13 | from utils.my_langchain_tools import * 14 | from utils.error_handler import handle_query_error, clear_error_state, clear_confirmation_state 15 | from utils.message_handler import store_ai_message 16 | from utils.chain_processor import process_chain_response 17 | import my_db_specifics 18 | from langchain_core.messages import ToolMessage, AIMessage 19 | import graph_definition as gd 20 | import os 21 | #from copy import deepcopy 22 | 23 | config = {"configurable": {"thread_id": "1", "user_id": "sixing"}} 24 | 25 | # Example queries that can be used via buttons 26 | EXAMPLE_QUERIES = [] 27 | 28 | def load_examples(): 29 | for example_list in [ 30 | (my_db_specifics.sql_examples), 31 | (my_db_specifics.graph_examples), 32 | (my_db_specifics.full_text_search_examples), 33 | (my_db_specifics.vector_search_examples) 34 | ]: 35 | for example in example_list: 36 | EXAMPLE_QUERIES.append(example) 37 | 38 | with open("interaction.jsonl", "r") as file: 39 | for line in file: 40 | example = json.loads(line) 41 | EXAMPLE_QUERIES.append(example) 42 | 43 | load_examples() 44 | 45 | def create_example_buttons(): 46 | """Create buttons for example queries in a single column""" 47 | 48 | for idx, example in enumerate(EXAMPLE_QUERIES): 49 | if st.button( 50 | example["input"], 51 | key=f"example_{idx}", 52 | use_container_width=True, 53 | ): 54 | 55 | handle_example_query(example) 56 | st.rerun() 57 | 58 | 59 | 60 | def handle_example_query(example): 61 | """Handle when an example query button is clicked""" 62 | # Add the user's "question" to the chat 63 | st.session_state.messages.append(AIMessage(content=example["input"], additional_kwargs={"function": {"arguments": str({"question": example["input"], "top_k": 5})}})) 64 | st.session_state.messages.append(HumanMessage(content=example["input"])) 65 | 66 | # Set up the confirmation state as if the bot generated this query 67 | st.session_state.awaiting_confirmation = True 68 | st.session_state.current_query = example["query"] 69 | st.session_state.current_chain_input = example["input"] 70 | st.session_state.tool_name = example["tool_name"] 71 | 72 | 73 | def process_confirmed_query(query): 74 | """Process a confirmed query and store the response""" 75 | with st.spinner("Processing confirmed query..."): 76 | #print ("hello", query) 77 | 78 | #print ("in if prompt", app.get_state(config)) 79 | 80 | tool_name = st.session_state.tool_name 81 | 82 | tool_message = [ 83 | { 84 | 85 | "name": tool_name, 86 | "type": "user", 87 | "content": query 88 | } 89 | ] 90 | 91 | gd.app.update_state(config, {"messages": tool_message}, as_node="human_feedback") 92 | print ("process_confirmed_query", gd.app.get_state(config)) 93 | #app.stream(None, config, stream_mode="values") 94 | events = list(gd.app.stream(None, config, stream_mode="values")) 95 | #print ("events", events) 96 | last_event = events[-1] 97 | #print (last_event) 98 | #print ("state", event) 99 | #print (event["messages"][-1].content) 100 | 101 | print ("----question-----") 102 | print (last_event["messages"][-1].additional_kwargs.get("question")) 103 | print ("----final_query-----") 104 | print (last_event["messages"][-1].additional_kwargs.get("query")) 105 | 106 | print ("----query_result-----") 107 | print (last_event["messages"][-1].additional_kwargs.get("execute_result")) 108 | 109 | print ("----answer-----") 110 | print (last_event["messages"][-1].content) 111 | 112 | store_ai_message(last_event["messages"][-1].content, last_event["messages"][-1].additional_kwargs.get("query")) 113 | 114 | clear_confirmation_state() 115 | 116 | def handle_confirmation_result(confirmation_result): 117 | """Handle the result of query confirmation""" 118 | 119 | if confirmation_result == "waiting": 120 | return False 121 | 122 | if confirmation_result is not None: 123 | try: 124 | process_confirmed_query(confirmation_result) 125 | return True 126 | except Exception as e: 127 | handle_query_error(e) 128 | return True 129 | else: 130 | st.warning("Query rejected. Please try a different question.") 131 | clear_error_state() 132 | return True 133 | 134 | def run_chatbot(): 135 | """Main function to run the chatbot interface""" 136 | # Configure the sidebar 137 | 138 | with st.sidebar: 139 | st.markdown("### Example queries you can try:") 140 | create_example_buttons() 141 | # Main chat interface 142 | st.title("DrugBot 💊") 143 | 144 | # Initialize session state 145 | init_session_state() 146 | 147 | print ("********************At the beginning st.session_state", st.session_state) 148 | 149 | # Display chat messages in main area 150 | display_chat_messages() 151 | 152 | 153 | 154 | # Handle confirmation UI if needed 155 | if st.session_state.awaiting_confirmation: 156 | #print ("in awaiting_confirmation", app.get_state(config)) 157 | confirmation_result = create_query_confirmation_ui() 158 | 159 | if handle_confirmation_result(confirmation_result): 160 | st.rerun() 161 | 162 | # Create columns for chat input and dropdown 163 | input_col, dropdown_col, is_expanded_col = st.columns([5, 1, 1]) 164 | #input_col, dropdown_col = st.columns([5, 1]) 165 | 166 | with input_col: 167 | prompt = st.chat_input( 168 | "What would you like to know about the drugs database?", 169 | key="chat_input" 170 | ) 171 | 172 | with dropdown_col: 173 | user_tool = st.selectbox( 174 | "Tools to select:", 175 | options=["Automatic", "SQL", "Graph", "Vector", "Fulltext", "Mimicking"], 176 | key="tool_selector", 177 | label_visibility="collapsed" 178 | ) 179 | 180 | with is_expanded_col: 181 | is_expanded_checkbox = st.checkbox( 182 | "Medical terms to UMLS", 183 | key="is_expanded" 184 | ) 185 | 186 | 187 | if prompt: 188 | 189 | 190 | #print ("input_message", input_message) 191 | try: 192 | expanded_prompt = prompt 193 | 194 | if is_expanded_checkbox: 195 | with st.spinner("Querying UMLS..."): 196 | expanded_prompt = expand_question(prompt) 197 | 198 | input_message = HumanMessage(content=expanded_prompt, tool_choice=user_tool.lower()) 199 | 200 | with st.spinner("Processing response..."): 201 | #st.session_state.messages.append(HumanMessage(content=input_message.content)) 202 | st.session_state.messages.append(HumanMessage(content=prompt)) 203 | for event in gd.app.stream({"messages": [input_message]}, config, stream_mode="values"): 204 | print ("len:", len(event["messages"])) 205 | 206 | #print ("in try if prompt", app.get_state(config).values["messages"]) 207 | 208 | 209 | #st.session_state.tool_name = "clarifying" 210 | #print ("!!!!!!!!!!!!!!!!!!!!!!gd.app.get_state(config).values", gd.app.get_state(config).values) 211 | generated_message = gd.app.get_state(config).values["messages"][-1] 212 | #print ("!!!!!!!!!!!!!!!!!!!!!!generated_message", generated_message) 213 | #### AI needs to ask a clarifying question 214 | if isinstance(generated_message, AIMessage): 215 | tool_name = "clarifying" 216 | tool_call_id = "dummy_tool_id" 217 | 218 | print ("=================================================tool_name", tool_name, "tool_call_id", tool_call_id) 219 | process_chain_response(generated_message.content, tool_name, tool_call_id, expanded_prompt) 220 | elif isinstance(generated_message, ToolMessage): 221 | print ("============================ in prompt, toolmessage, generated_message", generated_message) 222 | generated_query = generated_message.content 223 | #print ("generated_query\n", type(app.get_state(config).values["messages"][-1]), app.get_state(config).values["messages"][-1]) 224 | tool_name = generated_message.name 225 | tool_call_id = generated_message.tool_call_id 226 | 227 | process_chain_response(generated_query, tool_name, tool_call_id, expanded_prompt) 228 | except Exception as e: 229 | st.error(f"Error: {str(e)}") 230 | st.rerun() 231 | 232 | if __name__ == "__main__": 233 | 234 | # checkpoint = "checkpoints.db" 235 | # if os.path.isfile(checkpoint): 236 | # os.remove(checkpoint) 237 | 238 | st.set_page_config( 239 | page_title="DrugBot", 240 | page_icon="💊", 241 | layout="wide", # Make better use of screen width 242 | initial_sidebar_state="expanded" # Start with sidebar visible 243 | ) 244 | add_button_styles() 245 | run_chatbot() -------------------------------------------------------------------------------- /utils/umls.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | from threading import Thread 4 | import queue 5 | import yaml 6 | # with open("config.yaml", "r") as stream: 7 | # try: 8 | # PARAM = yaml.safe_load(stream) 9 | # except yaml.YAMLError as exc: 10 | # print(exc) 11 | 12 | # umls_token = PARAM["umls_token"] 13 | #https://uts-ws.nlm.nih.gov/rest/content/current/source/MSH/D014839/relations?apiKey= 14 | 15 | #https://uts-ws.nlm.nih.gov/rest/content/current/source/HPO/HP:0002013/relations?includeRelationLabels=CHD&apiKey= 16 | done = set() 17 | 18 | def get_all_items(source_name: str, id: str, umls_token: str, includeRelationLabels: str="", additionalRelationLabel: str=""): 19 | extra_parameter = "" 20 | if includeRelationLabels != "": 21 | extra_parameter = f"includeRelationLabels={includeRelationLabels}" 22 | 23 | if additionalRelationLabel != "": 24 | extra_parameter += f"&includeAdditionalRelationLabels={additionalRelationLabel}" 25 | 26 | url = f"https://uts-ws.nlm.nih.gov/rest/content/current/source/{source_name}/{id}/relations?{extra_parameter}&apiKey={umls_token}" 27 | #print (url) 28 | 29 | response = requests.get(url).json() 30 | 31 | if "status" in response and str(response["status"]) == "404": 32 | return [] 33 | else: 34 | #print ("response", response) 35 | 36 | page_count = int(response["pageCount"]) 37 | 38 | #print ("page_count", page_count) 39 | 40 | results = get_item(response, additionalRelationLabel) 41 | 42 | for i in range(2, page_count+1): 43 | new_url = url + f"&pageNumber={i}" 44 | 45 | new_response = requests.get(new_url).json() 46 | 47 | #print ("new_url", new_url) 48 | #print ("new_response", new_response) 49 | 50 | results += get_item(new_response) 51 | 52 | return results 53 | 54 | 55 | 56 | def get_item(response: str, additionalRelationLabel: str=""): 57 | 58 | result = [] 59 | 60 | if "result" not in response: 61 | return result 62 | 63 | else: 64 | for r in response["result"]: 65 | if r["classType"] == "AtomClusterRelation" and "relatedId" in r and "relatedIdName" in r: 66 | 67 | 68 | if additionalRelationLabel == "" or (additionalRelationLabel != "" and r["additionalRelationLabel"] == additionalRelationLabel): 69 | item_parent_id = r["relatedId"].split("/")[-1] 70 | item_name = r["relatedFromIdName"] 71 | item_parent_name = r["relatedIdName"] 72 | 73 | result.append((item_name, item_parent_id, item_parent_name)) 74 | 75 | 76 | return result 77 | 78 | def get_relation(subcategory_id: str, subcategory: str, umls_token: str, includeRelationLabels: str="", additionalRelationLabel: str=""): 79 | entities = [] 80 | for entity in get_all_items(subcategory, subcategory_id, umls_token, includeRelationLabels, additionalRelationLabel): 81 | #print (entity) 82 | entities.append({"id": entity[1], "name": entity[2]}) 83 | return entities 84 | 85 | def recursive_get_subcategory_parent(subcategory_id: str, subcategory: str, umls_token: str, includeRelationLabels: str="", additionalRelationLabel: str=""): 86 | #print ("done", done) 87 | 88 | in_queue = queue.Queue() 89 | out_queue = queue.Queue() 90 | 91 | threads = 1 92 | 93 | def work(): 94 | while True: 95 | son_subcategory_id = in_queue.get() 96 | #print ("son_subcategory_id", son_subcategory_id, get_all_items(subcategory, son_subcategory_id, "includeRelationLabels=CHD&includeAdditionalRelationLabels=isa", umls_token)) 97 | 98 | #print (done) 99 | if son_subcategory_id not in done: 100 | #for entity in get_all_items("HPO", son_item_id, "includeRelationLabels=CHD&includeAdditionalRelationLabels=isa", umls_token): 101 | #print ("here", get_all_items(subcategory, son_subcategory_id, extra_parameter, umls_token)) 102 | for entity in get_all_items(subcategory, son_subcategory_id, umls_token, includeRelationLabels, additionalRelationLabel): 103 | out_queue.put((son_subcategory_id, entity[0], entity[1], entity[2])) 104 | 105 | in_queue.put(entity[1]) 106 | done.add(son_subcategory_id) 107 | 108 | in_queue.task_done() 109 | 110 | for i in range(threads): 111 | 112 | t = Thread(target=work) 113 | t.daemon = True 114 | t.start() 115 | 116 | 117 | in_queue.put(subcategory_id) 118 | 119 | 120 | in_queue.join() 121 | 122 | result = [] 123 | 124 | while not out_queue.empty(): 125 | 126 | result.append( out_queue.get()) 127 | 128 | #print result 129 | 130 | return result 131 | 132 | def search(query_name: str, umls_token: str, partial: bool=False, amount_of_results: int=3): 133 | 134 | url = f"https://uts-ws.nlm.nih.gov/rest/search/current?string={query_name}&apiKey={umls_token}&partialSearch={str(partial).lower()}" 135 | 136 | response = requests.get(url).json() 137 | 138 | #print ("response", response, type(response)) 139 | 140 | if "status" in response and str(response["status"]) == "404": 141 | return [] 142 | 143 | else: 144 | 145 | 146 | 147 | if "result" in response and "results" in response["result"]: 148 | result = [] 149 | 150 | for r in response["result"]["results"]: 151 | if len(result) < amount_of_results and "ui" in r and "name" in r: 152 | temp = {} 153 | temp["cui"] = r["ui"] 154 | temp["name"] = r["name"] 155 | result.append(temp) 156 | return result 157 | else: 158 | return [] 159 | 160 | def get_subcategory_id(umls: str, subcategory: str, umls_token: str): 161 | url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{umls}/atoms?&apiKey={umls_token}&sabs={subcategory}" 162 | 163 | response = requests.get(url).json() 164 | 165 | if "status" in response and str(response["status"]) == "404": 166 | return "" 167 | 168 | else: 169 | 170 | result = "" 171 | 172 | if "result" in response: 173 | 174 | for r in response["result"]: 175 | if "rootSource" in r and r["rootSource"] == subcategory and "code" in r: 176 | #if "rootSource" in r and r["rootSource"] == "MSH" and "code" in r: 177 | 178 | result = r["code"].split("/")[-1] 179 | 180 | return result 181 | else: 182 | 183 | return result 184 | 185 | def get_semantic_type(cui: str, umls_token: str): 186 | url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cui}?apiKey={umls_token}" 187 | 188 | response = requests.get(url).json() 189 | 190 | if "status" in response and str(response["status"]) == "404": 191 | return [] 192 | 193 | else: 194 | 195 | 196 | 197 | if "result" in response: 198 | result = [] 199 | 200 | if "semanticTypes" in response["result"]: 201 | for s in response["result"]["semanticTypes"]: 202 | result.append(s["name"]) 203 | 204 | return result 205 | else: 206 | 207 | return [] 208 | 209 | def get_name(MSH: str, umls_token: str): 210 | url = f"https://uts-ws.nlm.nih.gov/rest/content/current/source/MSH/{MSH}?apiKey={umls_token}" 211 | 212 | response = requests.get(url).json() 213 | 214 | #print ("response", response, type(response)) 215 | 216 | if "status" in response and str(response["status"]) == "404": 217 | return "" 218 | 219 | else: 220 | 221 | result = "" 222 | if "result" in response and "name" in response["result"]: 223 | 224 | result = response["result"]["name"] 225 | 226 | return result 227 | else: 228 | 229 | return result 230 | 231 | 232 | 233 | def get_substance_ATC_by_cui (cui: str, umls_token: str): 234 | atc_code = get_subcategory_id(cui, "ATC", umls_token) 235 | return atc_code 236 | 237 | def get_substance_ATC_by_name (substrance: str, umls_token: str): 238 | cui = search(substrance, umls_token) 239 | atc_code = get_substance_ATC_by_cui(cui, umls_token) 240 | return atc_code 241 | 242 | def get_cui_definition(cui: str, umls_token: str): 243 | url = f"https://uts-ws.nlm.nih.gov/rest/content/current/CUI/{cui}/definitions?apiKey={umls_token}" 244 | response = requests.get(url).json() 245 | 246 | 247 | if "status" in response and str(response["status"]) == "404": 248 | return "" 249 | 250 | else: 251 | result = "" 252 | if "result" in response: 253 | 254 | for r in response["result"]: 255 | if "rootSource" in r and r["rootSource"] == "MSH" and"value" in r: 256 | 257 | result = r["value"] 258 | 259 | return result 260 | else: 261 | 262 | return result 263 | 264 | 265 | if __name__ == "__main__": 266 | #print(get_parent_HPO("HP:0000001", umls_token)) 267 | #print(recursive_get_parent_HPO("HP:0011458", umls_token)) 268 | 269 | #### get may_be_prevented_by drugs of a disease 270 | #result = get_all_items("MSH", "D014839", "includeAdditionalRelationLabels=may_be_prevented_by", umls_token) 271 | 272 | #assert len(result) == 40 273 | #print (result) 274 | 275 | #### recursive get all parent HPO 276 | #print ("\n\n") 277 | #print(recursive_get_parent_HPO("HP:0011458", umls_token)) 278 | ### 279 | #print (get_all_items("MSH", "D015282", "includeAdditionalRelationLabels=isa", umls_token)) 280 | #print (search("Gsk2256098", umls_token)) 281 | #print (get_MSH("C2981865", umls_token)) 282 | 283 | #print (get_name("C000604998", umls_token)) 284 | #print (get_all_items("MSH", "C000604998", "includeAdditionalRelationLabels=mapped_to", umls_token)) 285 | 286 | #name = get_subcategory_id("C0754187", "MSH", umls_token) 287 | 288 | #print (name) 289 | 290 | #print(get_substance_ATC_by_name("pralsetinib", umls_token)) 291 | #print (search("ACYCLOVIR", umls_token)) 292 | #print (get_semantic_type("C0001367", umls_token)) 293 | 294 | print (get_cui_definition("C0019151", umls_token)) -------------------------------------------------------------------------------- /graph_definition.py: -------------------------------------------------------------------------------- 1 | 2 | from langgraph.prebuilt import ToolNode, tools_condition 3 | from langgraph.checkpoint.sqlite import SqliteSaver 4 | from langgraph.graph import END, StateGraph, MessagesState, START 5 | import utils.my_langchain_tools as my_langchain_tools 6 | import sqlite3 7 | from langchain_openai import ChatOpenAI 8 | from langgraph.store.memory import InMemoryStore 9 | from typing import TypedDict, Literal, Dict 10 | from langchain_openai import ChatOpenAI 11 | import streamlit as st 12 | from langchain_core.messages import trim_messages 13 | 14 | import utils.memory_handler as my_memory 15 | 16 | from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage, AIMessage 17 | 18 | import json 19 | import utils.my_llm as my_llm 20 | 21 | #llm = ChatOpenAI(model="gpt-4o", temperature=0) 22 | 23 | 24 | all_tools = [my_langchain_tools.sql, my_langchain_tools.graph, my_langchain_tools.mimicking, my_langchain_tools.vector, my_langchain_tools.fulltext] 25 | all_tools_category = {"sql": [my_langchain_tools.sql], "graph": [my_langchain_tools.graph], "mimicking": [my_langchain_tools.mimicking], 26 | "vector": [my_langchain_tools.vector], "fulltext": [my_langchain_tools.fulltext], "automatic": all_tools} 27 | 28 | class State(MessagesState): 29 | selected_tools: list[str] 30 | #umls_terms: Dict[str, str] 31 | 32 | 33 | 34 | # Define the function that calls the model 35 | def choose_query_tool(state, config, store): 36 | MODEL_SYSTEM_MESSAGE = """You are a helpful assistant tasked with performing Q&A with a drug-trial DuckDB as backend. 37 | Your task is to call the right query tool and forward the user's question to it. 38 | IMPORTANT: do not modify or generate a query yourself. That is the job of the query tools. 39 | Capture both the question and the amount of results 'top_k' that the user want to see. If the user does not specify the amount, don't ask the user back to clarify and just set top_k = 30. 40 | Only one query tool for one question. Do not break the question into multiple parts. 41 | If the user has defined some concepts or terms, use them in your query faithfully to personalize your responses. 42 | Here is the memory (it may be empty): {memory}""" 43 | 44 | print("---call_tool_to_generate_query---") 45 | 46 | # Get selected tools 47 | selected_tools = state["selected_tools"] 48 | obj_tools = [] 49 | for t in selected_tools: 50 | if t in all_tools_category: 51 | obj_tools = all_tools_category[t] 52 | 53 | print("---call_tool_to_generate_query---\nobj_tools:", obj_tools) 54 | model = my_llm.llm.bind_tools(obj_tools) 55 | 56 | # Get user context and memories 57 | user_id = config["configurable"]["user_id"] 58 | namespace = ("concept", user_id) 59 | existing_items = store.search(namespace) 60 | existing_memories = ([existing_item.value 61 | for existing_item in existing_items] 62 | if existing_items 63 | else None) 64 | 65 | # Format memories 66 | format_memories = "\n" 67 | if existing_memories: 68 | for memo in existing_memories: 69 | format_memories += f" {memo.get('name')}: {str(memo.get('items'))}\n" 70 | 71 | # Prepare messages 72 | system_msg = MODEL_SYSTEM_MESSAGE.format(memory=format_memories) 73 | messages = state["messages"] 74 | print("++++++++---In choose_query_tool\nmessages:", messages) 75 | 76 | # Get recent message history 77 | last_human_msg = [] 78 | for m in messages[::-1]: 79 | if m.type in ["human", "ai"]: 80 | if m.content.strip(): 81 | last_human_msg.insert(0, m) 82 | if len(last_human_msg) == 10: 83 | break 84 | 85 | print("++++++++---last_msg", last_human_msg) 86 | 87 | # Get tool selection from the model 88 | tool_selection = model.invoke([SystemMessage(content=system_msg)] + last_human_msg) 89 | print("++++++++---In choose_query_tool, tool_selection:", tool_selection) 90 | 91 | 92 | return {"messages": [tool_selection]} 93 | 94 | 95 | def limit_query_tool(state, config, store): 96 | """ You have five tools to choose from to answer the user's question. 97 | sql covers all the tables and should be prefered. 98 | graph covers the relationships among drugs, disorders and MOA. 99 | vector covers the disorder definition. 100 | fulltext covers the trials' StudyTitles. 101 | mimicking uses user defined query templates and is good for complex queries. automatic means you choose the best tool.""" 102 | print ("---limit_query_tool---") 103 | tool_call_id = "" 104 | for m in state["messages"][::-1]: 105 | if m.type == "ai": 106 | #print ("---limit_query_tool---", m.type, m) 107 | tool_call_id = m.additional_kwargs["tool_calls"][0]["id"] 108 | break 109 | 110 | 111 | print ("I captured the tool_call_id", tool_call_id) 112 | state["messages"].append(ToolMessage(content='', tool_call_id=tool_call_id)) 113 | 114 | for m in state["messages"][::-1]: 115 | if m.type == "human": 116 | tool_choice = m.tool_choice 117 | if tool_choice in all_tools_category: 118 | return {"selected_tools": [tool_choice]} 119 | else: 120 | return {"selected_tools": ["automatic"]} 121 | #return {"selected_tools": } 122 | 123 | MODEL_SYSTEM_MESSAGE = """You are a helpful chatbot. 124 | 125 | You are designed to be a companion to a user, helping them get answer from a drug-trial database. 126 | 127 | You have a long term memory which keeps track one thing: 128 | User defined concepts (terms and their definitions or examples provided by the user, such as the 'Four Horsemen' are cardiovascular disease, cancer, neurodegenerative disease, and foundational disease) 129 | 130 | Here is the current User Concepts (may be empty if no information has been collected yet): 131 | 132 | {user_definition} 133 | 134 | 135 | Here are your instructions for reasoning about the user's messages: 136 | 137 | 1. Reason carefully about the user's messages as presented below. 138 | 139 | 2. Must take one and only one of the following actions, never more than one!!!!!!! Do NOT try to answer the question yourself, it is the job of the query tools: 140 | - If the message looks like a definition, update the user's definition by calling the update_concept tool 141 | - If the message looks like a question or a request, use the limit_query_tool route to generate a query 142 | 143 | 3. Tell the user that you have updated your memory, if appropriate: 144 | - Tell the user them when you update the concept list""" 145 | 146 | class Choose_Direction(TypedDict): 147 | """ Decision on which route to go next """ 148 | action_type: Literal['update_concept', 'limit_query_tool'] 149 | 150 | def select_intent(state, config, store): 151 | """Load user defined concepts from the store and use them to personalize the chatbot's response.""" 152 | 153 | # Get the user ID from the config 154 | user_id = config["configurable"]["user_id"] 155 | 156 | # Retrieve profile memory from the store 157 | namespace = ("profile", user_id) 158 | memories = store.search(namespace) 159 | if memories: 160 | user_definition = memories[0].value 161 | else: 162 | user_definition = None 163 | 164 | system_msg = MODEL_SYSTEM_MESSAGE.format(user_definition=user_definition) 165 | 166 | print ("Before response") 167 | #print ("""---select_intent---\n state["messages"]:""", state) 168 | #for m in state["messages"]: 169 | # print ("---select_intent---", m) 170 | 171 | #messages = filter_messages(state["messages"]) 172 | 173 | #messages = state["messages"] 174 | # messages = trim_messages( 175 | # state["messages"], 176 | # max_tokens=32000, 177 | # strategy="last", 178 | # token_counter=ChatOpenAI(model="gpt-4o"), 179 | # allow_partial=False, 180 | # ) 181 | #messages = filter_messages(state["messages"]) 182 | #response = llm.bind_tools([Choose_Direction], parallel_tool_calls=False).invoke([SystemMessage(content=system_msg)]+state["messages"]) 183 | 184 | response = my_llm.llm.bind_tools([Choose_Direction], parallel_tool_calls=False).invoke([SystemMessage(content=system_msg)] + state["messages"]) 185 | 186 | print ("---select_intent---\nresponse: ", response) 187 | #print ("+++++---select_intent---+++\nresponse: ", response, "\nconfig: ", config, "\nstore: ", store) 188 | return {"messages": [response]} 189 | 190 | 191 | 192 | def route_message(state, config, store) -> Literal[END, "update_concept", "limit_query_tool"]: 193 | 194 | """Reflect on the memories and chat history to decide whether to update the memory collection.""" 195 | message = state['messages'][-1] 196 | print ("---route_message---\n message:", state['messages']) 197 | print ("len(message.tool_calls) = ", len(message.tool_calls)) 198 | if len(message.tool_calls) ==0: 199 | return END 200 | else: 201 | tool_call = message.tool_calls[0] 202 | print ("---route_message---\n tool_call:", tool_call) 203 | if tool_call['args']["action_type"] == "update_concept": 204 | print ("***********update_concept") 205 | return "update_concept" 206 | elif tool_call['args']["action_type"] == "limit_query_tool": 207 | print ("***********limit_query_tool") 208 | return "limit_query_tool" 209 | else: 210 | raise ValueError 211 | 212 | def human_feedback(state): 213 | print("---human_feedback state---", state) 214 | print("---human_confirms_query---") 215 | pass 216 | 217 | 218 | 219 | builder = StateGraph(State) 220 | # Define the three nodes we will cycle between 221 | 222 | builder.add_node("select_intent", select_intent) 223 | 224 | #builder.add_node("intents", intents) 225 | 226 | builder.add_node("update_concept", my_memory.update_concept) 227 | 228 | 229 | builder.add_node("limit_query_tool", limit_query_tool) 230 | builder.add_node("choose_query_tool", choose_query_tool) 231 | tool_node = ToolNode(all_tools) 232 | builder.add_node("tools", tool_node) 233 | #workflow.add_node("ask_human", ask_human) 234 | builder.add_node("human_feedback", human_feedback) 235 | builder.add_node("execute_query_and_answer", my_langchain_tools.execute_query_and_answer) 236 | 237 | builder.add_edge(START, "select_intent") 238 | builder.add_conditional_edges("select_intent", route_message) 239 | builder.add_edge("update_concept", END) 240 | 241 | builder.add_edge("limit_query_tool", "choose_query_tool") 242 | builder.add_conditional_edges("choose_query_tool", tools_condition, path_map=["tools", "__end__"]) 243 | 244 | #workflow.add_edge("tools", "execute_query_and_answer") 245 | builder.add_edge("tools", "human_feedback") 246 | 247 | # After we get back the human response, we go back to the choose_query_tool 248 | builder.add_edge("human_feedback", "execute_query_and_answer") 249 | 250 | builder.add_edge("execute_query_and_answer", END) 251 | 252 | #memory = MemorySaver() 253 | 254 | # if os.path.exists("checkpoints.db"): 255 | # os.remove("checkpoints.db") 256 | db_path = 'checkpoints.db' 257 | conn = sqlite3.connect(db_path, check_same_thread=False) 258 | within_thread_memory = SqliteSaver(conn) 259 | #within_thread_memory = MemorySaver() 260 | 261 | # Finally, we compile it! 262 | # This compiles it into a LangChain Runnable, 263 | # meaning you can use it as you would any other runnable 264 | # We add a breakpoint BEFORE the `ask_human` node so it never executes 265 | across_thread_memory = InMemoryStore() 266 | app = builder.compile(checkpointer=within_thread_memory, interrupt_before=["human_feedback"], store=across_thread_memory) -------------------------------------------------------------------------------- /my_db_specifics.py: -------------------------------------------------------------------------------- 1 | property_graph_definition = """CREATE PROPERTY GRAPH drug_graph 2 | VERTEX TABLES ( 3 | Drug, Disorder, MOA 4 | ) 5 | EDGE TABLES ( 6 | DrugDisorder SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui) 7 | DESTINATION KEY (disorder_cui) REFERENCES Disorder (disorder_cui) 8 | LABEL MAY_TREAT, 9 | DrugMOA SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui) 10 | DESTINATION KEY (moa_id) REFERENCES MOA (moa_id) 11 | LABEL HAS_MOA 12 | );""" 13 | 14 | 15 | initialization_commands = ["LOAD duckpgq;", "LOAD fts;", "LOAD vss;"] 16 | 17 | 18 | 19 | sql_examples = [ 20 | { "input": "How many drugs are there", 21 | "query": 'SELECT COUNT("drug_cui") AS "drug_count" FROM "Drug"', 22 | "tool_name": "sql" 23 | }, 24 | { 25 | "input": "What is the MOA of abiraterone", 26 | "query": """SELECT MOA.name 27 | FROM DrugMOA, Drug, MOA 28 | WHERE DrugMOA.drug_cui = Drug.drug_cui 29 | AND DrugMOA.moa_id = MOA.moa_id 30 | AND LOWER(Drug.name) = LOWER('abiraterone');""", 31 | "tool_name": "sql" 32 | }, 33 | { 34 | "input": "What diseases can fluocinolone acetonide treat?", 35 | "query": """SELECT Disorder.name 36 | FROM DrugDisorder, Drug, Disorder 37 | WHERE DrugDisorder.drug_cui = Drug.drug_cui 38 | AND DrugDisorder.disorder_cui = Disorder.disorder_cui 39 | AND LOWER(Drug.name) = LOWER('fluocinolone acetonide');""", 40 | "tool_name": "sql" 41 | }, 42 | 43 | { 44 | "input": "Count all the trials with 'Fluticasone propionate' by sponsor and then by phase?", 45 | "query": """SELECT Sponsor, Phase, COUNT(PostingID) AS trial_count 46 | FROM Trials, Drug 47 | WHERE LOWER(Drug.name) = LOWER('Fluticasone propionate') AND list_contains(Trials.drug_cui, Drug.drug_cui) 48 | ORDER BY Sponsor, Phase;""", 49 | "tool_name": "sql" 50 | }, 51 | ] 52 | 53 | sql_database_prompt = """ 54 | The Drug table contains information about drugs. Each row represents a drug and has the following columns: 55 | - drug_cui: The unique UMLS identifier for the drug. 56 | - name: The name of the drug. 57 | 58 | The Disorder table contains information about disorders. Each row represents a disorder and has the following columns: 59 | - disorder_cui: The unique UMLS identifier for the disorder. 60 | - name: The name of the disorder. 61 | - definition: The definition of the disorder provided by UMLS. 62 | - definitionEmbedding: The 1536 vector embedding of the definition of the disorder. 63 | 64 | The MOA table contains information about the mechanism of action of drugs. Each row represents a mechanism of action and has the following columns: 65 | - moa_id: The unique UMLS identifier for the mechanism of action. 66 | - name: The name of the mechanism of action. 67 | 68 | The DrugMOA table contains information about the relationship between drugs and their mechanisms of action. Each row represents that a drug has the mechanism of action and has the following columns: 69 | - drug_cui: The unique UMLS identifier for the drug. 70 | - moa_id: The unique UMLS identifier for the mechanism of action. 71 | 72 | The DrugDisorder table contains information about the relationship between drugs and disorders. Each row represents that a drug may be used to treat a disorder and has the following columns: 73 | - drug_cui: The unique UMLS identifier for the drug. 74 | - disorder_cui: The unique UMLS identifier for the disorder. 75 | 76 | The Trials table contains information about 2000+ clinical trials. Each row represents that a trial. 77 | - PostingID: The unique id for the trial. 78 | - Sponsor: The sponsor behind the trial. 79 | - StudyTitle: The title of the trial. I have created a full-text search index on this column. For example, you can use the match_bm25 function to search for relevant trials whose StudyTitle contain "double blind & Valaciclovir". 80 | SELECT PostingID as trial_id, StudyTitle, score 81 | FROM ( 82 | SELECT *, fts_main_Trials.match_bm25( 83 | PostingID, 84 | 'double blind & Valaciclovir', 85 | fields := 'StudyTitle' 86 | ) AS score 87 | FROM Trials 88 | ) 89 | WHERE score IS NOT NULL 90 | ORDER BY score DESC LIMIT 5; 91 | - Disorder: The disorder that the trial focuses on. 92 | - Phase: The phase of the trial. 93 | - LinkToSponsorStudyRegistry: The link to the sponsor's study registry. 94 | - LinkToClinicalTrials: The link to the clinicaltrials.gov website. 95 | - drug_cui: This column contains the drug_cui for the "Drug" that the trial tested. Use this column to join with the "Drug" table. 96 | - drug_names: This column contains the drug names for the "cui" that the trial focuses on. 97 | """ 98 | 99 | graph_examples = [ 100 | { 101 | "input": "What is the MOA of Nicotinamide? Use the graph query.", 102 | "query": """FROM GRAPH_TABLE (drug_graph 103 | MATCH 104 | (d:Drug WHERE LOWER(d.name) = LOWER('Nicotinamide'))-[h:HAS_MOA]->(m:MOA) 105 | COLUMNS (m.name AS moa_name) 106 | ); 107 | """, 108 | "tool_name": "graph" 109 | }, 110 | { 111 | "input": "Which drugs can be used to treat Alzheimer's Disease? Only give me all results without limit.", 112 | "query": """FROM GRAPH_TABLE (drug_graph 113 | MATCH 114 | (i:Drug)-[m:MAY_TREAT]->(c:Disorder WHERE LOWER(c.name) = LOWER('Alzheimer''s Disease')) 115 | COLUMNS (i.name AS drug_name) 116 | ); 117 | """, 118 | "tool_name": "graph" 119 | }, 120 | { 121 | "input": "What is the mechanism of action of drugs that can treat Alzheimer's Disease? Give me 5 drugs and their MOA.", 122 | "query": """FROM GRAPH_TABLE (drug_graph 123 | MATCH 124 | (mo:MOA)<-[h:HAS_MOA]-(i:Drug)-[m:MAY_TREAT]->(c:Disorder WHERE LOWER(c.name) = LOWER('Alzheimer''s Disease')) 125 | COLUMNS (i.name AS drug_name, mo.name AS moa_name) 126 | ); 127 | """, 128 | "tool_name": "graph" 129 | }, 130 | { 131 | "input": "What disorders can drugs with the 'GABA B receptor interactions' MOA treat? Give me 3 drugs and their targeting disorders.", 132 | "query": """FROM GRAPH_TABLE (drug_graph 133 | MATCH 134 | (mo:MOA WHERE LOWER(mo.name) = LOWER('GABA B receptor interactions'))<-[h:HAS_MOA]-(i:Drug)-[m:MAY_TREAT]->(c:Disorder) 135 | COLUMNS (i.name AS drug_name, c.name AS disorder_name) 136 | ); 137 | """, 138 | "tool_name": "graph" 139 | }, 140 | { 141 | "input": "What diseases can etoposide treat? Only give me all results without limit.", 142 | "query": """FROM GRAPH_TABLE (drug_graph 143 | MATCH 144 | (i:Drug WHERE LOWER(i.name) = LOWER('etoposide'))-[m:MAY_TREAT]->(c:Disorder) 145 | COLUMNS (c.name AS disorder_name) 146 | ); 147 | """, 148 | "tool_name": "graph" 149 | }, 150 | { 151 | "input": "What is Fascioliasis?", 152 | "query": """FROM GRAPH_TABLE (drug_graph 153 | MATCH 154 | (c:Disorder WHERE LOWER(c.name) = LOWER('Fascioliasis')) 155 | COLUMNS (c.definition AS disorder_definition) 156 | ); 157 | """, 158 | "tool_name": "graph" 159 | }, 160 | { 161 | "input": "What drugs' names contain the 'libenclam'?", 162 | "query": """FROM GRAPH_TABLE (drug_graph 163 | MATCH (i:Drug WHERE LOWER(i.name) LIKE '%' || LOWER('libenclam') || '%') 164 | COLUMNS (i.name AS drug_name) 165 | ); 166 | """, 167 | "tool_name": "graph" 168 | } 169 | 170 | ] 171 | 172 | graph_database_prompt = """ 173 | The Drug table contains information about drugs. Each row represents a drug and has the following columns: 174 | - drug_cui: The unique UMLS identifier for the drug. 175 | - name: The name of the drug. 176 | 177 | The Disorder table contains information about disorders. Each row represents a disorder and has the following columns: 178 | - disorder_cui: The unique UMLS identifier for the disorder. 179 | - name: The name of the disorder. 180 | - definition: The definition of the disorder provided by UMLS. 181 | - definitionEmbedding: The 1536 vector embedding of the definition of the disorder. 182 | 183 | The MOA table contains information about the mechanism of action of drugs. Each row represents a mechanism of action and has the following columns: 184 | - moa_id: The unique UMLS identifier for the mechanism of action. 185 | - name: The name of the mechanism of action. 186 | 187 | The DrugMOA table contains information about the relationship between drugs and their mechanisms of action. Each row represents that a drug has the mechanism of action and has the following columns: 188 | - drug_cui: The unique UMLS identifier for the drug. 189 | - moa_id: The unique UMLS identifier for the mechanism of action. 190 | 191 | The DrugDisorder table contains information about the relationship between drugs and disorders. Each row represents that a drug may be used to treat a disorder and has the following columns: 192 | - drug_cui: The unique UMLS identifier for the drug. 193 | - disorder_cui: The unique UMLS identifier for the disorder. 194 | 195 | There is a PROPERTY GRAPH called "drug_graph". This graph is defined as follows: 196 | VERTEX TABLES ( 197 | Drug, Disorder, MOA 198 | ) 199 | EDGE TABLES ( 200 | DrugDisorder SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui) 201 | DESTINATION KEY (disorder_cui) REFERENCES Disorder (disorder_cui) 202 | LABEL MAY_TREAT, 203 | DrugMOA SOURCE KEY (drug_cui) REFERENCES Drug (drug_cui) 204 | DESTINATION KEY (moa_id) REFERENCES MOA (moa_id) 205 | LABEL HAS_MOA 206 | ); 207 | """ 208 | 209 | full_text_search_examples = [ 210 | { "input": """Search 3 trials with these words in their titles "double blind & Valaciclovir".""", 211 | "query": """SELECT PostingID as trial_id, StudyTitle, score 212 | FROM ( 213 | SELECT *, fts_main_Trials.match_bm25( 214 | PostingID, 215 | 'double blind & Valaciclovir', 216 | fields := 'StudyTitle' 217 | ) AS score 218 | FROM Trials 219 | ) 220 | WHERE score IS NOT NULL 221 | ORDER BY score DESC LIMIT 3;""", 222 | "tool_name": "fulltext" 223 | }, 224 | ] 225 | 226 | full_text_search_query_template = """SELECT PostingID as trial_id, StudyTitle, score 227 | FROM ( 228 | SELECT *, fts_main_Trials.match_bm25( 229 | PostingID, 230 | '{original_question}', 231 | fields := '{field}' 232 | ) AS score 233 | FROM Trials 234 | ) 235 | WHERE score IS NOT NULL 236 | ORDER BY score DESC LIMIT {limit};""" 237 | 238 | 239 | 240 | 241 | vector_search_query_template = """SELECT name, definition 242 | FROM Disorder 243 | ORDER BY array_distance(definitionEmbedding, embeddings('{question}')::FLOAT[1536]) 244 | LIMIT {limit};""" 245 | 246 | 247 | vector_search_examples = [ 248 | { "input": "Show 3 joint-related disorder?", 249 | "query": vector_search_query_template.format(question ="joint-related disorder", limit=3), 250 | "tool_name": "vector" 251 | }, 252 | 253 | ] 254 | -------------------------------------------------------------------------------- /evaluation/drugbot_data.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "What is the MOA of hydroflumethiazide?", 4 | "generated_answer": "Sure, here is the information formatted into a sentence:\n\nThe mechanism of action (MOA) of hydroflumethiazide (CUI: C0020273) is that it belongs to the class of \"Sodium Chloride Symporter Inhibitors\".", 5 | "expected_answer": "Sodium Chloride Symporter Inhibitors" 6 | }, 7 | { 8 | "question": "What diseases can hydroflumethiazide treat?", 9 | "generated_answer": "Sure, here are the results formatted into sentences:\n\nHydroflumethiazide (CUI: C0020273) can treat the following diseases:\n1. Nephrotic Syndrome\n2. Hypertensive disease\n3. Heart failure\n4. Edema\n5. Glomerulonephritis\n6. Kidney Failure\n7. Liver Cirrhosis", 10 | "expected_answer": "Hypertensive disease, Glomerulonephritis, Heart failure, Edema, Kidney Failure, Nephrotic Syndrome, Liver Cirrhosis" 11 | }, 12 | { 13 | "question": "Can fluocinolone acetonide treat Facial Dermatoses?", 14 | "generated_answer": "Sure, I can format the information into sentences for easy understanding:\n\n---\n\n### Query Result:\n\n- **Drug:** fluocinolone acetonide\n- **Condition:** Facial Dermatoses\n\n### Interpretation:\n\nThe query result indicates that \"fluocinolone acetonide\" has a relationship with \"Facial Dermatoses.\"\n\n---\n\nIf you prefer a table format, here it is:\n\n| Drug | Condition |\n|-----------------------|--------------------|\n| fluocinolone acetonide | Facial Dermatoses |\n\n---\n\nThis presentation maintains the integrity of the original data from the query result.", 15 | "expected_answer": "Yes" 16 | }, 17 | { 18 | "question": "What is the MOA of abiraterone?", 19 | "generated_answer": "To understand the mechanism of action (MOA) of abiraterone (CUI: C0754011), here are the results formatted into sentences:\n\n1. Abiraterone acts as a \"Cytochrome P450 2C8 Inhibitors\".\n2. Abiraterone acts as a \"Cytochrome P450 2D6 Inhibitors\".\n3. Abiraterone acts as a \"Cytochrome P450 17A1 Inhibitors\".\n\nAlternatively, the results can be presented in a table format:\n\n| Mechanism of Action |\n|----------------------------------------------|\n| Cytochrome P450 2C8 Inhibitors |\n| Cytochrome P450 2D6 Inhibitors |\n| Cytochrome P450 17A1 Inhibitors |\n\nThese results indicate the specific cytochrome P450 enzymes that abiraterone inhibits.", 20 | "expected_answer": "Cytochrome P450 17A1 Inhibitors, Cytochrome P450 2D6 Inhibitors, Cytochrome P450 2C8 Inhibitors" 21 | }, 22 | { 23 | "question": "double blind & Alzheimer's disease", 24 | "generated_answer": "Here are the results formatted into sentences for better understanding:\n\n1. A Preliminary, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease. (Score: 4.241119052925081)\n2. A 15-week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease. (Score: 4.241119052925081)\n3. A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease. (Score: 4.241119052925081)\n4. A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease. (Score: 4.241119052925081)\n5. A 54-week, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Effects of Donepezil Hydrochloride (E2020) on Functional Outcomes in Patients with Alzheimer's Disease. (Score: 4.148404157417646)\n6. A Randomized, Double-Blind, Placebo-Controlled Evaluation of the Effects of Donepezil Hydrochloride (E2020) in the Management of Patients with Alzheimer's Disease in a Nursing Home Facility. (Score: 4.059656212167762)\n7. Efficacy and Safety of the Rivastigmine Transdermal Patch in Patients With Probable Alzheimer's Disease. (Score: 3.949404716274111)\n8. Clinical efficacy and safety of donepezil on cognitive and global function in patients with Alzheimer's disease. A 24-week, multicenter, double-blind, placebo-controlled study in Japan. (Score: 3.893084565453792)\n9. A 24-weeks, Multi-center, Randomized, Double-blind, Placebo Controlled Study to Evaluate the Efficacy and Safety of Donepezil Hydrochloride in Chinese Subjects With Severe Alzheimer's Disease. (Score: 3.814821642474924)\n10. A 24-Week Efficacy, Safety and Tolerability of Rivastigmine Patch Study in Patients With Probable Alzheimer's Disease. (Score: 3.760182960813248)", 25 | "expected_answer": "'A Preliminary, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 15-week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 54-week, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Effects of Donepezil Hydrochloride (E2020) on Functional Outcomes in Patients with Alzheimer's Disease.'" 26 | }, 27 | { 28 | "question": "What drug may treat type 2 diabetes?", 29 | "generated_answer": "No results found.", 30 | "expected_answer": "glibenclamide, insulin (human), acarbose, dulaglutide, insulin lispro, acetohexamide, insulin degludec, insulin aspart, linagliptin, pioglitazone, glimepiride, insulin glargine, empagliflozin, pramlintide, miglitol, repaglinide, tirzepatide, lixisenatide, rosiglitazone, insulin glulisine, glipizide, tolbutamide, tolazamide, canagliflozin, alogliptin, dapagliflozin, chlorpropamide, exenatide, saxagliptin, metformin" 31 | }, 32 | { 33 | "question": "What drug may treat \u591a\u767a\u51fa\u8840\u6027\u8089\u816b?", 34 | "generated_answer": "Certainly! Here are the results formatted into sentences for better understanding:\n\n1. Vinblastine may treat Kaposi Sarcoma.\n2. Dactinomycin may treat Kaposi Sarcoma.\n3. Peginterferon alfa-2b may treat Kaposi Sarcoma.\n4. Alitretinoin may treat Kaposi Sarcoma.\n5. Etoposide may treat Kaposi Sarcoma.\n6. Interferon alfa-2b may treat Kaposi Sarcoma.\n7. Interferon alfa-2a may treat Kaposi Sarcoma.\n8. Peginterferon alfa-2a may treat Kaposi Sarcoma.\n9. Paclitaxel may treat Kaposi Sarcoma.\n10. Interferon alfacon-1 may treat Kaposi Sarcoma.\n11. Vinblastine sulfate may treat Kaposi Sarcoma.\n12. Doxorubicin hydrochloride liposome may treat Kaposi Sarcoma.\n\nAlternatively, here is the information presented in a table:\n\n| Drug |\n|-----------------------------------|\n| Vinblastine |\n| Dactinomycin |\n| Peginterferon alfa-2b |\n| Alitretinoin |\n| Etoposide |\n| Interferon alfa-2b |\n| Interferon alfa-2a |\n| Peginterferon alfa-2a |\n| Paclitaxel |\n| Interferon alfacon-1 |\n| Vinblastine sulfate |\n| Doxorubicin hydrochloride liposome|\n\nThese drugs may treat Kaposi Sarcoma (CUI: C0036220).", 35 | "expected_answer": "vinblastine, dactinomycin, peginterferon alfa-2b, alitretinoin, etoposide, interferon alfa-2b, interferon alfa-2a, peginterferon alfa-2a, paclitaxel, interferon alfacon-1, vinblastine sulfate, doxorubicin hydrochloride liposome" 36 | }, 37 | { 38 | "question": "\u591a\u767a\u51fa\u8840\u6027\u8089\u816b\u306e\u6cbb\u7642\u85ac\u306b\u3064\u3044\u3066?", 39 | "generated_answer": "Sure, here are the results formatted into a table for better understanding:\n\n| Treatment for Kaposi Sarcoma (CUI: C0036220) |\n|---------------------------------------------------------|\n| vinblastine |\n| dactinomycin |\n| peginterferon alfa-2b |\n| alitretinoin |\n| etoposide |\n| interferon alfa-2b |\n| interferon alfa-2a |\n| peginterferon alfa-2a |\n| paclitaxel |\n| interferon alfacon-1 |\n| vinblastine sulfate |\n| doxorubicin hydrochloride liposome |\n\nThese are the treatment options for Kaposi Sarcoma according to the provided query results.", 40 | "expected_answer": "vinblastine, dactinomycin, peginterferon alfa-2b, alitretinoin, etoposide, interferon alfa-2b, interferon alfa-2a, peginterferon alfa-2a, paclitaxel, interferon alfacon-1, vinblastine sulfate, doxorubicin hydrochloride liposome" 41 | }, 42 | { 43 | "question": "What disorder can 1,1-Dimethylbiguanide treat?", 44 | "generated_answer": "Certainly! Here is the information formatted into sentences for better readability:\n\nMetformin (CUI: C0025598) can be used to treat the following disorder: \n\n- \"Diabetes Mellitus, Non-Insulin-Dependent\"\n\nAlternatively, here is the information presented in a table format:\n\n| Medication (CUI) | Disorder Treated |\n|-------------------------|----------------------------------------|\n| Metformin (CUI: C0025598)| Diabetes Mellitus, Non-Insulin-Dependent |\n\nThis ensures that the quoted strings from the query result remain unchanged.", 45 | "expected_answer": "Diabetes Mellitus, Non-Insulin-Dependent" 46 | }, 47 | { 48 | "question": "What diseases can aflibercept treat?", 49 | "generated_answer": "Here are the diseases that aflibercept (CUI: C1134659) can treat:\n\n1. \"Macular retinal edema\"\n2. \"Diabetic Retinopathy\"", 50 | "expected_answer": "macular edema and diabetic retinopathy" 51 | }, 52 | { 53 | "question": "Which substances can be used to treat Macular Edema?", 54 | "generated_answer": "Sure, here is the information formatted into sentences:\n\nThe substances that can be used to treat Macular retinal edema (CUI: C0271051) are:\n\n1. Flurbiprofen\n2. Aflibercept\n3. Ranibizumab", 55 | "expected_answer": "flurbiprofen, ranibizumab and aflibercept" 56 | }, 57 | { 58 | "question": "What ATC category does donepezil have?", 59 | "generated_answer": "Here is the formatted result based on the given query:\n\n**Question:**\nWhat ATC category does donepezil (CUI: C0527316) have?\n\n**Answer:**\nDonepezil (CUI: C0527316) belongs to the \"Cholinesterase Inhibitors\" category.", 60 | "expected_answer": "N06DA02" 61 | }, 62 | { 63 | "question": "What kind of MOA does donepezil have?", 64 | "generated_answer": "Donepezil (CUI: C0527316) has the following mechanism of action (MOA):\n\n- It belongs to the class of \"Cholinesterase Inhibitors\".", 65 | "expected_answer": "Cholinesterase Inhibitors" 66 | }, 67 | { 68 | "question": "Do acetohexamide and pioglitazone share an MOA? If yes, what is it?", 69 | "generated_answer": "Based on the query results, it appears that acetohexamide (CUI: C0000992) and pioglitazone (CUI: C0071097) share a mechanism of action (MOA). Both are classified under \"Insulin Receptor Agonists.\"\n\nIn sentence form:\nAcetohexamide and pioglitazone share a mechanism of action. They are both classified as \"Insulin Receptor Agonists.\"\n\nIn table form:\n\n| Drug | CUI | Mechanism of Action |\n|---------------|-----------|--------------------------|\n| Acetohexamide | C0000992 | Insulin Receptor Agonists|\n| Pioglitazone | C0071097 | Insulin Receptor Agonists|", 70 | "expected_answer": "Insulin Receptor Agonists" 71 | }, 72 | { 73 | "question": "Give me a substance that may treat Type 1 Diabetes and has \"glargine\" in its name", 74 | "generated_answer": "Sure, here is the information formatted into a table and sentences:\n\n### Table Format\n\n| Substance Name |\n|----------------------|\n| insulin glargine |\n\n### Sentence Format\n\nThe substance that may treat Diabetes Mellitus, Insulin-Dependent (CUI: C0011854) and has 'glargine' in its name is 'insulin glargine'.", 75 | "expected_answer": "insulin glargine" 76 | }, 77 | { 78 | "question": "Which substance that may treat Type 1 Diabetes but does not have the MOA Insulin Receptor Agonists", 79 | "generated_answer": "", 80 | "expected_answer": "pramlintide" 81 | } 82 | ] -------------------------------------------------------------------------------- /evaluation/qa_set.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "What is the MOA of hydroflumethiazide?", 4 | "query": "FROM GRAPH_TABLE (drug_graph\n MATCH\n (d:Drug WHERE LOWER(d.name) = LOWER('hydroflumethiazide'))-[h:HAS_MOA]->(m:MOA)\n COLUMNS (m.name AS moa_name)\n ) LIMIT 30;", 5 | "expected_answer": "Sodium Chloride Symporter Inhibitors", 6 | "draft_query": "FROM GRAPH_TABLE (drug_graph\n MATCH\n (d:Drug WHERE LOWER(d.name) = LOWER('hydroflumethiazide'))-[h:HAS_MOA]->(m:MOA)\n COLUMNS (m.name AS moa_name)) LIMIT 5;", 7 | "generated_answer": "The mechanism of action (MOA) of hydroflumethiazide is categorized as follows: Inhibitor Type: Sodium Chloride Symporter Inhibitors. This indicates that hydroflumethiazide works by inhibiting the sodium chloride symporter." 8 | }, 9 | { 10 | "question": "What diseases can hydroflumethiazide treat?", 11 | "query": "FROM GRAPH_TABLE (drug_graph \n MATCH \n (i:Drug WHERE LOWER(i.name) = LOWER('hydroflumethiazide'))-[m:MAY_TREAT]->(c:Disorder) \n COLUMNS (c.name AS disorder_name)) LIMIT 30;", 12 | "expected_answer": "Hypertensive disease, Glomerulonephritis, Heart failure, Edema, Kidney Failure, Nephrotic Syndrome, Liver Cirrhosis", 13 | "draft_query": "FROM GRAPH_TABLE (drug_graph \n MATCH \n (i:Drug WHERE LOWER(i.name) = LOWER('hydroflumethiazide'))-[m:MAY_TREAT]->(c:Disorder) \n COLUMNS (c.name AS disorder_name) \n ) \n LIMIT 30;", 14 | "generated_answer": "Hypertensive disease\nGlomerulonephritis\nHeart failure\nEdema\nKidney Failure\nNephrotic Syndrome" 15 | }, 16 | { 17 | "question": "Can fluocinolone acetonide treat Facial Dermatoses?", 18 | "query":"FROM GRAPH_TABLE (drug_graph \n(i:Drug WHERE LOWER(i.name) = LOWER('fluocinolone acetonide'))-[m:MAY_TREAT]->(c:Disorder) \n COLUMNS (c.name AS disorder_name) \n ) \nLIMIT 4;", 19 | "expected_answer": "Yes", 20 | "draft_query": "FROM GRAPH_TABLE (drug_graph \n(i:Drug WHERE LOWER(i.name) = LOWER('fluocinolone acetonide'))-[m:MAY_TREAT]->(c:Disorder) \n COLUMNS (c.name AS disorder_name) \n ) \nLIMIT 4;", 21 | "generated_answer": "Leg Dermatoses\nPsoriasis\nPruritus\nFacial Dermatoses" 22 | }, 23 | { 24 | "question": "What is the MOA of abiraterone?", 25 | "query": "SELECT MOA.name\n FROM DrugMOA,\n Drug,\n MOA\n WHERE DrugMOA.drug_cui = Drug.drug_cui\n AND DrugMOA.moa_id = MOA.moa_id\n AND LOWER(Drug.name) = LOWER('abiraterone');", 26 | "expected_answer": "Cytochrome P450 17A1 Inhibitors, Cytochrome P450 2D6 Inhibitors, Cytochrome P450 2C8 Inhibitors", 27 | "draft_query": "SELECT MOA.name\n FROM DrugMOA,\n Drug,\n MOA\n WHERE DrugMOA.drug_cui = Drug.drug_cui\n AND DrugMOA.moa_id = MOA.moa_id\n AND LOWER(Drug.name) = LOWER('abiraterone');", 28 | "generated_answer": "Based on the query results, here is a formatted presentation of the information:\n### Diseases Related to Fluocinolone Acetonide\n| Disease/Condition |\n|------------------------------------|\n| Cytochrome P450 17A1 Inhibitors |\n| Cytochrome P450 2D6 Inhibitors |\n| Cytochrome P450 2C8 Inhibitors | \nThese entries represent types of inhibitors associated with fluocinolone acetonide." 29 | }, 30 | { 31 | "question": "double blind & Alzheimer's disease", 32 | "query": "SELECT PostingID AS trial_id, StudyTitle, score\n FROM (SELECT *,\n fts_main_Trials.match_bm25(PostingID, 'double blind & Alzheimer''s disease', fields := 'StudyTitle') AS score\n FROM Trials)\n WHERE score IS NOT NULL\n ORDER BY score DESC\n LIMIT 10;", 33 | "expected_answer": "'A Preliminary, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 15-week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.', 'A 54-week, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Effects of Donepezil Hydrochloride (E2020) on Functional Outcomes in Patients with Alzheimer's Disease.'", 34 | "draft_query": "SELECT PostingID AS trial_id, StudyTitle, score\n FROM (SELECT *,\n fts_main_Trials.match_bm25(PostingID, 'double blind & Alzheimer''s disease', fields := 'StudyTitle') AS score\n FROM Trials)\n WHERE score IS NOT NULL\n ORDER BY score DESC\n LIMIT 10;", 35 | "generated_answer": "Here is a summary of the query results formatted into a table for easier understanding:\n| Study ID | Title | Score |\n|----------|---------------------------------------------------------------------------------------------------------------------|----------------|\n| 19971 | 'A Preliminary, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.' | 4.2411 |\n| 19972 | 'A 15-week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.' | 4.2411 |\n| 19973 | 'A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.' | 4.2411 |\n| 19974 | 'A 30-Week, Multi-Center, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Safety and Efficacy of E2020 in Patients with Alzheimer's Disease.' | 4.2411 |\n| 19978 | 'A 54-week, Randomized, Double-Blind, Placebo-Controlled Evaluation of the Effects of Donepezil Hydrochloride (E2020) on Functional Outcomes in Patients with Alzheimer's Disease.' | 4.1484 |\n| 19977 | 'A Randomized, Double-Blind, Placebo-Controlled Evaluation of the Effects of Donepezil Hydrochloride (E2020) in the Management of Patients with Alzheimer's Disease in a Nursing Home Facility.' | 4.0597 |\n| 4688 | 'Efficacy and Safety of the Rivastigmine Transdermal Patch in Patients With Probable Alzheimer's Disease' | 3.9494 |\n| 19970 | 'Clinical efficacy and safety of donepezil on cognitive and global function in patients with Alzheimer's disease. A 24-week, multicenter, double-blind, placebo-controlled study in Japan.' | 3.8931 |\n| 19980 | 'A 24-weeks, Multi-center, Randomized, Double-blind, Placebo Controlled Study to Evaluate the Efficacy and Safety of Donepezil Hydrochloride in Chinese Subjects With Severe Alzheimer's Disease' | 3.8148 |\n| 21103 | 'A 24-Week Efficacy, Safety and Tolerability of Rivastigmine Patch Study in Patients With Probable Alzheimer's Disease' | 3.7602 |\n This table lists various studies related to Alzheimer's disease that employed a double-blind methodology, along with their respective titles and scores." 36 | }, 37 | { 38 | "question": "What drug may treat type 2 diabetes?", 39 | "generated_answer": "", 40 | "expected_answer": "glibenclamide, insulin (human), acarbose, dulaglutide, insulin lispro, acetohexamide, insulin degludec, insulin aspart, linagliptin, pioglitazone, glimepiride, insulin glargine, empagliflozin, pramlintide, miglitol, repaglinide, tirzepatide, lixisenatide, rosiglitazone, insulin glulisine, glipizide, tolbutamide, tolazamide, canagliflozin, alogliptin, dapagliflozin, chlorpropamide, exenatide, saxagliptin, metformin" 41 | }, 42 | { 43 | "question": "What drug may treat 多発出血性肉腫?", 44 | "generated_answer": "", 45 | "expected_answer": "vinblastine, dactinomycin, peginterferon alfa-2b, alitretinoin, etoposide, interferon alfa-2b, interferon alfa-2a, peginterferon alfa-2a, paclitaxel, interferon alfacon-1, vinblastine sulfate, doxorubicin hydrochloride liposome" 46 | }, 47 | { 48 | "question": "多発出血性肉腫の治療薬について?", 49 | "generated_answer": "", 50 | "expected_answer": "vinblastine, dactinomycin, peginterferon alfa-2b, alitretinoin, etoposide, interferon alfa-2b, interferon alfa-2a, peginterferon alfa-2a, paclitaxel, interferon alfacon-1, vinblastine sulfate, doxorubicin hydrochloride liposome" 51 | }, 52 | { 53 | "question": "What disorder can 1,1-Dimethylbiguanide treat?", 54 | "generated_answer": "", 55 | "expected_answer": "Diabetes Mellitus, Non-Insulin-Dependent" 56 | }, 57 | { 58 | "question": "What diseases can aflibercept treat?", 59 | "generated_answer": "Aflibercept can treat the following diseases:\n\n| Substance | Condition |\n|-------------|-----------------------|\n| aflibercept | Macular Edema |\n| aflibercept | Diabetic Retinopathy |", 60 | "expected_answer": "macular edema and diabetic retinopathy" 61 | }, 62 | { 63 | "question": "Which substances can be used to treat Macular Edema?", 64 | "generated_answer": "The substances that can be used to treat **Macular Edema** are:\n\n| Substance | CUI | ATC Codes | ATC URLs | Level |\n|--------------|-----------|------------------------------------------------|--------------------------------------------------------------------------------------------------------------|-------|\n| ranibizumab | C1566537 | S01LA04 | [ATC/S01LA04](http://purl.bioontology.org/ontology/ATC/S01LA04) | 5 |\n| aflibercept | C1134659 | L01XX44, S01LA05 | [ATC/L01XX44](http://purl.bioontology.org/ontology/ATC/L01XX44), [ATC/S01LA05](http://purl.bioontology.org/ontology/ATC/S01LA05) | 5 |\n| flurbiprofen | C0016377 | M01AE09, S01BC04, M02AA19, R02AX01 | [ATC/M01AE09](http://purl.bioontology.org/ontology/ATC/M01AE09), [ATC/S01BC04](http://purl.bioontology.org/ontology/ATC/S01BC04), [ATC/M02AA19](http://purl.bioontology.org/ontology/ATC/M02AA19), [ATC/R02AX01](http://purl.bioontology.org/ontology/ATC/R02AX01) | 5 |\n\n### Mechanisms of Action\n\n| Substance | Mechanism of Action |\n|--------------|---------------------------------------------------------------|\n| ranibizumab | Vascular Endothelial Growth Factor Receptor Inhibitors |\n| ranibizumab | Vascular Endothelial Growth Factor-directed Antibody Interactions |\n| ranibizumab | Vascular Endothelial Growth Factor Inhibitors |\n| aflibercept | Vascular Endothelial Growth Factor Inhibitors |\n| flurbiprofen | Cyclooxygenase Inhibitors |\n\n### Products Containing These Substances\n\n| Product Name | Substance | Condition Name |\n|-------------------------------------------------------------------|--------------|----------------|\n| Eylea aflibercept (rch) 40 mg/mL solution for intravitreal injection vial | aflibercept | Macular Edema |\n| EYLEA 114.3 mg/ml solution for injection | aflibercept | Macular Edema |", 65 | "expected_answer": "flurbiprofen, ranibizumab and aflibercept" 66 | }, 67 | { 68 | "question": "What ATC category does donepezil have?", 69 | "generated_answer": "Donepezil belongs to the **Anticholinesterase anti-dementia drugs** category in the ATC classification system. Here are the details:\n\n| Substance | ATC Code | ATC Category | ATC URL |\n|-----------------------------------|----------|-------------------------------------|-----------------------------------------------------|\n| donepezil | N06DA02 | Anticholinesterase anti-dementia drugs | [ATC N06DA](http://purl.bioontology.org/ontology/ATC/N06DA) |\n| donepezil and memantine | N06DA52 | Anticholinesterase anti-dementia drugs | [ATC N06DA](http://purl.bioontology.org/ontology/ATC/N06DA) |\n| donepezil, memantine and Ginkgo folium | N06DA53 | Anticholinesterase anti-dementia drugs | [ATC N06DA](http://purl.bioontology.org/ontology/ATC/N06DA) |\n\nDonepezil is primarily used to treat **Alzheimer Disease**.", 70 | "expected_answer": "N06DA02" 71 | }, 72 | { 73 | "question": "What kind of MOA does donepezil have?", 74 | "generated_answer": "Donepezil is associated with the mechanism of action (MOA) of being a **Cholinesterase Inhibitor**.", 75 | "expected_answer": "Cholinesterase Inhibitors" 76 | }, 77 | { 78 | "question": "Do acetohexamide and pioglitazone share an MOA? If yes, what is it?", 79 | "generated_answer": "Yes, acetohexamide and pioglitazone share a mechanism of action (MOA). The shared MOA is **Insulin Receptor Agonists**.\n\nHere is a summary of the shared MOA:\n\n| Substance | Shared MOA |\n|----------------|--------------------------|\n| acetohexamide | Insulin Receptor Agonists|\n| pioglitazone | Insulin Receptor Agonists|", 80 | "expected_answer": "Insulin Receptor Agonists" 81 | }, 82 | { 83 | "question": "Give me a substance that may treat Type 1 Diabetes and has \"glargine\" in its name", 84 | "generated_answer": "A substance that may treat Type 1 Diabetes and has *glargine* in its name is **insulin glargine**.\n\nHere are some details about *insulin glargine*:\n\n| **Substance Name** | **ATC Code** | **Condition** | **Mechanism of Action** |\n|--------------------|--------------|---------------------------------|------------------------------|\n| insulin glargine | A10AE04 | Diabetes Mellitus, Type 1 | Insulin Receptor Agonists |\n\nAdditionally, there is a product containing *insulin glargine*:\n\n| **Product Name** | **Chinese Name** | **License Number** |\n|-----------------------------------------------|-----------------------------------|-----------------------------------|\n| Soliqua solution for injection 100 units+50\u03bcg | \u723d\u80f0\u9054\u6ce8\u5c04\u5291100\u55ae\u4f4d+50\u5fae\u514b | \u885b\u90e8\u83cc\u75ab\u8f38\u5b57\u7b2c001080\u865f |\n\nIf you need more information, feel free to ask!", 85 | "expected_answer": "insulin glargine" 86 | }, 87 | { 88 | "question": "Which substance that may treat Type 1 Diabetes but does not have the MOA Insulin Receptor Agonists", 89 | "generated_answer": "The substance that may treat **Diabetes Mellitus, Type 1** but does not have the **MOA Insulin Receptor Agonists** is **pramlintide**. \n\nHere is the detailed information:\n\n| Substance | ATC Code | CUI | ATC URL |\n|------------|-----------|-----------|------------------------------------------------------|\n| pramlintide | A10BX05 | C0537551 | [ATC A10BX05](http://purl.bioontology.org/ontology/ATC/A10BX05) |\n\n**Mechanism of Action**: Amylin Agonists", 90 | "expected_answer": "pramlintide" 91 | } 92 | ] -------------------------------------------------------------------------------- /utils/my_langchain_tools.py: -------------------------------------------------------------------------------- 1 | #import duckdb 2 | import streamlit as st 3 | from langchain_community.utilities import SQLDatabase 4 | #from langchain_openai import ChatOpenAI 5 | import yaml 6 | #import os 7 | from langchain_core.example_selectors import SemanticSimilarityExampleSelector 8 | #from langchain_openai import OpenAIEmbeddings 9 | from langchain_openai import AzureOpenAIEmbeddings 10 | from langchain_core.output_parsers import StrOutputParser 11 | #from operator import itemgetter 12 | #from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool 13 | from sqlalchemy import create_engine 14 | from langchain_core.tools import tool 15 | from langchain_core.prompts import FewShotPromptTemplate 16 | from langchain_core.prompts import PromptTemplate 17 | import duckdb 18 | import json 19 | from langchain_community.vectorstores import LanceDB 20 | from langchain.chains import create_sql_query_chain 21 | from typing import Dict 22 | #from langchain_core.messages import AIMessage 23 | # from langchain_core.runnables import ( 24 | # Runnable, 25 | # RunnableLambda, 26 | # RunnableMap, 27 | # RunnablePassthrough, 28 | # ) 29 | 30 | import utils.umls_mapper as umls_mapper 31 | import my_db_specifics as my_db_specifics 32 | import utils.my_llm as my_llm 33 | 34 | 35 | print ("---my_langchain_tools---, st.session_state", st.session_state) 36 | 37 | with open("config.yaml", "r") as stream: 38 | try: 39 | PARAM = yaml.safe_load(stream) 40 | except yaml.YAMLError as exc: 41 | print(exc) 42 | 43 | 44 | 45 | 46 | def expand_question(question: str) -> str: 47 | """Expand the original question with the extracted medical terms.""" 48 | #print ("---expand_question---") 49 | #messages = filter_messages(state["messages"]) 50 | 51 | print ("---expand_question---\n question:", question) 52 | terms = umls_mapper.term_extractor(question) 53 | 54 | print ("---expand_question---\n terms:", terms) 55 | 56 | umls_terms = umls_mapper.entity_recognition(terms) 57 | 58 | expanded_version = question 59 | for t in umls_terms: 60 | 61 | expanded_version = expanded_version.replace(t, f"{umls_terms[t]['name']} (CUI: {umls_terms[t]['cui']})") 62 | 63 | 64 | return expanded_version 65 | 66 | @tool 67 | def mimicking(question: str, top_k: int): 68 | """ When you think the question is unlikely to be answer by a single simple query tool, 69 | or the question may likely require a complex combination of sql, vector, graph, and full-text search tools, 70 | or it may require to join several tables, use this tool to generate those complex queries by closely mimicing the examples.""" 71 | examples = [] 72 | 73 | print ("in mimicking") 74 | 75 | for line in open("interaction.jsonl", "r").readlines(): 76 | example = json.loads(line) 77 | examples.append(example) 78 | 79 | print ("examples", json.dumps(examples, indent=2)) 80 | database_description = my_db_specifics.sql_database_prompt 81 | 82 | print ("before example_selector") 83 | example_selector = SemanticSimilarityExampleSelector.from_examples( 84 | examples, 85 | my_llm.embeddings, 86 | LanceDB, 87 | k=5, 88 | input_keys=["input"], 89 | ) 90 | print ("before example_prompt") 91 | example_prompt = PromptTemplate.from_template("User input: {input}\nquery: {query}") 92 | print ("after example_prompt") 93 | complex_generation_prompt = FewShotPromptTemplate( 94 | example_selector=example_selector, 95 | example_prompt=example_prompt, 96 | prefix="""You are a duckdb expert. Given an input question, take the examples as templates, and only substitute the template variables with those extracted from the question. Closely mimicing the examples and don't modify the examplar structure easily, since they are curated by human. Add a 'LIMIT {top_k}' clause to the end of the query. \n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding queries. Use them to as inspiration generate your query. 97 | In the question, if you see the '(CUI: xxx)' is present after a medical term, instead of the term, use the that CUI to write your query. 98 | - Almost always start with SELECT, unless it is a graph query. 99 | - The subquery in FROM clause should have an alias, without the keyword AS, Here is an example: SELECT * FROM Trials, GRAPH_TABLE( ... ) drug_for_disease WHERE Trials.drug_cui = drug_for_disease.drug_cui 100 | - If the search term contains a single quote, it should be escaped with another single quote. For example, 'Alzheimer's Disease' should be 'Alzheimer''s Disease'. 101 | - Only return query not anything else like ```sql ... ``` 102 | - Every variable in the graph pattern has to be bound by a variable. For example, (i:Drug)-[:MAY_TREAT]->(c:Disorder WHERE c.name = 'Alzheimer''s Disease') is not correct because :MAY_TREAT is not bound to a variable. Instead, it should be (i:Drug)-[m:MAY_TREAT]->(c:Disorder WHERE c.name = 'Alzheimer''s Disease'). 103 | - If it is a graph query, use "COLUMNS" as the return statement in the graph query. 104 | - Based on the question, include a 'LIMIT' clause before the end of the query. Never write 'LIMIT 0;' nor 'LIMIT;' If you are unsure about the number of results, remove the LIMIT clause entirely. 105 | - Make sure all parentheses are balanced. 106 | - Ends with a semicolon 107 | - Output the final query only. 108 | """, 109 | suffix="User input: {input}\ngraph query: ", 110 | input_variables=["input", "table_info", "top_k"], 111 | ) 112 | 113 | print ("before generate_query") 114 | generate_query = ( 115 | complex_generation_prompt 116 | | my_llm.llm | StrOutputParser() 117 | ) 118 | 119 | query = generate_query.invoke({"input": question, "table_info": database_description, "top_k": top_k}) 120 | print ("query", query) 121 | return query 122 | 123 | 124 | @tool 125 | def sql(question: str, top_k: int): 126 | """ Use the SQL route to get the answer from the database. It can find data across all tables. Consider it as the default tool. top_k is the number of results to return.""" 127 | 128 | print ("sql question", question, "top_k", top_k) 129 | #print ("sql umls_terms", umls_terms) 130 | database_description = my_db_specifics.sql_database_prompt 131 | 132 | examples = my_db_specifics.sql_examples 133 | 134 | 135 | example_selector = SemanticSimilarityExampleSelector.from_examples( 136 | examples, 137 | my_llm.embeddings, 138 | LanceDB, 139 | k=5, 140 | input_keys=["input"], 141 | ) 142 | print ("==== sql ====") 143 | 144 | example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}") 145 | sql_generation_prompt = FewShotPromptTemplate( 146 | example_selector=example_selector, 147 | example_prompt=example_prompt, 148 | prefix="""You are a DuckDB expert. Given an input question, create a syntactically correct DuckDB query to run. Ignore the {top_k} parameter for now. 149 | Here is the relevant table info: {table_info} 150 | In the question, if you see the '(CUI: xxx)' is present after a medical term, instead of the term, use the that CUI to write your query. 151 | - If the search term contains a single quote, it should be escaped with another single quote. For example, 'Alzheimer's Disease' should be 'Alzheimer''s Disease'. 152 | - Only return SQL Query not anything else like ```sql ... ``` 153 | - Using NOT IN with NULL values 154 | - Using UNION when UNION ALL should have been used 155 | - Using BETWEEN for exclusive ranges 156 | - Data type mismatch in predicates 157 | - Using the correct number of arguments for functions 158 | - Casting to the correct data type 159 | - Using the proper columns for joins 160 | - Never write a LIMIT clause. 161 | - Make sure all parentheses are balanced. 162 | - Ends with a semicolon 163 | - Output the final SQL query only. 164 | Below are a number of examples of questions and their corresponding SQL queries.""", 165 | suffix="User input: {input}\nSQL query: ", 166 | input_variables=["input", "top_k", "table_info"], 167 | ) 168 | 169 | engine = create_engine('duckdb:///' + PARAM['drugdb_path'], connect_args={ 170 | 'read_only': True 171 | }) 172 | 173 | db = SQLDatabase(engine=engine) 174 | 175 | #db = SQLDatabase.from_uri('duckdb:///' + PARAM['drugdb_path']) 176 | write_query = create_sql_query_chain(my_llm.llm, db, sql_generation_prompt) 177 | print ("write_query", write_query) 178 | 179 | 180 | 181 | sql_query = write_query.invoke({"question": question, "table_info": database_description, "top_k": top_k}) 182 | sql_query = sql_query.strip() 183 | if top_k is not None: 184 | if sql_query.endswith(";"): 185 | sql_query = sql_query[:-1] + f"\n LIMIT {top_k};" 186 | else: 187 | sql_query = sql_query + f"\n LIMIT {top_k};" 188 | #print ("sql_query", sql_query) 189 | 190 | engine.dispose() 191 | #db.close() 192 | #return {"draft_query": [sql_query]} 193 | return sql_query 194 | #print ("sql_query", sql_query) 195 | #return AIMessage(sql_query) 196 | 197 | 198 | @tool 199 | def graph(question: str, top_k: int): 200 | 201 | """Use the graph query language route to get the answer from the database. Only suitable for questions that involve the interrelationship between the Drugs, Disorders, and MOA tables. top_k is the number of results to return.""" 202 | 203 | print ("++++++++++++++++question", question, "top_k", top_k) 204 | 205 | database_description = my_db_specifics.graph_database_prompt 206 | print ("before examples") 207 | examples = my_db_specifics.graph_examples 208 | print ("before example_selector") 209 | example_selector = SemanticSimilarityExampleSelector.from_examples( 210 | examples, 211 | my_llm.embeddings, 212 | LanceDB, 213 | k=5, 214 | input_keys=["input"] 215 | ) 216 | print ("before example_prompt") 217 | example_prompt = PromptTemplate.from_template("User input: {input}\ngraph query: {query}") 218 | print ("after example_prompt") 219 | #- Add a 'LIMIT {top_k}' clause to the end of the query. Place the LIMIT clause after the closing parenthesis. 220 | pgq_generation_prompt = FewShotPromptTemplate( 221 | example_selector=example_selector, 222 | example_prompt=example_prompt, 223 | prefix="""You are a DuckPGQ expert. Given an input question, create a syntactically correct graph query to run. 224 | Here is the relevant table info: {table_info} 225 | In the question, if you see the '(CUI: xxx)' is present after a medical term, instead of the term, use the that CUI to write your query. 226 | DuckPGQ is very similar to Cypher. But there are some differences. 227 | Double check the user's DuckPGQ graph query for common mistakes, including: 228 | - If the search term contains a single quote, it should be escaped with another single quote. For example, 'Alzheimer's Disease' should be 'Alzheimer''s Disease'. 229 | - It must start with "FROM GRAPH_TABLE (drug_graph" before the MATCH clause. It ends with a closing parenthesis before the LIMIT clause. 230 | - Only return graph query not anything else like ```sql ... ``` 231 | - Every variable in the graph pattern has to be bound by a variable. For example, (i:Drug)-[:MAY_TREAT]->(c:Disorder WHERE c.name = 'Alzheimer''s Disease') is not correct because :MAY_TREAT is not bound to a variable. Instead, it should be (i:Drug)-[m:MAY_TREAT]->(c:Disorder WHERE c.name = 'Alzheimer''s Disease'). 232 | - Use "COLUMNS" as the return statement in the graph query. 233 | - Replace all '\n' with a space. 234 | 235 | - Never write a LIMIT clause. 236 | - Make sure all parentheses are balanced. 237 | - Ends with a semicolon 238 | - Output the final graph query only. 239 | Below are a number of examples of questions and their corresponding graph queries.""", 240 | suffix="User input: {input}\ngraph query: ", 241 | input_variables=["input", "table_info", "top_k"], 242 | ) 243 | 244 | 245 | 246 | generate_query = ( 247 | pgq_generation_prompt 248 | | my_llm.llm | StrOutputParser() 249 | ) 250 | 251 | graph_query = generate_query.invoke({"input": question, "table_info": database_description, "top_k": top_k}) 252 | graph_query = graph_query.strip() 253 | #print ("graph_query", graph_query) 254 | if top_k is not None: 255 | if graph_query.endswith(";"): 256 | graph_query = graph_query[:-1] + f"\n LIMIT {top_k};" 257 | else: 258 | graph_query = graph_query + f"\n LIMIT {top_k};" 259 | return graph_query 260 | #return AIMessage(graph_query) 261 | 262 | 263 | @tool 264 | def vector(question: str, top_k: int = 5) -> str: 265 | """Use vector search to get the disorder from the database. Only suitable for fuzzy questions that involve the definition of disorder, 266 | such as "joint-related disorders" or "disorders that cause rash". If the user asks for a specific disorder, such as 'what is Pericarditis?', use the SQL or graph instead.""" 267 | 268 | vector_query = my_db_specifics.vector_search_query_template.format(question=question, limit=top_k) 269 | 270 | return vector_query 271 | 272 | @tool 273 | def fulltext(question: str, top_k: int = 10) -> str: 274 | """Use the full text search to get the trials from the database. Only suitable for questions that involve the StudyTitle. Use this tool when users question does not read like a sentence and looks like some keywords instead. Keep the original query for the user's reference. And keep all the operators such as &, |, and ! in the query.""" 275 | field_with_full_text_search = "StudyTitle" 276 | 277 | generate_query = my_db_specifics.full_text_search_query_template.format(original_question=question.replace("'", "''"), field=field_with_full_text_search, limit=top_k) 278 | 279 | 280 | return generate_query 281 | 282 | 283 | def execute_query_and_answer(state): 284 | 285 | """ Node to answer a question """ 286 | print("---execute_query_and_answer---") 287 | 288 | #print ("state", state) 289 | # for i, m in enumerate(state["messages"]): 290 | # print (i, m) 291 | 292 | messages = state["messages"] 293 | question = "" 294 | query = "" 295 | #query = state["messages"][-2].content 296 | for m in messages[::-1]: 297 | if question != "" and query != "": 298 | break 299 | #print ("execute_query_and_answer.m", m) 300 | if m.type == "ai" and len(question) == 0: 301 | if len(m.content.strip()) == 0: 302 | print ("---execute_query_and_answer--- m", m) 303 | question = json.loads(m.additional_kwargs["tool_calls"][0]["function"]["arguments"])["question"] 304 | elif m.type == "human" and len(query) == 0: 305 | #print ("m.content", m.content) 306 | query = m.content 307 | 308 | 309 | #1 content='' additional_kwargs={'tool_calls': [{'id': 'call_u3LxRfcWLN4lGxWD9joQzYmN', 'function': {'arguments': '{"question":"What diseases can hydroflumethiazide treat?","top_k":5}', 'name': 'sql'}, 'type': 'function'}], 'refusal': None} response_metadata={'token_usage': {'completion_tokens': 28, 'prompt_tokens': 109, 'total_tokens': 137, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_45cf54deae', 'finish_reason': 'tool_calls', 'logprobs': None} id='run-82248ab2-7303-467c-8a4e-a5c3ca416eb3-0' tool_calls=[{'name': 'sql', 'args': {'question': 'What diseases can hydroflumethiazide treat?', 'top_k': 5}, 'id': 'call_u3LxRfcWLN4lGxWD9joQzYmN', 'type': 'tool_call'}] usage_metadata={'input_tokens': 109, 'output_tokens': 28, 'total_tokens': 137, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}} 310 | #{'arguments': '{"question":"What diseases can hydroflumethiazide treat?","top_k":5}', 'name': 'sql'} 311 | #print(state["messages"][-3].additional_kwargs["tool_calls"][0]["function"]["arguments"]) 312 | #question = json.loads(state["messages"][-3].additional_kwargs["tool_calls"][0]["function"]["arguments"])["question"] 313 | 314 | #pass 315 | 316 | def embeddings(document: str) -> list[float]: 317 | #result = my_llm.embeddings(model=PARAM['vector_embedding_model']).embed_query(document) 318 | result = my_llm.embeddings.embed_query(document) 319 | return result 320 | 321 | con = duckdb.connect(PARAM['drugdb_path'], config = {"allow_unsigned_extensions": "true"}) 322 | con.create_function('embeddings', embeddings) 323 | 324 | 325 | 326 | for c in my_db_specifics.initialization_commands: 327 | #print (c) 328 | con.sql(c) 329 | 330 | # # Answer 331 | print ("query", query) 332 | 333 | execute_result = con.sql(query).fetchall() 334 | print ("execute_result", execute_result) 335 | 336 | 337 | final_response = "" 338 | if len(execute_result) == 0: 339 | final_response = "No results found." 340 | else: 341 | 342 | answer_prompt = PromptTemplate.from_template( 343 | """Given the Question {question} and the query_result {query_result}, format the results into sentences or a table for the human to understand. 344 | Don't add any data or facts outside of the query_result. 345 | Don't alter the quoted strings from the query_result even if they are not grammatically correct (no conversion of "inhibitors" to "inhibitor" and vice versa). 346 | """ 347 | ) 348 | 349 | formulate_human_readable_answer = ( 350 | answer_prompt 351 | | my_llm.llm 352 | | StrOutputParser() 353 | ) 354 | 355 | final_response = formulate_human_readable_answer.invoke({"question": question, "query_result": execute_result}) 356 | 357 | con.close() 358 | 359 | tool_call_id = "" 360 | for m in state["messages"][::-1]: 361 | #print ("m", m.type, m) 362 | if m.type == "tool": 363 | tool_call_id = m.tool_call_id 364 | break 365 | #print ("in execute_query_and_answer tool_calls = state['messages'][-1].tool_calls", tool_calls[0]['id']) 366 | print ("return tool_call_id", tool_call_id) 367 | #print ("return final_response", final_response) 368 | return {"messages": {"question": question, "query": query, "execute_result": execute_result, "role": "assistant", "content": final_response, "tool_call_id":tool_call_id}} 369 | --------------------------------------------------------------------------------