├── .gitignore
├── ARCH-OPEN
│   ├── books_data
│   │   ├── llm_qa_pairs_books_0_4305.pkl
│   │   └── synthetic_data_textbook.ipynb
│   ├── pubmed_data
│   │   ├── llm_qa_pairs_pubmed_0_3309.pkl
│   │   └── synthetic_data_pubmed.ipynb
│   ├── pubmed_qa_pairs.json
│   ├── synthetic_data_compilation.ipynb
│   └── textbook_qa_pairs.json
├── README.md
├── files
│   ├── answer
│   │   ├── fine-tuned
│   │   │   └── text.txt
│   │   └── raw
│   │       └── text.txt
│   └── query
│       └── text.txt
├── final_models
│   └── text.txt
├── generate_histo_patches.ipynb
├── generate_histo_patches.py
├── generate_llava_med_query.py
├── images
│   ├── path-rag.png
│   └── text.txt
├── llama_7B_model_weights.py
├── model_delta_weights
│   └── text.txt
├── recall_calculation.py
├── tem_data.ipynb
└── test.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | pvqa/*
2 | LLaVA-Med/*
3 | histo_image_patch/*
4 | nohup.out
5 | arch/*
6 | histocartography/*
--------------------------------------------------------------------------------
/ARCH-OPEN/books_data/llm_qa_pairs_books_0_4305.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/embedded-robotics/path-rag/e1bd5ea6aec4825d708b4f7b77b2fcaf52fff824/ARCH-OPEN/books_data/llm_qa_pairs_books_0_4305.pkl
--------------------------------------------------------------------------------
/ARCH-OPEN/books_data/synthetic_data_textbook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import os\n",
11 | "import json\n",
12 | "import openai\n",
13 | "import backoff\n",
14 | "from PIL import Image\n",
15 | "import pickle\n",
16 | "import time\n",
17 | "import sys"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "sys.path.append(\"data/mn27889/path-rag\")\n",
27 | "file_path = os.path.join(\"arch\", \"books_set\", \"captions.json\")\n",
28 | "\n",
29 | "with open(file_path, 'rb') as file:\n",
30 | " captions_data = json.load(file)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "OPENAI_API_PATH = os.path.join(os.getcwd(), 'api.key')\n",
40 | "\n",
41 | "with open(OPENAI_API_PATH) as f:\n",
42 | " openai.api_key = f.read().strip()\n",
43 | "\n",
44 | "@backoff.on_exception(backoff.expo, openai.OpenAIError)\n",
45 | "def completions_with_backoff(**kwargs):\n",
46 | " return openai.chat.completions.create(**kwargs)\n",
47 | "\n",
48 | "def gpt(user_prompt, system_prompt=\"You are an expert pathologist\", model=\"gpt-4\", temperature=0.7, max_tokens=1000) -> list:\n",
49 | "\n",
50 | " messages = [{\"role\": \"system\", \"content\": system_prompt},\n",
51 | " {\"role\": \"user\", \"content\": user_prompt}]\n",
52 | " \n",
53 | " res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens)\n",
54 | " \n",
55 | " return res.choices[0].message.content"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "base_prompt = '''You are provided with a text description (figure caption) of a pathology image. Unfortunately, you don't have access to the original image.\n",
65 | "Your job is to generate a total of 5 open-ended question/answer pairs from this figure caption starting with \"What\" or \"Where\". Below are the requirements to generate the question/answer pairs:\n",
66 | "\n",
67 | "- Avoid quoting or referring to specific facts, terms, abbreviations, dates, numbers or names, as these may reveal the conversation is based on the text information, rather than image itself.\n",
68 | "- Focus on the visual aspects of the image that can be inferred without the text information\n",
69 | "- Do not use phrases like \"mentioned\", \"caption\", \"context\", \"without the image\" in the question/answer pairs. Instead, refer to the information as being \"in the image\" or preferably don't mention anything\n",
70 | "- Ensure that question/anwer pairs are diverse and cover a range of visual aspects of the image\n",
71 | "- Answer responsibly, avoiding overconfidence, and do not provide medical advice or diagnostic information\n",
72 | "\n",
73 | "Caption: {caption}\n",
74 | "Question:\n",
75 | "Answer:\n",
76 | "'''"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# Getting the results and saving it\n",
86 | "index_list = []\n",
87 | "figure_id_list = []\n",
88 | "letter_list = []\n",
89 | "caption_list = []\n",
90 | "uuid_list = []\n",
91 | "llm_response_list = []\n",
92 | "\n",
93 | "start_index = 0\n",
94 | "current_index = start_index\n",
95 | "total_records = len(captions_data)\n",
96 | "\n",
97 | "while True:\n",
98 | " try:\n",
99 | " for index in range(start_index, total_records):\n",
100 | " current_index = index\n",
101 | " figure_id = captions_data[str(current_index)]['figure_id']\n",
102 | " letter = captions_data[str(current_index)]['letter']\n",
103 | " caption = captions_data[str(current_index)]['caption']\n",
104 | " uuid = captions_data[str(current_index)]['uuid']\n",
105 | " \n",
106 | " user_prompt = base_prompt.format(caption = caption)\n",
107 | " response = gpt(user_prompt)\n",
108 | " \n",
109 | " index_list.append(current_index)\n",
110 | " figure_id_list.append(figure_id)\n",
111 | " letter_list.append(letter)\n",
112 | " caption_list.append(caption)\n",
113 | " uuid_list.append(uuid)\n",
114 | " llm_response_list.append(response)\n",
115 | "\n",
116 | " print(\"Index:\", current_index)\n",
117 | " print(\"Figure_ID:\", figure_id)\n",
118 | " print(\"Letter:\", letter)\n",
119 | " print(\"Caption:\", caption)\n",
120 | " print(\"UUID:\", uuid)\n",
121 | " print()\n",
122 | " print(response)\n",
123 | " print()\n",
124 | " \n",
125 | " except Exception as err:\n",
126 | " print(\"Something went wrong: \", err)\n",
127 | " start_index = current_index\n",
128 | " print(\"Waiting for 10 seconds before continuing again with index:\", start_index)\n",
129 | " time.sleep(10)\n",
130 | "\n",
131 | " # Break the loop if current_index has completed\n",
132 | " if current_index == (total_records - 1):\n",
133 | " break\n",
134 | "\n",
135 | "llm_qa_pairs_books = pd.DataFrame({'index': index_list, 'figure_id': figure_id_list, 'letter': letter_list,\n",
136 | " 'caption': caption_list, 'uuid': uuid_list, 'llm_qa_pairs_books': llm_response_list})\n",
137 | "\n",
138 | "file_name = 'llm_qa_pairs_books_' + str(start_index) + '_' + str(total_records) + '.pkl'\n",
139 | "\n",
140 | "with open(file_name, 'wb') as file:\n",
141 | " pickle.dump(llm_qa_pairs_books, file)"
142 | ]
143 | }
144 | ],
145 | "metadata": {
146 | "kernelspec": {
147 | "display_name": "path-rag",
148 | "language": "python",
149 | "name": "python3"
150 | },
151 | "language_info": {
152 | "codemirror_mode": {
153 | "name": "ipython",
154 | "version": 3
155 | },
156 | "file_extension": ".py",
157 | "mimetype": "text/x-python",
158 | "name": "python",
159 | "nbconvert_exporter": "python",
160 | "pygments_lexer": "ipython3",
161 | "version": "3.9.19"
162 | }
163 | },
164 | "nbformat": 4,
165 | "nbformat_minor": 2
166 | }
167 |
--------------------------------------------------------------------------------
/ARCH-OPEN/pubmed_data/llm_qa_pairs_pubmed_0_3309.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/embedded-robotics/path-rag/e1bd5ea6aec4825d708b4f7b77b2fcaf52fff824/ARCH-OPEN/pubmed_data/llm_qa_pairs_pubmed_0_3309.pkl
--------------------------------------------------------------------------------
/ARCH-OPEN/pubmed_data/synthetic_data_pubmed.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import os\n",
11 | "import json\n",
12 | "import openai\n",
13 | "import backoff\n",
14 | "from PIL import Image\n",
15 | "import pickle\n",
16 | "import time\n",
17 | "import sys"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "sys.path.append(\"data/mn27889/path-rag\")\n",
27 | "file_path = os.path.join(\"arch\", \"pubmed_set\", \"captions.json\")\n",
28 | "\n",
29 | "with open(file_path, 'rb') as file:\n",
30 | " captions_data = json.load(file)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "OPENAI_API_PATH = os.path.join(os.getcwd(), 'api.key')\n",
40 | "\n",
41 | "with open(OPENAI_API_PATH) as f:\n",
42 | " openai.api_key = f.read().strip()\n",
43 | "\n",
44 | "@backoff.on_exception(backoff.expo, openai.OpenAIError)\n",
45 | "def completions_with_backoff(**kwargs):\n",
46 | " return openai.chat.completions.create(**kwargs)\n",
47 | "\n",
48 | "def gpt(user_prompt, system_prompt=\"You are an expert pathologist\", model=\"gpt-4\", temperature=0.7, max_tokens=1000) -> list:\n",
49 | "\n",
50 | " messages = [{\"role\": \"system\", \"content\": system_prompt},\n",
51 | " {\"role\": \"user\", \"content\": user_prompt}]\n",
52 | " \n",
53 | " res = completions_with_backoff(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens)\n",
54 | " \n",
55 | " return res.choices[0].message.content"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "base_prompt = '''You are provided with a text description (figure caption) of a pathology image. Unfortunately, you don't have access to the original image.\n",
65 | "Your job is to generate a total of 5 open-ended question/answer pairs from this figure caption starting with \"What\" or \"Where\". Below are the requirements to generate the question/answer pairs:\n",
66 | "\n",
67 | "- Avoid quoting or referring to specific facts, terms, abbreviations, dates, numbers or names, as these may reveal the conversation is based on the text information, rather than image itself.\n",
68 | "- Focus on the visual aspects of the image that can be inferred without the text information\n",
69 | "- Do not use phrases like \"mentioned\", \"caption\", \"context\", \"without the image\" in the question/answer pairs. Instead, refer to the information as being \"in the image\" or preferably don't mention anything\n",
70 | "- Ensure that question/anwer pairs are diverse and cover a range of visual aspects of the image\n",
71 | "- Answer responsibly, avoiding overconfidence, and do not provide medical advice or diagnostic information\n",
72 | "\n",
73 | "Caption: {caption}\n",
74 | "Question:\n",
75 | "Answer:\n",
76 | "'''"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# Getting the results and saving it\n",
86 | "index_list = []\n",
87 | "caption_list = []\n",
88 | "uuid_list = []\n",
89 | "llm_response_list = []\n",
90 | "\n",
91 | "start_index = 0\n",
92 | "current_index = start_index\n",
93 | "total_records = len(captions_data)\n",
94 | "\n",
95 | "while True:\n",
96 | " try:\n",
97 | " for index in range(start_index, total_records):\n",
98 | " current_index = index\n",
99 | " caption = captions_data[str(current_index)]['caption']\n",
100 | " uuid = captions_data[str(current_index)]['uuid']\n",
101 | " \n",
102 | " user_prompt = base_prompt.format(caption = caption)\n",
103 | " response = gpt(user_prompt)\n",
104 | " \n",
105 | " index_list.append(current_index)\n",
106 | " caption_list.append(caption)\n",
107 | " uuid_list.append(uuid)\n",
108 | " llm_response_list.append(response)\n",
109 | "\n",
110 | " print(\"Index:\", current_index)\n",
111 | " print(\"Caption:\", caption)\n",
112 | " print(\"UUID:\", uuid)\n",
113 | " print()\n",
114 | " print(response)\n",
115 | " print()\n",
116 | " \n",
117 | " except Exception as err:\n",
118 | " print(\"Something went wrong: \", err)\n",
119 | " start_index = current_index\n",
120 | " print(\"Waiting for 10 seconds before continuing again with index:\", start_index)\n",
121 | " time.sleep(10)\n",
122 | "\n",
123 | " # Break the loop if current_index has completed\n",
124 | " if current_index == (total_records - 1):\n",
125 | " break\n",
126 | "\n",
127 | "\n",
128 | "llm_qa_pairs_pubmed = pd.DataFrame({'index': index_list, 'caption': caption_list, 'uuid': uuid_list, 'llm_qa_pairs_pubmed': llm_response_list})\n",
129 | "\n",
130 | "file_name = 'llm_qa_pairs_pubmed_' + str(start_index) + '_' + str(total_records) + '.pkl'\n",
131 | "\n",
132 | "with open(file_name, 'wb') as file:\n",
133 | " pickle.dump(llm_qa_pairs_pubmed, file)"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": []
142 | }
143 | ],
144 | "metadata": {
145 | "kernelspec": {
146 | "display_name": "path-rag",
147 | "language": "python",
148 | "name": "python3"
149 | },
150 | "language_info": {
151 | "codemirror_mode": {
152 | "name": "ipython",
153 | "version": 3
154 | },
155 | "file_extension": ".py",
156 | "mimetype": "text/x-python",
157 | "name": "python",
158 | "nbconvert_exporter": "python",
159 | "pygments_lexer": "ipython3",
160 | "version": "3.9.19"
161 | }
162 | },
163 | "nbformat": 4,
164 | "nbformat_minor": 2
165 | }
166 |
--------------------------------------------------------------------------------
/ARCH-OPEN/synthetic_data_compilation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### This code will compile the questions/answers from both the responses received from GPT for textbook and pubmed data"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import json\n",
17 | "import pickle\n",
18 | "import pandas as pd\n",
19 | "import re"
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "### Pubmed Data"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "with open('pubmed_data/llm_qa_pairs_pubmed_0_3309.pkl', 'rb') as file:\n",
36 | " pubmed_data = pickle.load(file)"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 3,
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "name": "stderr",
46 | "output_type": "stream",
47 | "text": [
48 | "<>:7: SyntaxWarning: invalid escape sequence '\\s'\n",
49 | "<>:8: SyntaxWarning: invalid escape sequence '\\s'\n",
50 | "<>:7: SyntaxWarning: invalid escape sequence '\\s'\n",
51 | "<>:8: SyntaxWarning: invalid escape sequence '\\s'\n",
52 | "/tmp/ipykernel_2634287/1379646763.py:7: SyntaxWarning: invalid escape sequence '\\s'\n",
53 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n",
54 | "/tmp/ipykernel_2634287/1379646763.py:8: SyntaxWarning: invalid escape sequence '\\s'\n",
55 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "pubmed_data_list = []\n",
61 | "\n",
62 | "for index, row in pubmed_data.iterrows():\n",
63 | " caption = row['caption']\n",
64 | " uuid = row['uuid']\n",
65 | " qa_pairs = row['llm_qa_pairs']\n",
66 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n",
67 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n",
68 | " assert len(questions) == 5\n",
69 | " assert len(answers) == 5\n",
70 | "\n",
71 | " data = {'caption': caption,\n",
72 | " 'uuid': uuid}\n",
73 | " for i in range(0,5):\n",
74 | " data['Question_' + str(i+1)] = questions[i][1]\n",
75 | " data['Answer_' + str(i+1)] = answers[i][1]\n",
76 | " \n",
77 | " pubmed_data_list.append(data)"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 4,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "with open('pubmed_qa_pairs.json', 'w') as file:\n",
87 | " json.dump(pubmed_data_list, file, indent=2)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {},
93 | "source": [
94 | "### Textbook Data"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "with open('books_data/llm_qa_pairs_books_0_4305.pkl', 'rb') as file:\n",
104 | " textbook_data = pickle.load(file)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 6,
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stderr",
114 | "output_type": "stream",
115 | "text": [
116 | "<>:9: SyntaxWarning: invalid escape sequence '\\s'\n",
117 | "<>:10: SyntaxWarning: invalid escape sequence '\\s'\n",
118 | "<>:9: SyntaxWarning: invalid escape sequence '\\s'\n",
119 | "<>:10: SyntaxWarning: invalid escape sequence '\\s'\n",
120 | "/tmp/ipykernel_2634287/2596557835.py:9: SyntaxWarning: invalid escape sequence '\\s'\n",
121 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n",
122 | "/tmp/ipykernel_2634287/2596557835.py:10: SyntaxWarning: invalid escape sequence '\\s'\n",
123 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n"
124 | ]
125 | }
126 | ],
127 | "source": [
128 | "textbook_data_list = []\n",
129 | "\n",
130 | "for index, row in textbook_data.iterrows():\n",
131 | " figure_id = row['figure_id']\n",
132 | " letter = row['letter']\n",
133 | " caption = row['caption']\n",
134 | " uuid = row['uuid']\n",
135 | " qa_pairs = row['llm_qa_pairs_books']\n",
136 | " questions = re.findall('(Question.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n",
137 | " answers = re.findall('(Answer.*:[\\s\\n]*)(.+)(\\n*)', qa_pairs)\n",
138 | " assert len(questions) == 5\n",
139 | " assert len(answers) == 5\n",
140 | "\n",
141 | " data = {'figure_id':figure_id,\n",
142 | " 'letter': letter,\n",
143 | " 'caption': caption,\n",
144 | " 'uuid': uuid}\n",
145 | " for i in range(0,5):\n",
146 | " data['Question_' + str(i+1)] = questions[i][1]\n",
147 | " data['Answer_' + str(i+1)] = answers[i][1]\n",
148 | " \n",
149 | " textbook_data_list.append(data)"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": 7,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "with open('textbook_qa_pairs.json', 'w') as file:\n",
159 | " json.dump(textbook_data_list, file, indent=2)"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": []
168 | }
169 | ],
170 | "metadata": {
171 | "kernelspec": {
172 | "display_name": "path-rag-dpo",
173 | "language": "python",
174 | "name": "python3"
175 | },
176 | "language_info": {
177 | "codemirror_mode": {
178 | "name": "ipython",
179 | "version": 3
180 | },
181 | "file_extension": ".py",
182 | "mimetype": "text/x-python",
183 | "name": "python",
184 | "nbconvert_exporter": "python",
185 | "pygments_lexer": "ipython3",
186 | "version": "3.12.4"
187 | }
188 | },
189 | "nbformat": 4,
190 | "nbformat_minor": 2
191 | }
192 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Path-RAG: Knowledge-Guided Key Region Retrieval for Open-ended Pathology Visual Question Answering
2 |
3 |
4 |
5 |
6 | *Accurate diagnosis and prognosis assisted by pathology images are essential for cancer treatment selection and planning. Despite the recent trend of adopting deep-learning approaches for analyzing complex pathology images, they fall short as they often overlook the domain-expert understanding of tissue structure and cell composition. In this work, we focus on a challenging Open-ended Pathology VQA (PathVQA-Open) task and propose a novel framework named Path-RAG, which leverages HistoCartography to retrieve relevant domain knowledge from pathology images and significantly improves performance on PathVQA-Open. Admitting the complexity of pathology image analysis, Path-RAG adopts a human-centered AI approach by retrieving domain knowledge using HistoCartography to select the relevant patches from pathology images. Our experiments suggest that domain guidance can significantly boost the accuracy of LLaVA-Med from 38% to 47%, with a notable gain of 28% for H&E-stained pathology images in the PathVQA-Open dataset. For longer-form question and answer pairs, our model consistently achieves significant improvements of 32.5% in ARCH-Open PubMed and 30.6% in ARCH-Open Books on H&E images.*
7 |
8 |
9 | ---
10 |
11 | Awais Naeem*, Tianhao Li*, Huang-Ru Liao*, Jiawei Xu*, Aby Mammen Mathew*, Zehao Zhu*, Zhen Tan**, Ajay Jaiswal*, Raffi Salibian***, Ziniu Hu***, Tianlong Chen****, Ying Ding*
12 |
13 | *University of Texas at Austin, USA \
14 | **Arizona State University, USA \
15 | ***University of California, Los Angeles, USA \
16 | ****Massachusetts Institute of Technology, USA
17 |
18 | ---
19 |
20 | ## Path-RAG Implementation
21 |
22 | ### 1. Clone this repository and navigate to path-rag folder
23 |
24 | ```Shell
25 | git clone https://github.com/embedded-robotics/path-rag.git
26 | cd path-rag
27 | ```
28 |
29 | ### 2. Install Package: Create conda environment
30 |
31 | ```Shell
32 | conda create -n path-rag python=3.10 -y
33 | conda activate path-rag
34 | pip install --upgrade pip # enable PEP 660 support for LLaVA-Med
35 | ```
36 |
37 | ### 3. Download the PathVQA dataset from the following link
38 |
39 | [PathVQA Dataset](https://github.com/UCSD-AI4H/PathVQA/blob/master/data/README.md)
40 |
41 | ### 4. Clone the HistoCartography tool, set up the model checkpoints in `histocartography/checkpoints` and install the dependencies
42 |
43 | ```Shell
44 | git clone https://github.com/BiomedSciAI/histocartography
45 | ```
46 |
47 | ### 5. Clone the LLaVA-Med repository and install the dependencies
48 |
49 | ```Shell
50 | git clone https://github.com/microsoft/LLaVA-Med
51 | ```
52 |
53 | ### 6. Download the LLaMA-7B model and weights from HuggingFace
54 |
55 | ```Shell
56 | python llama_7B_model_weights.py # LLaMA-7B weights/model stored into $HF_HOME (By Default $HF_HOME = ~/.cache/huggingface)
57 | ```
58 |
59 | ### 7. Download LLaVA-Med delta weights `llava_med_in_text_60k_ckpt2_delta` and `pvqa-9epoch_delta` from `https://github.com/microsoft/LLaVA-Med#model-download`. Put them inside a folder named `model_delta_weights`
60 |
61 | ### 8. Apply the LLaVA-Med delta weights to the base LLaMA-7B model to obtain the final LLaVA-Med weights
62 |
63 | ```Shell
64 | cd LLaVA-Med
65 | ```
66 |
67 | #### LLaVA-Med pre-trained on general biomedicine data
68 |
69 | ```Shell
70 | python3 -m llava.model.apply_delta \
71 | --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \
72 | --target ../final_models/llava_med \
73 | --delta ../model_delta_weights/llava_med_in_text_60k_ckpt2_delta
74 | ```
75 |
76 | #### LLaVA-Med fine-tuned on PathVQA
77 |
78 | ```Shell
79 | python -m llava.model.apply_delta \
80 | --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \
81 | --target ../final_models/llava_med_pvqa \
82 | --delta ../model_delta_weights/pvqa-9epoch_delta
83 | ```
84 |
85 | ```Shell
86 | cd ..
87 | ```
88 |
89 | ### 9. Generate the top patches for open-ended PathVQA images using HistoCartography
90 |
91 | ```Shell
92 | python generate_histo_patches.py
93 | ```
94 |
95 | ### 10. Generate the query files to be passed to LLaVA-Med for both the images and the patches
96 |
97 | ```Shell
98 | python generate_llava_med_query.py
99 | ```
100 |
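Each line in the generated query files is a standalone JSON object. The snippet below only illustrates the record format written by `generate_llava_med_query.py`; it is not part of the pipeline, and the image file name is a made-up placeholder.

```python
# Illustration of one record in files/query/image_description.jsonl as written by
# generate_llava_med_query.py. The image name "test_0001.jpg" is a hypothetical example.
import json

record = {"question_id": 0,
          "image": "test_0001.jpg",
          "text": "Describe the following image in detail.\n"}
print(json.dumps(record))  # one such line is written per image or patch
```
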
101 | ### 11. Generate the answers for all the query files using the raw model (`final_models/llava_med`) and the fine-tuned model (`final_models/llava_med_pvqa`)
102 |
103 | ```Shell
104 | cd LLaVA-Med
105 | ```
106 |
107 | #### Raw Model
108 | ```Shell
109 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \
110 | --question-file ../files/query/image_direct.jsonl \
111 | --image-folder ../pvqa/images/test \
112 | --answers-file ../files/answer/raw/answer_image_direct.jsonl
113 | ```
114 |
115 | ```Shell
116 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \
117 | --question-file ../files/query/patch_direct.jsonl \
118 | --image-folder ../pvqa/images/test \
119 | --answers-file ../files/answer/raw/answer_patch_direct.jsonl
120 | ```
121 |
122 | ```Shell
123 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \
124 | --question-file ../files/query/image_description.jsonl \
125 | --image-folder ../pvqa/images/test \
126 | --answers-file ../files/answer/raw/answer_image_description.jsonl
127 | ```
128 |
129 | ```Shell
130 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med \
131 | --question-file ../files/query/patch_description.jsonl \
132 | --image-folder ../pvqa/images/test \
133 | --answers-file ../files/answer/raw/answer_patch_description.jsonl
134 | ```
135 |
136 | #### Fine-Tuned Model
137 | ```Shell
138 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \
139 | --question-file ../files/query/image_direct.jsonl \
140 | --image-folder ../pvqa/images/test \
141 | --answers-file ../files/answer/fine-tuned/answer_image_direct.jsonl
142 | ```
143 |
144 | ```Shell
145 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \
146 | --question-file ../files/query/patch_direct.jsonl \
147 | --image-folder ../pvqa/images/test \
148 | --answers-file ../files/answer/fine-tuned/answer_patch_direct.jsonl
149 | ```
150 |
151 | ```Shell
152 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \
153 | --question-file ../files/query/image_description.jsonl \
154 | --image-folder ../pvqa/images/test \
155 | --answers-file ../files/answer/fine-tuned/answer_image_description.jsonl
156 | ```
157 |
158 | ```Shell
159 | python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \
160 | --question-file ../files/query/patch_description.jsonl \
161 | --image-folder ../pvqa/images/test \
162 | --answers-file ../files/answer/fine-tuned/answer_patch_description.jsonl
163 | ```
164 |
165 | ### 12. Evaluate the results for different use cases using `recall_calculation.py` (a usage sketch follows below)
166 |
167 | (i) Path-RAG w/o GPT: Combine the answers of the image + all patches to form the final predicted answer\
168 | (ii) Path-RAG (description): Combine the descriptions of the image + all patches, then use GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)\
169 | (iii) Path-RAG (answer): Combine the answers of the image + all patches, then use GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)
170 |
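As referenced above, the snippet below is a minimal sketch of how `recall_score` from `recall_calculation.py` can be applied. It assumes you have already paired each ground-truth answer with the corresponding model prediction (for example by matching question ids between the query and answer `.jsonl` files); the key names `answer` and `llava_answer` are the defaults expected by `recall_score`, and the example pairs are purely illustrative.

```python
# Minimal recall-scoring sketch. eval_list holds one dict per test question with the
# ground truth under 'answer' and the model output under 'llava_answer' (the default
# keys of recall_score). The two entries below are made-up illustrative values.
from recall_calculation import recall_score

eval_list = [
    {"answer": "fibrous connective tissue", "llava_answer": "dense fibrous connective tissue"},
    {"answer": "squamous epithelium", "llava_answer": "stratified squamous epithelium"},
]

print("Average recall:", recall_score(eval_list))
```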
171 |
172 | ## ARCH-Open Dataset
173 |
174 | 1. Download the `books_set` and `pubmed_set` of the ARCH dataset from `https://warwick.ac.uk/fac/cross_fac/tia/data/arch`. Store both of these folders in a folder named `arch`. Both `books_set` and `pubmed_set` contain a `captions.json` file which lists a **caption** and a **UUID**, where the **UUID** is the file name in the `images` folder and the **caption** is the description of the image.
175 |
176 | 2. Using `captions.json` and the `images` folder under `arch/books_set`, run the notebook `ARCH-OPEN/books_data/synthetic_data_textbook.ipynb`, specifying your OpenAI credentials, to generate the question-answer pairs for the books set
177 |
178 | 3. Using `captions.json` and the `images` folder under `arch/pubmed_set`, run the notebook `ARCH-OPEN/pubmed_data/synthetic_data_pubmed.ipynb`, specifying your OpenAI credentials, to generate the question-answer pairs for the pubmed set
179 |
180 | 4. Run the notebook `ARCH-OPEN/synthetic_data_compilation.ipynb` to compile the `pubmed` and `books` question-answer pairs into JSON files, namely `ARCH-OPEN/pubmed_qa_pairs.json` and `ARCH-OPEN/textbook_qa_pairs.json`. These files are already provided and can be used directly
181 |
182 | 5. The `pubmed_qa_pairs.json` and `textbook_qa_pairs.json` files contain 5 question-answer pairs for each pair of `caption` and `uuid` (the `uuid` refers to the image name in the ARCH data: `arch/pubmed_set/images`, `arch/books_set/images`) in the following format (for both `pubmed_set` and `books_set`); a short loading sketch is shown after the example:
183 |
184 | ```Shell
185 | {
186 | "figure_id": "00",
187 | "letter": "A",
188 | "caption": " A, Spindle cell variant of embryonal rhabdomyosarcoma is characterized by fascicles of eosinophilic spindle cells (B), some of which can show prominent paranuclear vacuolisation, as seen in leiomyosarcoma.",
189 | "uuid": "890e2e79-ab0a-4a2e-9d62-b0b6b3d43884",
190 | "Question_1": "What could be the general shape of cells in a spindle cell variant of embryonal rhabdomyosarcoma as seen in the image?",
191 | "Answer_1": "The cells often present with a spindle-like elongated shape.",
192 | "Question_2": "What type of structures could be visible in the image indicating the presence of spindle cells?",
193 | "Answer_2": "Fascicles, or bundles, of cells could be visible in the image, indicating the presence of spindle cells.",
194 | "Question_3": "Where in the cell would we likely find paranuclear vacuolisation in the image?",
195 | "Answer_3": "Paranuclear vacuolisation is usually seen around the nucleus area of the cell.",
196 | "Question_4": "What color might the spindle cells appear in the image?",
197 | "Answer_4": "The spindle cells may appear eosinophilic, or pinkish-red, under the microscope due to staining.",
198 | "Question_5": "What visual feature might differentiate spindle cells from leiomyosarcoma cells in the image?",
199 | "Answer_5": "Spindle cells might show prominent paranuclear vacuolisation, a feature that can differentiate them from leiomyosarcoma cells."
200 | }
201 | ```
202 |
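As noted in item 5, the following is a minimal loading sketch that relies only on the record format shown above (`uuid`, `caption`, `Question_1`..`Question_5`, `Answer_1`..`Answer_5`). The image file extension is an assumption (the repository notebooks use `.jpg` for `pubmed_set` and `.png` for `books_set`), so adjust it to match your ARCH download.

```python
# Minimal sketch: iterate the compiled ARCH-Open QA pairs and resolve each record to
# its source image. The ".jpg" extension for pubmed_set images is an assumption.
import json
import os

with open("ARCH-OPEN/pubmed_qa_pairs.json") as f:
    qa_records = json.load(f)

for record in qa_records[:3]:  # first few records only
    image_path = os.path.join("arch", "pubmed_set", "images", record["uuid"] + ".jpg")
    for i in range(1, 6):  # each record holds 5 question-answer pairs
        print(image_path, "|", record[f"Question_{i}"], "->", record[f"Answer_{i}"])
```
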
203 | ## Acknowledgement
204 | We would like to acknowledge the following funding support: NIH OT2OD032581, NIH OTA-21-008, NIH 1OT2OD032742-01, NSF 2333703, NSF 2303038.
--------------------------------------------------------------------------------
/files/answer/fine-tuned/text.txt:
--------------------------------------------------------------------------------
1 | For Answer files by fine-tuned model
2 |
--------------------------------------------------------------------------------
/files/answer/raw/text.txt:
--------------------------------------------------------------------------------
1 | For Answer files by raw model
2 |
--------------------------------------------------------------------------------
/files/query/text.txt:
--------------------------------------------------------------------------------
1 | For Query files
2 |
--------------------------------------------------------------------------------
/final_models/text.txt:
--------------------------------------------------------------------------------
1 | For final models
2 |
--------------------------------------------------------------------------------
/generate_histo_patches.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from PIL import Image\n",
10 | "import numpy as np\n",
11 | "import os"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 8,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "data_path = \"arch\"\n",
21 | "\n",
22 | "# Extract the image data for PubMed dataset\n",
23 | "pubmed_path_images = os.path.join(data_path, \"pubmed_set\", \"images\")\n",
24 | "pubmed_img_uuid_list = os.listdir(pubmed_path_images)\n",
25 | "pubmed_img_uuid = [uuid.split('.')[0] for uuid in pubmed_img_uuid_list]\n",
26 | "pubmed_img_uuid_path = [os.path.join(pubmed_path_images, img_uuid) for img_uuid in pubmed_img_uuid_list]\n",
27 | "\n",
28 | "# # Extract the image data for Books dataset\n",
29 | "# books_path_images = os.path.join(data_path, \"books_set\", \"images\")\n",
30 | "# books_img_uuid = os.listdir(books_path_images)\n",
31 | "# books_img_uuid = [uuid.split('.')[0] for uuid in books_img_uuid]\n",
32 | "# books_img_uuid_path = [os.path.join(books_path_images, img_uuid + '.png') for img_uuid in books_img_uuid]"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 6,
38 | "metadata": {},
39 | "outputs": [
40 | {
41 | "name": "stdout",
42 | "output_type": "stream",
43 | "text": [
44 | "71\n"
45 | ]
46 | }
47 | ],
48 | "source": [
49 | "for i in range(len(pubmed_img_uuid)):\n",
50 | " if pubmed_img_uuid[i] == 'a6fade92-c4ce-4eed-b8aa-795e85c4645e':\n",
51 | " print(i)\n",
52 | " break"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": []
61 | }
62 | ],
63 | "metadata": {
64 | "kernelspec": {
65 | "display_name": "path-rag",
66 | "language": "python",
67 | "name": "python3"
68 | },
69 | "language_info": {
70 | "codemirror_mode": {
71 | "name": "ipython",
72 | "version": 3
73 | },
74 | "file_extension": ".py",
75 | "mimetype": "text/x-python",
76 | "name": "python",
77 | "nbconvert_exporter": "python",
78 | "pygments_lexer": "ipython3",
79 | "version": "3.9.19"
80 | }
81 | },
82 | "nbformat": 4,
83 | "nbformat_minor": 2
84 | }
85 |
--------------------------------------------------------------------------------
/generate_histo_patches.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import os
4 | import json
5 | import pickle
6 | import sys
7 | sys.path.append(os.path.join(os.getcwd(), 'histocartography'))
8 |
9 | from histocartography.preprocessing import NucleiExtractor, DeepFeatureExtractor, KNNGraphBuilder
10 |
11 | PVQA_DATA_PATH = "pvqa"
12 | HISTO_PATCH_SAVE_PATH = "histo_image_patch"
13 | ARCH_DATA_PATH = "arch"
14 |
15 | # Cell Graph Generation Definitions
16 | nuclei_detector = NucleiExtractor()
17 | feats_extractor = DeepFeatureExtractor(architecture='resnet34', patch_size=72, resize_size=224)
18 | knn_graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True)
19 |
20 | # PathVQA Dataset Processing
21 | def get_path_vqa_open_images(data_path : str = "pvqa"):
22 |
23 | # Reading the PathVQA dataset
24 | img_train_path = os.path.join(data_path, "images", "test")
25 | qas_train_path = os.path.join(data_path, "qas", "test", "test_qa.pkl")
26 | with open(qas_train_path, 'rb') as file:
27 | pvqa_qas = pickle.load(file)
28 |
29 | # Getting all open-ended images
30 | qas_general = [qas for qas in pvqa_qas if qas['answer'] != 'yes' and qas['answer'] != 'no']
31 | img_general = [qas['image'] for qas in qas_general]
32 | img_general = list(set(img_general))
33 | img_general = sorted(img_general, key=str)
34 |     img_general_path = [os.path.join(img_train_path, img_name + '.jpg') for img_name in img_general]
35 |
36 | return img_general, img_general_path
37 |
38 | # ARCH Dataset Processing
39 | def get_arch_open_images(data_path : str = "arch"):
40 |
41 | # Extract the image data for PubMed dataset
42 | pubmed_path_images = os.path.join(data_path, "pubmed_set", "images")
43 | pubmed_img_uuid_list = os.listdir(pubmed_path_images)
44 | pubmed_img_uuid = [uuid.split('.')[0] for uuid in pubmed_img_uuid_list]
45 | pubmed_img_uuid_path = [os.path.join(pubmed_path_images, img_uuid) for img_uuid in pubmed_img_uuid_list]
46 |
47 | # Extract the image data for Books dataset
48 | books_path_images = os.path.join(data_path, "books_set", "images")
49 | books_img_uuid_list = os.listdir(books_path_images)
50 | books_img_uuid = [uuid.split('.')[0] for uuid in books_img_uuid_list]
51 | books_img_uuid_path = [os.path.join(books_path_images, img_uuid) for img_uuid in books_img_uuid_list]
52 |
53 | return pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path
54 |
55 | # Save top patches using histocartography
56 | def save_histocartography_top_patches(img_general : list, img_general_path: list):
57 | for image_idx in range(0, len(img_general)):
58 |
59 | print(f"{image_idx}: Started ")
60 | query_img = Image.open(img_general_path[image_idx]).convert(mode="RGB")
61 | image = np.array(query_img)
62 | nuclei_map, nuclei_centers = nuclei_detector.process(image)
63 |
64 | # Only consider if more than 5 nuclei are detected since knn needs to form a graph using 5 neighbors.
65 | # If less than 5 nuclei are present, most of the images are not pathology related
66 | if nuclei_centers.shape[0] > 5:
67 | print(f"{image_idx}: Patches ")
68 |
69 | # Get the Features
70 | features = feats_extractor.process(image, nuclei_map)
71 |
72 | # Make Cell Graph
73 | cell_graph = knn_graph_builder.process(nuclei_map, features)
74 |
75 | # Make calculations to extract patches and the overlap images
76 | width, height = query_img.size
77 | width_range = np.linspace(0, width, 4, dtype=int)
78 | height_range = np.linspace(0, height, 4, dtype=int)
79 |
80 | overlap_percent = 20
81 | width_overlap = int((overlap_percent/100) * width)
82 | height_overlap = int((overlap_percent/100) * height)
83 |
84 | # Extract the patches
85 | image_patches = []
86 | patch_nuclei_centers = []
87 | for i in range(len(width_range)-1):
88 | for j in range(len(height_range)-1):
89 | # Consider the overlap width from second patch only
90 | if i != 0:
91 | start_width = width_range[i] - width_overlap
92 | else:
93 | start_width = width_range[i]
94 |
95 | # Consider the overlap height from second patch only
96 | if j != 0:
97 | start_height = height_range[j] - height_overlap
98 | else:
99 | start_height = height_range[j]
100 |
101 | # List out the patch ranges
102 | left = start_width
103 | upper = start_height
104 | right = width_range[i+1]
105 | lower = height_range[j+1]
106 |
107 | center_list = []
108 | for center in nuclei_centers:
109 | if ((center[0] >= left) and (center[0] <=right) and
110 | (center[1] >= upper) and (center[1] <=lower)):
111 | center_list.append(center)
112 |
113 | image_patches.append(query_img.crop((left, upper, right, lower)))
114 | patch_nuclei_centers.append(center_list)
115 |
116 | # Calculate the length of nuclei in each patch
117 | patch_center_length = []
118 | for center in patch_nuclei_centers:
119 | patch_center_length.append(len(center))
120 |
121 | # Sort the patch indices based on maximum number of nuclei
122 | sorted_indices_desc = np.flip(np.argsort(patch_center_length))
123 |
124 | # Create a directory to store all the patches of the image
125 | save_directory = os.path.join(HISTO_PATCH_SAVE_PATH, img_general[image_idx])
126 | if not os.path.isdir(save_directory):
127 | os.mkdir(save_directory)
128 |
129 | # Store all the image patches into the newly created directory
130 | for patch_index in range(0,6):
131 | save_file_path = os.path.join(save_directory, str(patch_index+1) + ".png")
132 | image_patches[sorted_indices_desc[patch_index]].save(save_file_path)
133 |
134 | print(f"{image_idx}: Ended ")
135 | print(".........")
136 |
137 | # Save top patches using histocartography
138 | def save_histocartography_top_patches_arch(img_uuid : list, img_uuid_path: list, books_pubmed_class: str):
139 | for image_idx in range(0, len(img_uuid)):
140 |
141 | print(f"{image_idx}/{len(img_uuid)}: Started ")
142 |
143 | query_img = Image.open(img_uuid_path[image_idx]).convert(mode="RGB")
144 | image = np.array(query_img)
145 | nuclei_map, nuclei_centers = nuclei_detector.process(image)
146 |
147 | # Only consider if more than 5 nuclei are detected since knn needs to form a graph using 5 neighbors.
148 | # If less than 5 nuclei are present, most of the images are not pathology related
149 | if nuclei_centers.shape[0] > 5:
150 | print(f"{image_idx}: Creating Patches ")
151 |
152 | # Get the Features
153 | features = feats_extractor.process(image, nuclei_map)
154 |
155 | # Make Cell Graph
156 | cell_graph = knn_graph_builder.process(nuclei_map, features)
157 |
158 | # Make calculations to extract patches and the overlap images
159 | width, height = query_img.size
160 | width_range = np.linspace(0, width, 4, dtype=int)
161 | height_range = np.linspace(0, height, 4, dtype=int)
162 |
163 | overlap_percent = 20
164 | width_overlap = int((overlap_percent/100) * width)
165 | height_overlap = int((overlap_percent/100) * height)
166 |
167 | # Extract the patches
168 | image_patches = []
169 | patch_nuclei_centers = []
170 | for i in range(len(width_range)-1):
171 | for j in range(len(height_range)-1):
172 | # Consider the overlap width from second patch only
173 | if i != 0:
174 | start_width = width_range[i] - width_overlap
175 | else:
176 | start_width = width_range[i]
177 |
178 | # Consider the overlap height from second patch only
179 | if j != 0:
180 | start_height = height_range[j] - height_overlap
181 | else:
182 | start_height = height_range[j]
183 |
184 | # List out the patch ranges
185 | left = start_width
186 | upper = start_height
187 | right = width_range[i+1]
188 | lower = height_range[j+1]
189 |
190 | center_list = []
191 | for center in nuclei_centers:
192 | if ((center[0] >= left) and (center[0] <=right) and
193 | (center[1] >= upper) and (center[1] <=lower)):
194 | center_list.append(center)
195 |
196 | image_patches.append(query_img.crop((left, upper, right, lower)))
197 | patch_nuclei_centers.append(center_list)
198 |
199 | # Calculate the length of nuclei in each patch
200 | patch_center_length = []
201 | for center in patch_nuclei_centers:
202 | patch_center_length.append(len(center))
203 |
204 | # Sort the patch indices based on maximum number of nuclei
205 | sorted_indices_desc = np.flip(np.argsort(patch_center_length))
206 |
207 | # Create a directory to store all the patches of the image
208 | save_directory = os.path.join(os.getcwd(), HISTO_PATCH_SAVE_PATH, ARCH_DATA_PATH, books_pubmed_class, img_uuid[image_idx])
209 | if not os.path.isdir(save_directory):
210 | os.mkdir(save_directory)
211 |
212 | # Store all the image patches into the newly created directory
213 | for patch_index in range(0,6):
214 | save_file_path = os.path.join(save_directory, str(patch_index+1) + ".png")
215 | image_patches[sorted_indices_desc[patch_index]].save(save_file_path)
216 | else:
217 | # Create an empty directory for non-pathology images
218 | print(f"{image_idx}: Non-Pathology Images")
219 | save_directory = os.path.join(os.getcwd(), HISTO_PATCH_SAVE_PATH, ARCH_DATA_PATH, books_pubmed_class, img_uuid[image_idx])
220 | if not os.path.isdir(save_directory):
221 | os.mkdir(save_directory)
222 |
223 | print(f"{image_idx}/{len(img_uuid)}: Ended ")
224 | print(".........")
225 |
226 |
227 | if __name__ == "__main__":
228 |
229 | # # Get all the open-ended images of PathVQA
230 | # img_general, img_general_path = get_path_vqa_open_images(PVQA_DATA_PATH)
231 |
232 | # # Generate the top patches using histocartography and save them
233 | # save_histocartography_top_patches(img_general, img_general_path)
234 |
235 | pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path = get_arch_open_images(ARCH_DATA_PATH)
236 |
237 | save_histocartography_top_patches_arch(pubmed_img_uuid, pubmed_img_uuid_path, "pubmed")
238 |
239 | save_histocartography_top_patches_arch(books_img_uuid, books_img_uuid_path, "books")
240 |
--------------------------------------------------------------------------------
/generate_llava_med_query.py:
--------------------------------------------------------------------------------
1 | # This program will generate the query .jsonl files which will be input to LLaVA-Med to generate descriptions or answers
2 | # for all the images in PathVQA and all the relevant patches
3 |
4 | import os
5 | import pandas as pd
6 | import pickle
7 | from tqdm import tqdm
8 | import json
9 |
10 | PVQA_DATA_PATH = "pvqa"
11 | HISTO_PATCH_SAVE_PATH = "histo_image_patch"
12 | LLAVA_MED_QUERY_PATH = os.path.join("files", "query")
13 |
14 | def get_path_vqa_open_data(data_path : str = "pvqa"):
15 | # Reading the PathVQA dataset
16 | qas_train_path = os.path.join(data_path, "qas", "test", "test_qa.pkl")
17 | with open(qas_train_path, 'rb') as file:
18 | pvqa_qas = pickle.load(file)
19 |
20 | # Getting all open-ended images
21 | qas_general = [qas for qas in pvqa_qas if qas['answer'] != 'yes' and qas['answer'] != 'no']
22 | img_general = [qas['image'] for qas in qas_general]
23 |
24 | return qas_general, img_general
25 |
26 | def generate_image_files_direct(qas_general: list, img_general: list):
27 | # Generating all the info for images with original questions
28 | question = []
29 | idx = 0
30 | for i in range(0, len(qas_general)):
31 |         question.append({"question_id": idx, "image": img_general[i]+'.jpg', "text": qas_general[i]['question']+"\n"})  # each qas entry is a dict; use its question text
32 | idx = idx+1
33 |
34 | # Writing each dictionary as a JSON object on a new line
35 | img_direct_path = os.path.join(LLAVA_MED_QUERY_PATH, 'image_direct.jsonl')
36 | with open(img_direct_path, 'w') as file:
37 | for item in question:
38 | json_line = json.dumps(item) # Convert dictionary to JSON string
39 | file.write(json_line + '\n')
40 |
41 | def generate_patch_files_direct(qas_general: list, img_general: list):
42 | # Generating all the info for patches with original questions
43 | question = []
44 | idx = 0
45 | for i in range(0, len(qas_general)):
46 |         question.append({"question_id": idx, "image": img_general[i]+"/1.png", "text": qas_general[i]['question']+"\n"})
47 |         idx = idx+1
48 |         question.append({"question_id": idx, "image": img_general[i]+"/2.png", "text": qas_general[i]['question']+"\n"})
49 |         idx = idx+1
50 |         question.append({"question_id": idx, "image": img_general[i]+"/3.png", "text": qas_general[i]['question']+"\n"})
51 | idx = idx+1
52 |
53 | # Writing each dictionary as a JSON object on a new line
54 | patch_direct_path = os.path.join(LLAVA_MED_QUERY_PATH, 'patch_direct.jsonl')
55 | with open(patch_direct_path, 'w') as file:
56 | for item in question:
57 | json_line = json.dumps(item) # Convert dictionary to JSON string
58 | file.write(json_line + '\n')
59 |
60 | def generate_image_files_description(img_general: list):
61 | # Getting the description info for all the images
62 | question = []
63 | idx = 0
64 | for i in img_general:
65 | question.append({"question_id": idx, "image": i+'.jpg', "text": "Describe the following image in detail.\n"})
66 | idx = idx+1
67 |
68 | # Writing each dictionary as a JSON object on a new line
69 | img_desc_path = os.path.join(LLAVA_MED_QUERY_PATH, 'image_description.jsonl')
70 | with open(img_desc_path, 'w') as file:
71 | for item in question:
72 | json_line = json.dumps(item) # Convert dictionary to JSON string
73 | file.write(json_line + '\n')
74 |
75 | def generate_patch_files_description():
76 | # Getting all the directories of generated patches
77 | patch_dir_list = os.listdir(HISTO_PATCH_SAVE_PATH)
78 |     patch_dir_list.sort()  # sort in place; assigning the result of .sort() would yield None
79 |
80 | # Getting the description info for top 3 patches
81 | question = []
82 | idx = 0
83 | for i in patch_dir_list:
84 | question.append({"question_id": idx, "image": i+"/1.png", "text": "Describe the following image in detail.\n"})
85 | idx = idx+1
86 | question.append({"question_id": idx, "image": i+"/2.png", "text": "Describe the following image in detail.\n"})
87 | idx = idx+1
88 | question.append({"question_id": idx, "image": i+"/3.png", "text": "Describe the following image in detail.\n"})
89 | idx = idx+1
90 |
91 | # Writing each dictionary as a JSON object on a new line
92 | patch_desc_path = os.path.join(LLAVA_MED_QUERY_PATH, 'patch_description.jsonl')
93 | with open(patch_desc_path, 'w') as file:
94 | for item in question:
95 | json_line = json.dumps(item) # Convert dictionary to JSON string
96 | file.write(json_line + '\n')
97 |
98 | if __name__ == "__main__":
99 |
100 | # Get all the data from PathVQA
101 | qas_general, img_general = get_path_vqa_open_data(PVQA_DATA_PATH)
102 |
103 | # Generate Image File with Answers
104 | generate_image_files_direct(qas_general, img_general)
105 |
106 | # Generate Patch File with Answers
107 | generate_patch_files_direct(qas_general, img_general)
108 |
109 | # Generate Image File with Descriptions
110 | generate_image_files_description(img_general)
111 |
112 |     # Generate Patch File with Descriptions
113 | generate_patch_files_description()
114 |
--------------------------------------------------------------------------------
/images/path-rag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/embedded-robotics/path-rag/e1bd5ea6aec4825d708b4f7b77b2fcaf52fff824/images/path-rag.png
--------------------------------------------------------------------------------
/images/text.txt:
--------------------------------------------------------------------------------
1 | Add repo images
2 |
--------------------------------------------------------------------------------
/llama_7B_model_weights.py:
--------------------------------------------------------------------------------
1 | # This program will download LLaMA-7B weights from HuggingFace and store the resulting model into the path specified by $HF_HOME
2 | # By Default $HF_HOME = ~/.cache/huggingface
3 |
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 |
6 | # Load tokenizer and model
7 | tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
8 | model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
9 |
10 | # Try out an example to make sure the base model is downloaded successfully and working
11 | input_text = "The future of AI in healthcare is"
12 | input_ids = tokenizer.encode(input_text, return_tensors="pt")
13 | # Generate text
14 | output = model.generate(input_ids, max_length=50, num_return_sequences=1)
15 | # Decode and print the generated text
16 | print(tokenizer.decode(output[0], skip_special_tokens=True))
--------------------------------------------------------------------------------
/model_delta_weights/text.txt:
--------------------------------------------------------------------------------
1 | For model delta weights
2 |
--------------------------------------------------------------------------------
/recall_calculation.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import re
3 |
4 | contractions = {
5 | "aint": "ain't",
6 | "arent": "aren't",
7 | "cant": "can't",
8 | "couldve": "could've",
9 | "couldnt": "couldn't",
10 | "couldn'tve": "couldn't've",
11 | "couldnt've": "couldn't've",
12 | "didnt": "didn't",
13 | "doesnt": "doesn't",
14 | "dont": "don't",
15 | "hadnt": "hadn't",
16 | "hadnt've": "hadn't've",
17 | "hadn'tve": "hadn't've",
18 | "hasnt": "hasn't",
19 | "havent": "haven't",
20 | "hed": "he'd",
21 | "hed've": "he'd've",
22 | "he'dve": "he'd've",
23 | "hes": "he's",
24 | "howd": "how'd",
25 | "howll": "how'll",
26 | "hows": "how's",
27 | "Id've": "I'd've",
28 | "I'dve": "I'd've",
29 | "Im": "I'm",
30 | "Ive": "I've",
31 | "isnt": "isn't",
32 | "itd": "it'd",
33 | "itd've": "it'd've",
34 | "it'dve": "it'd've",
35 | "itll": "it'll",
36 | "let's": "let's",
37 | "maam": "ma'am",
38 | "mightnt": "mightn't",
39 | "mightnt've": "mightn't've",
40 | "mightn'tve": "mightn't've",
41 | "mightve": "might've",
42 | "mustnt": "mustn't",
43 | "mustve": "must've",
44 | "neednt": "needn't",
45 | "notve": "not've",
46 | "oclock": "o'clock",
47 | "oughtnt": "oughtn't",
48 | "ow's'at": "'ow's'at",
49 | "'ows'at": "'ow's'at",
50 | "'ow'sat": "'ow's'at",
51 | "shant": "shan't",
52 | "shed've": "she'd've",
53 | "she'dve": "she'd've",
54 | "she's": "she's",
55 | "shouldve": "should've",
56 | "shouldnt": "shouldn't",
57 | "shouldnt've": "shouldn't've",
58 | "shouldn'tve": "shouldn't've",
59 | "somebody'd": "somebodyd",
60 | "somebodyd've": "somebody'd've",
61 | "somebody'dve": "somebody'd've",
62 | "somebodyll": "somebody'll",
63 | "somebodys": "somebody's",
64 | "someoned": "someone'd",
65 | "someoned've": "someone'd've",
66 | "someone'dve": "someone'd've",
67 | "someonell": "someone'll",
68 | "someones": "someone's",
69 | "somethingd": "something'd",
70 | "somethingd've": "something'd've",
71 | "something'dve": "something'd've",
72 | "somethingll": "something'll",
73 | "thats": "that's",
74 | "thered": "there'd",
75 | "thered've": "there'd've",
76 | "there'dve": "there'd've",
77 | "therere": "there're",
78 | "theres": "there's",
79 | "theyd": "they'd",
80 | "theyd've": "they'd've",
81 | "they'dve": "they'd've",
82 | "theyll": "they'll",
83 | "theyre": "they're",
84 | "theyve": "they've",
85 | "twas": "'twas",
86 | "wasnt": "wasn't",
87 | "wed've": "we'd've",
88 | "we'dve": "we'd've",
89 | "weve": "we've",
90 | "werent": "weren't",
91 | "whatll": "what'll",
92 | "whatre": "what're",
93 | "whats": "what's",
94 | "whatve": "what've",
95 | "whens": "when's",
96 | "whered": "where'd",
97 | "wheres": "where's",
98 | "whereve": "where've",
99 | "whod": "who'd",
100 | "whod've": "who'd've",
101 | "who'dve": "who'd've",
102 | "wholl": "who'll",
103 | "whos": "who's",
104 | "whove": "who've",
105 | "whyll": "why'll",
106 | "whyre": "why're",
107 | "whys": "why's",
108 | "wont": "won't",
109 | "wouldve": "would've",
110 | "wouldnt": "wouldn't",
111 | "wouldnt've": "wouldn't've",
112 | "wouldn'tve": "wouldn't've",
113 | "yall": "y'all",
114 | "yall'll": "y'all'll",
115 | "y'allll": "y'all'll",
116 | "yall'd've": "y'all'd've",
117 | "y'alld've": "y'all'd've",
118 | "y'all'dve": "y'all'd've",
119 | "youd": "you'd",
120 | "youd've": "you'd've",
121 | "you'dve": "you'd've",
122 | "youll": "you'll",
123 | "youre": "you're",
124 | "youve": "you've",
125 | }
126 |
127 | manual_map = {
128 | "none": "0",
129 | "zero": "0",
130 | "one": "1",
131 | "two": "2",
132 | "three": "3",
133 | "four": "4",
134 | "five": "5",
135 | "six": "6",
136 | "seven": "7",
137 | "eight": "8",
138 | "nine": "9",
139 | "ten": "10",
140 | }
141 | articles = ["a", "an", "the"]
142 | period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
143 | comma_strip = re.compile(r"(\d)(\,)(\d)")
144 | punct = [
145 | ";",
146 | r"/",
147 | "[",
148 | "]",
149 | '"',
150 | "{",
151 | "}",
152 | "(",
153 | ")",
154 | "=",
155 | "+",
156 | "\\",
157 | "_",
158 | "-",
159 | ">",
160 | "<",
161 | "@",
162 | "`",
163 | ",",
164 | "?",
165 | "!",
166 | ]
167 |
168 | def normalize_word(token):
169 | _token = token
170 | for p in punct:
171 | if (p + " " in token or " " + p in token) or (
172 | re.search(comma_strip, token) != None
173 | ):
174 | _token = _token.replace(p, "")
175 | else:
176 | _token = _token.replace(p, " ")
177 | token = period_strip.sub("", _token, re.UNICODE)
178 |
179 | _token = []
180 | temp = token.lower().split()
181 | for word in temp:
182 | word = manual_map.setdefault(word, word)
183 | if word not in articles:
184 | _token.append(word)
185 | for i, word in enumerate(_token):
186 | if word in contractions:
187 | _token[i] = contractions[word]
188 | token = " ".join(_token)
189 | token = token.replace(",", "")
190 | return token
191 |
192 | def split_sentence(sentence, n):
193 | words = defaultdict(int)
194 | # tmp_sentence = re.sub("[^a-zA-Z ]", "", sentence)
195 | tmp_sentence = sentence
196 | tmp_sentence = tmp_sentence.lower()
197 | tmp_sentence = tmp_sentence.strip().split()
198 | length = len(tmp_sentence)
199 | for i in range(length - n + 1):
200 | tmp_words = " ".join(tmp_sentence[i: i + n])
201 | if tmp_words:
202 | words[tmp_words] += 1
203 | return words
204 |
205 | def calculate_f1score(candidate, reference):
206 |
207 | candidate = normalize_word(candidate)
208 | reference = normalize_word(reference)
209 |
210 | candidate_words = split_sentence(candidate, 1)
211 | reference_words = split_sentence(reference, 1)
212 | word_set = set()
213 | for word in candidate_words:
214 | word_set.add(word)
215 | for word in reference_words:
216 | word_set.add(word)
217 |
218 | tp = 0
219 | fp = 0
220 | fn = 0
221 | for word in word_set:
222 | if word in candidate_words and word in reference_words:
223 | tp += candidate_words[word]
224 | elif word in candidate_words and word not in reference_words:
225 | fp += candidate_words[word]
226 | elif word not in candidate_words and word in reference_words:
227 | fn += reference_words[word]
228 |
229 | if len(candidate_words) == 0:
230 | return 0, 0, 0 # "0 (warning: length of candidate's words is 0)"
231 | elif len(reference_words) == 0:
232 | return 0, 0, 0
233 | else:
234 | precision = tp / (tp + fp)
235 | recall = tp / (tp + fn)
236 | if tp == 0:
237 | return 0, 0, 0
238 | else:
239 | return 2 * precision * recall / (precision + recall), precision, recall
240 |
241 |
242 | def recall_score(eval_list: list, answer_key: str='answer', prediction_key: str='llava_answer'):
243 | score = []
244 | for i in eval_list:
245 | gt = i[answer_key]
246 | pred = i[prediction_key]
247 | f1_score, precision, recall = calculate_f1score(pred, gt)
248 | score.append(recall)
249 | total = 0
250 | for i in score:
251 | total = total + i
252 | return total/len(score)
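253 | 
254 | 
255 | # Minimal usage sketch (illustrative only): the answer file path below and the
256 | # assumption that each jsonl record carries "answer" and "llava_answer" fields
257 | # are not verified here; adjust them to match the files produced by the
258 | # LLaVA-Med evaluation step in the README.
259 | if __name__ == "__main__":
260 |     import json
261 | 
262 |     with open("files/answer/raw/answer_image_direct.jsonl") as f:
263 |         eval_list = [json.loads(line) for line in f]
264 | 
265 |     print(f"Average recall: {recall_score(eval_list):.4f}")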
--------------------------------------------------------------------------------
/tem_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# !ln -sf /usr/lib/x86_64-linux-gnu/libffi.so.7 /data/mn27889/miniconda3/envs/path-rag/lib/libffi.so.7"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from PIL import Image\n",
19 | "import numpy as np\n",
20 | "import os\n",
21 | "import pickle\n",
22 | "import json\n",
23 | "import numpy as np\n",
24 | "import sys\n",
25 | "sys.path.append(os.path.join(os.getcwd(), 'histocartography'))\n",
26 | "from histocartography.preprocessing import NucleiExtractor, DeepFeatureExtractor, KNNGraphBuilder"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "HISTO_PATCH_SAVE_PATH = \"histo_image_patch\"\n",
36 | "ARCH_DATA_PATH = \"arch\""
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "# Cell Graph Generation Definitions\n",
46 | "nuclei_detector = NucleiExtractor()\n",
47 | "feats_extractor = DeepFeatureExtractor(architecture='resnet34', patch_size=72, resize_size=224)\n",
48 | "knn_graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | " # ARCH Dataset Processing\n",
58 | "def get_arch_open_images(data_path : str = \"arch\"):\n",
59 | " \n",
60 | " # Extract the images for PubMed dataset\n",
61 | " pubmed_path = os.path.join(data_path, \"pubmed_set\")\n",
62 | " pubmed_path_captions = os.path.join(pubmed_path, \"captions.json\")\n",
63 | " pubmed_path_images = os.path.join(pubmed_path, \"images\")\n",
64 | " \n",
65 | " with open(pubmed_path_captions, 'rb') as file:\n",
66 | " pubmed_captions = json.load(file)\n",
67 | " \n",
68 | " # Getting all open-ended images\n",
69 | " pubmed_img_uuid = [value['uuid'] for index, value in pubmed_captions.items()]\n",
70 | " pubmed_img_uuid_path = [os.path.join(pubmed_path_images, img_uuid + '.jpg') for img_uuid in pubmed_img_uuid]\n",
71 | "\n",
72 | " # Extract the images for Books dataset\n",
73 | " books_path = os.path.join(data_path, \"books_set\")\n",
74 | " books_path_captions = os.path.join(books_path, \"captions.json\")\n",
75 | " books_path_images = os.path.join(books_path, \"images\")\n",
76 | " \n",
77 | " with open(books_path_captions, 'rb') as file:\n",
78 | " books_captions = json.load(file)\n",
79 | " \n",
80 | " # Getting all open-ended images\n",
81 | " books_img_uuid = [value['uuid'] for index, value in books_captions.items()]\n",
82 | " books_img_uuid_path = [os.path.join(books_path_images, img_uuid + '.png') for img_uuid in books_img_uuid]\n",
83 | " \n",
84 | " return pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "pubmed_img_uuid, pubmed_img_uuid_path, books_img_uuid, books_img_uuid_path = get_arch_open_images()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "# Save top patches using histocartography\n",
103 | "def save_histocartography_top_patches_arch(img_uuid : list, img_uuid_path: list, books_pubmed_class: str):\n",
104 | " # for image_idx in range(0, len(img_uuid)):\n",
105 | " \n",
106 | "    for image_idx in range(0, 10):  # limited to the first 10 images for a quick test run\n",
107 | " print(f\"{image_idx}/{len(img_uuid)}: Started \")\n",
108 | " query_img = Image.open(img_uuid_path[image_idx]).convert(mode=\"RGB\")\n",
109 | " image = np.array(query_img)\n",
110 | " nuclei_map, nuclei_centers = nuclei_detector.process(image)\n",
111 | "\n",
112 | " # Only consider if more than 5 nuclei are detected since knn needs to form a graph using 5 neighbors.\n",
113 | " # If less than 5 nuclei are present, most of the images are not pathology related\n",
114 | " if nuclei_centers.shape[0] > 5:\n",
115 | " print(f\"{image_idx}: Patches \")\n",
116 | " \n",
117 | " # Get the Features\n",
118 | " features = feats_extractor.process(image, nuclei_map)\n",
119 | " \n",
120 | " # Make Cell Graph\n",
121 | " cell_graph = knn_graph_builder.process(nuclei_map, features)\n",
122 | " \n",
123 | " # Make calculations to extract patches and the overlap images\n",
124 | " width, height = query_img.size\n",
125 | " width_range = np.linspace(0, width, 4, dtype=int)\n",
126 | " height_range = np.linspace(0, height, 4, dtype=int)\n",
127 | "\n",
128 | " overlap_percent = 20\n",
129 | " width_overlap = int((overlap_percent/100) * width)\n",
130 | " height_overlap = int((overlap_percent/100) * height)\n",
131 | " \n",
132 | " # Extract the patches\n",
133 | " image_patches = []\n",
134 | " patch_nuclei_centers = []\n",
135 | " for i in range(len(width_range)-1):\n",
136 | " for j in range(len(height_range)-1):\n",
137 | " # Consider the overlap width from second patch only\n",
138 | " if i != 0:\n",
139 | " start_width = width_range[i] - width_overlap\n",
140 | " else:\n",
141 | " start_width = width_range[i]\n",
142 | "\n",
143 | " # Consider the overlap height from second patch only\n",
144 | " if j != 0:\n",
145 | " start_height = height_range[j] - height_overlap\n",
146 | " else:\n",
147 | " start_height = height_range[j]\n",
148 | " \n",
149 | " # List out the patch ranges\n",
150 | " left = start_width\n",
151 | " upper = start_height\n",
152 | " right = width_range[i+1]\n",
153 | " lower = height_range[j+1]\n",
154 | " \n",
155 | " center_list = []\n",
156 | " for center in nuclei_centers:\n",
157 | " if ((center[0] >= left) and (center[0] <=right) and \n",
158 | " (center[1] >= upper) and (center[1] <=lower)):\n",
159 | " center_list.append(center)\n",
160 | "\n",
161 | " image_patches.append(query_img.crop((left, upper, right, lower)))\n",
162 | " patch_nuclei_centers.append(center_list)\n",
163 | "\n",
164 | " # Calculate the length of nuclei in each patch\n",
165 | " patch_center_length = []\n",
166 | " for center in patch_nuclei_centers:\n",
167 | " patch_center_length.append(len(center))\n",
168 | " \n",
169 | " # Sort the patch indices based on maximum number of nuclei\n",
170 | " sorted_indices_desc = np.flip(np.argsort(patch_center_length))\n",
171 | " \n",
172 | " # Create a directory to store all the patches of the image\n",
173 | "            save_directory = os.path.join(os.getcwd(), HISTO_PATCH_SAVE_PATH, ARCH_DATA_PATH, books_pubmed_class, img_uuid[image_idx])\n",
174 | " if not os.path.isdir(save_directory):\n",
175 | "                os.makedirs(save_directory)\n",
176 | " \n",
177 | " # Store all the image patches into the newly created directory\n",
178 | " for patch_index in range(0,6):\n",
179 | " save_file_path = os.path.join(save_directory, str(patch_index+1) + \".png\")\n",
180 | " image_patches[sorted_indices_desc[patch_index]].save(save_file_path)\n",
181 | " \n",
182 | " print(f\"{image_idx}/{len(img_uuid)}: Ended \")\n",
183 | " print(\".........\")"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "save_histocartography_top_patches_arch(books_img_uuid, books_img_uuid_path, \"books\")"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": []
201 | }
202 | ],
203 | "metadata": {
204 | "kernelspec": {
205 | "display_name": "path-rag",
206 | "language": "python",
207 | "name": "python3"
208 | },
209 | "language_info": {
210 | "codemirror_mode": {
211 | "name": "ipython",
212 | "version": 3
213 | },
214 | "file_extension": ".py",
215 | "mimetype": "text/x-python",
216 | "name": "python",
217 | "nbconvert_exporter": "python",
218 | "pygments_lexer": "ipython3",
219 | "version": "3.9.19"
220 | }
221 | },
222 | "nbformat": 4,
223 | "nbformat_minor": 2
224 | }
225 |
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Path-RAG: Knowledge-Guided Key Region Retrieval for Open-ended Pathology Visual Question Answering\n",
8 | "\n",
9 | "\n",
10 | "\n",
11 | " \n",
12 | " *Accurate diagnosis and prognosis assisted by pathology images are essential for cancer treatment selection and planning. Despite the recent trend of adopting deep-learning approaches for analyzing complex pathology images, they fall short as they often overlook the domain-expert understanding of tissue structure and cell composition. In this work, we focus on a challenging Open-ended Pathology VQA (PathVQA-Open) task and propose a novel framework named Path-RAG, which leverages HistoCartography to retrieve relevant domain knowledge from pathology images and significantly improves performance on PathVQA-Open. Admitting the complexity of pathology image analysis, Path-RAG adopts a human-centered AI approach by retrieving domain knowledge using HistoCartography to select the relevant patches from pathology images. Our experiments suggest that domain guidance can significantly boost the accuracy of LLaVA-Med from 38\\% to 47\\%, with a notable gain of 28\\% for H\\&E-stained pathology images in the PathVQA-Open dataset. For longer-form question and answer pairs, our model consistently achieves significant improvements of 32.5\\% in ARCH-Open PubMed and 30.6\\% in ARCH-Open Books on H\\&E images.*\n",
13 | "\n",
14 | "\n",
15 | "---\n",
16 | "\n",
17 | "Awais Naeem*, Tianhao Li*, Huang-Ru Liao*, Jiawei Xu*, Aby Mammen Mathew*, Zehao Zhu*, Zhen Tan**, Ajay Jaiswal*, Raffi Salibian*** , Ziniu Hu*** , Tianlong Chen****, Ying Ding*\n",
18 | "\n",
19 | "*University of Texas at Austin, USA \\\n",
20 | "**Arizona State University, USA \\\n",
21 | "***University of California, Los Angeles, USA \\\n",
22 | "****Massachusetts Institute of Technology, USA\n",
23 | "\n",
24 | "---\n",
25 | "\n",
26 | "## Path-RAG Implementation\n",
27 | "\n",
28 | "### 1. Clone this repository and navigate to path-rag folder\n",
29 | "\n",
30 | "```Shell\n",
31 | "git clone https://github.com/embedded-robotics/path-rag.git\n",
32 | "cd path-rag\n",
33 | "```\n",
34 | "\n",
35 | "### 2. Install Package: Create conda environment\n",
36 | "\n",
37 | "```Shell\n",
38 | "conda create -n path-rag python=3.10 -y\n",
39 | "conda activate path-rag\n",
40 | "pip install --upgrade pip # enable PEP 660 support for LLaVA-Med\n",
41 | "```\n",
42 | "\n",
43 | "### 3. Download the PathVQA dataset from the following link\n",
44 | "\n",
45 | "[PathVQA Dataset](https://github.com/UCSD-AI4H/PathVQA/blob/master/data/README.md)\n",
46 | "\n",
47 | "### 4. Clone the HistoCartography tool, set up the model checkpoints in `histocartography/checkpoints`, and install the dependencies\n",
48 | "\n",
49 | "```Shell\n",
50 | "git clone https://github.com/BiomedSciAI/histocartography\n",
51 | "```\n",
52 | "\n",
53 | "### 5. Clone the LLaVA-Med repository and install the dependencies\n",
54 | "\n",
55 | "```Shell\n",
56 | "git clone https://github.com/microsoft/LLaVA-Med\n",
57 | "```\n",
58 | "\n",
59 | "### 6. Download the LLaMA-7B model and weights from HuggingFace\n",
60 | "\n",
61 | "```Shell\n",
62 | "python llama_7B_model_weights.py # LLaMA-7B weights/model stored into $HF_HOME (By Default $HF_HOME = ~/.cache/huggingface)\n",
63 | "```\n",
64 | "\n",
65 | "### 7. Download LLaVA-Med delta weights `llava_med_in_text_60k_ckpt2_delta` and `pvqa-9epoch_delta` from `https://github.com/microsoft/LLaVA-Med#model-download`. Put them inside a folder named `model_delta_weights`\n",
66 | "\n",
67 | "### 8. Apply the LLaVA-Med delta weights to base LLaMA-7B to come up with the final weights for LLaVA-Med\n",
68 | "\n",
69 | "```Shell\n",
70 | "cd LLaVA-Med\n",
71 | "```\n",
72 | "\n",
73 | "#### LLaVA-Med pre-trained on general biomedicine data\n",
74 | "\n",
75 | "```Shell\n",
76 | "python3 -m llava.model.apply_delta \\\n",
77 | " --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \\\n",
78 | " --target ../final_models/llava_med \\\n",
79 | " --delta ../model_delta_weights/llava_med_in_text_60k_ckpt2_delta\n",
80 | "```\n",
81 | "\n",
82 | "#### LLaVA-Med fine-tuned on PathVQA\n",
83 | "\n",
84 | "```Shell\n",
85 | "python -m llava.model.apply_delta \\\n",
86 | " --base ~/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16 \\\n",
87 | " --target ../final_models/llava_med_pvqa \\\n",
88 | " --delta ../model_delta_weights/pvqa-9epoch_delta\n",
89 | "```\n",
90 | "\n",
91 | "```Shell\n",
92 | "cd ..\n",
93 | "```\n",
94 | "\n",
95 | "### 9. Generate the top patches for open-ended PathVQA images using HistoCartography\n",
96 | "\n",
97 | "```Shell\n",
98 | "python generate_histo_patches.py\n",
99 | "```\n",
100 | "\n",
101 | "### 10. Generate the query files to be passed to LLaVA-Med for both the full images and the patches\n",
102 | "\n",
103 | "```Shell\n",
104 | "python generate_llava_med_query.py\n",
105 | "```\n",
106 | "\n",
107 | "### 11. Generate the answers for all the query files using the raw model (`final_models/llava_med`) and the fine-tuned model (`final_models/llava_med_pvqa`)\n",
108 | "\n",
109 | "```Shell\n",
110 | "cd LLaVA-Med\n",
111 | "```\n",
112 | "\n",
113 | "#### Raw Model\n",
114 | "```Shell\n",
115 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n",
116 | " --question-file ../files/query/image_direct.jsonl \\\n",
117 | " --image-folder ../pvqa/images/test \\\n",
118 | " --answers-file ../files/answer/raw/answer_image_direct.jsonl\n",
119 | "```\n",
120 | "\n",
121 | "```Shell\n",
122 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n",
123 | " --question-file ../files/query/patch_direct.jsonl \\\n",
124 | " --image-folder ../pvqa/images/test \\\n",
125 | " --answers-file ../files/answer/raw/answer_patch_direct.jsonl\n",
126 | "```\n",
127 | "\n",
128 | "```Shell\n",
129 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n",
130 | " --question-file ../files/query/image_description.jsonl \\\n",
131 | " --image-folder ../pvqa/images/test \\\n",
132 | " --answers-file ../files/answer/raw/answer_image_description.jsonl\n",
133 | "```\n",
134 | "\n",
135 | "```Shell\n",
136 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med \\\n",
137 | " --question-file ../files/query/patch_description.jsonl \\\n",
138 | " --image-folder ../pvqa/images/test \\\n",
139 | " --answers-file ../files/answer/raw/answer_patch_description.jsonl\n",
140 | "```\n",
141 | "\n",
142 | "#### Fine-Tuned Model\n",
143 | "```Shell\n",
144 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n",
145 | " --question-file ../files/query/image_direct.jsonl \\\n",
146 | " --image-folder ../pvqa/images/test \\\n",
147 | " --answers-file ../files/answer/fine-tuned/answer_image_direct.jsonl\n",
148 | "```\n",
149 | "\n",
150 | "```Shell\n",
151 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n",
152 | " --question-file ../files/query/patch_direct.jsonl \\\n",
153 | " --image-folder ../pvqa/images/test \\\n",
154 | " --answers-file ../files/answer/fine-tuned/answer_patch_direct.jsonl\n",
155 | "```\n",
156 | "\n",
157 | "```Shell\n",
158 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n",
159 | " --question-file ../files/query/image_description.jsonl \\\n",
160 | " --image-folder ../pvqa/images/test \\\n",
161 | " --answers-file ../files/answer/fine-tuned/answer_image_description.jsonl\n",
162 | "```\n",
163 | "\n",
164 | "```Shell\n",
165 | "python llava/eval/model_vqa.py --model-name ../final_models/llava_med_pvqa \\\n",
166 | " --question-file ../files/query/patch_description.jsonl \\\n",
167 | " --image-folder ../pvqa/images/test \\\n",
168 | " --answers-file ../files/answer/fine-tuned/answer_patch_description.jsonl\n",
169 | "```\n",
170 | "\n",
171 | "### 12. Evaluate the results for different use-cases using `recall_calculation.py`\n",
172 | "\n",
173 | "(i) Path-RAG w/o GPT: Combine the answer of the image + all patches to be the final predicted answer (a minimal sketch is shown below)\\\n",
174 | "(ii) Path-RAG (description): Combine the description of the image + all patches, then involve GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)\\\n",
175 | "(iii) Path-RAG (answer): Combine the answer of the image + all patches, then involve GPT-4 for reasoning to get the final predicted answer (See Supplementary Section for Prompts)\n",
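"\n",
"A minimal sketch of use-case (i), assuming the step-11 answer files are jsonl records that share a `question_id` field, carry the model output in `text`, and that the ground-truth `answer` has been merged into the image-level records; these field names are illustrative assumptions rather than the exact schema:\n",
"\n",
"```Python\n",
"import json\n",
"from collections import defaultdict\n",
"from recall_calculation import recall_score\n",
"\n",
"def load_jsonl(path):\n",
"    with open(path) as f:\n",
"        return [json.loads(line) for line in f]\n",
"\n",
"image_answers = load_jsonl('files/answer/raw/answer_image_direct.jsonl')\n",
"patch_answers = load_jsonl('files/answer/raw/answer_patch_direct.jsonl')\n",
"\n",
"# Group the patch-level answers by their originating question\n",
"patches_by_qid = defaultdict(list)\n",
"for rec in patch_answers:\n",
"    patches_by_qid[rec['question_id']].append(rec['text'])\n",
"\n",
"# Path-RAG w/o GPT: concatenate the image answer with all patch answers\n",
"eval_list = []\n",
"for rec in image_answers:\n",
"    combined = ' '.join([rec['text']] + patches_by_qid[rec['question_id']])\n",
"    eval_list.append({'answer': rec['answer'], 'llava_answer': combined})\n",
"\n",
"print('Path-RAG w/o GPT recall:', recall_score(eval_list))\n",
"```\n",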
176 | "\n",
177 | "\n",
178 | "## ARCH-Open Dataset\n",
179 | "\n",
180 | "1. Download the `books_set` and `pubmed_set` of the ARCH dataset from `https://warwick.ac.uk/fac/cross_fac/tia/data/arch`. Store both of these folders in a folder named `arch`. Both `books_set` and `pubmed_set` contain a `captions.json` which lists a **caption** and a **UUID**, where **UUID** is the file name in the `images` folder and **caption** is the description of the image.\n",
181 | "\n",
182 | "2. Using `captions.json` and the `images` folder under `arch/books_set`, run the notebook `ARCH-OPEN/books_data/synthetic_data_textbook.ipynb` (specifying your OpenAI credentials) to generate the question-answer pairs for the books set\n",
183 | "\n",
184 | "3. Using `captions.json` and the `images` folder under `arch/pubmed_set`, run the notebook `ARCH-OPEN/pubmed_data/synthetic_data_pubmed.ipynb` (specifying your OpenAI credentials) to generate the question-answer pairs for the pubmed set\n",
185 | "\n",
186 | "4. Run the notebook `ARCH-OPEN/synthetic_data_compilation.ipynb` to compile the `pubmed` and `books` question-answer pairs into JSON files, namely `ARCH-OPEN/pubmed_qa_pairs.json` and `ARCH-OPEN/textbook_qa_pairs.json`. These files are already provided and can be used directly\n",
187 | "\n",
188 | "5. The `pubmed_qa_pairs.json` and `textbook_qa_pairs.json` files contain 5 question-answer pairs for each `caption`/`uuid` pair (where `uuid` refers to the image name in the ARCH data: `arch/pubmed_set/images`, `arch/books_set/images`) in the following format, which applies to both `pubmed_set` and `books_set`; a short loading sketch follows the example:\n",
189 | "\n",
190 | "```Shell\n",
191 | " {\n",
192 | " \"figure_id\": \"00\",\n",
193 | " \"letter\": \"A\",\n",
194 | " \"caption\": \" A, Spindle cell variant of embryonal rhabdomyosarcoma is characterized by fascicles of eosinophilic spindle cells (B), some of which can show prominent paranuclear vacuolisation, as seen in leiomyosarcoma.\",\n",
195 | " \"uuid\": \"890e2e79-ab0a-4a2e-9d62-b0b6b3d43884\",\n",
196 | " \"Question_1\": \"What could be the general shape of cells in a spindle cell variant of embryonal rhabdomyosarcoma as seen in the image?\",\n",
197 | " \"Answer_1\": \"The cells often present with a spindle-like elongated shape.\",\n",
198 | " \"Question_2\": \"What type of structures could be visible in the image indicating the presence of spindle cells?\",\n",
199 | " \"Answer_2\": \"Fascicles, or bundles, of cells could be visible in the image, indicating the presence of spindle cells.\",\n",
200 | " \"Question_3\": \"Where in the cell would we likely find paranuclear vacuolisation in the image?\",\n",
201 | " \"Answer_3\": \"Paranuclear vacuolisation is usually seen around the nucleus area of the cell.\",\n",
202 | " \"Question_4\": \"What color might the spindle cells appear in the image?\",\n",
203 | " \"Answer_4\": \"The spindle cells may appear eosinophilic, or pinkish-red, under the microscope due to staining.\",\n",
204 | " \"Question_5\": \"What visual feature might differentiate spindle cells from leiomyosarcoma cells in the image?\",\n",
205 | " \"Answer_5\": \"Spindle cells might show prominent paranuclear vacuolisation, a feature that can differentiate them from leiomyosarcoma cells.\"\n",
206 | " }\n",
207 | "```\n",
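"\n",
"A minimal loading sketch (assuming each compiled JSON file is a list of records in the format above; the variable names are illustrative):\n",
"\n",
"```Python\n",
"import json\n",
"import os\n",
"\n",
"with open('ARCH-OPEN/pubmed_qa_pairs.json') as f:\n",
"    qa_records = json.load(f)\n",
"\n",
"for rec in qa_records[:2]:\n",
"    # PubMed images are stored as <uuid>.jpg; books images use <uuid>.png\n",
"    image_path = os.path.join('arch', 'pubmed_set', 'images', rec['uuid'] + '.jpg')\n",
"    for i in range(1, 6):\n",
"        print(image_path, rec[f'Question_{i}'], '->', rec[f'Answer_{i}'])\n",
"```\n",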
208 | "\n",
209 | "## Acknowledgement\n",
210 | "We would like to acknowledge the following funding support: NIH OT2OD032581, NIH OTA-21-008, NIH 1OT2OD032742-01, NSF 2333703, NSF 2303038."
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {},
216 | "source": []
217 | }
218 | ],
219 | "metadata": {
220 | "language_info": {
221 | "name": "python"
222 | }
223 | },
224 | "nbformat": 4,
225 | "nbformat_minor": 2
226 | }
227 |
--------------------------------------------------------------------------------