├── .gitignore ├── ARCH-OPEN ├── books_data │ ├── llm_qa_pairs_books_0_4305.pkl │ └── synthetic_data_textbook.ipynb ├── pubmed_data │ ├── llm_qa_pairs_pubmed_0_3309.pkl │ └── synthetic_data_pubmed.ipynb ├── pubmed_qa_pairs.json ├── synthetic_data_compilation.ipynb └── textbook_qa_pairs.json ├── README.md ├── files ├── answer │ ├── fine-tuned │ │ └── text.txt │ └── raw │ │ └── text.txt └── query │ └── text.txt ├── final_models └── text.txt ├── generate_histo_patches.ipynb ├── generate_histo_patches.py ├── generate_llava_med_query.py ├── images ├── path-rag.png └── text.txt ├── llama_7B_model_weights.py ├── model_delta_weights └── text.txt ├── recall_calculation.py ├── tem_data.ipynb └── test.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | pvqa/* 2 | LLaVA-Med/* 3 | histo_image_patch/* 4 | nohup.out 5 | arch/* 6 | histocartography/* -------------------------------------------------------------------------------- /ARCH-OPEN/books_data/llm_qa_pairs_books_0_4305.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/embedded-robotics/path-rag/e1bd5ea6aec4825d708b4f7b77b2fcaf52fff824/ARCH-OPEN/books_data/llm_qa_pairs_books_0_4305.pkl -------------------------------------------------------------------------------- /ARCH-OPEN/books_data/synthetic_data_textbook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import os\n", 11 | "import json\n", 12 | "import openai\n", 13 | "import backoff\n", 14 | "from PIL import Image\n", 15 | "import pickle\n", 16 | "import time\n", 17 | "import sys" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "sys.path.append(\"data/mn27889/path-rag\")\n", 27 | "file_path = os.path.join(\"arch\", \"books_set\", \"captions.json\")\n", 28 | "\n", 29 | "with open(file_path, 'rb') as file:\n", 30 | " captions_data = json.load(file)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "OPENAI_API_PATH = os.path.join(os.getcwd(), 'api.key')\n", 40 | "\n", 41 | "with open(OPENAI_API_PATH) as f:\n", 42 | " openai.api_key = f.read().strip()\n", 43 | "\n", 44 | "@backoff.on_exception(backoff.expo, openai.OpenAIError)\n", 45 | "def completions_with_backoff(**kwargs):\n", 46 | " return openai.chat.completions.create(**kwargs)\n", 47 | "\n", 48 | "def gpt(user_prompt, system_prompt=\"You are an expert pathologist\", model=\"gpt-4\", temperature=0.7, max_tokens=1000) -> list:\n", 49 | "\n", 50 | " messages = [{\"role\": \"system\", \"content\": system_prompt},\n", 51 | " {\"role\": \"user\", \"content\": user_prompt}]\n", 52 | " \n", 53 | " res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens)\n", 54 | " \n", 55 | " return res.choices[0].message.content" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "base_prompt = '''You are provided with a text description (figure caption) of a pathology image. 
Unfortunately, you don't have access to the original image.\n", 65 | "Your job is to generate a total of 5 open-ended question/answer pairs from this figure caption starting with \"What\" or \"Where\". Below are the requirements to generate the question/answer pairs:\n", 66 | "\n", 67 | "- Avoid quoting or referring to specific facts, terms, abbreviations, dates, numbers or names, as these may reveal the conversation is based on the text information, rather than image itself.\n", 68 | "- Focus on the visual aspects of the image that can be inferred without the text information\n", 69 | "- Do not use phrases like \"mentioned\", \"caption\", \"context\", \"without the image\" in the question/answer pairs. Instead, refer to the information as being \"in the image\" or preferably don't mention anything\n", 70 | "- Ensure that question/anwer pairs are diverse and cover a range of visual aspects of the image\n", 71 | "- Answer responsibly, avoiding overconfidence, and do not provide medical advice or diagnostic information\n", 72 | "\n", 73 | "Caption: {caption}\n", 74 | "Question:\n", 75 | "Answer:\n", 76 | "'''" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Getting the results and saving it\n", 86 | "index_list = []\n", 87 | "figure_id_list = []\n", 88 | "letter_list = []\n", 89 | "caption_list = []\n", 90 | "uuid_list = []\n", 91 | "llm_response_list = []\n", 92 | "\n", 93 | "start_index = 0\n", 94 | "current_index = start_index\n", 95 | "total_records = len(captions_data)\n", 96 | "\n", 97 | "while True:\n", 98 | " try:\n", 99 | " for index in range(start_index, total_records):\n", 100 | " current_index = index\n", 101 | " figure_id = captions_data[str(current_index)]['figure_id']\n", 102 | " letter = captions_data[str(current_index)]['letter']\n", 103 | " caption = captions_data[str(current_index)]['caption']\n", 104 | " uuid = captions_data[str(current_index)]['uuid']\n", 105 | " \n", 106 | " user_prompt = base_prompt.format(caption = caption)\n", 107 | " response = gpt(user_prompt)\n", 108 | " \n", 109 | " index_list.append(current_index)\n", 110 | " figure_id_list.append(figure_id)\n", 111 | " letter_list.append(letter)\n", 112 | " caption_list.append(caption)\n", 113 | " uuid_list.append(uuid)\n", 114 | " llm_response_list.append(response)\n", 115 | "\n", 116 | " print(\"Index:\", current_index)\n", 117 | " print(\"Figure_ID:\", figure_id)\n", 118 | " print(\"Letter:\", letter)\n", 119 | " print(\"Caption:\", caption)\n", 120 | " print(\"UUID:\", uuid)\n", 121 | " print()\n", 122 | " print(response)\n", 123 | " print()\n", 124 | " \n", 125 | " except Exception as err:\n", 126 | " print(\"Something went wrong: \", err)\n", 127 | " start_index = current_index\n", 128 | " print(\"Waiting for 10 seconds before continuing again with index:\", start_index)\n", 129 | " time.sleep(10)\n", 130 | "\n", 131 | " # Break the loop if current_index has completed\n", 132 | " if current_index == (total_records - 1):\n", 133 | " break\n", 134 | "\n", 135 | "llm_qa_pairs_books = pd.DataFrame({'index': index_list, 'figure_id': figure_id_list, 'letter': letter_list,\n", 136 | " 'caption': caption_list, 'uuid': uuid_list, 'llm_qa_pairs_books': llm_response_list})\n", 137 | "\n", 138 | "file_name = 'llm_qa_pairs_books_' + str(start_index) + '_' + str(total_records) + '.pkl'\n", 139 | "\n", 140 | "with open(file_name, 'wb') as file:\n", 141 | " pickle.dump(llm_qa_pairs_books, file)" 142 | ] 143 | } 144 | ], 145 | 
"metadata": { 146 | "kernelspec": { 147 | "display_name": "path-rag", 148 | "language": "python", 149 | "name": "python3" 150 | }, 151 | "language_info": { 152 | "codemirror_mode": { 153 | "name": "ipython", 154 | "version": 3 155 | }, 156 | "file_extension": ".py", 157 | "mimetype": "text/x-python", 158 | "name": "python", 159 | "nbconvert_exporter": "python", 160 | "pygments_lexer": "ipython3", 161 | "version": "3.9.19" 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 2 166 | } 167 | -------------------------------------------------------------------------------- /ARCH-OPEN/pubmed_data/llm_qa_pairs_pubmed_0_3309.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/embedded-robotics/path-rag/e1bd5ea6aec4825d708b4f7b77b2fcaf52fff824/ARCH-OPEN/pubmed_data/llm_qa_pairs_pubmed_0_3309.pkl -------------------------------------------------------------------------------- /ARCH-OPEN/pubmed_data/synthetic_data_pubmed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import os\n", 11 | "import json\n", 12 | "import openai\n", 13 | "import backoff\n", 14 | "from PIL import Image\n", 15 | "import pickle\n", 16 | "import time\n", 17 | "import sys" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "sys.path.append(\"data/mn27889/path-rag\")\n", 27 | "file_path = os.path.join(\"arch\", \"pubmed_set\", \"captions.json\")\n", 28 | "\n", 29 | "with open(file_path, 'rb') as file:\n", 30 | " captions_data = json.load(file)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "OPENAI_API_PATH = os.path.join(os.getcwd(), 'api.key')\n", 40 | "\n", 41 | "with open(OPENAI_API_PATH) as f:\n", 42 | " openai.api_key = f.read().strip()\n", 43 | "\n", 44 | "@backoff.on_exception(backoff.expo, openai.OpenAIError)\n", 45 | "def completions_with_backoff(**kwargs):\n", 46 | " return openai.chat.completions.create(**kwargs)\n", 47 | "\n", 48 | "def gpt(user_prompt, system_prompt=\"You are an expert pathologist\", model=\"gpt-4\", temperature=0.7, max_tokens=1000) -> list:\n", 49 | "\n", 50 | " messages = [{\"role\": \"system\", \"content\": system_prompt},\n", 51 | " {\"role\": \"user\", \"content\": user_prompt}]\n", 52 | " \n", 53 | " res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens)\n", 54 | " \n", 55 | " return res.choices[0].message.content" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "base_prompt = '''You are provided with a text description (figure caption) of a pathology image. Unfortunately, you don't have access to the original image.\n", 65 | "Your job is to generate a total of 5 open-ended question/answer pairs from this figure caption starting with \"What\" or \"Where\". 
Below are the requirements to generate the question/answer pairs:\n", 66 | "\n", 67 | "- Avoid quoting or referring to specific facts, terms, abbreviations, dates, numbers or names, as these may reveal the conversation is based on the text information, rather than image itself.\n", 68 | "- Focus on the visual aspects of the image that can be inferred without the text information\n", 69 | "- Do not use phrases like \"mentioned\", \"caption\", \"context\", \"without the image\" in the question/answer pairs. Instead, refer to the information as being \"in the image\" or preferably don't mention anything\n", 70 | "- Ensure that question/anwer pairs are diverse and cover a range of visual aspects of the image\n", 71 | "- Answer responsibly, avoiding overconfidence, and do not provide medical advice or diagnostic information\n", 72 | "\n", 73 | "Caption: {caption}\n", 74 | "Question:\n", 75 | "Answer:\n", 76 | "'''" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Getting the results and saving it\n", 86 | "index_list = []\n", 87 | "caption_list = []\n", 88 | "uuid_list = []\n", 89 | "llm_response_list = []\n", 90 | "\n", 91 | "start_index = 0\n", 92 | "current_index = start_index\n", 93 | "total_records = len(captions_data)\n", 94 | "\n", 95 | "while True:\n", 96 | " try:\n", 97 | " for index in range(start_index, total_records):\n", 98 | " current_index = index\n", 99 | " caption = captions_data[str(current_index)]['caption']\n", 100 | " uuid = captions_data[str(current_index)]['uuid']\n", 101 | " \n", 102 | " user_prompt = base_prompt.format(caption = caption)\n", 103 | " response = gpt(user_prompt)\n", 104 | " \n", 105 | " index_list.append(current_index)\n", 106 | " caption_list.append(caption)\n", 107 | " uuid_list.append(uuid)\n", 108 | " llm_response_list.append(response)\n", 109 | "\n", 110 | " print(\"Index:\", current_index)\n", 111 | " print(\"Caption:\", caption)\n", 112 | " print(\"UUID:\", uuid)\n", 113 | " print()\n", 114 | " print(response)\n", 115 | " print()\n", 116 | " \n", 117 | " except Exception as err:\n", 118 | " print(\"Something went wrong: \", err)\n", 119 | " start_index = current_index\n", 120 | " print(\"Waiting for 10 seconds before continuing again with index:\", start_index)\n", 121 | " time.sleep(10)\n", 122 | "\n", 123 | " # Break the loop if current_index has completed\n", 124 | " if current_index == (total_records - 1):\n", 125 | " break\n", 126 | "\n", 127 | "\n", 128 | "llm_qa_pairs_pubmed = pd.DataFrame({'index': index_list, 'caption': caption_list, 'uuid': uuid_list, 'llm_qa_pairs_pubmed': llm_response_list})\n", 129 | "\n", 130 | "file_name = 'llm_qa_pairs_pubmed_' + str(start_index) + '_' + str(total_records) + '.pkl'\n", 131 | "\n", 132 | "with open(file_name, 'wb') as file:\n", 133 | " pickle.dump(llm_qa_pairs_pubmed, file)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "path-rag", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.9.19" 161 | } 162 | }, 163 | "nbformat": 4, 164 | 
"nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /ARCH-OPEN/synthetic_data_compilation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### This code will compile the questions/answers from both the responses received from GPT for textbook and pubmed data" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import json\n", 17 | "import pickle\n", 18 | "import pandas as pd\n", 19 | "import re" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "### Pubmed Data" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "with open('pubmed_data/llm_qa_pairs_pubmed_0_3309.pkl', 'rb') as file:\n", 36 | " pubmed_data = pickle.load(file)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stderr", 46 | "output_type": "stream", 47 | "text": [ 48 | "<>:7: SyntaxWarning: invalid escape sequence '\\s'\n", 49 | "<>:8: SyntaxWarning: invalid escape sequence '\\s'\n", 50 | "<>:7: SyntaxWarning: invalid escape sequence '\\s'\n", 51 | "<>:8: SyntaxWarning: invalid escape sequence '\\s'\n", 52 | "/tmp/ipykernel_2634287/1379646763.py:7: SyntaxWarning: invalid escape sequence '\\s'\n", 53 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n", 54 | "/tmp/ipykernel_2634287/1379646763.py:8: SyntaxWarning: invalid escape sequence '\\s'\n", 55 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "pubmed_data_list = []\n", 61 | "\n", 62 | "for index, row in pubmed_data.iterrows():\n", 63 | " caption = row['caption']\n", 64 | " uuid = row['uuid']\n", 65 | " qa_pairs = row['llm_qa_pairs']\n", 66 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n", 67 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n", 68 | " assert len(questions) == 5\n", 69 | " assert len(answers) == 5\n", 70 | "\n", 71 | " data = {'caption': caption,\n", 72 | " 'uuid': uuid}\n", 73 | " for i in range(0,5):\n", 74 | " data['Question_' + str(i+1)] = questions[i][1]\n", 75 | " data['Answer_' + str(i+1)] = answers[i][1]\n", 76 | " \n", 77 | " pubmed_data_list.append(data)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "with open('pubmed_qa_pairs.json', 'w') as file:\n", 87 | " json.dump(pubmed_data_list, file, indent=2)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Textbook Data" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "with open('books_data/llm_qa_pairs_books_0_4305.pkl', 'rb') as file:\n", 104 | " textbook_data = pickle.load(file)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 6, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stderr", 114 | "output_type": "stream", 115 | "text": [ 116 | "<>:9: SyntaxWarning: invalid escape sequence '\\s'\n", 117 | "<>:10: SyntaxWarning: invalid escape sequence '\\s'\n", 118 | "<>:9: SyntaxWarning: invalid 
escape sequence '\\s'\n", 119 | "<>:10: SyntaxWarning: invalid escape sequence '\\s'\n", 120 | "/tmp/ipykernel_2634287/2596557835.py:9: SyntaxWarning: invalid escape sequence '\\s'\n", 121 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n", 122 | "/tmp/ipykernel_2634287/2596557835.py:10: SyntaxWarning: invalid escape sequence '\\s'\n", 123 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "textbook_data_list = []\n", 129 | "\n", 130 | "for index, row in textbook_data.iterrows():\n", 131 | " figure_id = row['figure_id']\n", 132 | " letter = row['letter']\n", 133 | " caption = row['caption']\n", 134 | " uuid = row['uuid']\n", 135 | " qa_pairs = row['llm_qa_pairs_books']\n", 136 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n", 137 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n", 138 | " assert len(questions) == 5\n", 139 | " assert len(answers) == 5\n", 140 | "\n", 141 | " data = {'figure_id':figure_id,\n", 142 | " 'letter': letter,\n", 143 | " 'caption': caption,\n", 144 | " 'uuid': uuid}\n", 145 | " for i in range(0,5):\n", 146 | " data['Question_' + str(i+1)] = questions[i][1]\n", 147 | " data['Answer_' + str(i+1)] = answers[i][1]\n", 148 | " \n", 149 | " textbook_data_list.append(data)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "with open('textbook_qa_pairs.json', 'w') as file:\n", 159 | " json.dump(textbook_data_list, file, indent=2)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "path-rag-dpo", 173 | "language": "python", 174 | "name": "python3" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.12.4" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 2 191 | } 192 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Path-RAG: Knowledge-Guided Key Region Retrieval for Open-ended Pathology Visual Question Answering 2 | 3 |
![Path-RAG framework overview](images/path-rag.png)
4 |
5 | 6 | *Accurate diagnosis and prognosis assisted by pathology images are essential for cancer treatment selection and planning. Despite the recent trend of adopting deep-learning approaches for analyzing complex pathology images, they fall short as they often overlook the domain-expert understanding of tissue structure and cell composition. In this work, we focus on a challenging Open-ended Pathology VQA (PathVQA-Open) task and propose a novel framework named Path-RAG, which leverages HistoCartography to retrieve relevant domain knowledge from pathology images and significantly improves performance on PathVQA-Open. Admitting the complexity of pathology image analysis, Path-RAG adopts a human-centered AI approach by retrieving domain knowledge using HistoCartography to select the relevant patches from pathology images. Our experiments suggest that domain guidance can significantly boost the accuracy of LLaVA-Med from 38\% to 47\%, with a notable gain of 28\% for H\&E-stained pathology images in the PathVQA-Open dataset. For longer-form question and answer pairs, our model consistently achieves significant improvements of 32.5\% in ARCH-Open PubMed and 30.6\% in ARCH-Open Books on H\&E images.* 7 |

8 | 9 | --- 10 | 11 | Awais Naeem*, Tianhao Li*, Huang-Ru Liao*, Jiawei Xu*, Aby Mammen Mathew*, Zehao Zhu*, Zhen Tan**, Ajay Jaiswal*, Raffi Salibian*** , Ziniu Hu*** , Tianlong Chen****, Ying Ding* 12 | 13 | *University of Texas at Austin, USA \ 14 | **Arizona State University, USA \ 15 | ***University of California, Los Angeles, USA \ 16 | ****Massachusetts Institute of Technology, USA 17 | 18 | --- 19 | 20 | ## Path-RAG Implementation 21 | 22 | ### 1. Clone this repository and navigate to path-rag folder 23 | 24 | ```Shell 25 | git clone https://github.com/embedded-robotics/path-rag.git 26 | cd path-rag 27 | ``` 28 | 29 | ### 2. Install Package: Create conda environment 30 | 31 | ```Shell 32 | conda create -n path-rag python=3.10 -y 33 | conda activate path-rag 34 | pip install --upgrade pip # enable PEP 660 support for LLaVA-Med 35 | ``` 36 | 37 | ### 3. Download the PathVQA dataset from the following link 38 | 39 | [PathVQA Dataset](https://github.com/UCSD-AI4H/PathVQA/blob/master/data/README.md) 40 | 41 | ### 4. Clone the HistoCartography tool, setup the model checkpoints in `histocartography/checkpoints` and install the dependencies 42 | 43 | ```Shell 44 | git clone https://github.com/BiomedSciAI/histocartography 45 | ``` 46 | 47 | ### 5. Clone the LLaVA-Med repository and install the dependencies 48 | 49 | ```Shell 50 | git clone https://github.com/microsoft/LLaVA-Med 51 | ``` 52 | 53 | ### 6. Download the LLaMA-7B model and weights from HuggingFace 54 | 55 | ```Shell 56 | python llama_7B_model_weights.py # LLaMA-7B weights/model stored into $HF_HOME (By Default $HF_HOME = ~/.cache/huggingface) 57 | ``` 58 | 59 | ### 7. Download LLaVA-Med delta weights `llava_med_in_text_60k_ckpt2_delta` and `pvqa-9epoch_delta` from `https://github.com/microsoft/LLaVA-Med#model-download`. Put them inside a folder named `model_delta_weights` 60 | 61 | ### 8. Apply the LLaVA-Med delta weights to base LLaMA-7B to come up with the final weights for LLaVA-Med 62 | 63 | ```Shell 64 | cd LLaVA-Med 65 | ``` 66 | 67 | #### LLaVA-Med pre-trained on general biomedicine data 68 | 69 | ```Shell 70 | !python3 -m llava.model.apply_delta \ 71 | --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \ 72 | --target ../final_models/llava_med \ 73 | --delta ../model_delta_weights/llava_med_in_text_60k_ckpt2_delta 74 | ``` 75 | 76 | #### LLaVA-Med fine-tuned on PathVQA 77 | 78 | ```Shell 79 | !python -m llava.model.apply_delta \ 80 | --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \ 81 | --target ../final_models/llava_med_pvqa \ 82 | --delta ../model_delta_weights/pvqa-9epoch_delta 83 | ``` 84 | 85 | ```Shell 86 | cd .. 87 | ``` 88 | 89 | ### 9. Generate the top patches for open-ended PathVQA images using HistoCartography 90 | 91 | ```Shell 92 | python generate_histo_patches.py 93 | ``` 94 | 95 | ### 10. Generate the files for query to be asked for LLaVA-Med for both the images and patches 96 | 97 | ```Shell 98 | python generate_llava_med_query.py 99 | ``` 100 | 101 | ### 11. 
Generate the answers for all the query files using the raw model (`final_models/llava_med`) and the fine-tuned model (`final_models/llava_med_pvqa`) 102 | 103 | ```Shell 104 | cd LLaVA-Med 105 | ``` 106 | 107 | #### Raw Model 108 | ```Shell 109 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \ 110 | --question-file ../files/query/image_direct.jsonl \ 111 | --image-folder ../pvqa/images/test \ 112 | --answers-file ../files/answer/raw/answer_image_direct.jsonl 113 | ``` 114 | 115 | ```Shell 116 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \ 117 | --question-file ../files/query/patch_direct.jsonl \ 118 | --image-folder ../pvqa/images/test \ 119 | --answers-file ../files/answer/raw/answer_patch_direct.jsonl 120 | ``` 121 | 122 | ```Shell 123 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \ 124 | --question-file ../files/query/image_description.jsonl \ 125 | --image-folder ../pvqa/images/test \ 126 | --answers-file ../files/answer/raw/answer_image_description.jsonl 127 | ``` 128 | 129 | ```Shell 130 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \ 131 | --question-file ../files/query/patch_description.jsonl \ 132 | --image-folder ../pvqa/images/test \ 133 | --answers-file ../files/answer/raw/answer_patch_description.jsonl 134 | ``` 135 | 136 | #### Fine-Tuned Model 137 | ```Shell 138 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \ 139 | --question-file ../files/query/image_direct.jsonl \ 140 | --image-folder ../pvqa/images/test \ 141 | --answers-file ../files/answer/fine-tuned/answer_image_direct.jsonl 142 | ``` 143 | 144 | ```Shell 145 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \ 146 | --question-file ../files/query/patch_direct.jsonl \ 147 | --image-folder ../pvqa/images/test \ 148 | --answers-file ../files/answer/fine-tuned/answer_patch_direct.jsonl 149 | ``` 150 | 151 | ```Shell 152 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \ 153 | --question-file ../files/query/image_description.jsonl \ 154 | --image-folder ../pvqa/images/test \ 155 | --answers-file ../files/answer/fine-tuned/answer_image_description.jsonl 156 | ``` 157 | 158 | ```Shell 159 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \ 160 | --question-file ../files/query/patch_description.jsonl \ 161 | --image-folder ../pvqa/images/test \ 162 | --answers-file ../files/answer/fine-tuned/answer_patch_description.jsonl 163 | ``` 164 | 165 | ### 12. Evaluate the results for different use-cases using `recall_calculation.py` 166 | 167 | (i) Path-RAG w/o GPT: Combine the answers of the image + all patches to form the final predicted answer\ 168 | (ii) Path-RAG (description): Combine the descriptions of the image + all patches. Then involve GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)\ 169 | (iii) Path-RAG (answer): Combine the answers of the image + all patches. Then involve GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts). A minimal scoring sketch using `recall_score` is provided in the appendix at the end of this README. 170 | 171 | 172 | ## ARCH-Open Dataset 173 | 174 | 1. Download the `books_set` and `pubmed_set` of the ARCH dataset from `https://warwick.ac.uk/fac/cross_fac/tia/data/arch`. Store both of these folders in a folder named `arch`. 
Both `books_set` and `pubmed_set` contain a `captions.json` file which lists a **caption** and a **UUID**, where the **UUID** is the file name in the `images` folder and the **caption** is the textual description of that image. 175 | 176 | 2. Using the `captions.json` and `images` folder under `arch/books_set`, run the notebook `ARCH-OPEN/books_data/synthetic_data_textbook.ipynb` (after specifying the OpenAI credentials) to generate the question-answer pairs for the books set 177 | 178 | 3. Using the `captions.json` and `images` folder under `arch/pubmed_set`, run the notebook `ARCH-OPEN/pubmed_data/synthetic_data_pubmed.ipynb` (after specifying the OpenAI credentials) to generate the question-answer pairs for the pubmed set 179 | 180 | 4. Run the notebook `ARCH-OPEN/synthetic_data_compilation.ipynb` to compile the `pubmed` and `books` question-answer pairs into JSON files, namely `ARCH-OPEN/pubmed_qa_pairs.json` and `ARCH-OPEN/textbook_qa_pairs.json`. These files are already provided and can be used directly 181 | 182 | 5. The `pubmed_qa_pairs.json` and `textbook_qa_pairs.json` files contain 5 question-answer pairs for each pair of `caption` and `uuid` (the `uuid` refers to the image name in the ARCH data under `arch/pubmed_set/images` or `arch/books_set/images`) in the following format for both `pubmed_set` and `books_set` (the `figure_id` and `letter` fields appear only in the `books_set` entries): 183 | 184 | ```json 185 | { 186 | "figure_id": "00", 187 | "letter": "A", 188 | "caption": " A, Spindle cell variant of embryonal rhabdomyosarcoma is characterized by fascicles of eosinophilic spindle cells (B), some of which can show prominent paranuclear vacuolisation, as seen in leiomyosarcoma.", 189 | "uuid": "890e2e79-ab0a-4a2e-9d62-b0b6b3d43884", 190 | "Question_1": "What could be the general shape of cells in a spindle cell variant of embryonal rhabdomyosarcoma as seen in the image?", 191 | "Answer_1": "The cells often present with a spindle-like elongated shape.", 192 | "Question_2": "What type of structures could be visible in the image indicating the presence of spindle cells?", 193 | "Answer_2": "Fascicles, or bundles, of cells could be visible in the image, indicating the presence of spindle cells.", 194 | "Question_3": "Where in the cell would we likely find paranuclear vacuolisation in the image?", 195 | "Answer_3": "Paranuclear vacuolisation is usually seen around the nucleus area of the cell.", 196 | "Question_4": "What color might the spindle cells appear in the image?", 197 | "Answer_4": "The spindle cells may appear eosinophilic, or pinkish-red, under the microscope due to staining.", 198 | "Question_5": "What visual feature might differentiate spindle cells from leiomyosarcoma cells in the image?", 199 | "Answer_5": "Spindle cells might show prominent paranuclear vacuolisation, a feature that can differentiate them from leiomyosarcoma cells." 200 | } 201 | ``` 202 | 203 | ## Acknowledgement 204 | We would like to acknowledge the following funding support: NIH OT2OD032581, NIH OTA-21-008, NIH 1OT2OD032742-01, NSF 2333703, NSF 2303038. 
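
## Appendix: Recall Evaluation Sketch

The snippet below is a minimal, illustrative sketch of how the recall metric from `recall_calculation.py` (Step 12) could be applied to one of the generated answer files. It is not part of the original pipeline: it assumes that the answers file written by `llava/eval/model_vqa.py` is a JSONL file whose `text` field holds the model's answer and whose line order matches `files/query/image_direct.jsonl`; adjust the paths and keys to your setup.

```Python
import json
import pickle

from recall_calculation import recall_score

# Ground-truth open-ended QA pairs (same filtering as generate_llava_med_query.py)
with open("pvqa/qas/test/test_qa.pkl", "rb") as f:
    pvqa_qas = pickle.load(f)
qas_general = [qa for qa in pvqa_qas if qa["answer"] not in ("yes", "no")]

# Model predictions written by llava/eval/model_vqa.py
# (assumed to be JSONL with the answer stored under "text")
with open("files/answer/fine-tuned/answer_image_direct.jsonl") as f:
    predictions = [json.loads(line) for line in f]

# Pair each ground-truth answer with the corresponding predicted answer
eval_list = [
    {"answer": qa["answer"], "llava_answer": pred["text"]}
    for qa, pred in zip(qas_general, predictions)
]

# recall_score averages the per-example word-level recall
print("Average recall:", recall_score(eval_list))
```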
-------------------------------------------------------------------------------- /files/answer/fine-tuned/text.txt: -------------------------------------------------------------------------------- 1 | For Answer files by fine-tuned model 2 | -------------------------------------------------------------------------------- /files/answer/raw/text.txt: -------------------------------------------------------------------------------- 1 | For Answer files by raw model 2 | -------------------------------------------------------------------------------- /files/query/text.txt: -------------------------------------------------------------------------------- 1 | For Query files 2 | -------------------------------------------------------------------------------- /final_models/text.txt: -------------------------------------------------------------------------------- 1 | Forfinal modesl 2 | -------------------------------------------------------------------------------- /generate_histo_patches.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from PIL import Image\n", 10 | "import numpy as np\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 8, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "data_path = \"arch\"\n", 21 | "\n", 22 | "# Extract the image data for PubMed dataset\n", 23 | "pubmed_path_images = os.path.join(data_path, \"pubmed_set\", \"images\")\n", 24 | "pubmed_img_uuid_list = os.listdir(pubmed_path_images)\n", 25 | "pubmed_img_uuid = [uuid.split('.')[0] for uuid in pubmed_img_uuid_list]\n", 26 | "pubmed_img_uuid_path = [os.path.join(pubmed_path_images, img_uuid) for img_uuid in pubmed_img_uuid_list]\n", 27 | "\n", 28 | "# # Extract the image data for Books dataset\n", 29 | "# books_path_images = os.path.join(data_path, \"books_set\", \"images\")\n", 30 | "# books_img_uuid = os.listdir(books_path_images)\n", 31 | "# books_img_uuid = [uuid.split('.')[0] for uuid in books_img_uuid]\n", 32 | "# books_img_uuid_path = [os.path.join(books_path_images, img_uuid + '.png') for img_uuid in books_img_uuid]" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 6, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "71\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "for i in range(len(pubmed_img_uuid)):\n", 50 | " if pubmed_img_uuid[i] == 'a6fade92-c4ce-4eed-b8aa-795e85c4645e':\n", 51 | " print(i)\n", 52 | " break" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [] 61 | } 62 | ], 63 | "metadata": { 64 | "kernelspec": { 65 | "display_name": "path-rag", 66 | "language": "python", 67 | "name": "python3" 68 | }, 69 | "language_info": { 70 | "codemirror_mode": { 71 | "name": "ipython", 72 | "version": 3 73 | }, 74 | "file_extension": ".py", 75 | "mimetype": "text/x-python", 76 | "name": "python", 77 | "nbconvert_exporter": "python", 78 | "pygments_lexer": "ipython3", 79 | "version": "3.9.19" 80 | } 81 | }, 82 | "nbformat": 4, 83 | "nbformat_minor": 2 84 | } 85 | -------------------------------------------------------------------------------- /generate_histo_patches.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import os 
4 | import json 5 | import pickle 6 | import sys 7 | sys.path.append(os.path.join(os.getcwd(), 'histocartography')) 8 | 9 | from histocartography.preprocessing import NucleiExtractor, DeepFeatureExtractor, KNNGraphBuilder 10 | 11 | PVQA_DATA_PATH = "pvqa" 12 | HISTO_PATCH_SAVE_PATH = "histo_image_patch" 13 | ARCH_DATA_PATH = "arch" 14 | 15 | # Cell Graph Generation Definitions 16 | nuclei_detector = NucleiExtractor() 17 | feats_extractor = DeepFeatureExtractor(architecture='resnet34', patch_size=72, resize_size=224) 18 | knn_graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True) 19 | 20 | # PathVQA Dataset Processing 21 | def get_path_vqa_open_images(data_path : str = "pvqa"): 22 | 23 | # Reading the PathVQA dataset 24 | img_train_path = os.path.join(data_path, "images", "test") 25 | qas_train_path = os.path.join(data_path, "qas", "test", "test_qa.pkl") 26 | with open(qas_train_path, 'rb') as file: 27 | pvqa_qas = pickle.load(file) 28 | 29 | # Getting all open-ended images 30 | qas_general = [qas for qas in pvqa_qas if qas['answer'] != 'yes' and qas['answer'] != 'no'] 31 | img_general = [qas['image'] for qas in qas_general] 32 | img_general = list(set(img_general)) 33 | img_general = sorted(img_general, key=str) 34 | img_general_path = [img_train_path + img_name + '.jpg' for img_name in img_general] 35 | 36 | return img_general, img_general_path 37 | 38 | # PathVQA Dataset Processing 39 | def get_arch_open_images(data_path : str = "arch"): 40 | 41 | # Extract the image data for PubMed dataset 42 | pubmed_path_images = os.path.join(data_path, "pubmed_set", "images") 43 | pubmed_img_uuid_list = os.listdir(pubmed_path_images) 44 | pubmed_img_uuid = [uuid.split('.')[0] for uuid in pubmed_img_uuid_list] 45 | pubmed_img_uuid_path = [os.path.join(pubmed_path_images, img_uuid) for img_uuid in pubmed_img_uuid_list] 46 | 47 | # Extract the image data for Books dataset 48 | books_path_images = os.path.join(data_path, "books_set", "images") 49 | books_img_uuid_list = os.listdir(books_path_images) 50 | books_img_uuid = [uuid.split('.')[0] for uuid in books_img_uuid_list] 51 | books_img_uuid_path = [os.path.join(books_path_images, img_uuid) for img_uuid in books_img_uuid_list] 52 | 53 | return pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path 54 | 55 | # Save top patches using histocartography 56 | def save_histocartography_top_patches(img_general : list, img_general_path: list): 57 | for image_idx in range(0, len(img_general)): 58 | 59 | print(f"{image_idx}: Started ") 60 | query_img = Image.open(img_general_path[image_idx]).convert(mode="RGB") 61 | image = np.array(query_img) 62 | nuclei_map, nuclei_centers = nuclei_detector.process(image) 63 | 64 | # Only consider if more than 5 nuclei are detected since knn needs to form a graph using 5 neighbors. 
65 | # If less than 5 nuclei are present, most of the images are not pathology related 66 | if nuclei_centers.shape[0] > 5: 67 | print(f"{image_idx}: Patches ") 68 | 69 | # Get the Features 70 | features = feats_extractor.process(image, nuclei_map) 71 | 72 | # Make Cell Graph 73 | cell_graph = knn_graph_builder.process(nuclei_map, features) 74 | 75 | # Make calculations to extract patches and the overlap images 76 | width, height = query_img.size 77 | width_range = np.linspace(0, width, 4, dtype=int) 78 | height_range = np.linspace(0, height, 4, dtype=int) 79 | 80 | overlap_percent = 20 81 | width_overlap = int((overlap_percent/100) * width) 82 | height_overlap = int((overlap_percent/100) * height) 83 | 84 | # Extract the patches 85 | image_patches = [] 86 | patch_nuclei_centers = [] 87 | for i in range(len(width_range)-1): 88 | for j in range(len(height_range)-1): 89 | # Consider the overlap width from second patch only 90 | if i != 0: 91 | start_width = width_range[i] - width_overlap 92 | else: 93 | start_width = width_range[i] 94 | 95 | # Consider the overlap height from second patch only 96 | if j != 0: 97 | start_height = height_range[j] - height_overlap 98 | else: 99 | start_height = height_range[j] 100 | 101 | # List out the patch ranges 102 | left = start_width 103 | upper = start_height 104 | right = width_range[i+1] 105 | lower = height_range[j+1] 106 | 107 | center_list = [] 108 | for center in nuclei_centers: 109 | if ((center[0] >= left) and (center[0] <=right) and 110 | (center[1] >= upper) and (center[1] <=lower)): 111 | center_list.append(center) 112 | 113 | image_patches.append(query_img.crop((left, upper, right, lower))) 114 | patch_nuclei_centers.append(center_list) 115 | 116 | # Calculate the length of nuclei in each patch 117 | patch_center_length = [] 118 | for center in patch_nuclei_centers: 119 | patch_center_length.append(len(center)) 120 | 121 | # Sort the patch indices based on maximum number of nuclei 122 | sorted_indices_desc = np.flip(np.argsort(patch_center_length)) 123 | 124 | # Create a directory to store all the patches of the image 125 | save_directory = os.path.join(HISTO_PATCH_SAVE_PATH, img_general[image_idx]) 126 | if not os.path.isdir(save_directory): 127 | os.mkdir(save_directory) 128 | 129 | # Store all the image patches into the newly created directory 130 | for patch_index in range(0,6): 131 | save_file_path = os.path.join(save_directory, str(patch_index+1) + ".png") 132 | image_patches[sorted_indices_desc[patch_index]].save(save_file_path) 133 | 134 | print(f"{image_idx}: Ended ") 135 | print(".........") 136 | 137 | # Save top patches using histocartography 138 | def save_histocartography_top_patches_arch(img_uuid : list, img_uuid_path: list, books_pubmed_class: str): 139 | for image_idx in range(0, len(img_uuid)): 140 | 141 | print(f"{image_idx}/{len(img_uuid)}: Started ") 142 | 143 | query_img = Image.open(img_uuid_path[image_idx]).convert(mode="RGB") 144 | image = np.array(query_img) 145 | nuclei_map, nuclei_centers = nuclei_detector.process(image) 146 | 147 | # Only consider if more than 5 nuclei are detected since knn needs to form a graph using 5 neighbors. 
148 | # If less than 5 nuclei are present, most of the images are not pathology related 149 | if nuclei_centers.shape[0] > 5: 150 | print(f"{image_idx}: Creating Patches ") 151 | 152 | # Get the Features 153 | features = feats_extractor.process(image, nuclei_map) 154 | 155 | # Make Cell Graph 156 | cell_graph = knn_graph_builder.process(nuclei_map, features) 157 | 158 | # Make calculations to extract patches and the overlap images 159 | width, height = query_img.size 160 | width_range = np.linspace(0, width, 4, dtype=int) 161 | height_range = np.linspace(0, height, 4, dtype=int) 162 | 163 | overlap_percent = 20 164 | width_overlap = int((overlap_percent/100) * width) 165 | height_overlap = int((overlap_percent/100) * height) 166 | 167 | # Extract the patches 168 | image_patches = [] 169 | patch_nuclei_centers = [] 170 | for i in range(len(width_range)-1): 171 | for j in range(len(height_range)-1): 172 | # Consider the overlap width from second patch only 173 | if i != 0: 174 | start_width = width_range[i] - width_overlap 175 | else: 176 | start_width = width_range[i] 177 | 178 | # Consider the overlap height from second patch only 179 | if j != 0: 180 | start_height = height_range[j] - height_overlap 181 | else: 182 | start_height = height_range[j] 183 | 184 | # List out the patch ranges 185 | left = start_width 186 | upper = start_height 187 | right = width_range[i+1] 188 | lower = height_range[j+1] 189 | 190 | center_list = [] 191 | for center in nuclei_centers: 192 | if ((center[0] >= left) and (center[0] <=right) and 193 | (center[1] >= upper) and (center[1] <=lower)): 194 | center_list.append(center) 195 | 196 | image_patches.append(query_img.crop((left, upper, right, lower))) 197 | patch_nuclei_centers.append(center_list) 198 | 199 | # Calculate the length of nuclei in each patch 200 | patch_center_length = [] 201 | for center in patch_nuclei_centers: 202 | patch_center_length.append(len(center)) 203 | 204 | # Sort the patch indices based on maximum number of nuclei 205 | sorted_indices_desc = np.flip(np.argsort(patch_center_length)) 206 | 207 | # Create a directory to store all the patches of the image 208 | save_directory = os.path.join(os.getcwd(), HISTO_PATCH_SAVE_PATH, ARCH_DATA_PATH, books_pubmed_class, img_uuid[image_idx]) 209 | if not os.path.isdir(save_directory): 210 | os.mkdir(save_directory) 211 | 212 | # Store all the image patches into the newly created directory 213 | for patch_index in range(0,6): 214 | save_file_path = os.path.join(save_directory, str(patch_index+1) + ".png") 215 | image_patches[sorted_indices_desc[patch_index]].save(save_file_path) 216 | else: 217 | # Create an empty directory for non-pathology images 218 | print(f"{image_idx}: Non-Pathology Images") 219 | save_directory = os.path.join(os.getcwd(), HISTO_PATCH_SAVE_PATH, ARCH_DATA_PATH, books_pubmed_class, img_uuid[image_idx]) 220 | if not os.path.isdir(save_directory): 221 | os.mkdir(save_directory) 222 | 223 | print(f"{image_idx}/{len(img_uuid)}: Ended ") 224 | print(".........") 225 | 226 | 227 | if __name__ == "__main__": 228 | 229 | # # Get all the open-ended images of PathVQA 230 | # img_general, img_general_path = get_path_vqa_open_images(PVQA_DATA_PATH) 231 | 232 | # # Generate the top patches using histocartography and save them 233 | # save_histocartography_top_patches(img_general, img_general_path) 234 | 235 | pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path = get_arch_open_images(ARCH_DATA_PATH) 236 | 237 | save_histocartography_top_patches_arch(pubmed_img_uuid, 
pubmed_img_uuid_path, "pubmed") 238 | 239 | save_histocartography_top_patches_arch(books_img_uuid, books_img_uuid_path, "books") 240 | -------------------------------------------------------------------------------- /generate_llava_med_query.py: -------------------------------------------------------------------------------- 1 | # This program will generate the query .json files which will need to be input into LLaVA-Med to generate descriptions or answers 2 | # for all the images in PathVQA and all the relevant patches 3 | 4 | import os 5 | import pandas as pd 6 | import pickle 7 | from tqdm import tqdm 8 | import json 9 | 10 | PVQA_DATA_PATH = "pvqa" 11 | HISTO_PATCH_SAVE_PATH = "histo_image_patch" 12 | LLAVA_MED_QUERY_PATH = os.path.join("files", "query") 13 | 14 | def get_path_vqa_open_data(data_path : str = "pvqa"): 15 | # Reading the PathVQA dataset 16 | qas_train_path = os.path.join(data_path, "qas", "test", "test_qa.pkl") 17 | with open(qas_train_path, 'rb') as file: 18 | pvqa_qas = pickle.load(file) 19 | 20 | # Getting all open-ended images 21 | qas_general = [qas for qas in pvqa_qas if qas['answer'] != 'yes' and qas['answer'] != 'no'] 22 | img_general = [qas['image'] for qas in qas_general] 23 | 24 | return qas_general, img_general 25 | 26 | def generate_image_files_direct(qas_general: list, img_general: list): 27 | # Generating all the info for images with original questions 28 | question = [] 29 | idx = 0 30 | for i in range(0, len(qas_general)): 31 | question.append({"question_id": idx, "image": img_general[i]+'.jpg', "text": qas_general[i]+"\n"}) 32 | idx = idx+1 33 | 34 | # Writing each dictionary as a JSON object on a new line 35 | img_direct_path = os.path.join(LLAVA_MED_QUERY_PATH, 'image_direct.jsonl') 36 | with open(img_direct_path, 'w') as file: 37 | for item in question: 38 | json_line = json.dumps(item) # Convert dictionary to JSON string 39 | file.write(json_line + '\n') 40 | 41 | def generate_patch_files_direct(qas_general: list, img_general: list): 42 | # Generating all the info for patches with original questions 43 | question = [] 44 | idx = 0 45 | for i in range(0, len(qas_general)): 46 | question.append({"question_id": idx, "image": img_general[i]+"/1.png", "text": qas_general[i]+"\n"}) 47 | idx = idx+1 48 | question.append({"question_id": idx, "image": img_general[i]+"/2.png", "text": qas_general[i]+"\n"}) 49 | idx = idx+1 50 | question.append({"question_id": idx, "image": img_general[i]+"/3.png", "text": qas_general[i]+"\n"}) 51 | idx = idx+1 52 | 53 | # Writing each dictionary as a JSON object on a new line 54 | patch_direct_path = os.path.join(LLAVA_MED_QUERY_PATH, 'patch_direct.jsonl') 55 | with open(patch_direct_path, 'w') as file: 56 | for item in question: 57 | json_line = json.dumps(item) # Convert dictionary to JSON string 58 | file.write(json_line + '\n') 59 | 60 | def generate_image_files_description(img_general: list): 61 | # Getting the description info for all the images 62 | question = [] 63 | idx = 0 64 | for i in img_general: 65 | question.append({"question_id": idx, "image": i+'.jpg', "text": "Describe the following image in detail.\n"}) 66 | idx = idx+1 67 | 68 | # Writing each dictionary as a JSON object on a new line 69 | img_desc_path = os.path.join(LLAVA_MED_QUERY_PATH, 'image_description.jsonl') 70 | with open(img_desc_path, 'w') as file: 71 | for item in question: 72 | json_line = json.dumps(item) # Convert dictionary to JSON string 73 | file.write(json_line + '\n') 74 | 75 | def generate_patch_files_description(): 76 | # Getting all 
the directories of generated patches 77 | patch_dir_list = os.listdir(HISTO_PATCH_SAVE_PATH) 78 | patch_dir_list = sorted(patch_dir_list) # sort directory names for a deterministic order (list.sort() returns None) 79 | 80 | # Getting the description info for top 3 patches 81 | question = [] 82 | idx = 0 83 | for i in patch_dir_list: 84 | question.append({"question_id": idx, "image": i+"/1.png", "text": "Describe the following image in detail.\n"}) 85 | idx = idx+1 86 | question.append({"question_id": idx, "image": i+"/2.png", "text": "Describe the following image in detail.\n"}) 87 | idx = idx+1 88 | question.append({"question_id": idx, "image": i+"/3.png", "text": "Describe the following image in detail.\n"}) 89 | idx = idx+1 90 | 91 | # Writing each dictionary as a JSON object on a new line 92 | patch_desc_path = os.path.join(LLAVA_MED_QUERY_PATH, 'patch_description.jsonl') 93 | with open(patch_desc_path, 'w') as file: 94 | for item in question: 95 | json_line = json.dumps(item) # Convert dictionary to JSON string 96 | file.write(json_line + '\n') 97 | 98 | if __name__ == "__main__": 99 | 100 | # Get all the data from PathVQA 101 | qas_general, img_general = get_path_vqa_open_data(PVQA_DATA_PATH) 102 | 103 | # Generate Image Query File with Original Questions 104 | generate_image_files_direct(qas_general, img_general) 105 | 106 | # Generate Patch Query File with Original Questions 107 | generate_patch_files_direct(qas_general, img_general) 108 | 109 | # Generate Image Query File with Description Prompts 110 | generate_image_files_description(img_general) 111 | 112 | # Generate Patch Query File with Description Prompts 113 | generate_patch_files_description() 114 | -------------------------------------------------------------------------------- /images/path-rag.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/embedded-robotics/path-rag/e1bd5ea6aec4825d708b4f7b77b2fcaf52fff824/images/path-rag.png -------------------------------------------------------------------------------- /images/text.txt: -------------------------------------------------------------------------------- 1 | Add repo images 2 | -------------------------------------------------------------------------------- /llama_7B_model_weights.py: -------------------------------------------------------------------------------- 1 | # This program will download LLaMA-7B weights from HuggingFace and store the resulting model into the path specified by $HF_HOME 2 | # By Default $HF_HOME = ~/.cache/huggingface 3 | 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | 6 | # Load tokenizer and model 7 | tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") 8 | model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b") 9 | 10 | # Try out an example to make sure the base model is downloaded successfully and working 11 | input_text = "The future of AI in healthcare is" 12 | input_ids = tokenizer.encode(input_text, return_tensors="pt") 13 | # Generate text 14 | output = model.generate(input_ids, max_length=50, num_return_sequences=1) 15 | # Decode and print the generated text 16 | print(tokenizer.decode(output[0], skip_special_tokens=True)) -------------------------------------------------------------------------------- /model_delta_weights/text.txt: -------------------------------------------------------------------------------- 1 | For model delta weights 2 | -------------------------------------------------------------------------------- /recall_calculation.py: -------------------------------------------------------------------------------- 1 | from collections import 
defaultdict 2 | import re 3 | 4 | contractions = { 5 | "aint": "ain't", 6 | "arent": "aren't", 7 | "cant": "can't", 8 | "couldve": "could've", 9 | "couldnt": "couldn't", 10 | "couldn'tve": "couldn't've", 11 | "couldnt've": "couldn't've", 12 | "didnt": "didn't", 13 | "doesnt": "doesn't", 14 | "dont": "don't", 15 | "hadnt": "hadn't", 16 | "hadnt've": "hadn't've", 17 | "hadn'tve": "hadn't've", 18 | "hasnt": "hasn't", 19 | "havent": "haven't", 20 | "hed": "he'd", 21 | "hed've": "he'd've", 22 | "he'dve": "he'd've", 23 | "hes": "he's", 24 | "howd": "how'd", 25 | "howll": "how'll", 26 | "hows": "how's", 27 | "Id've": "I'd've", 28 | "I'dve": "I'd've", 29 | "Im": "I'm", 30 | "Ive": "I've", 31 | "isnt": "isn't", 32 | "itd": "it'd", 33 | "itd've": "it'd've", 34 | "it'dve": "it'd've", 35 | "itll": "it'll", 36 | "let's": "let's", 37 | "maam": "ma'am", 38 | "mightnt": "mightn't", 39 | "mightnt've": "mightn't've", 40 | "mightn'tve": "mightn't've", 41 | "mightve": "might've", 42 | "mustnt": "mustn't", 43 | "mustve": "must've", 44 | "neednt": "needn't", 45 | "notve": "not've", 46 | "oclock": "o'clock", 47 | "oughtnt": "oughtn't", 48 | "ow's'at": "'ow's'at", 49 | "'ows'at": "'ow's'at", 50 | "'ow'sat": "'ow's'at", 51 | "shant": "shan't", 52 | "shed've": "she'd've", 53 | "she'dve": "she'd've", 54 | "she's": "she's", 55 | "shouldve": "should've", 56 | "shouldnt": "shouldn't", 57 | "shouldnt've": "shouldn't've", 58 | "shouldn'tve": "shouldn't've", 59 | "somebody'd": "somebodyd", 60 | "somebodyd've": "somebody'd've", 61 | "somebody'dve": "somebody'd've", 62 | "somebodyll": "somebody'll", 63 | "somebodys": "somebody's", 64 | "someoned": "someone'd", 65 | "someoned've": "someone'd've", 66 | "someone'dve": "someone'd've", 67 | "someonell": "someone'll", 68 | "someones": "someone's", 69 | "somethingd": "something'd", 70 | "somethingd've": "something'd've", 71 | "something'dve": "something'd've", 72 | "somethingll": "something'll", 73 | "thats": "that's", 74 | "thered": "there'd", 75 | "thered've": "there'd've", 76 | "there'dve": "there'd've", 77 | "therere": "there're", 78 | "theres": "there's", 79 | "theyd": "they'd", 80 | "theyd've": "they'd've", 81 | "they'dve": "they'd've", 82 | "theyll": "they'll", 83 | "theyre": "they're", 84 | "theyve": "they've", 85 | "twas": "'twas", 86 | "wasnt": "wasn't", 87 | "wed've": "we'd've", 88 | "we'dve": "we'd've", 89 | "weve": "we've", 90 | "werent": "weren't", 91 | "whatll": "what'll", 92 | "whatre": "what're", 93 | "whats": "what's", 94 | "whatve": "what've", 95 | "whens": "when's", 96 | "whered": "where'd", 97 | "wheres": "where's", 98 | "whereve": "where've", 99 | "whod": "who'd", 100 | "whod've": "who'd've", 101 | "who'dve": "who'd've", 102 | "wholl": "who'll", 103 | "whos": "who's", 104 | "whove": "who've", 105 | "whyll": "why'll", 106 | "whyre": "why're", 107 | "whys": "why's", 108 | "wont": "won't", 109 | "wouldve": "would've", 110 | "wouldnt": "wouldn't", 111 | "wouldnt've": "wouldn't've", 112 | "wouldn'tve": "wouldn't've", 113 | "yall": "y'all", 114 | "yall'll": "y'all'll", 115 | "y'allll": "y'all'll", 116 | "yall'd've": "y'all'd've", 117 | "y'alld've": "y'all'd've", 118 | "y'all'dve": "y'all'd've", 119 | "youd": "you'd", 120 | "youd've": "you'd've", 121 | "you'dve": "you'd've", 122 | "youll": "you'll", 123 | "youre": "you're", 124 | "youve": "you've", 125 | } 126 | 127 | manual_map = { 128 | "none": "0", 129 | "zero": "0", 130 | "one": "1", 131 | "two": "2", 132 | "three": "3", 133 | "four": "4", 134 | "five": "5", 135 | "six": "6", 136 | "seven": "7", 137 | "eight": 
"8", 138 | "nine": "9", 139 | "ten": "10", 140 | } 141 | articles = ["a", "an", "the"] 142 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 143 | comma_strip = re.compile("(\d)(\,)(\d)") 144 | punct = [ 145 | ";", 146 | r"/", 147 | "[", 148 | "]", 149 | '"', 150 | "{", 151 | "}", 152 | "(", 153 | ")", 154 | "=", 155 | "+", 156 | "\\", 157 | "_", 158 | "-", 159 | ">", 160 | "<", 161 | "@", 162 | "`", 163 | ",", 164 | "?", 165 | "!", 166 | ] 167 | 168 | def normalize_word(token): 169 | _token = token 170 | for p in punct: 171 | if (p + " " in token or " " + p in token) or ( 172 | re.search(comma_strip, token) != None 173 | ): 174 | _token = _token.replace(p, "") 175 | else: 176 | _token = _token.replace(p, " ") 177 | token = period_strip.sub("", _token, re.UNICODE) 178 | 179 | _token = [] 180 | temp = token.lower().split() 181 | for word in temp: 182 | word = manual_map.setdefault(word, word) 183 | if word not in articles: 184 | _token.append(word) 185 | for i, word in enumerate(_token): 186 | if word in contractions: 187 | _token[i] = contractions[word] 188 | token = " ".join(_token) 189 | token = token.replace(",", "") 190 | return token 191 | 192 | def split_sentence(sentence, n): 193 | words = defaultdict(int) 194 | # tmp_sentence = re.sub("[^a-zA-Z ]", "", sentence) 195 | tmp_sentence = sentence 196 | tmp_sentence = tmp_sentence.lower() 197 | tmp_sentence = tmp_sentence.strip().split() 198 | length = len(tmp_sentence) 199 | for i in range(length - n + 1): 200 | tmp_words = " ".join(tmp_sentence[i: i + n]) 201 | if tmp_words: 202 | words[tmp_words] += 1 203 | return words 204 | 205 | def calculate_f1score(candidate, reference): 206 | 207 | candidate = normalize_word(candidate) 208 | reference = normalize_word(reference) 209 | 210 | candidate_words = split_sentence(candidate, 1) 211 | reference_words = split_sentence(reference, 1) 212 | word_set = set() 213 | for word in candidate_words: 214 | word_set.add(word) 215 | for word in reference_words: 216 | word_set.add(word) 217 | 218 | tp = 0 219 | fp = 0 220 | fn = 0 221 | for word in word_set: 222 | if word in candidate_words and word in reference_words: 223 | tp += candidate_words[word] 224 | elif word in candidate_words and word not in reference_words: 225 | fp += candidate_words[word] 226 | elif word not in candidate_words and word in reference_words: 227 | fn += reference_words[word] 228 | 229 | if len(candidate_words) == 0: 230 | return 0, 0, 0 # "0 (warning: length of candidate's words is 0)" 231 | elif len(reference_words) == 0: 232 | return 0, 0, 0 233 | else: 234 | precision = tp / (tp + fp) 235 | recall = tp / (tp + fn) 236 | if tp == 0: 237 | return 0, 0, 0 238 | else: 239 | return 2 * precision * recall / (precision + recall), precision, recall 240 | 241 | 242 | def recall_score(eval_list: list, answer_key: str='answer', prediction_key: str='llava_answer'): 243 | score = [] 244 | for i in eval_list: 245 | gt = i[answer_key] 246 | pred = i[prediction_key] 247 | f1_score, precision, recall = calculate_f1score(pred, gt) 248 | score.append(recall) 249 | total = 0 250 | for i in score: 251 | total = total + i 252 | return total/len(score) -------------------------------------------------------------------------------- /tem_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# !ln -sf /usr/lib/x86_64-linux-gnu/libffi.so.7 
/data/mn27889/miniconda3/envs/path-rag/lib/libffi.so.7" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from PIL import Image\n", 19 | "import numpy as np\n", 20 | "import os\n", 21 | "import pickle\n", 22 | "import json\n", 23 | "import numpy as np\n", 24 | "import sys\n", 25 | "sys.path.append(os.path.join(os.getcwd(), 'histocartography'))\n", 26 | "from histocartography.preprocessing import NucleiExtractor, DeepFeatureExtractor, KNNGraphBuilder" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "HISTO_PATCH_SAVE_PATH = \"histo_image_patch\"\n", 36 | "ACRH_DATA_PATH = \"arch\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Cell Graph Generation Definitions\n", 46 | "nuclei_detector = NucleiExtractor()\n", 47 | "feats_extractor = DeepFeatureExtractor(architecture='resnet34', patch_size=72, resize_size=224)\n", 48 | "knn_graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# PathVQA Dataset Processing\n", 58 | "def get_arch_open_images(data_path : str = \"arch\"):\n", 59 | " \n", 60 | " # Extract the images for PubMed dataset\n", 61 | " pubmed_path = os.path.join(data_path, \"pubmed_set\")\n", 62 | " pubmed_path_captions = os.path.join(pubmed_path, \"captions.json\")\n", 63 | " pubmed_path_images = os.path.join(pubmed_path, \"images\")\n", 64 | " \n", 65 | " with open(pubmed_path_captions, 'rb') as file:\n", 66 | " pubmed_captions = json.load(file)\n", 67 | " \n", 68 | " # Getting all open-ended images\n", 69 | " pubmed_img_uuid = [value['uuid'] for index, value in pubmed_captions.items()]\n", 70 | " pubmed_img_uuid_path = [os.path.join(pubmed_path_images, img_uuid + '.jpg') for img_uuid in pubmed_img_uuid]\n", 71 | "\n", 72 | " # Extract the images for Books dataset\n", 73 | " books_path = os.path.join(data_path, \"books_set\")\n", 74 | " books_path_captions = os.path.join(books_path, \"captions.json\")\n", 75 | " books_path_images = os.path.join(books_path, \"images\")\n", 76 | " \n", 77 | " with open(books_path_captions, 'rb') as file:\n", 78 | " books_captions = json.load(file)\n", 79 | " \n", 80 | " # Getting all open-ended images\n", 81 | " books_img_uuid = [value['uuid'] for index, value in books_captions.items()]\n", 82 | " books_img_uuid_path = [os.path.join(books_path_images, img_uuid + '.png') for img_uuid in books_img_uuid]\n", 83 | " \n", 84 | " return pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path = get_arch_open_images()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# Save top patches using histocartography\n", 103 | "def save_histocartography_top_patches_arch(img_uuid : list, img_uuid_path: list, books_pubmed_class: str):\n", 104 | " # for image_idx in range(0, len(img_uuid)):\n", 105 | " \n", 106 | " for image_idx in range(0, 10):\n", 107 | " print(f\"{image_idx}/{len(img_uuid)}: Started \")\n", 108 | 
" query_img = Image.open(img_uuid_path[image_idx]).convert(mode=\"RGB\")\n", 109 | " image = np.array(query_img)\n", 110 | " nuclei_map, nuclei_centers = nuclei_detector.process(image)\n", 111 | "\n", 112 | " # Only consider if more than 5 nuclei are detected since knn needs to form a graph using 5 neighbors.\n", 113 | " # If less than 5 nuclei are present, most of the images are not pathology related\n", 114 | " if nuclei_centers.shape[0] > 5:\n", 115 | " print(f\"{image_idx}: Patches \")\n", 116 | " \n", 117 | " # Get the Features\n", 118 | " features = feats_extractor.process(image, nuclei_map)\n", 119 | " \n", 120 | " # Make Cell Graph\n", 121 | " cell_graph = knn_graph_builder.process(nuclei_map, features)\n", 122 | " \n", 123 | " # Make calculations to extract patches and the overlap images\n", 124 | " width, height = query_img.size\n", 125 | " width_range = np.linspace(0, width, 4, dtype=int)\n", 126 | " height_range = np.linspace(0, height, 4, dtype=int)\n", 127 | "\n", 128 | " overlap_percent = 20\n", 129 | " width_overlap = int((overlap_percent/100) * width)\n", 130 | " height_overlap = int((overlap_percent/100) * height)\n", 131 | " \n", 132 | " # Extract the patches\n", 133 | " image_patches = []\n", 134 | " patch_nuclei_centers = []\n", 135 | " for i in range(len(width_range)-1):\n", 136 | " for j in range(len(height_range)-1):\n", 137 | " # Consider the overlap width from second patch only\n", 138 | " if i != 0:\n", 139 | " start_width = width_range[i] - width_overlap\n", 140 | " else:\n", 141 | " start_width = width_range[i]\n", 142 | "\n", 143 | " # Consider the overlap height from second patch only\n", 144 | " if j != 0:\n", 145 | " start_height = height_range[j] - height_overlap\n", 146 | " else:\n", 147 | " start_height = height_range[j]\n", 148 | " \n", 149 | " # List out the patch ranges\n", 150 | " left = start_width\n", 151 | " upper = start_height\n", 152 | " right = width_range[i+1]\n", 153 | " lower = height_range[j+1]\n", 154 | " \n", 155 | " center_list = []\n", 156 | " for center in nuclei_centers:\n", 157 | " if ((center[0] >= left) and (center[0] <=right) and \n", 158 | " (center[1] >= upper) and (center[1] <=lower)):\n", 159 | " center_list.append(center)\n", 160 | "\n", 161 | " image_patches.append(query_img.crop((left, upper, right, lower)))\n", 162 | " patch_nuclei_centers.append(center_list)\n", 163 | "\n", 164 | " # Calculate the length of nuclei in each patch\n", 165 | " patch_center_length = []\n", 166 | " for center in patch_nuclei_centers:\n", 167 | " patch_center_length.append(len(center))\n", 168 | " \n", 169 | " # Sort the patch indices based on maximum number of nuclei\n", 170 | " sorted_indices_desc = np.flip(np.argsort(patch_center_length))\n", 171 | " \n", 172 | " # Create a directory to store all the patches of the image\n", 173 | " save_directory = os.path.join(os.getcwd(), HISTO_PATCH_SAVE_PATH, ACRH_DATA_PATH, books_pubmed_class, img_uuid[image_idx])\n", 174 | " if not os.path.isdir(save_directory):\n", 175 | " os.mkdir(save_directory)\n", 176 | " \n", 177 | " # Store all the image patches into the newly created directory\n", 178 | " for patch_index in range(0,6):\n", 179 | " save_file_path = os.path.join(save_directory, str(patch_index+1) + \".png\")\n", 180 | " image_patches[sorted_indices_desc[patch_index]].save(save_file_path)\n", 181 | " \n", 182 | " print(f\"{image_idx}/{len(img_uuid)}: Ended \")\n", 183 | " print(\".........\")" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 
190 | "outputs": [], 191 | "source": [ 192 | "save_histocartography_top_patches_arch(books_img_uuid, books_img_uuid_path, \"books\")" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "path-rag", 206 | "language": "python", 207 | "name": "python3" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 3 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython3", 219 | "version": "3.9.19" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 2 224 | } 225 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Path-RAG: Knowledge-Guided Key Region Retrieval for Open-ended Pathology Visual Question Answering\n", 8 | "\n", 9 | "

![Path-RAG framework overview](images/path-rag.png)\n", 10 | "
\n", 11 | " \n", 12 | " *Accurate diagnosis and prognosis assisted by pathology images are essential for cancer treatment selection and planning. Despite the recent trend of adopting deep-learning approaches for analyzing complex pathology images, they fall short as they often overlook the domain-expert understanding of tissue structure and cell composition. In this work, we focus on a challenging Open-ended Pathology VQA (PathVQA-Open) task and propose a novel framework named Path-RAG, which leverages HistoCartography to retrieve relevant domain knowledge from pathology images and significantly improves performance on PathVQA-Open. Admitting the complexity of pathology image analysis, Path-RAG adopts a human-centered AI approach by retrieving domain knowledge using HistoCartography to select the relevant patches from pathology images. Our experiments suggest that domain guidance can significantly boost the accuracy of LLaVA-Med from 38\\% to 47\\%, with a notable gain of 28\\% for H\\&E-stained pathology images in the PathVQA-Open dataset. For longer-form question and answer pairs, our model consistently achieves significant improvements of 32.5\\% in ARCH-Open PubMed and 30.6\\% in ARCH-Open Books on H\\&E images.*\n", 13 | "

\n", 14 | "\n", 15 | "---\n", 16 | "\n", 17 | "Awais Naeem*, Tianhao Li*, Huang-Ru Liao*, Jiawei Xu*, Aby Mammen Mathew*, Zehao Zhu*, Zhen Tan**, Ajay Jaiswal*, Raffi Salibian*** , Ziniu Hu*** , Tianlong Chen****, Ying Ding*\n", 18 | "\n", 19 | "*University of Texas at Austin, USA \\\n", 20 | "**Arizona State University, USA \\\n", 21 | "***University of California, Los Angeles, USA \\\n", 22 | "****Massachusetts Institute of Technology, USA\n", 23 | "\n", 24 | "---\n", 25 | "\n", 26 | "## Path-RAG Implementation\n", 27 | "\n", 28 | "### 1. Clone this repository and navigate to path-rag folder\n", 29 | "\n", 30 | "```Shell\n", 31 | "git clone https://github.com/embedded-robotics/path-rag.git\n", 32 | "cd path-rag\n", 33 | "```\n", 34 | "\n", 35 | "### 2. Install Package: Create conda environment\n", 36 | "\n", 37 | "```Shell\n", 38 | "conda create -n path-rag python=3.10 -y\n", 39 | "conda activate path-rag\n", 40 | "pip install --upgrade pip # enable PEP 660 support for LLaVA-Med\n", 41 | "```\n", 42 | "\n", 43 | "### 3. Download the PathVQA dataset from the following link\n", 44 | "\n", 45 | "[PathVQA Dataset](https://github.com/UCSD-AI4H/PathVQA/blob/master/data/README.md)\n", 46 | "\n", 47 | "### 4. Clone the HistoCartography tool, setup the model checkpoints in `histocartography/checkpoints` and install the dependencies\n", 48 | "\n", 49 | "```Shell\n", 50 | "git clone https://github.com/BiomedSciAI/histocartography\n", 51 | "```\n", 52 | "\n", 53 | "### 5. Clone the LLaVA-Med repository and install the dependencies\n", 54 | "\n", 55 | "```Shell\n", 56 | "git clone https://github.com/microsoft/LLaVA-Med\n", 57 | "```\n", 58 | "\n", 59 | "### 6. Download the LLaMA-7B model and weights from HuggingFace\n", 60 | "\n", 61 | "```Shell\n", 62 | "python llama_7B_model_weights.py # LLaMA-7B weights/model stored into $HF_HOME (By Default $HF_HOME = ~/.cache/huggingface)\n", 63 | "```\n", 64 | "\n", 65 | "### 7. Download LLaVA-Med delta weights `llava_med_in_text_60k_ckpt2_delta` and `pvqa-9epoch_delta` from `https://github.com/microsoft/LLaVA-Med#model-download`. Put them inside a folder named `model_delta_weights`\n", 66 | "\n", 67 | "### 8. Apply the LLaVA-Med delta weights to base LLaMA-7B to come up with the final weights for LLaVA-Med\n", 68 | "\n", 69 | "```Shell\n", 70 | "cd LLaVA-Med\n", 71 | "```\n", 72 | "\n", 73 | "#### LLaVA-Med pre-trained on general biomedicine data\n", 74 | "\n", 75 | "```Shell\n", 76 | "!python3 -m llava.model.apply_delta \\\n", 77 | " --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \\\n", 78 | " --target ../final_models/llava_med \\\n", 79 | " --delta ../model_delta_weights/llava_med_in_text_60k_ckpt2_delta\n", 80 | "```\n", 81 | "\n", 82 | "#### LLaVA-Med fine-tuned on PathVQA\n", 83 | "\n", 84 | "```Shell\n", 85 | "!python -m llava.model.apply_delta \\\n", 86 | " --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \\\n", 87 | " --target ../final_models/llava_med_pvqa \\\n", 88 | " --delta ../model_delta_weights/pvqa-9epoch_delta\n", 89 | "```\n", 90 | "\n", 91 | "```Shell\n", 92 | "cd ..\n", 93 | "```\n", 94 | "\n", 95 | "### 9. Generate the top patches for open-ended PathVQA images using HistoCartography\n", 96 | "\n", 97 | "```Shell\n", 98 | "python generate_histo_patches.py\n", 99 | "```\n", 100 | "\n", 101 | "### 10. 
Generate the query files to be passed to LLaVA-Med for both the images and the patches\n", 102 | "\n", 103 | "```Shell\n", 104 | "python generate_llava_med_query.py\n", 105 | "```\n", 106 | "\n", 107 | "### 11. Generate the answers for all the query files using the raw model (`final_models/llava_med`) and the fine-tuned model (`final_models/llava_med_pvqa`)\n", 108 | "\n", 109 | "```Shell\n", 110 | "cd LLaVA-Med\n", 111 | "```\n", 112 | "\n", 113 | "#### Raw Model\n", 114 | "```Shell\n", 115 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n", 116 | "    --question-file ../files/query/image_direct.jsonl \\\n", 117 | "    --image-folder ../pvqa/images/test \\\n", 118 | "    --answers-file ../files/answer/raw/answer_image_direct.jsonl\n", 119 | "```\n", 120 | "\n", 121 | "```Shell\n", 122 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n", 123 | "    --question-file ../files/query/patch_direct.jsonl \\\n", 124 | "    --image-folder ../pvqa/images/test \\\n", 125 | "    --answers-file ../files/answer/raw/answer_patch_direct.jsonl\n", 126 | "```\n", 127 | "\n", 128 | "```Shell\n", 129 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n", 130 | "    --question-file ../files/query/image_description.jsonl \\\n", 131 | "    --image-folder ../pvqa/images/test \\\n", 132 | "    --answers-file ../files/answer/raw/answer_image_description.jsonl\n", 133 | "```\n", 134 | "\n", 135 | "```Shell\n", 136 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n", 137 | "    --question-file ../files/query/patch_description.jsonl \\\n", 138 | "    --image-folder ../pvqa/images/test \\\n", 139 | "    --answers-file ../files/answer/raw/answer_patch_description.jsonl\n", 140 | "```\n", 141 | "\n", 142 | "#### Fine-Tuned Model\n", 143 | "```Shell\n", 144 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n", 145 | "    --question-file ../files/query/image_direct.jsonl \\\n", 146 | "    --image-folder ../pvqa/images/test \\\n", 147 | "    --answers-file ../files/answer/fine-tuned/answer_image_direct.jsonl\n", 148 | "```\n", 149 | "\n", 150 | "```Shell\n", 151 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n", 152 | "    --question-file ../files/query/patch_direct.jsonl \\\n", 153 | "    --image-folder ../pvqa/images/test \\\n", 154 | "    --answers-file ../files/answer/fine-tuned/answer_patch_direct.jsonl\n", 155 | "```\n", 156 | "\n", 157 | "```Shell\n", 158 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n", 159 | "    --question-file ../files/query/image_description.jsonl \\\n", 160 | "    --image-folder ../pvqa/images/test \\\n", 161 | "    --answers-file ../files/answer/fine-tuned/answer_image_description.jsonl\n", 162 | "```\n", 163 | "\n", 164 | "```Shell\n", 165 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n", 166 | "    --question-file ../files/query/patch_description.jsonl \\\n", 167 | "    --image-folder ../pvqa/images/test \\\n", 168 | "    --answers-file ../files/answer/fine-tuned/answer_patch_description.jsonl\n", 169 | "```\n", 170 | "\n", 171 | "### 12. Evaluate the results for different use-cases using `recall_calculation.py`\n", 172 | "\n", 173 | "(i) Path-RAG w/o GPT: Combine the answers of the image + all patches to form the final predicted answer\\\n", 174 | "(ii) Path-RAG (description): Combine the descriptions of the image + all patches. 
Then involve GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)\\\n", 175 | "(iii) Path-RAG (answer): Combine the answers of the image + all patches. Then involve GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)\n", 176 | "\n", 177 | "\n", 178 | "## ARCH-Open Dataset\n", 179 | "\n", 180 | "1. Download the `books_set` and `pubmed_set` of the ARCH dataset from `https://warwick.ac.uk/fac/cross_fac/tia/data/arch`. Store both of these folders in a folder named `arch`. Both `books_set` and `pubmed_set` contain a `captions.json`, which lists a **caption** and a **UUID**, where **UUID** is the file name in the `images` folder and **caption** is the description of the image.\n", 181 | "\n", 182 | "2. Using `captions.json` and the `images` folder under `arch/books_set`, run the notebook `ARCH-OPEN/books_data/synthetic_data_textbook.ipynb` (specifying your OpenAI credentials) to generate the question-answer pairs for the books set\n", 183 | "\n", 184 | "3. Using `captions.json` and the `images` folder under `arch/pubmed_set`, run the notebook `ARCH-OPEN/pubmed_data/synthetic_data_pubmed.ipynb` (specifying your OpenAI credentials) to generate the question-answer pairs for the pubmed set\n", 185 | "\n", 186 | "4. Run the notebook `ARCH-OPEN/synthetic_data_compilation.ipynb` to compile the `pubmed` and `books` question-answer pairs into the JSON files `ARCH-OPEN/pubmed_qa_pairs.json` and `ARCH-OPEN/textbook_qa_pairs.json`. These files are already provided and can be used directly\n", 187 | "\n", 188 | "5. The `pubmed_qa_pairs.json` and `textbook_qa_pairs.json` files contain 5 question-answer pairs for each pair of `caption` and `uuid` (the `uuid` refers to the image name in the ARCH data under `arch/pubmed_set/images` or `arch/books_set/images`) in the following format (for both `pubmed_set` and `books_set`):\n", 189 | "\n", 190 | "```json\n", 191 | "    {\n", 192 | "        \"figure_id\": \"00\",\n", 193 | "        \"letter\": \"A\",\n", 194 | "        \"caption\": \" A, Spindle cell variant of embryonal rhabdomyosarcoma is characterized by fascicles of eosinophilic spindle cells (B), some of which can show prominent paranuclear vacuolisation, as seen in leiomyosarcoma.\",\n", 195 | "        \"uuid\": \"890e2e79-ab0a-4a2e-9d62-b0b6b3d43884\",\n", 196 | "        \"Question_1\": \"What could be the general shape of cells in a spindle cell variant of embryonal rhabdomyosarcoma as seen in the image?\",\n", 197 | "        \"Answer_1\": \"The cells often present with a spindle-like elongated shape.\",\n", 198 | "        \"Question_2\": \"What type of structures could be visible in the image indicating the presence of spindle cells?\",\n", 199 | "        \"Answer_2\": \"Fascicles, or bundles, of cells could be visible in the image, indicating the presence of spindle cells.\",\n", 200 | "        \"Question_3\": \"Where in the cell would we likely find paranuclear vacuolisation in the image?\",\n", 201 | "        \"Answer_3\": \"Paranuclear vacuolisation is usually seen around the nucleus area of the cell.\",\n", 202 | "        \"Question_4\": \"What color might the spindle cells appear in the image?\",\n", 203 | "        \"Answer_4\": \"The spindle cells may appear eosinophilic, or pinkish-red, under the microscope due to staining.\",\n", 204 | "        \"Question_5\": \"What visual feature might differentiate spindle cells from leiomyosarcoma cells in the image?\",\n", 205 | "        \"Answer_5\": \"Spindle cells might show prominent paranuclear vacuolisation, a feature that can differentiate them from leiomyosarcoma cells.\"\n", 206 | "    }\n", 207 | "```\n", 208 | 
"\n", 209 | "## Acknowledgement\n", 210 | "We would like to acknowledge the following funding support: NIH OT2OD032581, NIH OTA-21-008, NIH 1OT2OD032742-01, NSF 2333703, NSF 2303038." 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "language_info": { 221 | "name": "python" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 2 226 | } 227 | --------------------------------------------------------------------------------
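For step 12 above, the sketch below (not part of the repository) shows one way the `recall_score` helper in `recall_calculation.py` could be applied to a single answers file. The file names and the `question_id`/`text`/`answer` keys are assumptions about what `generate_llava_med_query.py` and LLaVA-Med's `model_vqa.py` actually write; adjust the paths and keys to match the files produced in your run.

```python
# Minimal, illustrative usage sketch for recall_calculation.py.
# File names and JSON keys below are assumptions, not part of the repository.
import json

from recall_calculation import recall_score

# Hypothetical inputs: ground-truth questions/answers and the corresponding
# LLaVA-Med predictions for one query file.
with open("files/query/image_direct.jsonl") as f:
    questions = [json.loads(line) for line in f]
with open("files/answer/raw/answer_image_direct.jsonl") as f:
    predictions = [json.loads(line) for line in f]

# recall_score expects a list of dicts with the reference answer under 'answer'
# and the model output under 'llava_answer' (its default key names).
pred_by_id = {p["question_id"]: p.get("text", "") for p in predictions}
eval_list = [
    {"answer": q["answer"], "llava_answer": pred_by_id[q["question_id"]]}
    for q in questions
    if q["question_id"] in pred_by_id
]

print("Average word-level recall:", recall_score(eval_list))
```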