├── LLMs └── .gitignore ├── README.md ├── app.py ├── data └── paper_PDF │ └── .gitignore ├── embedding_models └── .gitignore ├── multi-modal_qdrant_retrieval_demo.ipynb ├── ollama_multimodal_llm.py ├── qdrant_db └── .gitignore ├── retrieval_serving.py ├── retriever_config.yaml ├── serve_grobid_light.sh └── src ├── __pycache__ ├── __init__.cpython-311.pyc ├── custom_embeddings.cpython-311.pyc ├── custom_vectore_store.cpython-311.pyc ├── mm_retriever.cpython-311.pyc └── prompt.cpython-311.pyc ├── build_qdrant_collections.py ├── custom_embeddings.py ├── custom_vectore_store.py ├── llava_llamacpp.py ├── mm_query_engine.py ├── mm_retriever.py ├── parse_pdf.py └── prompt.py /LLMs/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/LLMs/.gitignore -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiModal RAG with LlamaIndex + Qdrant + Local Vision-LLM & Embedding models 2 | 3 | ## Overview 4 | 5 | This project is implemented within the framework of LlamaIndex, using multiple custom components to achieve a fully localized multimodal document-QA system without relying on any APIs or remote resources. 6 | 7 | Below are the main tools/models used: 8 | 9 | - **PDF parser**: [SciPDF Parser](https://github.com/titipata/scipdf_parser) 10 | 11 | - **RAG Framework**: [LlamaIndex](https://github.com/run-llama/llama_index) 12 | 13 | - **Vector DataBase**: [Qdrant](https://qdrant.tech/) 14 | 15 | - **Vison-LLM**: A [gguf quantized 7B version](https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf) of [LLaVA](https://llava-vl.github.io/) 16 | 17 | - **LLM Inference Framework**: [llama.cpp](https://github.com/ggerganov/llama.cpp) & [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) 18 | 19 | - **Embedding Models**: 20 | - **BGE** models for text [embedding](https://huggingface.co/BAAI/bge-small-en-v1.5) and [reranking](https://huggingface.co/BAAI/bge-reranker-base) 21 | 22 | - Efficient **SPLADE** models ([doc](https://huggingface.co/naver/efficient-splade-VI-BT-large-doc), [query](https://huggingface.co/naver/efficient-splade-VI-BT-large-query)) for sparse retrieval 23 | 24 | 25 | - [**CLIP**](https://huggingface.co/openai/clip-vit-base-patch32) for query-to-image retrieval 26 | 27 | 28 | ## Environment 29 | 30 | WSL2 on Windows 10, Ubuntu 22.04 31 | 32 | CUDA Toolkit Version 12.3 33 | 34 | 35 | ## Library Installation 36 | 37 | **SciPDF Parser** 38 | 39 | Follow the steps from https://github.com/titipata/scipdf_parser/blob/master/README.md. 40 | 41 | Install this branch to avoid exceeding the request time limit when dealing with large PDF files. 42 | ``` 43 | pip install https://github.com/Virgil-L/scipdf_parser.git 44 | ``` 45 | To save memory and computing cost, this project uses a lightweight image of [GROBID](https://github.com/kermitt2/grobid), run `serve_grobid_light.sh` to get it. 46 | 47 | **LlamaIndex** 48 | 49 | ``` 50 | pip install -q llama-index llama-index-embeddings-huggingface 51 | ``` 52 | 53 | **Qdrant** 54 | ``` 55 | pip install qdrant-client 56 | ``` 57 | 58 | **llama-cpp-python** 59 | 60 | ``` 61 | # Linux and Mac 62 | CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install llama-cpp-python 63 | ``` 64 | 65 | ## Examples 66 | 67 | Download the sample pdf data 68 | ``` 69 | wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/paper_PDF/llama2.pdf" 70 | wget --user-agent "Mozilla" "https://arxiv.org/pdf/2310.03744.pdf" -O "data/paper_PDF/llava-vl.pdf" 71 | 72 | wget --user-agent "Mozilla" "https://arxiv.org/pdf/1706.03762.pdf" -O "data/paper_PDF/attention.pdf" 73 | wget --user-agent "Mozilla" "https://arxiv.org/pdf/1810.04805.pdf" -O "data/paper_PDF/bert.pdf" 74 | ``` 75 | 76 | Parse pdf files to extract text and image elemnts 77 | ``` 78 | python parse_pdf.py \ 79 | --pdf_folder "./data/paper_PDF" \ 80 | --image_resolution 300 \ 81 | --max_timeout 120 82 | ``` 83 | 84 | Create Qdrant collections, build and ingest text and image nodes 85 | ``` 86 | python build_qdrant_collections.py \ 87 | --pdf_folder "./data/paper_PDF" \ 88 | --storage_path "./qdrant_db" \ 89 | --text_embedding_model "./embedding_models/bge-small-en-v1.5" \ 90 | --sparse_text_embedding_model "./embedding_models/efficient-splade-VI-BT-large-doc" \ 91 | --image_embedding_model "./embedding_models/clip-vit-base-patch32" \ 92 | --chunk_size 384 \ 93 | --chunk_overlap 32 94 | ``` -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import streamlit as st 3 | from PIL import Image 4 | import argparse 5 | from streamlit_feedback import streamlit_feedback 6 | import time 7 | from io import BytesIO 8 | import base64 9 | import requests 10 | from tqdm import tqdm 11 | import json 12 | import ollama 13 | from src.prompt import generate_sys_prompt 14 | 15 | MLLM_NAME = 'minicpm-v:8b-2.6-q4_K_M' 16 | DEFAULT_PROLOGUE = "Hello! I'm a multimodal retriever. I can help you with a variety of tasks. What would you like to know?" 17 | 18 | st.set_page_config( 19 | page_title="Chat with Multimodal Retriever", 20 | page_icon="🔍", 21 | initial_sidebar_state="expanded", 22 | menu_items={ 23 | "Get help": "https://github.com/Virgil-L", 24 | "Report a bug": "https://github.com/Virgil-L", 25 | "About": "Built by @Virgil-L with Streamlit & LlamaIndex", 26 | } 27 | ) 28 | 29 | # Initialize session state 30 | if 'llm_prompt_tokens' not in st.session_state: 31 | st.session_state['llm_prompt_tokens'] = 0 32 | 33 | if 'llm_completion_tokens' not in st.session_state: 34 | st.session_state['llm_completion_tokens'] = 0 35 | 36 | if 'messages' not in st.session_state: 37 | st.session_state['messages'] = [] 38 | 39 | if 'btn_llama_index' not in st.session_state: 40 | st.session_state['btn_llama_index'] = False 41 | 42 | if 'btn_retriever' not in st.session_state: 43 | st.session_state['btn_retriever'] = False 44 | 45 | if 'btn_diff' not in st.session_state: 46 | st.session_state['btn_diff'] = False 47 | 48 | if 'btn_rag' not in st.session_state: 49 | st.session_state['btn_rag'] = False 50 | # if 'openai_api_key' in st.session_state: 51 | # openai.api_key = st.session_state['openai_api_key'] 52 | 53 | def mm_retrieve(query, text_topk = 3, image_topk = 1, port=5000): 54 | req_data = {"query": query, "text_topk": text_topk, "image_topk": image_topk} 55 | 56 | response = requests.post(f"http://localhost:{port}/api", headers={"Content-Type": "application/json"}, data=json.dumps(req_data)) 57 | rep_data = response.json() 58 | text_sources, image_sources = [], [] 59 | if 'text_result' in rep_data: 60 | text_sources = [{'text': item['node']['text'], 61 | 'elementType': item['node']['metadata']['metadata']['elementType'], 62 | 'source_file': item['node']['metadata']['metadata']['source_file_path']} for item in rep_data['text_result']] 63 | if 'image_result' in rep_data: 64 | image_sources = [{'image': base64.b64decode(item['node']['image']), 65 | 'caption': item['node']['text'], 66 | 'elementType': item['node']['metadata']['elementType'], 67 | 'source_file': item['node']['metadata']['source_file_path']} for item in rep_data['image_result']] 68 | return { 69 | 'text_sources': text_sources, 70 | 'image_sources': image_sources 71 | } 72 | 73 | 74 | def display_sources(sources): 75 | with st.expander("See sources"): 76 | if sources['text_sources']: 77 | st.markdown("#### Text sources:") 78 | for i, item in enumerate(sources['text_sources']): 79 | # 对每一篇文章,创建一个container,展示文章的文本内容 80 | txt_sources_container = st.container() 81 | with txt_sources_container: 82 | st.markdown(f"**Ref [{i+1}]**:\n\nfrom: {item['source_file']}\n\n > {item['text']}\n\n") 83 | if sources['image_sources']: 84 | st.markdown("#### Image sources:") 85 | for i, item in enumerate(sources['image_sources']): 86 | img_sources_container = st.container() 87 | with img_sources_container: 88 | st.markdown(f"**Ref [{i+1}]**\n\nfrom: {(item['source_file'])} \n\n") 89 | st.image(item['image'], caption=item['caption'], use_column_width=True) 90 | 91 | 92 | def display_chat_history(messages, dialogue_container): 93 | """Display previous chat messages.""" 94 | with dialogue_container: 95 | for message in messages: 96 | with st.chat_message(message["role"]): 97 | if st.session_state.with_sources: 98 | if "sources" in message: 99 | #TODO: 展示图片,文本块等参考信息 100 | display_sources(message["sources"]) 101 | st.write(message["content"]) 102 | 103 | 104 | 105 | 106 | def clear_chat_history(): 107 | """"Clear chat history and reset questions' buttons.""" 108 | 109 | # st.session_state.messages = [ 110 | # {"role": "assistant", "content": DEFAULT_PROLOGUE} 111 | # ] 112 | st.session_state.messages = [] 113 | st.session_state["btn_llama_index"] = False 114 | st.session_state["btn_retriever"] = False 115 | st.session_state["btn_diff"] = False 116 | st.session_state["btn_rag"] = False 117 | 118 | 119 | 120 | def upload_image(): 121 | uploaded_file = st.file_uploader("Choose an image from your computer", type=["jpg", "jpeg", "png", "webp"]) 122 | if uploaded_file is not None: 123 | image = Image.open(uploaded_file) 124 | st.image(image, caption="Uploaded Image", use_column_width=True) 125 | return image 126 | 127 | 128 | 129 | def get_user_input(): 130 | query = st.chat_input(placeholder="Enter your prompt here", max_chars=2048) 131 | if query: 132 | st.session_state['llm_prompt_tokens'] += 1 133 | st.session_state['messages'].append({"role": "user", "content": query}) 134 | return query 135 | 136 | 137 | 138 | def request_assistant_response(messages, sources=None): 139 | ##TODO: prompt模板、多模态上下文输入、多轮对话 140 | system = generate_sys_prompt(sources) 141 | 142 | if messages[0]['role'] == 'assistant': 143 | messages = messages[1:] 144 | if sources and sources['image_sources']: 145 | messages.insert(0, {'role': 'system', 'content': system, 'images': [item['image'] for item in sources['image_sources']]}) 146 | else: 147 | messages.insert(0, {'role': 'system', 'content': system}) 148 | resp_stream = ollama.chat( 149 | model=MLLM_NAME, 150 | messages=messages, 151 | stream=True, 152 | keep_alive=600, 153 | options={ 154 | 'num_predict': -1, 155 | 'temperature': st.session_state.get('temperature', 0.5), 156 | 'top_p': st.session_state.get('top_p', 0.9), 157 | 'stop': ['', '<|im_end|>'], 158 | 'frequency_penalty':st.session_state.get('frequency_penalty', 2.0), 159 | 'num_ctx':8192, 160 | }, 161 | ) 162 | for chunk in resp_stream: 163 | yield chunk['message']['content'] 164 | 165 | 166 | def generate_assistant_response(messages, sources=None): 167 | 168 | with st.spinner("I am on it..."): 169 | resp_stream = request_assistant_response(messages, sources) 170 | with st.chat_message("assistant"): 171 | full_response = st.write_stream(resp_stream) 172 | message = {'role': 'assistant', 'content': full_response} 173 | if st.session_state.with_sources: 174 | message['sources'] = sources 175 | display_sources(sources) 176 | st.session_state.messages.append(message) 177 | 178 | 179 | 180 | def format_sources(response): 181 | ##TODO 182 | # """Format filename, authors and scores of the response source nodes.""" 183 | # base = "https://github.com/jerryjliu/llama_index/tree/main/" 184 | # return "\n".join([f"- {base}{source['filename']} (author: '{source['author']}'; score: {source['score']})\n" for source in get_metadata(response)]) 185 | raise NotImplementedError 186 | 187 | def get_metadata(response): 188 | """Parse response source nodes and return a list of dictionaries with filenames, authors and scores.""" 189 | ##TODO 190 | # sources = [] 191 | # for item in response.source_nodes: 192 | # if hasattr(item, "metadata"): 193 | # filename = item.metadata.get('filename').replace('\\', '/') 194 | # author = item.metadata.get('author') 195 | # score = float("{:.3f}".format(item.score)) 196 | # sources.append({'filename': filename, 'author': author, 'score': score}) 197 | 198 | # return sources 199 | raise NotImplementedError 200 | 201 | 202 | 203 | def side_bar(): 204 | """Configure the sidebar and user's preferences.""" 205 | st.sidebar.title("Configurations") 206 | 207 | with st.sidebar.expander("🔑 OPENAI-API-KEY", expanded=True): 208 | st.text_input(label='OPENAI-API-KEY', 209 | type='password', 210 | key='openai_api_key', 211 | label_visibility='hidden').strip() 212 | "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)" 213 | 214 | with st.sidebar.expander("💲 GPT3.5 INFERENCE COST", expanded=True): 215 | i_tokens = st.session_state['llm_prompt_tokens'] 216 | o_tokens = st.session_state['llm_completion_tokens'] 217 | st.markdown(f'LLM Prompt: {i_tokens} tokens') 218 | st.markdown(f'LLM Completion: {o_tokens} tokens') 219 | 220 | i_cost = (i_tokens / 1000) * 0.0015 221 | o_cost = (o_tokens / 1000) * 0.002 222 | st.markdown('**Cost Estimation: ${0}**'.format(round(i_cost + o_cost, 5))) 223 | "[OpenAI Pricing](https://openai.com/pricing)" 224 | 225 | with st.sidebar.expander("🔧 SETTINGS", expanded=True): 226 | st.toggle('Cache Results', value=True, key="with_cache") 227 | st.toggle('Display Sources', value=True, key="with_sources") 228 | st.toggle('Streaming', value=False, key="with_streaming") 229 | st.toggle('Debug Info', value=False, key="debug_mode") 230 | 231 | st.sidebar.button('Clear Messages', type="primary", on_click=clear_chat_history) 232 | 233 | if st.session_state.debug_mode: 234 | with st.sidebar.expander(" 🕸️ Current Session State", expanded=True): 235 | st.write(st.session_state) 236 | 237 | 238 | ## Show external links 239 | 240 | #st.sidebar.divider() 241 | # with st.sidebar: 242 | # col_ll, col_gh = st.columns([1, 1]) 243 | # with col_ll: 244 | # "[![LlamaIndex Docs](https://img.shields.io/badge/LlamaIndex%20Docs-gray)](https://gpt-index.readthedocs.io/en/latest/index.html)" 245 | # with col_gh: 246 | # "[![Github](https://img.shields.io/badge/Github%20Repo-gray?logo=Github)](https://github.com/dcarpintero/llamaindexchat)" 247 | 248 | 249 | 250 | 251 | 252 | def layout(): 253 | st.header("Chat with 🦙 LlamaIndex Docs 🗂️") 254 | 255 | 256 | # Sample Questions for User input 257 | user_input_button = None 258 | 259 | ##TODO: Modify the sample questions 260 | btn_llama_index = st.session_state.get("btn_llama_index", False) 261 | btn_retriever = st.session_state.get("btn_retriever", False) 262 | btn_diff = st.session_state.get("btn_diff", False) 263 | btn_rag = st.session_state.get("btn_rag", False) 264 | 265 | col1, col2, col3, col4 = st.columns([1,1,1,1]) 266 | 267 | with col1: 268 | if st.button("explain the basic usage pattern of LlamaIndex", type="primary", disabled=btn_llama_index): 269 | user_input_button = "explain the basic usage pattern in LlamaIndex" 270 | st.session_state.btn_llama_index = True 271 | with col2: 272 | if st.button("how can I ingest data from the GoogleDocsReader?", type="primary", disabled=btn_retriever): 273 | user_input_button = "how can I ingest data from the GoogleDocsReader?" 274 | st.session_state.btn_retriever = True 275 | with col3: 276 | if st.button("what's the difference between document & node?", type="primary", disabled=btn_diff): 277 | user_input_button = "what's the difference between document and node?" 278 | st.session_state.btn_diff = True 279 | with col4: 280 | if st.button("how can I make a RAG application performant?", type="primary", disabled=btn_rag): 281 | user_input_button = "how can I make a RAG application performant?" 282 | st.session_state.btn_rag = True 283 | 284 | 285 | # System Message 286 | if "messages" not in st.session_state or not st.session_state.messages: 287 | st.session_state.messages = [ 288 | {"role": "assistant", "content": "Try one of the sample questions or ask your own!"} 289 | ] 290 | 291 | dialogue_container = st.container() 292 | 293 | # User input 294 | user_input = st.chat_input("Enter your question here") 295 | if user_input or user_input_button: 296 | st.session_state.messages.append({"role": "user", "content": user_input or user_input_button}) 297 | 298 | # Display previous chat 299 | display_chat_history(st.session_state.messages, dialogue_container) 300 | 301 | # Generate response 302 | if st.session_state.messages[-1]["role"] != "assistant": 303 | try: 304 | sources = None 305 | if st.session_state.with_sources: 306 | sources = mm_retrieve(user_input or user_input_button) 307 | generate_assistant_response(st.session_state.messages, sources) 308 | 309 | except Exception as ex: 310 | st.error(str(ex)) 311 | 312 | 313 | @st.cache_resource 314 | def initialization(): 315 | with st.spinner('Loading Environment...'): 316 | try: 317 | ##TODO: check retriever and ollama service, if not running, start them 318 | time.sleep(3) 319 | pass 320 | 321 | except Exception as ex: 322 | st.error(str(ex)) 323 | st.stop() 324 | 325 | def main(): 326 | # Initializations 327 | initialization() 328 | side_bar() 329 | layout() 330 | 331 | 332 | 333 | if __name__ == "__main__": 334 | main() -------------------------------------------------------------------------------- /data/paper_PDF/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/data/paper_PDF/.gitignore -------------------------------------------------------------------------------- /embedding_models/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/embedding_models/.gitignore -------------------------------------------------------------------------------- /multi-modal_qdrant_retrieval_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Build Qdrant Vector Stores" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from llama_index.legacy.vector_stores import QdrantVectorStore\n", 17 | "from custom_vectore_store import MultiModalQdrantVectorStore\n", 18 | "from custom_embeddings import custom_sparse_doc_vectors, custom_sparse_query_vectors\n", 19 | "\n", 20 | "from functools import partial\n", 21 | "\n", 22 | "from qdrant_client import QdrantClient\n", 23 | "from qdrant_client.http import models as qd_models\n", 24 | "\n", 25 | "try:\n", 26 | " client = QdrantClient(path=\"qdrant_db\")\n", 27 | " print(\"Connected to Qdrant\")\n", 28 | "except:\n", 29 | " pass\n", 30 | " print(\"Failed to connect to Qdrant\")\n", 31 | "\n", 32 | "\n", 33 | "\n", 34 | "import torch\n", 35 | "from transformers import AutoTokenizer, AutoModelForMaskedLM\n", 36 | "\n", 37 | "SPLADE_QUERY_PATH = \"./embedding_models/efficient-splade-VI-BT-large-query\"\n", 38 | "splade_q_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_PATH)\n", 39 | "splade_q_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_PATH)\n", 40 | "\n", 41 | "SPLADE_DOC_PATH = \"./embedding_models/efficient-splade-VI-BT-large-doc\"\n", 42 | "splade_d_tokenizer = AutoTokenizer.from_pretrained(SPLADE_DOC_PATH)\n", 43 | "splade_d_model = AutoModelForMaskedLM.from_pretrained(SPLADE_DOC_PATH)\n", 44 | "\n", 45 | "custom_sparse_doc_fn = partial(custom_sparse_doc_vectors, splade_d_tokenizer, splade_d_model, 512)\n", 46 | "custom_sparse_query_fn = partial(custom_sparse_query_vectors, splade_q_tokenizer, splade_q_model, 512)\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "text_store = QdrantVectorStore(\n", 56 | " client=client,\n", 57 | " collection_name=\"text_collection\",\n", 58 | " enable_hybrid=True,\n", 59 | " sparse_query_fn=custom_sparse_query_fn,\n", 60 | " sparse_doc_fn=custom_sparse_doc_fn,\n", 61 | " stores_text=True,\n", 62 | ")\n", 63 | "\n", 64 | "image_store = MultiModalQdrantVectorStore(\n", 65 | " client=client,\n", 66 | " collection_name=\"image_collection\",\n", 67 | " enable_hybrid=True,\n", 68 | " sparse_query_fn=custom_sparse_query_fn,\n", 69 | " sparse_doc_fn=custom_sparse_doc_fn,\n", 70 | " stores_text=False,\n", 71 | ")" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "from llama_index.legacy.embeddings import HuggingFaceEmbedding\n", 81 | "from custom_embeddings import CustomizedCLIPEmbedding\n", 82 | "\n", 83 | "BGE_PATH = \"./embedding_models/bge-small-en-v1.5\"\n", 84 | "CLIP_PATH = \"./embedding_models/clip-vit-base-patch32\"\n", 85 | "bge_embedding = HuggingFaceEmbedding(model_name=BGE_PATH, device=\"cpu\", pooling=\"mean\")\n", 86 | "clip_embedding = CustomizedCLIPEmbedding(model_name=CLIP_PATH, device=\"cpu\")\n" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## Customized Multi-modal Retriever with Reranker" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "from llama_index.core.postprocessor import SentenceTransformerRerank\n", 103 | "\n", 104 | "bge_reranker = SentenceTransformerRerank(\n", 105 | " model=\"./embedding_models/bge-reranker-base\",\n", 106 | " top_n=3,\n", 107 | " device=\"cpu\",\n", 108 | " keep_retrieval_score=False,\n", 109 | " )\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "from mm_retriever import MultiModalQdrantRetriever\n", 119 | "\n", 120 | "mm_retriever = MultiModalQdrantRetriever(\n", 121 | " text_vector_store = text_store,\n", 122 | " image_vector_store = image_store, \n", 123 | " text_embed_model = bge_embedding, \n", 124 | " mm_embed_model = clip_embedding,\n", 125 | " reranker = bge_reranker,\n", 126 | " text_similarity_top_k = 5,\n", 127 | " text_sparse_top_k = 5,\n", 128 | " text_rerank_top_n = 3,\n", 129 | " image_similarity_top_k = 5,\n", 130 | " image_sparse_top_k = 5,\n", 131 | " image_rerank_top_n = 1,\n", 132 | " sparse_query_fn = custom_sparse_query_fn,\n", 133 | ")" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "from llama_index.legacy.schema import QueryBundle\n", 143 | "query_bundle=QueryBundle(query_str=\"How does Llama 2 perform compared to other open-source models?\")\n", 144 | "\n", 145 | "# text_query_result = mm_retriever.retrieve_text_nodes(query_bundle=query_bundle, query_mode=\"hybrid\")\n", 146 | "# reranked_text_nodes = mm_retriever.rerank_text_nodes(query_bundle, text_query_result)\n", 147 | "# image_query_result = mm_retriever.retrieve_image_nodes(query_bundle=query_bundle, query_mode=\"hybrid\")\n", 148 | "# reranked_image_nodes = mm_retriever.rerank_image_nodes(query_bundle, image_query_result)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "## Load Quantized LLaVA-1.6 with llama-cpp framework" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "from llama_cpp.llama_chat_format import Llava15ChatHandler\n", 165 | "\n", 166 | "llava_chat_handler = Llava15ChatHandler(\n", 167 | " clip_model_path = \"LLMs/llava-1.6-mistral-7b-gguf/mmproj-model-f16.gguf\",\n", 168 | " verbose = False\n", 169 | ")\n", 170 | "\n", 171 | "\n", 172 | "## Load LLaVA with the original llama-cpp python bindings \n", 173 | "\n", 174 | "# from llama_cpp import Llama\n", 175 | "\n", 176 | "# llava_1_6 = Llama(\n", 177 | "# model_path=\"LLMs/llava-1.6-mistral-7b-gguf/llava-v1.6-mistral-7b.Q4_K_M.gguf\",\n", 178 | "# chat_format=\"llava-1-5\",\n", 179 | "# chat_handler=llava_chat_handler, # Optional chat handler to use when calling create_chat_completion.\n", 180 | "# n_ctx=2048, # (context window size) Text context, 0 = from model\n", 181 | "# logits_all=True, # Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.\n", 182 | "# offload_kqv=True, # Offload K, Q, V to GPU.\n", 183 | "# n_gpu_layers=40, # Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.\n", 184 | "# last_n_tokens_size=64, # maximum number of tokens to keep in the last_n_tokens deque.\n", 185 | "# verbose=True,\n", 186 | "\n", 187 | "# ## LoRA Params\n", 188 | "# # lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.\n", 189 | "# # lora_scale: float = 1.0,\n", 190 | "# # lora_path: Path to a LoRA file to apply to the model.\n", 191 | "\n", 192 | "# ## Tokenizer Override\n", 193 | "# # tokenizer: Optional[BaseLlamaTokenizer] = None,\n", 194 | "# )" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "## Load LLaVA with customized llama-index integration\n", 204 | "from llava_llamacpp import Llava_LlamaCPP\n", 205 | "\n", 206 | "model_kwargs = {\n", 207 | " \"chat_format\":\"llava-1-5\",\n", 208 | " \"chat_handler\":llava_chat_handler, \n", 209 | " \"logits_all\":True,\n", 210 | " \"offload_kqv\":True,\n", 211 | " \"n_gpu_layers\":40,\n", 212 | " \"last_n_tokens_size\":64,\n", 213 | " \n", 214 | " ## LoRA Params\n", 215 | " # lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.\n", 216 | " # lora_scale: float = 1.0,\n", 217 | " # lora_path: Path to a LoRA file to apply to the model.\n", 218 | "\n", 219 | " ## Tokenizer Override\n", 220 | " # tokenizer: Optional[BaseLlamaTokenizer] = None,\n", 221 | "}\n", 222 | "\n", 223 | "llava_1_6 = Llava_LlamaCPP(\n", 224 | " model_path=\"LLMs/llava-1.6-mistral-7b-gguf/llava-v1.6-mistral-7b.Q3_K_M.gguf\",\n", 225 | " temperature=0.5,\n", 226 | " max_new_tokens=1024,\n", 227 | " context_window=4096,\n", 228 | " verbose=True,\n", 229 | " model_kwargs = model_kwargs,\n", 230 | ")" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "## Build Query Engine" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "from mm_query_engine import CustomMultiModalQueryEngine\n", 247 | "\n", 248 | "query_engine = CustomMultiModalQueryEngine(\n", 249 | " retriever = mm_retriever,\n", 250 | " multi_modal_llm = llava_1_6,\n", 251 | ")" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "# retrieval_results = query_engine.retrieve(query_bundle=query_bundle, text_query_mode=\"hybrid\", image_query_mode=\"default\")\n", 261 | "# response = query_engine.synthesize(query_bundle, retrieval_results)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "response = query_engine.query(query_bundle)" 271 | ] 272 | } 273 | ], 274 | "metadata": { 275 | "kernelspec": { 276 | "display_name": "cu118py311", 277 | "language": "python", 278 | "name": "python3" 279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": "ipython3", 290 | "version": "3.11.3" 291 | } 292 | }, 293 | "nbformat": 4, 294 | "nbformat_minor": 2 295 | } 296 | -------------------------------------------------------------------------------- /ollama_multimodal_llm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Sequence, Tuple, List 2 | 3 | from llama_index.core.base.llms.types import ( 4 | ChatMessage, 5 | ChatResponse, 6 | ChatResponseAsyncGen, 7 | ChatResponseGen, 8 | CompletionResponse, 9 | CompletionResponseAsyncGen, 10 | CompletionResponseGen, 11 | MessageRole, 12 | ) 13 | from llama_index.core.bridge.pydantic import Field 14 | from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS 15 | from llama_index.core.multi_modal_llms import ( 16 | MultiModalLLM, 17 | MultiModalLLMMetadata, 18 | ) 19 | from llama_index.core.multi_modal_llms.generic_utils import image_documents_to_base64 20 | from llama_index.core.schema import ImageDocument 21 | 22 | 23 | def get_additional_kwargs( 24 | response: Dict[str, Any], exclude: Tuple[str, ...] 25 | ) -> Dict[str, Any]: 26 | return {k: v for k, v in response.items() if k not in exclude} 27 | 28 | 29 | def _messages_to_dicts(messages: Sequence[ChatMessage]) -> Sequence[Dict[str, Any]]: 30 | """Convert messages to dicts. 31 | 32 | For use in ollama API 33 | 34 | """ 35 | results = [] 36 | for message in messages: 37 | # TODO: just pass through the image arg for now. 38 | # TODO: have a consistent interface between multimodal models 39 | images = message.additional_kwargs.get("images") 40 | results.append( 41 | { 42 | "role": message.role.value, 43 | "content": message.content, 44 | "images": images, 45 | } 46 | ) 47 | return results 48 | 49 | 50 | class CustomOllamaMultiModal(MultiModalLLM): 51 | model: str = Field(description="The MultiModal Ollama model to use.") 52 | temperature: float = Field( 53 | default=0.75, 54 | description="The temperature to use for sampling.", 55 | gte=0.0, 56 | lte=1.0, 57 | ) 58 | context_window: int = Field( 59 | default=DEFAULT_CONTEXT_WINDOW, 60 | description="The maximum number of context tokens for the model.", 61 | gt=0, 62 | ) 63 | additional_kwargs: Dict[str, Any] = Field( 64 | default_factory=dict, 65 | description="Additional model parameters for the Ollama API.", 66 | ) 67 | 68 | def __init__(self, **kwargs: Any) -> None: 69 | """Init params.""" 70 | # make sure that ollama is installed 71 | try: 72 | import ollama # noqa: F401 73 | except ImportError: 74 | raise ImportError( 75 | "Ollama is not installed. Please install it using `pip install ollama`." 76 | ) 77 | super().__init__(**kwargs) 78 | 79 | @classmethod 80 | def class_name(cls) -> str: 81 | return "Ollama_multi_modal_llm" 82 | 83 | @property 84 | def metadata(self) -> MultiModalLLMMetadata: 85 | """LLM metadata.""" 86 | return MultiModalLLMMetadata( 87 | context_window=self.context_window, 88 | num_output=DEFAULT_NUM_OUTPUTS, 89 | model_name=self.model, 90 | is_chat_model=True, # Ollama supports chat API for all models 91 | ) 92 | 93 | @property 94 | def _model_kwargs(self) -> Dict[str, Any]: 95 | base_kwargs = { 96 | "temperature": self.temperature, 97 | "num_ctx": self.context_window, 98 | } 99 | return { 100 | **base_kwargs, 101 | **self.additional_kwargs, 102 | } 103 | 104 | # TODO: 105 | # def init_model(self, preserve_time) -> None: 106 | # response = ollama.generate( 107 | # model=self.model, 108 | # prompt=prompt, 109 | # images=image_documents_to_base64(image_documents), 110 | # stream=False, 111 | # options=self._model_kwargs, 112 | # **kwargs, 113 | # ) 114 | # """Init model.""" 115 | # pass 116 | 117 | def chat( 118 | self, 119 | #messages: Sequence[ChatMessage], 120 | ollama_messages: Sequence[Dict[str, Any]], # directly pass the ollama_messages 121 | **kwargs: Any) -> ChatResponse: 122 | """Chat.""" 123 | import ollama 124 | 125 | # ollama_messages = _messages_to_dicts(messages) 126 | response = ollama.chat( 127 | model=self.model, messages=ollama_messages, stream=False, **kwargs 128 | ) 129 | return ChatResponse( 130 | message=ChatMessage( 131 | content=response["message"]["content"], 132 | role=MessageRole(response["message"]["role"]), 133 | additional_kwargs=get_additional_kwargs(response, ("message",)), 134 | ), 135 | raw=response["message"], 136 | additional_kwargs=get_additional_kwargs(response, ("message",)), 137 | ) 138 | 139 | def stream_chat( 140 | self, 141 | #messages: Sequence[ChatMessage], 142 | ollama_messages: Sequence[Dict[str, Any]], # directly pass the ollama_messages 143 | **kwargs: Any 144 | ) -> ChatResponseGen: 145 | raise NotImplementedError("This method is not implemented yet.") 146 | 147 | 148 | 149 | # def stream_chat( 150 | # self, 151 | # #messages: Sequence[ChatMessage], 152 | # ollama_messages: Sequence[Dict[str, Any]], # directly pass the ollama_messages 153 | # **kwargs: Any 154 | # ) -> ChatResponseGen: 155 | # """Stream chat.""" 156 | # import ollama 157 | 158 | # #ollama_messages = _messages_to_dicts(messages) 159 | 160 | # response = ollama.chat( 161 | # model=self.model, messages=ollama_messages, stream=True, **kwargs 162 | # ) 163 | # text = "" 164 | # for chunk in response: 165 | # if "done" in chunk and chunk["done"]: 166 | # break 167 | # message = chunk["message"] 168 | # delta = message.get("content") 169 | # text += delta 170 | # yield ChatResponse( 171 | # message=ChatMessage( 172 | # content=text, 173 | # role=MessageRole(message["role"]), 174 | # additional_kwargs=get_additional_kwargs( 175 | # message, ("content", "role") 176 | # ), 177 | # ), 178 | # delta=delta, 179 | # raw=message, 180 | # additional_kwargs=get_additional_kwargs(chunk, ("message",)), 181 | # ) 182 | 183 | def complete( 184 | self, 185 | prompt: str, 186 | images: List[str], # directly pass the base64 images 187 | formatted: bool = False, 188 | **kwargs: Any, 189 | ) -> CompletionResponse: 190 | """Complete.""" 191 | import ollama 192 | 193 | response = ollama.generate( 194 | model=self.model, 195 | prompt=prompt, 196 | images=images, 197 | stream=False, 198 | options=self._model_kwargs, 199 | **kwargs, 200 | ) 201 | return CompletionResponse( 202 | text=response["response"], 203 | raw=response, 204 | additional_kwargs=get_additional_kwargs(response, ("response",)), 205 | ) 206 | 207 | def stream_complete( 208 | self, 209 | prompt: str, 210 | image_documents: Sequence[ImageDocument], 211 | formatted: bool = False, 212 | **kwargs: Any, 213 | ) -> CompletionResponseGen: 214 | raise NotImplementedError("This method is not implemented yet.") 215 | 216 | # def stream_complete( 217 | # self, 218 | # prompt: str, 219 | # image_documents: Sequence[ImageDocument], 220 | # formatted: bool = False, 221 | # **kwargs: Any, 222 | # ) -> CompletionResponseGen: 223 | # """Stream complete.""" 224 | # import ollama 225 | 226 | # response = ollama.generate( 227 | # model=self.model, 228 | # prompt=prompt, 229 | # images=image_documents_to_base64(image_documents), 230 | # stream=True, 231 | # options=self._model_kwargs, 232 | # **kwargs, 233 | # ) 234 | # text = "" 235 | # for chunk in response: 236 | # if "done" in chunk and chunk["done"]: 237 | # break 238 | # delta = chunk.get("response") 239 | # text += delta 240 | # yield CompletionResponse( 241 | # text=str(chunk["response"]), 242 | # delta=delta, 243 | # raw=chunk, 244 | # additional_kwargs=get_additional_kwargs(chunk, ("response",)), 245 | # ) 246 | 247 | async def acomplete( 248 | self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any 249 | ) -> CompletionResponse: 250 | raise NotImplementedError("Ollama does not support async completion.") 251 | 252 | async def achat( 253 | self, messages: Sequence[ChatMessage], **kwargs: Any 254 | ) -> ChatResponse: 255 | raise NotImplementedError("Ollama does not support async chat.") 256 | 257 | async def astream_complete( 258 | self, prompt: str, image_documents: Sequence[ImageDocument], **kwargs: Any 259 | ) -> CompletionResponseAsyncGen: 260 | raise NotImplementedError("Ollama does not support async streaming completion.") 261 | 262 | async def astream_chat( 263 | self, messages: Sequence[ChatMessage], **kwargs: Any 264 | ) -> ChatResponseAsyncGen: 265 | raise NotImplementedError("Ollama does not support async streaming chat.") 266 | -------------------------------------------------------------------------------- /qdrant_db/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/qdrant_db/.gitignore -------------------------------------------------------------------------------- /retrieval_serving.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request, jsonify 2 | import os 3 | from functools import partial 4 | import yaml 5 | import json 6 | import warnings 7 | warnings.filterwarnings(action="ignore", category=FutureWarning) 8 | 9 | import numpy as np 10 | from llama_index.legacy.schema import QueryBundle 11 | 12 | 13 | class JSON_Improved(json.JSONEncoder): 14 | def default(self, obj): 15 | if isinstance(obj, np.ndarray): 16 | return int(obj) 17 | elif isinstance(obj, np.float16): 18 | return float(obj) 19 | elif isinstance(obj, np.float32): 20 | return float(obj) 21 | elif isinstance(obj, np.float64): 22 | return float(obj) 23 | elif isinstance(obj, np.int32): 24 | return int(obj) 25 | else: 26 | return super(JSON_Improved, self).default(obj) 27 | 28 | 29 | from flask.json.provider import JSONProvider 30 | 31 | class CustomJSONProvider(JSONProvider): 32 | 33 | def dumps(self, obj, **kwargs): 34 | return json.dumps(obj, **kwargs, cls=JSON_Improved) 35 | 36 | def loads(self, s: str | bytes, **kwargs): 37 | return json.loads(s, **kwargs) 38 | 39 | 40 | Flask.json_provider_class = CustomJSONProvider 41 | app = Flask(__name__) 42 | 43 | 44 | with open("retriever_config.yaml", 'r') as f: 45 | try: 46 | config = yaml.safe_load(f) 47 | except yaml.YAMLError as exc: 48 | print(exc) 49 | 50 | 51 | mm_retriever = None 52 | client = None 53 | 54 | def initialize_service(config): 55 | global mm_retriever, client 56 | from qdrant_client import QdrantClient 57 | from qdrant_client.http import models as qd_models 58 | from src.custom_vectore_store import MultiModalQdrantVectorStore 59 | from src.custom_embeddings import custom_sparse_doc_vectors, custom_sparse_query_vectors 60 | from llama_index.legacy.vector_stores import QdrantVectorStore 61 | 62 | try: 63 | if os.path.exists(os.path.join(config['qdrant_path'], ".lock")): 64 | os.remove(os.path.join(config['qdrant_path'], ".lock")) 65 | client = QdrantClient(path=config['qdrant_path']) 66 | print("Connected to Qdrant") 67 | except Exception as e: 68 | print("Error connecting to Qdrant: ", str(e)) 69 | 70 | # load model 71 | from transformers import AutoTokenizer, AutoModelForMaskedLM 72 | 73 | splade_q_tokenizer = AutoTokenizer.from_pretrained(config['splade_query_path'], clean_up_tokenization_spaces=True) 74 | splade_q_model = AutoModelForMaskedLM.from_pretrained(config['splade_query_path']) 75 | 76 | splade_d_tokenizer = AutoTokenizer.from_pretrained(config['splade_doc_path'], clean_up_tokenization_spaces=True) 77 | splade_d_model = AutoModelForMaskedLM.from_pretrained(config['splade_doc_path']) 78 | 79 | custom_sparse_doc_fn = partial(custom_sparse_doc_vectors, splade_d_tokenizer, splade_d_model, 512) 80 | custom_sparse_query_fn = partial(custom_sparse_query_vectors, splade_q_tokenizer, splade_q_model, 512) 81 | 82 | text_store = QdrantVectorStore( 83 | client=client, 84 | collection_name=config['text_collection_name'], 85 | enable_hybrid=True, 86 | sparse_query_fn=custom_sparse_query_fn, 87 | sparse_doc_fn=custom_sparse_doc_fn, 88 | stores_text=True, 89 | ) 90 | 91 | image_store = MultiModalQdrantVectorStore( 92 | client=client, 93 | collection_name=config['image_collection_name'], 94 | enable_hybrid=True, 95 | sparse_query_fn=custom_sparse_query_fn, 96 | sparse_doc_fn=custom_sparse_doc_fn, 97 | stores_text=False, 98 | ) 99 | 100 | 101 | from llama_index.legacy.embeddings import HuggingFaceEmbedding 102 | from src.custom_embeddings import CustomizedCLIPEmbedding 103 | 104 | 105 | text_embedding = HuggingFaceEmbedding(model_name=config['embedding_path'], device="cpu", pooling="mean") 106 | image_embedding = CustomizedCLIPEmbedding(model_name=config['image_encoder_path'], device="cpu") 107 | 108 | from llama_index.core.postprocessor import SentenceTransformerRerank 109 | 110 | reranker = SentenceTransformerRerank( 111 | model=config['reranker_path'], 112 | top_n=3, 113 | device="cpu", 114 | keep_retrieval_score=False, 115 | ) 116 | from src.mm_retriever import MultiModalQdrantRetriever 117 | 118 | mm_retriever = MultiModalQdrantRetriever( 119 | text_vector_store = text_store, 120 | image_vector_store = image_store, 121 | text_embed_model = text_embedding, 122 | mm_embed_model = image_embedding, 123 | reranker = reranker, 124 | text_similarity_top_k = config['text_similarity_top_k'], 125 | text_sparse_top_k = config['text_sparse_top_k'], 126 | text_rerank_top_n = config['text_rerank_top_n'], 127 | image_similarity_top_k = config['image_similarity_top_k'], 128 | image_sparse_top_k = config['image_sparse_top_k'], 129 | image_rerank_top_n = config['image_rerank_top_n'], 130 | sparse_query_fn = custom_sparse_query_fn, 131 | ) 132 | 133 | 134 | def initialize(): 135 | try: 136 | initialize_service(config=config) 137 | print("########## Retriever Service initialized. ##########") 138 | return jsonify({"status": "success", "message": "Service initialized."}), 200 139 | except Exception as e: 140 | print(f"Error initializing service: {str(e)}") 141 | return jsonify({"status": "error", "message": str(e)}), 500 142 | 143 | with app.app_context(): 144 | initialize() 145 | 146 | 147 | def process_query(query, text_topk=None, image_topk=None): 148 | 149 | query_bundle=QueryBundle(query_str=query) 150 | 151 | text_query_result = mm_retriever.retrieve_text_nodes(query_bundle=query_bundle, query_mode="hybrid") 152 | print("Retrieved text nodes") 153 | reranked_text_nodes = mm_retriever.rerank_text_nodes(query_bundle, text_query_result, text_rerank_top_n=text_topk) 154 | print("Reranked text nodes") 155 | image_query_result = mm_retriever.retrieve_image_nodes(query_bundle=query_bundle, query_mode="hybrid") 156 | print("Retrieved image nodes") 157 | reranked_image_nodes = mm_retriever.rerank_image_nodes(query_bundle, image_query_result, image_rerank_top_n=image_topk) 158 | print("Reranked image nodes") 159 | 160 | # for item in reranked_image_nodes: 161 | # item.node.metadata['vectors'] = None 162 | # for item in reranked_text_nodes: 163 | # item.node.metadata['vectors'] = None 164 | # del item.node.metadata['regionBoundary'] 165 | # del item.node.metadata['captionBoundary'] 166 | return reranked_text_nodes, reranked_image_nodes 167 | 168 | 169 | @app.route("/api", methods=['POST']) 170 | def handle_request(): 171 | try: 172 | data = request.json 173 | if not data or 'query' not in data: 174 | return jsonify({"status": "error", "message": "Invalid request."}), 400 175 | 176 | text_topk = data.get('text_topk', None) 177 | image_topk = data.get('image_topk', None) 178 | query = data['query'] 179 | # 处理查询 180 | text_nodes, image_nodes = process_query(query, text_topk=text_topk, image_topk=image_topk) 181 | text_nodes = [node.to_dict() for node in text_nodes] 182 | image_nodes = [node.to_dict() for node in image_nodes] 183 | return jsonify({"status": "success", 184 | "query": query, 185 | "text_result": text_nodes, 186 | "image_result": image_nodes}), 200 187 | except Exception as e: 188 | return jsonify({"status": "error", "message": str(e)}), 500 189 | 190 | if __name__ == "__main__": 191 | from waitress import serve 192 | serve(app, host="127.0.0.1", port=5000) -------------------------------------------------------------------------------- /retriever_config.yaml: -------------------------------------------------------------------------------- 1 | qdrant_path: qdrant_db 2 | splade_query_path: ./models/embedding_models/efficient-splade-VI-BT-large-doc 3 | splade_doc_path: ./models/embedding_models/efficient-splade-VI-BT-large-query 4 | 5 | embedding_path: ./models/embedding_models/bge-small-en-v1.5 6 | image_encoder_path: ./models/embedding_models/clip-vit-base-patch32 7 | reranker_path: ./models/embedding_models/bge-reranker-base 8 | 9 | text_collection_name: text_collection 10 | image_collection_name: image_collection 11 | 12 | text_similarity_top_k: 5 13 | text_sparse_top_k: 5 14 | text_rerank_top_n: 3 15 | 16 | image_similarity_top_k: 5 17 | image_sparse_top_k: 5 18 | image_rerank_top_n: 1 -------------------------------------------------------------------------------- /serve_grobid_light.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Recommended way to run Grobid is to have docker installed. See details here: https://grobid.readthedocs.io/en/latest/Grobid-docker/ 3 | 4 | if ! command -v docker &> /dev/null; then 5 | echo "Error: Docker is not installed. Please install Docker before running Grobid." 6 | exit 1 7 | fi 8 | 9 | 10 | machine_arch=$(uname -m) 11 | 12 | if [ "$machine_arch" == "armv7l" ] || [ "$machine_arch" == "aarch64" ]; then 13 | docker run --rm --gpus all --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0-arm 14 | else 15 | docker run --rm --gpus all --init --ulimit core=0 -p 8070:8070 lfoppiano/grobid:0.8.0 16 | fi 17 | -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/src/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/custom_embeddings.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/src/__pycache__/custom_embeddings.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/custom_vectore_store.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/src/__pycache__/custom_vectore_store.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/mm_retriever.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/src/__pycache__/mm_retriever.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/prompt.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Virgil-L/Local_MultiModal_RAG_with_llamaindex/53fe2eed1be93ba795d9955b6f5a9493b79cdba4/src/__pycache__/prompt.cpython-311.pyc -------------------------------------------------------------------------------- /src/build_qdrant_collections.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from qdrant_client import QdrantClient 5 | from qdrant_client.http import models as qd_models 6 | 7 | from llama_index.legacy.schema import ImageNode, TextNode, NodeRelationship, RelatedNodeInfo 8 | from io import BytesIO 9 | import base64 10 | from PIL import Image 11 | 12 | import torch 13 | from transformers import AutoTokenizer, AutoModelForMaskedLM 14 | from llama_index.legacy.embeddings import HuggingFaceEmbedding 15 | from custom_embeddings import CustomizedCLIPEmbedding 16 | 17 | from llama_index.legacy.node_parser import SentenceSplitter 18 | 19 | from typing import ( 20 | Any, 21 | Dict, 22 | List, 23 | Tuple, 24 | ) 25 | 26 | import argparse 27 | 28 | 29 | PDF_FOLDER = "data/paper_PDF" 30 | BGE_PATH = "./embedding_models/bge-small-en-v1.5" 31 | CLIP_PATH = "./embedding_models/clip-vit-base-patch32" 32 | SPLADE_DOC_PATH = "./embedding_models/efficient-splade-VI-BT-large-doc" 33 | SPLADE_QUERY_PATH = "./embedding_models/efficient-splade-VI-BT-large-doc" 34 | QDRANT_PATH = "./qdrant_db" 35 | 36 | parser = argparse.ArgumentParser(description='Parse PDFs to extract figures and texts') 37 | parser.add_argument('--pdf_folder', type=str, default=PDF_FOLDER, help='Path to the folder containing PDFs') 38 | parser.add_argument('--text_embedding_model', type=str, default=BGE_PATH, help='Path to the text embedding model') 39 | parser.add_argument('--sparse_text_embedding_model', type=str, default=SPLADE_DOC_PATH, help='Path to the sparse text embedding model') 40 | parser.add_argument('--image_embedding_model', type=str, default=CLIP_PATH, help='Path to the image embedding model') 41 | parser.add_argument('--chunk_size', type=int, default=384, help='Size of the chunks to split the text into') 42 | parser.add_argument('--chunk_overlap', type=int, default=32, help='Overlap between the chunks') 43 | parser.add_argument('--storage_path', type=str, default=QDRANT_PATH, help='Path of the qdrant storage') 44 | 45 | 46 | 47 | 48 | 49 | 50 | def extract_text_nodes(text_data: Dict, text_parser, config_file) -> List[TextNode]: 51 | title_node = TextNode( 52 | text = ':\n'.join(('title', text_data['title'])), 53 | metadata = { 54 | "source_file_path": os.path.join(os.getcwd(), PDF_FOLDER, config_file.replace(".json", ".pdf")), 55 | "elementType": "title", 56 | } 57 | ) 58 | 59 | author_node = TextNode( 60 | text = ':\n'.join(('authors', text_data['authors'])), 61 | metadata = { 62 | "source_file_path": os.path.join(os.getcwd(), PDF_FOLDER, config_file.replace(".json", ".pdf")), 63 | "elementType": "author", 64 | } 65 | ) 66 | 67 | abstract_text = ':\n'.join(('abstract', text_data['abstract'])) 68 | splitted_abstract = text_parser.split_text(abstract_text) 69 | abstract_nodes = [TextNode( 70 | text = text, metadata = {"source_file_path": os.path.join(os.getcwd(), PDF_FOLDER, config_file.replace(".json", ".pdf")),"elementType": "abstract",} 71 | ) for text in splitted_abstract] 72 | 73 | 74 | section_text_list = [section['heading']+'\n'+section['text'] for section in text_data['sections']] 75 | 76 | 77 | for i in range(len(section_text_list)-1, -1, -1): 78 | if len(section_text_list[i]) < text_parser.chunk_size: 79 | if i > 0: 80 | section_text_list[i-1] += "\n" + section_text_list[i] 81 | section_text_list.pop(i) 82 | else: 83 | section_text_list[i+1] += "\n" + section_text_list[i] 84 | section_text_list.pop(i) 85 | 86 | 87 | section_nodes = [] 88 | for section_text in section_text_list: 89 | splitted_section = text_parser.split_text(section_text) 90 | section_nodes.extend([TextNode( 91 | text = text, metadata = {"source_file_path": os.path.join(os.getcwd(), PDF_FOLDER, config_file.replace(".json", ".pdf")),"elementType": "section",} 92 | ) for text in splitted_section]) 93 | 94 | 95 | non_title_nodes = [author_node] + abstract_nodes + section_nodes 96 | for node in non_title_nodes: 97 | build_parent_child_relationships(title_node, node) 98 | 99 | return [title_node] + non_title_nodes 100 | 101 | 102 | 103 | def build_parent_child_relationships(parent_node, child_node): 104 | child_node.relationships[NodeRelationship.PARENT] = RelatedNodeInfo(node_id=parent_node.id_, metadata={'elementType':parent_node.metadata['elementType']}) 105 | 106 | if NodeRelationship.CHILD not in parent_node.relationships.keys(): 107 | parent_node.relationships[NodeRelationship.CHILD] = [ 108 | RelatedNodeInfo( 109 | node_id=child_node.id_, 110 | metadata={'elementType':child_node.metadata['elementType']} 111 | ) 112 | ] 113 | else: 114 | parent_node.relationships[NodeRelationship.CHILD].append( 115 | RelatedNodeInfo( 116 | node_id=child_node.id_, 117 | metadata={'elementType':child_node.metadata['elementType']} 118 | ) 119 | ) 120 | 121 | return 122 | 123 | 124 | 125 | def compute_sparse_text_vector(text, tokenizer, model, max_length=512): 126 | """ 127 | Computes a vector from logits and attention mask using ReLU, log, and max operations. 128 | """ 129 | 130 | tokens = tokenizer( 131 | text, truncation=True, padding=True, max_length=max_length, return_tensors="pt" 132 | ) 133 | output = model(**tokens) 134 | logits, attention_mask = output.logits, tokens.attention_mask 135 | relu_log = torch.log(1 + torch.relu(logits)) 136 | weighted_log = relu_log * attention_mask.unsqueeze(-1) 137 | max_val, _ = torch.max(weighted_log, dim=1) 138 | vec = max_val.squeeze() 139 | 140 | indices = vec.nonzero(as_tuple=True)[0].tolist() 141 | values = vec[indices].tolist() 142 | 143 | return qd_models.SparseVector(indices=indices, values=values) 144 | 145 | 146 | 147 | def build_text_nodes(chunk_size, chunk_overlap, pdf_folder, bge_embedding): 148 | TEXT_FOLDER = os.path.join(PDF_FOLDER, "parsed_texts") 149 | text_config_files = os.listdir(TEXT_FOLDER) 150 | 151 | text_parser = SentenceSplitter( 152 | chunk_size=chunk_size, 153 | chunk_overlap=chunk_overlap, 154 | # separator=" ", 155 | ) 156 | 157 | text_nodes = [] 158 | for config_file in text_config_files: 159 | if config_file.endswith(".json"): 160 | with open(os.path.join(TEXT_FOLDER, config_file), "r") as cf: 161 | text_data = json.load(cf) 162 | text_nodes.extend(extract_text_nodes(text_data, text_parser, config_file)) 163 | 164 | 165 | for text_node in text_nodes: 166 | text_embedding = bge_embedding.get_text_embedding(text_node.get_text()) 167 | text_node.embedding = text_embedding 168 | 169 | return text_nodes 170 | 171 | 172 | def create_text_collection(text_nodes, client, collection_name, sparse_tokenizer, sparse_embedding): 173 | text_embedding_size = len(text_nodes[0].embedding) 174 | 175 | try: 176 | client.get_collection(collection_name) 177 | except: 178 | client.create_collection( 179 | collection_name=collection_name, 180 | vectors_config={ 181 | "text-dense": qd_models.VectorParams(size=text_embedding_size, distance=qd_models.Distance.COSINE, on_disk=True), 182 | }, 183 | sparse_vectors_config={ 184 | "text-sparse": qd_models.SparseVectorParams( 185 | index=qd_models.SparseIndexParams(on_disk=True) 186 | ) 187 | }, 188 | optimizers_config=qd_models.OptimizersConfigDiff(memmap_threshold=20000), 189 | ) 190 | 191 | client.upsert( 192 | collection_name=collection_name, 193 | points=[ 194 | qd_models.PointStruct( 195 | id = text_node.id_, 196 | vector = { 197 | "text-dense": text_node.embedding, 198 | "text-sparse": compute_sparse_text_vector(text_node.get_text(), sparse_tokenizer, sparse_embedding) 199 | }, 200 | payload = { 201 | "text": text_node.text, 202 | "metadata": text_node.metadata, 203 | } 204 | ) for text_node in text_nodes 205 | ] 206 | ) 207 | 208 | return 209 | 210 | 211 | 212 | def build_image_nodes(pdf_folder, bge_embedding, clip_embedding): 213 | 214 | IMAGE_FOLDER = os.path.join(pdf_folder, "parsed_figures", "data") 215 | img_config_files = os.listdir(IMAGE_FOLDER) 216 | 217 | img_nodes = [] 218 | 219 | for config_file in img_config_files: 220 | if config_file.endswith(".json"): 221 | with open(os.path.join(IMAGE_FOLDER, config_file), "r") as cf: 222 | img_data = json.load(cf) 223 | 224 | for img_config in img_data: 225 | with open(img_config['renderURL'], "rb") as img_file: 226 | img_base64_bytes = base64.b64encode(img_file.read()) 227 | 228 | img_metadata = {k: img_config[k] for k in img_config.keys() & {'name', 'page', 'figType', 'imageText', 'regionBoundary', 'captionBoundary'} } 229 | img_metadata['elementType'] = img_metadata.pop('figType') 230 | img_metadata['source_file_path'] = os.path.join(os.getcwd(), PDF_FOLDER, config_file.replace(".json", ".pdf")) 231 | 232 | img_node = ImageNode( 233 | image = img_base64_bytes, 234 | metadata = img_metadata, 235 | image_path=img_config["renderURL"], 236 | text=img_config['caption'], 237 | ) 238 | img_nodes.append(img_node) 239 | 240 | for img_node in img_nodes: 241 | img_embedding = clip_embedding.get_image_embedding(BytesIO(base64.b64decode(img_node.image))) 242 | text_embedding = bge_embedding.get_text_embedding(img_node.text) 243 | img_node.embedding = img_embedding 244 | img_node.text_embedding = text_embedding 245 | 246 | 247 | return img_nodes 248 | 249 | 250 | 251 | def create_image_collection(image_nodes, client, collection_name, sparse_tokenizer, sparse_embedding): 252 | image_embedding_size = len(image_nodes[0].embedding) 253 | text_embedding_size = len(image_nodes[0].text_embedding) 254 | 255 | try: 256 | client.get_collection(collection_name) 257 | except: 258 | client.create_collection( 259 | collection_name=collection_name, 260 | vectors_config={ 261 | "multi-modal": qd_models.VectorParams(size=image_embedding_size, distance=qd_models.Distance.COSINE, on_disk=True), 262 | "text-dense": qd_models.VectorParams(size=text_embedding_size, distance=qd_models.Distance.COSINE, on_disk=True), 263 | }, 264 | sparse_vectors_config={ 265 | "text-sparse": qd_models.SparseVectorParams( 266 | index=qd_models.SparseIndexParams(on_disk=True) 267 | ) 268 | }, 269 | optimizers_config=qd_models.OptimizersConfigDiff(memmap_threshold=20000), 270 | ) 271 | 272 | client.upsert( 273 | collection_name=collection_name, 274 | points=[ 275 | qd_models.PointStruct( 276 | id = image_node.id_, 277 | vector = { 278 | "multi-modal": image_node.embedding, 279 | "text-dense": image_node.text_embedding, 280 | "text-sparse":compute_sparse_text_vector(image_node.text, sparse_tokenizer, sparse_embedding), 281 | }, 282 | 283 | payload = { 284 | "image_path": image_node.image_path, 285 | "metadata": image_node.metadata, 286 | "image": image_node.image, 287 | "text": image_node.text, 288 | } 289 | ) for image_node in image_nodes 290 | ] 291 | ) 292 | 293 | return 294 | 295 | 296 | 297 | if __name__ == "__main__": 298 | args = parser.parse_args() 299 | 300 | print("Loading Embedding Models...\n") 301 | bge_embedding = HuggingFaceEmbedding(model_name=args.text_embedding_model, device="cpu", pooling="mean") 302 | clip_embedding = CustomizedCLIPEmbedding(model_name=args.image_embedding_model, device="cpu") 303 | 304 | splade_doc_tokenizer = AutoTokenizer.from_pretrained(args.sparse_text_embedding_model) 305 | splade_doc_embedding = AutoModelForMaskedLM.from_pretrained(args.sparse_text_embedding_model) 306 | 307 | print("Building Text and Image Nodes...\n") 308 | 309 | text_nodes = build_text_nodes( 310 | chunk_size=args.chunk_size, 311 | chunk_overlap=args.chunk_overlap, 312 | pdf_folder=args.pdf_folder, 313 | bge_embedding=bge_embedding, 314 | ) 315 | 316 | image_nodes = build_image_nodes( 317 | pdf_folder = args.pdf_folder, 318 | bge_embedding = bge_embedding, 319 | clip_embedding = clip_embedding, 320 | ) 321 | 322 | print("Creating Qdrant Collections...\n") 323 | client = QdrantClient(path=args.storage_path) 324 | create_text_collection( 325 | text_nodes=text_nodes, 326 | client=client, 327 | collection_name="text_collection", 328 | sparse_tokenizer=splade_doc_tokenizer, 329 | sparse_embedding=splade_doc_embedding) 330 | 331 | create_image_collection( 332 | image_nodes=image_nodes, 333 | client=client, 334 | collection_name="image_collection", 335 | sparse_tokenizer=splade_doc_tokenizer, 336 | sparse_embedding=splade_doc_embedding 337 | ) 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | -------------------------------------------------------------------------------- /src/custom_embeddings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, List, Tuple 3 | 4 | import torch 5 | from transformers import AutoTokenizer, AutoModelForMaskedLM 6 | 7 | from llama_index.legacy.bridge.pydantic import Field, PrivateAttr 8 | from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE 9 | from llama_index.legacy.embeddings.base import Embedding 10 | from llama_index.legacy.embeddings.multi_modal_base import MultiModalEmbedding 11 | from llama_index.legacy.schema import ImageType 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class CustomizedCLIPEmbedding(MultiModalEmbedding): 17 | """Customized multimodal embedding models for encoding text and image for Multi-Modal purpose. (e.g. CLIP, BLIP, BLIP2) 18 | 19 | This class provides an interface to generate embeddings using a model 20 | deployed in OpenAI CLIP. At the initialization it requires a model name 21 | of CLIP. 22 | 23 | """ 24 | 25 | embed_batch_size: int = Field(default=DEFAULT_EMBED_BATCH_SIZE, gt=0) 26 | 27 | _clip: Any = PrivateAttr() 28 | _model: Any = PrivateAttr() 29 | _preprocess: Any = PrivateAttr() 30 | _device: Any = PrivateAttr() 31 | 32 | @classmethod 33 | def class_name(cls) -> str: 34 | return "CustomizedCLIPEmbedding" 35 | 36 | def __init__( 37 | self, 38 | *, 39 | model_name: str, 40 | device: str = None, 41 | embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, 42 | **kwargs: Any, 43 | ): 44 | """Initializes the ClipEmbedding class. 45 | 46 | Args: 47 | embed_batch_size (int, optional): The batch size for embedding generation. Defaults to 10, 48 | must be > 0 and <= 100. 49 | model_name (str): The model name of Clip model. 50 | 51 | Raises: 52 | ImportError: If the `clip` package is not available in the PYTHONPATH. 53 | ValueError: If the model cannot be fetched from Open AI. or if the embed_batch_size 54 | is not in the range (0, 100]. 55 | """ 56 | if embed_batch_size <= 0: 57 | raise ValueError(f"Embed batch size {embed_batch_size} must be > 0.") 58 | 59 | # try: 60 | # import clip 61 | # import torch 62 | # except ImportError: 63 | # raise ImportError( 64 | # "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch." 65 | # ) 66 | 67 | try: 68 | from transformers import CLIPProcessor, CLIPModel 69 | import torch 70 | except ImportError: 71 | raise ImportError( 72 | "CustomizedCLIPEmbedding requires huggingface transformers and torch." 73 | ) 74 | 75 | super().__init__( 76 | embed_batch_size=embed_batch_size, model_name=model_name, **kwargs 77 | ) 78 | 79 | # try: 80 | # self._device = "cuda" if torch.cuda.is_available() else "cpu" 81 | # if self.model_name not in AVAILABLE_CLIP_MODELS: 82 | # raise ValueError( 83 | # f"Model name {self.model_name} is not available in CLIP." 84 | # ) 85 | # self._model, self._preprocess = clip.load( 86 | # self.model_name, device=self._device 87 | # ) 88 | 89 | # except Exception as e: 90 | # logger.error(f"Error while loading clip model.") 91 | # raise ValueError("Unable to fetch the requested embeddings model") from e 92 | 93 | try: 94 | if device == None: 95 | self._device = "cuda" if torch.cuda.is_available() else "cpu" 96 | else: 97 | self._device = device 98 | self._model = CLIPModel.from_pretrained(self.model_name).to(self._device) 99 | self._preprocess = CLIPProcessor.from_pretrained(self.model_name) 100 | 101 | except Exception as e: 102 | logger.error(f"Error while loading clip model.") 103 | raise ValueError("Unable to fetch the requested embeddings model") from e 104 | 105 | 106 | 107 | # TEXT EMBEDDINGS 108 | 109 | async def _aget_query_embedding(self, query: str) -> Embedding: 110 | return self._get_query_embedding(query) 111 | 112 | def _get_text_embedding(self, text: str) -> Embedding: 113 | return self._get_text_embeddings([text])[0] 114 | 115 | def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]: 116 | results = [] 117 | try: 118 | import torch 119 | except ImportError: 120 | raise ImportError( 121 | "CustomizedCLIPEmbedding requires `pip install torch`." 122 | ) 123 | with torch.no_grad(): 124 | for text in texts: 125 | # try: 126 | # import clip 127 | # except ImportError: 128 | # raise ImportError( 129 | # "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch." 130 | # ) 131 | # text_embedding = self._model.encode_text( 132 | # clip.tokenize(text).to(self._device) 133 | # ) 134 | 135 | #TODO 136 | text_embedding = self._model.get_text_features(**self._preprocess.tokenizer(text, return_tensors="pt").to(self._device)) 137 | 138 | results.append(text_embedding.tolist()[0]) 139 | 140 | return results 141 | 142 | def _get_query_embedding(self, query: str) -> Embedding: 143 | return self._get_text_embedding(query) 144 | 145 | # IMAGE EMBEDDINGS 146 | 147 | async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding: 148 | return self._get_image_embedding(img_file_path) 149 | 150 | def _get_image_embedding(self, img_file_path: ImageType) -> Embedding: 151 | try: 152 | import torch 153 | from PIL import Image 154 | except ImportError: 155 | raise ImportError( 156 | "CustomizedCLIPEmbedding requires `pip install torch` and `pip install pillow`." 157 | ) 158 | # with torch.no_grad(): 159 | # image = ( 160 | # self._preprocess(Image.open(img_file_path)) 161 | # .unsqueeze(0) 162 | # .to(self._device) 163 | # ) 164 | # return self._model.encode_image(image).tolist()[0] 165 | with torch.no_grad(): 166 | img_inputs = self._preprocess.image_processor.preprocess(Image.open(img_file_path), return_tensors="pt").to(self._device) 167 | return self._model.get_image_features(**img_inputs).tolist()[0] 168 | 169 | 170 | 171 | 172 | 173 | def custom_sparse_doc_vectors( 174 | doc_tokenizer, 175 | doc_model, 176 | max_length: int, 177 | texts: List[str], 178 | ) -> Tuple[List[List[int]], List[List[float]]]: 179 | 180 | tokens = doc_tokenizer( 181 | texts, max_length=max_length, truncation=True, padding=True, return_tensors="pt" 182 | ) 183 | 184 | # if torch.cuda.is_available(): 185 | # tokens = tokens.to("cuda") 186 | 187 | output = doc_model(**tokens) 188 | logits, attention_mask = output.logits, tokens.attention_mask 189 | relu_log = torch.log(1 + torch.relu(logits)) 190 | weighted_log = relu_log * attention_mask.unsqueeze(-1) 191 | tvecs, _ = torch.max(weighted_log, dim=1) 192 | 193 | indices = [] 194 | vecs = [] 195 | for batch in tvecs: 196 | indices.append(batch.nonzero(as_tuple=True)[0].tolist()) 197 | vecs.append(batch[indices[-1]].tolist()) 198 | 199 | return indices, vecs 200 | 201 | 202 | def custom_sparse_query_vectors( 203 | query_tokenizer, 204 | query_model, 205 | max_length: int, 206 | texts: List[str], 207 | ) -> Tuple[List[List[int]], List[List[float]]]: 208 | """ 209 | Computes vectors from logits and attention mask using ReLU, log, and max operations. 210 | """ 211 | # TODO: compute sparse vectors in batches if max length is exceeded 212 | tokens = query_tokenizer( 213 | texts, max_length=max_length, truncation=True, padding=True, return_tensors="pt" 214 | ) 215 | 216 | # if torch.cuda.is_available(): 217 | # tokens = tokens.to("cuda") 218 | 219 | output = query_model(**tokens) 220 | logits, attention_mask = output.logits, tokens.attention_mask 221 | relu_log = torch.log(1 + torch.relu(logits)) 222 | weighted_log = relu_log * attention_mask.unsqueeze(-1) 223 | tvecs, _ = torch.max(weighted_log, dim=1) 224 | 225 | # extract the vectors that are non-zero and their indices 226 | indices = [] 227 | vecs = [] 228 | for batch in tvecs: 229 | indices.append(batch.nonzero(as_tuple=True)[0].tolist()) 230 | vecs.append(batch[indices[-1]].tolist()) 231 | 232 | return indices, vecs 233 | 234 | 235 | if __name__ == "__main__": 236 | 237 | import requests 238 | from PIL import Image 239 | 240 | CLIP_PATH = "./embedding_models/clip-vit-base-patch32" 241 | clip_embedding = CustomizedCLIPEmbedding(model_name=CLIP_PATH, device="cpu") 242 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 243 | image = Image.open(requests.get(url, stream=True).raw) 244 | texts = ["This is a test sentence for custom CLIP embedding model.", "This is another test sentence."] 245 | txt_embedding = clip_embedding._get_text_embeddings(texts) 246 | print(f"\n\nText Embedding: {txt_embedding}") 247 | 248 | img_embedding = clip_embedding._get_image_embedding(requests.get(url, stream=True).raw) 249 | print(f"\n\nImage Embedding: {img_embedding}") -------------------------------------------------------------------------------- /src/custom_vectore_store.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-modal Qdrant vector store index. 3 | 4 | An index that is built on top of an existing Qdrant collection for text-image retrieval. 5 | 6 | """ 7 | 8 | 9 | 10 | import logging 11 | from typing import Any, List, Optional, Tuple, cast 12 | 13 | import qdrant_client 14 | from grpc import RpcError 15 | 16 | from llama_index.legacy.vector_stores.qdrant import QdrantVectorStore 17 | from llama_index.legacy.bridge.pydantic import Field, PrivateAttr 18 | from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode, ImageNode 19 | from llama_index.legacy.utils import iter_batch 20 | from llama_index.legacy.vector_stores.types import ( 21 | VectorStoreQuery, 22 | VectorStoreQueryMode, 23 | VectorStoreQueryResult, 24 | ) 25 | from llama_index.legacy.vector_stores.utils import ( 26 | legacy_metadata_dict_to_node, 27 | metadata_dict_to_node, 28 | node_to_metadata_dict, 29 | ) 30 | from llama_index.legacy.vector_stores.qdrant_utils import ( 31 | HybridFusionCallable, 32 | SparseEncoderCallable, 33 | default_sparse_encoder, 34 | relative_score_fusion, 35 | ) 36 | 37 | 38 | from qdrant_client.http import models as rest 39 | from qdrant_client.http.exceptions import UnexpectedResponse 40 | from qdrant_client.http.models import ( 41 | FieldCondition, 42 | Filter, 43 | MatchAny, 44 | MatchExcept, 45 | MatchText, 46 | MatchValue, 47 | Payload, 48 | Range, 49 | ) 50 | 51 | logger = logging.getLogger(__name__) 52 | import_err_msg = ( 53 | "`qdrant-client` package not found, please run `pip install qdrant-client`" 54 | ) 55 | 56 | 57 | class MultiModalQdrantVectorStore(QdrantVectorStore): 58 | """ 59 | Multi-modal Qdrant Vector Store. 60 | 61 | In this vector store, embeddings and docs are stored within a 62 | Qdrant collection. 63 | 64 | During query time, the index uses Qdrant to query for the top 65 | k most similar nodes. 66 | 67 | Args: 68 | collection_name: (str): name of the Qdrant collection 69 | client (Optional[Any]): QdrantClient instance from `qdrant-client` package 70 | aclient (Optional[Any]): AsyncQdrantClient instance from `qdrant-client` package 71 | url (Optional[str]): url of the Qdrant instance 72 | api_key (Optional[str]): API key for authenticating with Qdrant 73 | batch_size (int): number of points to upload in a single request to Qdrant. Defaults to 64 74 | parallel (int): number of parallel processes to use during upload. Defaults to 1 75 | max_retries (int): maximum number of retries in case of a failure. Defaults to 3 76 | client_kwargs (Optional[dict]): additional kwargs for QdrantClient and AsyncQdrantClient 77 | enable_hybrid (bool): whether to enable hybrid search using dense and sparse vectors 78 | sparse_doc_fn (Optional[SparseEncoderCallable]): function to encode sparse vectors 79 | sparse_query_fn (Optional[SparseEncoderCallable]): function to encode sparse queries 80 | hybrid_fusion_fn (Optional[HybridFusionCallable]): function to fuse hybrid search results 81 | """ 82 | 83 | stores_text: bool = True 84 | flat_metadata: bool = False 85 | 86 | collection_name: str 87 | path: Optional[str] 88 | url: Optional[str] 89 | api_key: Optional[str] 90 | batch_size: int 91 | parallel: int 92 | max_retries: int 93 | client_kwargs: dict = Field(default_factory=dict) 94 | enable_hybrid: bool 95 | 96 | _client: Any = PrivateAttr() 97 | _aclient: Any = PrivateAttr() 98 | _collection_initialized: bool = PrivateAttr() 99 | _sparse_doc_fn: Optional[SparseEncoderCallable] = PrivateAttr() 100 | _sparse_query_fn: Optional[SparseEncoderCallable] = PrivateAttr() 101 | _hybrid_fusion_fn: Optional[HybridFusionCallable] = PrivateAttr() 102 | 103 | def __init__( 104 | self, 105 | collection_name: str, 106 | client: Optional[Any] = None, 107 | aclient: Optional[Any] = None, 108 | url: Optional[str] = None, 109 | api_key: Optional[str] = None, 110 | batch_size: int = 64, 111 | parallel: int = 1, 112 | max_retries: int = 3, 113 | client_kwargs: Optional[dict] = None, 114 | enable_hybrid: bool = False, 115 | sparse_doc_fn: Optional[SparseEncoderCallable] = None, 116 | sparse_query_fn: Optional[SparseEncoderCallable] = None, 117 | hybrid_fusion_fn: Optional[HybridFusionCallable] = None, 118 | **kwargs: Any, 119 | ) -> None: 120 | """Init params.""" 121 | if ( 122 | client is None 123 | and aclient is None 124 | and (url is None or api_key is None or collection_name is None) 125 | ): 126 | raise ValueError( 127 | "Must provide either a QdrantClient instance or a url and api_key." 128 | ) 129 | 130 | if client is None and aclient is None: 131 | client_kwargs = client_kwargs or {} 132 | self._client = qdrant_client.QdrantClient( 133 | url=url, api_key=api_key, **client_kwargs 134 | ) 135 | self._aclient = qdrant_client.AsyncQdrantClient( 136 | url=url, api_key=api_key, **client_kwargs 137 | ) 138 | else: 139 | if client is not None and aclient is not None: 140 | logger.warning( 141 | "Both client and aclient are provided. If using `:memory:` " 142 | "mode, the data between clients is not synced." 143 | ) 144 | 145 | self._client = client 146 | self._aclient = aclient 147 | 148 | if self._client is not None: 149 | self._collection_initialized = self._collection_exists(collection_name) 150 | else: 151 | # need to do lazy init for async clients 152 | self._collection_initialized = False 153 | 154 | # TODO: setup hybrid search if enabled 155 | # if enable_hybrid: 156 | # self._sparse_doc_fn = sparse_doc_fn or default_sparse_encoder( 157 | # "naver/efficient-splade-VI-BT-large-doc" 158 | # ) 159 | # self._sparse_query_fn = sparse_query_fn or default_sparse_encoder( 160 | # "naver/efficient-splade-VI-BT-large-query" 161 | # ) 162 | # self._hybrid_fusion_fn = hybrid_fusion_fn or cast( 163 | # HybridFusionCallable, relative_score_fusion 164 | # ) 165 | 166 | super().__init__( 167 | client=client, 168 | collection_name=collection_name, 169 | url=url, 170 | api_key=api_key, 171 | batch_size=batch_size, 172 | parallel=parallel, 173 | max_retries=max_retries, 174 | client_kwargs=client_kwargs or {}, 175 | enable_hybrid=enable_hybrid, 176 | sparse_doc_fn=sparse_doc_fn, 177 | sparse_query_fn=sparse_query_fn, 178 | ) 179 | 180 | @classmethod 181 | def class_name(cls) -> str: 182 | return "MultiModalQdrantVectorStore" 183 | 184 | 185 | #TODO: write a more flexible hybrid fusion implementation 186 | def text_to_caption_query( 187 | self, 188 | query: VectorStoreQuery, 189 | **kwargs: Any, 190 | ) -> VectorStoreQueryResult: 191 | """ 192 | Text-to-Text query with similarity of query text embedding and node text (image caption) embedding 193 | 194 | Args: 195 | query (VectorStoreQuery): query 196 | """ 197 | 198 | 199 | query_embedding = cast(List[float], query.query_embedding) 200 | # NOTE: users can pass in qdrant_filters (nested/complicated filters) to override the default MetadataFilters 201 | qdrant_filters = kwargs.get("qdrant_filters") 202 | if qdrant_filters is not None: 203 | query_filter = qdrant_filters 204 | else: 205 | query_filter = cast(Filter, self._build_query_filter(query)) 206 | 207 | 208 | 209 | if ( 210 | query.mode == VectorStoreQueryMode.SPARSE 211 | and self.enable_hybrid 212 | and self._sparse_query_fn is not None 213 | and query.query_str is not None 214 | ): 215 | sparse_indices, sparse_embedding = self._sparse_query_fn( 216 | [query.query_str], 217 | ) 218 | sparse_top_k = query.sparse_top_k or query.similarity_top_k 219 | 220 | sparse_response = self._client.search_batch( 221 | collection_name=self.collection_name, 222 | requests=[ 223 | rest.SearchRequest( 224 | vector=rest.NamedSparseVector( 225 | name="text-sparse", 226 | vector=rest.SparseVector( 227 | indices=sparse_indices[0], 228 | values=sparse_embedding[0], 229 | ), 230 | ), 231 | limit=sparse_top_k, 232 | filter=query_filter, 233 | with_payload=True, 234 | with_vector=True, 235 | ), 236 | ], 237 | ) 238 | return self.parse_image_to_query_result(sparse_response[0]) 239 | 240 | elif self.enable_hybrid: 241 | # search for dense vectors only 242 | response = self._client.search_batch( 243 | collection_name=self.collection_name, 244 | requests=[ 245 | rest.SearchRequest( 246 | vector=rest.NamedVector( 247 | name="text-dense", 248 | vector=query_embedding, 249 | ), 250 | limit=query.similarity_top_k, 251 | filter=query_filter, 252 | with_payload=True, 253 | with_vector=True 254 | ), 255 | ], 256 | ) 257 | 258 | return self.parse_image_to_query_result(response[0]) 259 | else: 260 | response = self._client.search( 261 | collection_name=self.collection_name, 262 | query_vector=query_embedding, 263 | limit=query.similarity_top_k, 264 | query_filter=query_filter, 265 | ) 266 | return self.parse_image_to_query_result(response) 267 | 268 | 269 | 270 | def text_to_image_query( 271 | self, 272 | query: VectorStoreQuery, 273 | **kwargs: Any, 274 | ) -> VectorStoreQueryResult: 275 | """ 276 | Text-to-Image query with cros-modal similarity of query text embedding and node image embedding 277 | 278 | Args: 279 | query (VectorStoreQuery): query 280 | """ 281 | query_embedding = cast(List[float], query.query_embedding) 282 | # NOTE: users can pass in qdrant_filters (nested/complicated filters) to override the default MetadataFilters 283 | qdrant_filters = kwargs.get("qdrant_filters") 284 | if qdrant_filters is not None: 285 | query_filter = qdrant_filters 286 | else: 287 | query_filter = cast(Filter, self._build_query_filter(query)) 288 | 289 | 290 | response = self._client.search_batch( 291 | collection_name=self.collection_name, 292 | requests=[ 293 | rest.SearchRequest( 294 | vector=rest.NamedVector( 295 | name="multi-modal", 296 | vector=query_embedding, 297 | ), 298 | limit=query.similarity_top_k, 299 | filter=query_filter, 300 | with_payload=True, 301 | with_vector=True, 302 | ), 303 | ], 304 | ) 305 | return self.parse_image_to_query_result(response[0]) 306 | 307 | 308 | def parse_image_to_query_result(self, response: List[Any]) -> VectorStoreQueryResult: 309 | """ 310 | Convert vector store response to VectorStoreQueryResult. 311 | 312 | Args: 313 | response: List[Any]: List of results returned from the vector store. 314 | """ 315 | nodes = [] 316 | similarities = [] 317 | ids = [] 318 | 319 | for point in response: 320 | payload = cast(Payload, point.payload) 321 | 322 | # try: 323 | # node = metadata_dict_to_node(payload) 324 | # except Exception: 325 | # # NOTE: deprecated legacy logic for backward compatibility 326 | # logger.debug("Failed to parse Node metadata, fallback to legacy logic.") 327 | # metadata, node_info, relationships = legacy_metadata_dict_to_node( 328 | # payload 329 | # ) 330 | 331 | # node = ImageNode( 332 | # id_=str(point.id), 333 | # text=payload.get("text"), 334 | # image=payload.get("image"), 335 | # image_path=payload.get("image_path"), 336 | # metadata=metadata, 337 | # relationships=relationships, 338 | # ) 339 | 340 | node = ImageNode( 341 | id_=str(point.id), 342 | text=payload.get("text"), 343 | image=payload.get("image"), 344 | image_path=payload.get("image_path"), 345 | metadata=payload.get("metadata"), 346 | #relationships=relationships, 347 | ) 348 | node.metadata["vectors"] = point.vector 349 | 350 | 351 | 352 | nodes.append(node) 353 | similarities.append(point.score) 354 | ids.append(str(point.id)) 355 | 356 | return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) 357 | 358 | 359 | #TODO 360 | def _build_query_filter(self, query: VectorStoreQuery) -> Optional[Any]: 361 | 362 | return super()._build_query_filter(query) 363 | 364 | 365 | 366 | 367 | -------------------------------------------------------------------------------- /src/llava_llamacpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Dict, Optional, Sequence, Tuple 3 | 4 | import requests 5 | from llama_index.core.base.llms.types import ( 6 | ChatMessage, 7 | ChatResponse, 8 | ChatResponseGen, 9 | CompletionResponse, 10 | CompletionResponseGen, 11 | LLMMetadata, 12 | ) 13 | from llama_index.core.bridge.pydantic import Field, PrivateAttr 14 | from llama_index.core.callbacks import CallbackManager 15 | from llama_index.core.constants import ( 16 | DEFAULT_CONTEXT_WINDOW, 17 | DEFAULT_NUM_OUTPUTS, 18 | DEFAULT_TEMPERATURE, 19 | ) 20 | from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback 21 | from llama_index.core.llms.custom import CustomLLM 22 | from llama_index.core.base.llms.generic_utils import ( 23 | completion_response_to_chat_response, 24 | stream_completion_response_to_chat_response, 25 | ) 26 | from llama_index.core.types import BaseOutputParser, PydanticProgramMode 27 | from llama_index.core.utils import get_cache_dir 28 | from tqdm import tqdm 29 | 30 | from llama_cpp import Llama 31 | 32 | DEFAULT_LLAMA_CPP_GGML_MODEL = ( 33 | "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/resolve" 34 | "/main/llama-2-13b-chat.ggmlv3.q4_0.bin" 35 | ) 36 | DEFAULT_LLAMA_CPP_GGUF_MODEL = ( 37 | "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve" 38 | "/main/llama-2-13b-chat.Q4_0.gguf" 39 | ) 40 | DEFAULT_LLAMA_CPP_MODEL_VERBOSITY = True 41 | 42 | def get_additional_kwargs( 43 | response: Dict[str, Any], exclude: Tuple[str, ...] 44 | ) -> Dict[str, Any]: 45 | return {k: v for k, v in response.items() if k not in exclude} 46 | 47 | 48 | class Llava_LlamaCPP(CustomLLM): 49 | model_url: Optional[str] = Field( 50 | description="The URL llama-cpp model to download and use." 51 | ) 52 | model_path: Optional[str] = Field( 53 | description="The path to the llama-cpp model to use." 54 | ) 55 | temperature: float = Field( 56 | default=DEFAULT_TEMPERATURE, 57 | description="The temperature to use for sampling.", 58 | gte=0.0, 59 | lte=1.0, 60 | ) 61 | max_new_tokens: int = Field( 62 | default=DEFAULT_NUM_OUTPUTS, 63 | description="The maximum number of tokens to generate.", 64 | gt=0, 65 | ) 66 | context_window: int = Field( 67 | default=DEFAULT_CONTEXT_WINDOW, 68 | description="The maximum number of context tokens for the model.", 69 | gt=0, 70 | ) 71 | generate_kwargs: Dict[str, Any] = Field( 72 | default_factory=dict, description="Kwargs used for generation." 73 | ) 74 | model_kwargs: Dict[str, Any] = Field( 75 | default_factory=dict, description="Kwargs used for model initialization." 76 | ) 77 | verbose: bool = Field( 78 | default=DEFAULT_LLAMA_CPP_MODEL_VERBOSITY, 79 | description="Whether to print verbose output.", 80 | ) 81 | 82 | _model: Any = PrivateAttr() 83 | 84 | def __init__( 85 | self, 86 | model_url: Optional[str] = None, 87 | model_path: Optional[str] = None, 88 | temperature: float = DEFAULT_TEMPERATURE, 89 | max_new_tokens: int = DEFAULT_NUM_OUTPUTS, 90 | context_window: int = DEFAULT_CONTEXT_WINDOW, 91 | callback_manager: Optional[CallbackManager] = None, 92 | generate_kwargs: Optional[Dict[str, Any]] = None, 93 | model_kwargs: Optional[Dict[str, Any]] = None, 94 | verbose: bool = DEFAULT_LLAMA_CPP_MODEL_VERBOSITY, 95 | system_prompt: Optional[str] = None, 96 | messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, 97 | completion_to_prompt: Optional[Callable[[str], str]] = None, 98 | pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, 99 | output_parser: Optional[BaseOutputParser] = None, 100 | ) -> None: 101 | model_kwargs = { 102 | **{"n_ctx": context_window, "verbose": verbose}, 103 | **(model_kwargs or {}), # Override defaults via model_kwargs 104 | } 105 | 106 | # check if model is cached 107 | if model_path is not None: 108 | if not os.path.exists(model_path): 109 | raise ValueError( 110 | "Provided model path does not exist. " 111 | "Please check the path or provide a model_url to download." 112 | ) 113 | else: 114 | self._model = Llama(model_path=model_path, **model_kwargs) 115 | else: 116 | cache_dir = get_cache_dir() 117 | model_url = model_url or self._get_model_path_for_version() 118 | model_name = os.path.basename(model_url) 119 | model_path = os.path.join(cache_dir, "models", model_name) 120 | if not os.path.exists(model_path): 121 | os.makedirs(os.path.dirname(model_path), exist_ok=True) 122 | self._download_url(model_url, model_path) 123 | assert os.path.exists(model_path) 124 | 125 | self._model = Llama(model_path=model_path, **model_kwargs) 126 | 127 | model_path = model_path 128 | generate_kwargs = generate_kwargs or {} 129 | generate_kwargs.update( 130 | {"temperature": temperature, "max_tokens": max_new_tokens} 131 | ) 132 | 133 | super().__init__( 134 | model_path=model_path, 135 | model_url=model_url, 136 | temperature=temperature, 137 | context_window=context_window, 138 | max_new_tokens=max_new_tokens, 139 | callback_manager=callback_manager, 140 | generate_kwargs=generate_kwargs, 141 | model_kwargs=model_kwargs, 142 | verbose=verbose, 143 | system_prompt=system_prompt, 144 | messages_to_prompt=messages_to_prompt, 145 | completion_to_prompt=completion_to_prompt, 146 | pydantic_program_mode=pydantic_program_mode, 147 | output_parser=output_parser, 148 | ) 149 | 150 | @classmethod 151 | def class_name(cls) -> str: 152 | return "LlamaCPP_llm" 153 | 154 | @property 155 | def metadata(self) -> LLMMetadata: 156 | """LLM metadata.""" 157 | return LLMMetadata( 158 | context_window=self._model.context_params.n_ctx, 159 | num_output=self.max_new_tokens, 160 | model_name=self.model_path, 161 | ) 162 | 163 | def _get_model_path_for_version(self) -> str: 164 | """Get model path for the current llama-cpp version.""" 165 | import pkg_resources 166 | 167 | version = pkg_resources.get_distribution("llama-cpp-python").version 168 | major, minor, patch = version.split(".") 169 | 170 | # NOTE: llama-cpp-python<=0.1.78 supports GGML, newer support GGUF 171 | if int(major) <= 0 and int(minor) <= 1 and int(patch) <= 78: 172 | return DEFAULT_LLAMA_CPP_GGML_MODEL 173 | else: 174 | return DEFAULT_LLAMA_CPP_GGUF_MODEL 175 | 176 | def _download_url(self, model_url: str, model_path: str) -> None: 177 | completed = False 178 | try: 179 | print("Downloading url", model_url, "to path", model_path) 180 | with requests.get(model_url, stream=True) as r: 181 | with open(model_path, "wb") as file: 182 | total_size = int(r.headers.get("Content-Length") or "0") 183 | if total_size < 1000 * 1000: 184 | raise ValueError( 185 | "Content should be at least 1 MB, but is only", 186 | r.headers.get("Content-Length"), 187 | "bytes", 188 | ) 189 | print("total size (MB):", round(total_size / 1000 / 1000, 2)) 190 | chunk_size = 1024 * 1024 # 1 MB 191 | for chunk in tqdm( 192 | r.iter_content(chunk_size=chunk_size), 193 | total=int(total_size / chunk_size), 194 | ): 195 | file.write(chunk) 196 | completed = True 197 | except Exception as e: 198 | print("Error downloading model:", e) 199 | finally: 200 | if not completed: 201 | print("Download incomplete.", "Removing partially downloaded file.") 202 | os.remove(model_path) 203 | raise ValueError("Download incomplete.") 204 | 205 | @llm_chat_callback() 206 | def chat(self, messages: Sequence[Dict[str, Any]], **kwargs: Any) -> ChatResponse: 207 | # https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion 208 | response = self._model.create_chat_completion(messages, **kwargs)['choices'][0] 209 | 210 | return ChatResponse( 211 | message=ChatMessage( 212 | content=response["message"]["content"], 213 | role=MessageRole(response["message"]["role"]), 214 | additional_kwargs=get_additional_kwargs(response, ("message",)), 215 | ), 216 | raw=response["message"], 217 | additional_kwargs=get_additional_kwargs(response, ("message",)), 218 | ) 219 | 220 | 221 | 222 | ##TODO 223 | @llm_chat_callback() 224 | def stream_chat( 225 | self, messages: Sequence[ChatMessage], **kwargs: Any 226 | ) -> ChatResponseGen: 227 | prompt = self.messages_to_prompt(messages) 228 | completion_response = self.stream_complete(prompt, formatted=True, **kwargs) 229 | return stream_completion_response_to_chat_response(completion_response) 230 | 231 | ##TODO 232 | @llm_completion_callback() 233 | def complete( 234 | self, prompt: str, formatted: bool = False, **kwargs: Any 235 | ) -> CompletionResponse: 236 | self.generate_kwargs.update({"stream": False}) 237 | 238 | if not formatted: 239 | prompt = self.completion_to_prompt(prompt) 240 | 241 | response = self._model(prompt=prompt, **self.generate_kwargs) 242 | 243 | return CompletionResponse(text=response["choices"][0]["text"], raw=response) 244 | 245 | ##TODO 246 | @llm_completion_callback() 247 | def stream_complete( 248 | self, prompt: str, formatted: bool = False, **kwargs: Any 249 | ) -> CompletionResponseGen: 250 | self.generate_kwargs.update({"stream": True}) 251 | 252 | if not formatted: 253 | prompt = self.completion_to_prompt(prompt) 254 | 255 | response_iter = self._model(prompt=prompt, **self.generate_kwargs) 256 | 257 | def gen() -> CompletionResponseGen: 258 | text = "" 259 | for response in response_iter: 260 | delta = response["choices"][0]["text"] 261 | text += delta 262 | yield CompletionResponse(delta=delta, text=text, raw=response) 263 | 264 | return gen() 265 | -------------------------------------------------------------------------------- /src/mm_query_engine.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Sequence, Tuple 2 | 3 | from llama_index.core.base.response.schema import RESPONSE_TYPE, Response 4 | from llama_index.core.callbacks.base import CallbackManager 5 | from llama_index.core.callbacks.schema import CBEventType, EventPayload 6 | from llama_index.core.indices.multi_modal import MultiModalVectorIndexRetriever 7 | from llama_index.core.indices.query.base import BaseQueryEngine 8 | from llama_index.core.indices.query.schema import QueryBundle, QueryType 9 | from llama_index.core.multi_modal_llms.base import MultiModalLLM 10 | from llama_index.core.postprocessor.types import BaseNodePostprocessor 11 | from llama_index.core.prompts import BasePromptTemplate 12 | from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT 13 | from llama_index.core.query_engine.citation_query_engine import CITATION_QA_TEMPLATE 14 | from llama_index.core.prompts import PromptTemplate 15 | from llama_index.core.prompts.mixin import PromptMixinType 16 | from llama_index.core.schema import ImageNode, NodeWithScore 17 | from mm_retriever import MultiModalQdrantRetriever 18 | 19 | 20 | # rewrite CITATION_QA_TEMPLATE 21 | TEXT_QA_TEMPLATE = PromptTemplate( 22 | "Please provide an answer based solely on the provided sources. " 23 | "When referencing information from a source, " 24 | "cite the appropriate source(s) using their corresponding numbers. " 25 | "Every answer should include at least one source citation. " 26 | "Only cite a source when you are explicitly referencing it. " 27 | "If none of the sources are helpful, you should indicate that. " 28 | "Below are several numbered sources of information:" 29 | "\n------\n" 30 | "{context_str}" 31 | "\n------\n" 32 | "Query: {query_str}\n" 33 | "Answer: " 34 | ) 35 | 36 | IMAGE_QA_TEMPLATE = PromptTemplate( 37 | "\n" 38 | "Caption: {context_str}" 39 | "\n------\n" 40 | "You are a smart agent who can answer questions based on external information. " 41 | "Above is an annotated image you retrieved. Please provide an answer to the query based solely on the image and caption. " 42 | "If the image is not helpful, you should indicate that. \n" 43 | "Query: {query_str}\n" 44 | "Note: Don't include expressions like \"This image appears to be XXX\" in your answer.\n" 45 | "Answer: " 46 | ) 47 | 48 | ANSWER_INTEGRATION_TEMPLATE = PromptTemplate( 49 | "With the following sources related to your question from my knowledge base: \n" 50 | "\n"+"-"*50+"\n" 51 | "Paragraphs:\n\n" 52 | "{context_str}\n" 53 | "\nImages:\n" 54 | "{image_context_str}\n" 55 | "\n"+"-"*50+"\n" 56 | "Here is my answer:\n" 57 | "\n{text_context_response}\n{image_context_response}" 58 | ) 59 | 60 | # def _get_image_and_text_nodes( 61 | # nodes: List[NodeWithScore], 62 | # ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]: 63 | # image_nodes = [] 64 | # text_nodes = [] 65 | # for res_node in nodes: 66 | # if isinstance(res_node.node, ImageNode): 67 | # image_nodes.append(res_node) 68 | # else: 69 | # text_nodes.append(res_node) 70 | # return image_nodes, text_nodes 71 | 72 | 73 | class CustomMultiModalQueryEngine(BaseQueryEngine): 74 | """Simple Multi Modal Retriever query engine. 75 | 76 | Assumes that retrieved text context fits within context window of LLM, along with images. 77 | 78 | Args: 79 | retriever (MultiModalVectorIndexRetriever): A retriever object. 80 | multi_modal_llm (Optional[MultiModalLLM]): MultiModalLLM Models. 81 | text_qa_template (Optional[BasePromptTemplate]): Text QA Prompt Template. 82 | image_qa_template (Optional[BasePromptTemplate]): Image QA Prompt Template. 83 | node_postprocessors (Optional[List[BaseNodePostprocessor]]): Node Postprocessors. 84 | callback_manager (Optional[CallbackManager]): A callback manager. 85 | """ 86 | 87 | def __init__( 88 | self, 89 | retriever: MultiModalQdrantRetriever, 90 | multi_modal_llm: MultiModalLLM, 91 | text_qa_template: Optional[BasePromptTemplate] = None, 92 | image_qa_template: Optional[BasePromptTemplate] = None, 93 | answer_integration_template: Optional[BasePromptTemplate] = None, 94 | node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, 95 | callback_manager: Optional[CallbackManager] = None, 96 | **kwargs: Any, 97 | ) -> None: 98 | self._retriever = retriever 99 | self._multi_modal_llm = multi_modal_llm 100 | 101 | self._text_qa_template = text_qa_template or CITATION_QA_TEMPLATE 102 | self._image_qa_template = image_qa_template or IMAGE_QA_TEMPLATE 103 | 104 | self._answer_integration_template = answer_integration_template or ANSWER_INTEGRATION_TEMPLATE 105 | 106 | self._node_postprocessors = node_postprocessors or [] 107 | callback_manager = callback_manager or CallbackManager([]) 108 | for node_postprocessor in self._node_postprocessors: 109 | node_postprocessor.callback_manager = callback_manager 110 | 111 | super().__init__(callback_manager) 112 | 113 | def _get_prompts(self) -> Dict[str, Any]: 114 | """Get prompts.""" 115 | return {"text_qa_template": self._text_qa_template} 116 | 117 | def _get_prompt_modules(self) -> PromptMixinType: 118 | """Get prompt sub-modules.""" 119 | return {} 120 | 121 | def _apply_node_postprocessors( 122 | self, nodes: List[NodeWithScore], query_bundle: QueryBundle 123 | ) -> List[NodeWithScore]: 124 | for node_postprocessor in self._node_postprocessors: 125 | nodes = node_postprocessor.postprocess_nodes( 126 | nodes, query_bundle=query_bundle 127 | ) 128 | return nodes 129 | 130 | def retrieve(self, 131 | query_bundle: QueryBundle, 132 | text_query_mode: str = "hybrid", 133 | image_query_mode: str = "default", 134 | metadata_filters = None) -> Dict[str, List[NodeWithScore]]: 135 | 136 | text_retrieval_result = self._retriever.retrieve_text_nodes(query_bundle, text_query_mode, metadata_filters) 137 | image_retrieval_result = self._retriever.retrieve_image_nodes(query_bundle, image_query_mode, metadata_filters) 138 | 139 | reranked_text_nodes = self._retriever.rerank_text_nodes(query_bundle, text_retrieval_result) 140 | reranked_image_nodes = self._retriever.rerank_image_nodes(query_bundle, image_retrieval_result) 141 | 142 | retrieval_results = { 143 | "text_nodes": self._apply_node_postprocessors(reranked_text_nodes, query_bundle=query_bundle), 144 | "image_nodes": self._apply_node_postprocessors(reranked_image_nodes, query_bundle=query_bundle), 145 | } 146 | 147 | return retrieval_results 148 | 149 | # async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: 150 | # nodes = await self._retriever.aretrieve(query_bundle) 151 | # return self._apply_node_postprocessors(nodes, query_bundle=query_bundle) 152 | 153 | def synthesize( 154 | self, 155 | query_bundle: QueryBundle, 156 | #nodes: List[NodeWithScore], 157 | retrieval_results: Dict[str, List[NodeWithScore]], 158 | additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, 159 | ) -> RESPONSE_TYPE: 160 | 161 | 162 | #image_nodes, text_nodes = _get_image_and_text_nodes(nodes) 163 | image_nodes, text_nodes = retrieval_results["image_nodes"], retrieval_results["text_nodes"] 164 | 165 | #TODO: format prompt with (text context), (image + caption of image) 166 | context_str = "\n\n".join([f"Source {text_nodes.index(r)+1}:\n" + r.get_content() for r in text_nodes]) 167 | fmt_prompt = self._text_qa_template.format( 168 | context_str=context_str, 169 | query_str=query_bundle.query_str, 170 | ) 171 | 172 | image_context_str = "\n\n".join([r.get_content() for r in image_nodes]) 173 | image_query_fmt_prompt = self._image_qa_template.format(context_str=image_context_str, query_str=query_bundle.query_str) 174 | 175 | text_context_messages = [ 176 | { 177 | "role": "user", 178 | "content":[ 179 | {"type":"text", "text":fmt_prompt} 180 | ] 181 | } 182 | ] 183 | 184 | 185 | ## Generate response when the mllm(llava) is under llamacpp framework 186 | ##TODO: handle multiple image input 187 | image_url = f"data:image/png;base64,{image_nodes[0].node.image.decode('utf-8')}" 188 | image_context_messages = [ 189 | { 190 | "role": "user", 191 | "content": [ 192 | {"type": "image_url", "image_url": {"url": image_url}}, 193 | {"type": "text", "text": image_query_fmt_prompt} 194 | ] 195 | } 196 | ] 197 | text_context_response = self._multi_modal_llm.chat( 198 | messages=text_context_messages, 199 | ) 200 | 201 | image_context_response = self._multi_modal_llm.chat( 202 | messages=image_context_messages, 203 | ) 204 | 205 | 206 | ## Generate response when the mllm(llava) is under ollama framework 207 | # text_context_response = self._multi_modal_llm.complete( 208 | # prompt=fmt_prompt, 209 | # images=[], 210 | # ) 211 | 212 | # image_context_response = self._multi_modal_llm.complete( 213 | # prompt=image_query_fmt_prompt, 214 | # images=[image_node.node.image for image_node in image_nodes], 215 | # ) 216 | 217 | 218 | #TODO: transform encoded base64 image to image object in GUI 219 | synthesized_response = self._answer_integration_template.format( 220 | context_str=context_str, 221 | image_context_str= "\n\n".join([""+ str(r.node.image) + '\n' + r.node.get_content() for r in image_nodes]), 222 | text_context_response=text_context_response.text.replace("\n"," ").strip(), 223 | image_context_response=i_q_response.text.replace("\n"," ").strip(), 224 | ) 225 | 226 | return Response( 227 | response=str(synthesized_response), 228 | source_nodes=text_nodes+image_nodes, 229 | metadata={ 230 | "query_str": query_bundle.query_str, 231 | "model_config": self._multi_modal_llm.metadata, 232 | }, 233 | ) 234 | 235 | 236 | 237 | # async def asynthesize( 238 | # self, 239 | # query_bundle: QueryBundle, 240 | # nodes: List[NodeWithScore], 241 | # additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, 242 | # ) -> RESPONSE_TYPE: 243 | # image_nodes, text_nodes = _get_image_and_text_nodes(nodes) 244 | # context_str = "\n\n".join([r.get_content() for r in text_nodes]) 245 | # fmt_prompt = self._text_qa_template.format( 246 | # context_str=context_str, query_str=query_bundle.query_str 247 | # ) 248 | # llm_response = await self._multi_modal_llm.acomplete( 249 | # prompt=fmt_prompt, 250 | # image_documents=image_nodes, 251 | # ) 252 | # return Response( 253 | # response=str(llm_response), 254 | # source_nodes=nodes, 255 | # metadata={"text_nodes": text_nodes, "image_nodes": image_nodes}, 256 | # ) 257 | 258 | async def asynthesize( 259 | self, 260 | query_bundle: QueryBundle, 261 | nodes: List[NodeWithScore], 262 | additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, 263 | ) -> RESPONSE_TYPE: 264 | raise NotImplementedError("Async synthesize not implemented yet") 265 | 266 | 267 | def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: 268 | """Answer a query.""" 269 | with self.callback_manager.event( 270 | CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} 271 | ) as query_event: 272 | with self.callback_manager.event( 273 | CBEventType.RETRIEVE, 274 | payload={EventPayload.QUERY_STR: query_bundle.query_str}, 275 | ) as retrieve_event: 276 | retrieval_results = self.retrieve(query_bundle) 277 | 278 | retrieve_event.on_end( 279 | payload={EventPayload.NODES: retrieval_results}, 280 | ) 281 | 282 | response = self.synthesize( 283 | query_bundle, 284 | retrieval_results=retrieval_results, 285 | ) 286 | 287 | query_event.on_end(payload={EventPayload.RESPONSE: response}) 288 | 289 | return response 290 | 291 | async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: 292 | raise NotImplementedError("Async query not implemented yet") 293 | 294 | # async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: 295 | # """Answer a query.""" 296 | # with self.callback_manager.event( 297 | # CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str} 298 | # ) as query_event: 299 | # with self.callback_manager.event( 300 | # CBEventType.RETRIEVE, 301 | # payload={EventPayload.QUERY_STR: query_bundle.query_str}, 302 | # ) as retrieve_event: 303 | # nodes = await self.aretrieve(query_bundle) 304 | 305 | # retrieve_event.on_end( 306 | # payload={EventPayload.NODES: nodes}, 307 | # ) 308 | 309 | # response = await self.asynthesize( 310 | # query_bundle, 311 | # nodes=nodes, 312 | # ) 313 | 314 | # query_event.on_end(payload={EventPayload.RESPONSE: response}) 315 | 316 | # return response 317 | 318 | 319 | @property 320 | def retriever(self) -> MultiModalVectorIndexRetriever: 321 | """Get the retriever object.""" 322 | return self._retriever 323 | -------------------------------------------------------------------------------- /src/mm_retriever.py: -------------------------------------------------------------------------------- 1 | from llama_index.legacy.schema import QueryBundle 2 | from llama_index.legacy.retrievers import BaseRetriever 3 | from llama_index.legacy.schema import NodeWithScore 4 | from llama_index.legacy.vector_stores import VectorStoreQuery 5 | from llama_index.legacy.vector_stores.types import VectorStoreQueryMode 6 | from llama_index.legacy.vector_stores.qdrant import QdrantVectorStore 7 | from llama_index.legacy.schema import QueryType 8 | from .custom_vectore_store import MultiModalQdrantVectorStore 9 | from typing import Any, List, Optional 10 | 11 | import numpy as np 12 | from qdrant_client.http import models as rest 13 | 14 | def compute_cosine_similarity(a, b): 15 | if isinstance(a, rest.SparseVector): 16 | intersect_indices = set(a.indices) & set(b.indices) 17 | if len(intersect_indices) == 0: 18 | return 0 19 | else: 20 | a_intersect_values = [a.values[a.indices.index(i)] for i in intersect_indices] 21 | b_intersect_values = [b.values[b.indices.index(i)] for i in intersect_indices] 22 | 23 | 24 | return np.dot(a_intersect_values, b_intersect_values) / (np.linalg.norm(a.values) * np.linalg.norm(b.values)) 25 | else: 26 | return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) 27 | 28 | 29 | 30 | class MultiModalQdrantRetriever(BaseRetriever): 31 | """Retriever over a qdrant vector store.""" 32 | 33 | def __init__( 34 | self, 35 | text_vector_store: QdrantVectorStore, 36 | image_vector_store: MultiModalQdrantVectorStore, 37 | 38 | text_embed_model: Any, 39 | mm_embed_model: Any, 40 | reranker: Optional[Any]=None, 41 | 42 | 43 | text_similarity_top_k: int = 5, 44 | text_sparse_top_k: int = 5, 45 | text_rerank_top_n: int = 3, 46 | image_similarity_top_k: int = 5, 47 | image_sparse_top_k: int = 5, 48 | image_rerank_top_n: int = 1, 49 | 50 | sparse_query_fn: Optional[Any] = None, 51 | ) -> None: 52 | """Init params.""" 53 | 54 | self._text_vector_store = text_vector_store 55 | self._image_vector_store = image_vector_store 56 | self._text_embed_model = text_embed_model 57 | self._mm_embed_model = mm_embed_model 58 | self._reranker = reranker 59 | 60 | self._text_similarity_top_k = text_similarity_top_k 61 | self._text_sparse_top_k = text_sparse_top_k 62 | self._text_rerank_top_n = text_rerank_top_n 63 | self._image_similarity_top_k = image_similarity_top_k 64 | self._image_sparse_top_k = image_sparse_top_k 65 | self._image_rerank_top_n = image_rerank_top_n 66 | 67 | self._sparse_query_fn = sparse_query_fn 68 | 69 | super().__init__() 70 | 71 | def retrieve_text_nodes(self, query_bundle: QueryBundle, query_mode: str="hybrid", metadata_filters=None): 72 | 73 | query_embedding = self._text_embed_model.get_query_embedding( 74 | query_bundle.query_str 75 | ) 76 | 77 | # query with dense text embedding 78 | dense_query = VectorStoreQuery( 79 | query_str=query_bundle.query_str, 80 | query_embedding=query_embedding, 81 | similarity_top_k=self._text_similarity_top_k, 82 | sparse_top_k=self._text_sparse_top_k, 83 | mode=VectorStoreQueryMode.DEFAULT, 84 | filters=metadata_filters, 85 | ) 86 | 87 | # query with sparse text vector 88 | sparse_query = VectorStoreQuery( 89 | query_str=query_bundle.query_str, 90 | query_embedding=query_embedding, 91 | similarity_top_k=self._text_similarity_top_k, 92 | sparse_top_k=self._text_sparse_top_k, 93 | mode=VectorStoreQueryMode.SPARSE, 94 | filters=metadata_filters, 95 | ) 96 | 97 | # mm_query = VectorStoreQuery(...) 98 | 99 | # returns a VectorStoreQueryResult 100 | if query_mode == "default": 101 | dense_query_result = self._text_vector_store.query(dense_query) 102 | 103 | return { 104 | "text-dense": dense_query_result 105 | } 106 | 107 | elif query_mode == "sparse": 108 | sparse_query_result = self._text_vector_store.query(sparse_query) 109 | 110 | return { 111 | "text-sparse": sparse_query_result 112 | } 113 | 114 | 115 | elif query_mode == "hybrid": 116 | dense_query_result = self._text_vector_store.query(dense_query) 117 | sparse_query_result = self._text_vector_store.query(sparse_query) 118 | 119 | return { 120 | "text-dense": dense_query_result, 121 | "text-sparse": sparse_query_result 122 | } 123 | 124 | else: 125 | raise ValueError(f"Invalid text-to-text query mode: {query_mode}, must be one of ['default', 'sparse', 'hybrid']") 126 | 127 | 128 | def rerank_text_nodes(self, query_bundle: QueryBundle, text_retrieval_result, text_rerank_top_n = None): 129 | 130 | text_node_ids, text_nodes = [], [] 131 | # text_node_ids, text_node, text_node_scores = [], [], [] 132 | 133 | for key in text_retrieval_result.keys(): 134 | text_node_ids += text_retrieval_result[key].ids 135 | text_nodes += text_retrieval_result[key].nodes 136 | # text_node_scores += text_retrieval_result[key].similarities 137 | 138 | if text_rerank_top_n is None: 139 | text_rerank_top_n = self._text_rerank_top_n 140 | else: 141 | text_rerank_top_n = min(text_rerank_top_n, len(text_nodes)) 142 | self._reranker.top_n = text_rerank_top_n 143 | 144 | # drop duplicate nodes from sparse retrival and dense retrival 145 | unique_node_indices = list(set([text_node_ids.index(x) for x in text_node_ids if text_node_ids.count(x) >= 1])) 146 | 147 | ## reserve similarity score of retrival stage 148 | # text_nodes = [text_nodes[i] for i in unique_node_indices] 149 | # text_node_scores = [text_node_scores[i] for i in unique_node_indices] 150 | # text_nodes_with_score = [NodeWithScore(node=_[0], score=_[1]) for _ in list(zip(text_nodes, text_node_scores))] 151 | 152 | # set similarity score to 0.0 only for format consistency in reranking stage 153 | text_nodes_with_score = [NodeWithScore(node=text_nodes[i], score=0.0) for i in unique_node_indices] 154 | 155 | 156 | return self._reranker._postprocess_nodes(nodes=text_nodes_with_score, query_bundle=query_bundle) 157 | 158 | 159 | def retrieve_image_nodes(self, query_bundle: QueryBundle, query_mode: str="default", metadata_filters=None): 160 | 161 | 162 | if query_mode == "default": # Default: query with dense multi-modal embedding only 163 | mm_query = VectorStoreQuery( 164 | query_str=query_bundle.query_str, 165 | query_embedding=self._mm_embed_model.get_query_embedding(query_bundle.query_str), 166 | similarity_top_k=self._image_similarity_top_k, 167 | mode=VectorStoreQueryMode.DEFAULT, 168 | filters=metadata_filters, 169 | ) 170 | mm_query_result = self._image_vector_store.text_to_image_query(mm_query) 171 | 172 | return { 173 | "multi-modal": mm_query_result 174 | } 175 | 176 | 177 | elif query_mode == "text-dense": 178 | text_dense_query = VectorStoreQuery( 179 | query_str=query_bundle.query_str, 180 | query_embedding=self._text_embed_model.get_query_embedding(query_bundle.query_str), 181 | similarity_top_k=self._image_similarity_top_k, 182 | mode=VectorStoreQueryMode.DEFAULT, 183 | filters=metadata_filters, 184 | ) 185 | text_dense_query_result = self._image_vector_store.text_to_caption_query(text_dense_query) 186 | 187 | return { 188 | "text-dense": text_dense_query_result 189 | } 190 | 191 | 192 | elif query_mode == "text-sparse": 193 | text_sparse_query = VectorStoreQuery( 194 | query_str=query_bundle.query_str, 195 | #query_embedding=self._text_embed_model.get_query_embedding(query_bundle.query_str), 196 | similarity_top_k=self._image_sparse_top_k, 197 | mode=VectorStoreQueryMode.SPARSE, 198 | filters=metadata_filters, 199 | ) 200 | text_sparse_query_result = self._image_vector_store.text_to_caption_query(text_sparse_query) 201 | 202 | return { 203 | "text-sparse": text_sparse_query_result 204 | } 205 | 206 | elif query_mode == "hybrid": 207 | mm_query = VectorStoreQuery( 208 | query_str=query_bundle.query_str, 209 | query_embedding=self._mm_embed_model.get_query_embedding(query_bundle.query_str), 210 | similarity_top_k=self._image_similarity_top_k, 211 | mode=VectorStoreQueryMode.DEFAULT, 212 | filters=metadata_filters, 213 | ) 214 | mm_query_result = self._image_vector_store.text_to_image_query(mm_query) 215 | 216 | text_dense_query = VectorStoreQuery( 217 | query_str=query_bundle.query_str, 218 | query_embedding=self._text_embed_model.get_query_embedding(query_bundle.query_str), 219 | similarity_top_k=self._image_similarity_top_k, 220 | mode=VectorStoreQueryMode.DEFAULT, 221 | filters=metadata_filters, 222 | ) 223 | text_dense_query_result = self._image_vector_store.text_to_caption_query(text_dense_query) 224 | 225 | text_sparse_query = VectorStoreQuery( 226 | query_str=query_bundle.query_str, 227 | #query_embedding=self._text_embed_model.get_query_embedding(query_bundle.query_str), 228 | similarity_top_k=self._image_sparse_top_k, 229 | mode=VectorStoreQueryMode.SPARSE, 230 | filters=metadata_filters, 231 | ) 232 | text_sparse_query_result = self._image_vector_store.text_to_caption_query(text_sparse_query) 233 | 234 | return { 235 | "multi-modal": mm_query_result, 236 | "text-dense": text_dense_query_result, 237 | "text-sparse": text_sparse_query_result 238 | } 239 | 240 | else: 241 | raise ValueError(f"Invalid text-to-image query mode: {query_mode}, must be one of ['default', 'text-dense', 'text-sparse', 'hybrid']") 242 | 243 | 244 | 245 | def rerank_image_nodes(self, query_bundle: QueryBundle, image_retrieval_result, image_rerank_top_n = None): 246 | 247 | image_nodes, image_node_ids = [], [] 248 | for key in image_retrieval_result.keys(): 249 | image_node_ids += image_retrieval_result[key].ids 250 | image_nodes += image_retrieval_result[key].nodes 251 | 252 | # image_similarities = np.array(image_retrieval_result[key].similarities) 253 | # normed_similarities = (image_similarities - image_similarities.mean()) / image_similarities.std() 254 | # image_node_scores += normed_similarities.tolist() 255 | 256 | unique_node_indices = list(set([image_node_ids.index(x) for x in image_node_ids if image_node_ids.count(x) >= 1])) 257 | image_node_nodes = [image_nodes[i] for i in unique_node_indices] 258 | 259 | if image_rerank_top_n is None: 260 | image_rerank_top_n = self._image_rerank_top_n 261 | else: 262 | image_rerank_top_n = min(image_rerank_top_n, len(image_node_nodes)) 263 | 264 | query_str = query_bundle.query_str 265 | similarity_scores = {key: [] for key in image_retrieval_result.keys()} 266 | 267 | for key in image_retrieval_result.keys(): 268 | if key == "text-dense": 269 | query_embedding = self._text_embed_model.get_query_embedding(query_str) 270 | for i, node in enumerate(image_node_nodes): 271 | node_embedding = node.metadata['vectors'][key] 272 | similarity_scores[key].append(compute_cosine_similarity(query_embedding, node_embedding)) 273 | 274 | elif key == "text-sparse": 275 | query_embedding = self._sparse_query_fn(query_str) 276 | query_embedding = rest.SparseVector(indices=query_embedding[0][0], values=query_embedding[1][0]) 277 | for i, node in enumerate(image_node_nodes): 278 | node_embedding = node.metadata['vectors'][key] 279 | similarity_scores[key].append(compute_cosine_similarity(query_embedding, node_embedding)) 280 | 281 | elif key == "multi-modal": 282 | query_embedding = self._mm_embed_model.get_query_embedding(query_str) 283 | for i, node in enumerate(image_node_nodes): 284 | node_embedding = node.metadata['vectors'][key] 285 | similarity_scores[key].append(compute_cosine_similarity(query_embedding, node_embedding)) 286 | 287 | 288 | rerank_scores = np.zeros(len(image_node_nodes)) 289 | for key in similarity_scores.keys(): 290 | similarity_scores[key] = np.array(similarity_scores[key]) 291 | similarity_scores[key] = (similarity_scores[key] - similarity_scores[key].mean()) / similarity_scores[key].std() 292 | rerank_scores += similarity_scores[key] 293 | 294 | rerank_score_with_index = list(zip(rerank_scores, range(len(image_node_nodes)))) 295 | rerank_score_with_index = sorted(rerank_score_with_index, key=lambda x: x[0], reverse=True) 296 | topn_image_nodes = [NodeWithScore(node=image_node_nodes[_[1]], score=_[0]) for _ in rerank_score_with_index][:image_rerank_top_n] 297 | 298 | for node in topn_image_nodes: 299 | node.node.metadata['vectors'] = None 300 | 301 | return topn_image_nodes 302 | 303 | 304 | 305 | 306 | 307 | 308 | ### TODO: rewrite the following methods to use the new retrieve_text_nodes and retrieve_image_nodes 309 | 310 | def _retrieve(self, query_bundle: QueryBundle, query_mode: str="hybrid", metadata_filters=None): 311 | 312 | """ Deprecated abstract retrieve method from the BaseRetriever, this can only retrieve text nodes.""" 313 | 314 | raise NotImplementedError("This method is deprecated, please use retrieve_text_nodes and retrieve_image_nodes instead.") 315 | 316 | ###TODO: rewrite this method to use the new retrieve_text_nodes and retrieve_image_nodes 317 | # return { 318 | # "text_nodes": self.retrieve_text_nodes(query_bundle, query_mode, metadata_filters), 319 | # "image_nodes": self.retrieve_image_nodes(query_bundle, query_mode, metadata_filters) 320 | # } 321 | def retrieve(self, str_or_query_bundle: QueryType): 322 | raise NotImplementedError("This method is deprecated, please use retrieve_text_nodes and retrieve_image_nodes instead.") 323 | 324 | async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: 325 | """Asynchronously retrieve nodes given query. 326 | 327 | Implemented by the user. 328 | 329 | """ 330 | return self._retrieve(query_bundle) 331 | 332 | async def aretrieve(self, str_or_query_bundle: QueryType): 333 | return 334 | 335 | 336 | if __name__ == "__main__": 337 | 338 | from llama_index.legacy.vector_stores import QdrantVectorStore 339 | from custom_vectore_store import MultiModalQdrantVectorStore 340 | from custom_embeddings import custom_sparse_doc_vectors, custom_sparse_query_vectors 341 | 342 | from functools import partial 343 | 344 | from qdrant_client import QdrantClient 345 | from qdrant_client.http import models as qd_models 346 | 347 | try: 348 | client = QdrantClient(path="qdrant_db") 349 | except: 350 | print("Qdrant server not running, please start the server and try again.") 351 | pass 352 | 353 | # client = QdrantClient(path="qdrant_db") 354 | 355 | 356 | import torch 357 | from transformers import AutoTokenizer, AutoModelForMaskedLM 358 | 359 | 360 | from llama_index.legacy.embeddings import HuggingFaceEmbedding 361 | from custom_embeddings import CustomizedCLIPEmbedding 362 | 363 | BGE_PATH = "./embedding_models/bge-small-en-v1.5" 364 | CLIP_PATH = "./embedding_models/clip-vit-base-patch32" 365 | bge_embedding = HuggingFaceEmbedding(model_name=BGE_PATH, device="cpu", pooling="mean") 366 | clip_embedding = CustomizedCLIPEmbedding(model_name=CLIP_PATH, device="cpu") 367 | 368 | SPLADE_QUERY_PATH = "./embedding_models/efficient-splade-VI-BT-large-query" 369 | splade_q_tokenizer = AutoTokenizer.from_pretrained(SPLADE_QUERY_PATH) 370 | splade_q_model = AutoModelForMaskedLM.from_pretrained(SPLADE_QUERY_PATH) 371 | 372 | SPLADE_DOC_PATH = "./embedding_models/efficient-splade-VI-BT-large-doc" 373 | splade_d_tokenizer = AutoTokenizer.from_pretrained(SPLADE_DOC_PATH) 374 | splade_d_model = AutoModelForMaskedLM.from_pretrained(SPLADE_DOC_PATH) 375 | 376 | custom_sparse_doc_fn = partial(custom_sparse_doc_vectors, splade_d_tokenizer, splade_d_model, 512) 377 | custom_sparse_query_fn = partial(custom_sparse_query_vectors, splade_q_tokenizer, splade_q_model, 512) 378 | 379 | text_store = QdrantVectorStore( 380 | client=client, 381 | collection_name="text_collection", 382 | enable_hybrid=True, 383 | sparse_query_fn=custom_sparse_query_fn, 384 | sparse_doc_fn=custom_sparse_doc_fn, 385 | stores_text=True, 386 | ) 387 | 388 | image_store = MultiModalQdrantVectorStore( 389 | client=client, 390 | collection_name="image_collection", 391 | enable_hybrid=True, 392 | sparse_query_fn=custom_sparse_query_fn, 393 | sparse_doc_fn=custom_sparse_doc_fn, 394 | stores_text=False, 395 | ) 396 | 397 | 398 | mm_retriever = MultiModalQdrantRetriever( 399 | text_vector_store = text_store, 400 | image_vector_store = image_store, 401 | text_embed_model = bge_embedding, 402 | mm_embed_model = clip_embedding, 403 | ) 404 | 405 | text_query_result = mm_retriever.retrieve_text_nodes(query_bundle=QueryBundle(query_str="How does Llama 2 perform compared to other open-source models?"), query_mode="hybrid") 406 | 407 | image_retrieval_result = mm_retriever.retrieve_image_nodes(query_bundle=QueryBundle(query_str="How does Llama 2 perform compared to other open-source models?"), query_mode="hybrid") 408 | 409 | print(text_query_result) 410 | print(image_retrieval_result) -------------------------------------------------------------------------------- /src/parse_pdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import scipdf 4 | import argparse 5 | import warnings 6 | from bs4.builder import XMLParsedAsHTMLWarning 7 | 8 | warnings.filterwarnings('ignore', category=XMLParsedAsHTMLWarning) 9 | 10 | PDF_FOLDER = "./data/paper_PDF" 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Parse PDFs to extract figures and texts') 14 | parser.add_argument('--pdf_folder', type=str, default=PDF_FOLDER, help='Path to the folder containing PDFs') 15 | parser.add_argument('--image_resolution', type=int, default=300, help='Resolution of the images extracted from the PDFs') 16 | parser.add_argument('--max_timeout', type=int, default=120, help='Maximum processing time (sec) of figure extraction from a PDF') 17 | 18 | 19 | if __name__ == "__main__": 20 | 21 | args = parser.parse_args() 22 | 23 | pdf_files = os.listdir(args.pdf_folder) 24 | if not os.path.exists(os.path.join(args.pdf_folder, "parsed_figures")): 25 | os.makedirs(os.path.join(args.pdf_folder, "parsed_figures")) 26 | if not os.path.exists(os.path.join(args.pdf_folder, "parsed_texts")): 27 | os.makedirs(os.path.join(args.pdf_folder, "parsed_texts")) 28 | 29 | for file in pdf_files: 30 | if file.endswith(".pdf"): 31 | print(f"---------- Processing {file} ----------") 32 | 33 | fp = os.path.join(args.pdf_folder, file) 34 | scipdf.parse_figures(pdf_folder=fp, resolution=args.image_resolution, output_folder=os.path.join(args.pdf_folder, "parsed_figures"), max_timeout=args.max_timeout) 35 | 36 | parsed_res = scipdf.parse_pdf_to_dict(pdf_path=fp) 37 | with open(os.path.join(args.pdf_folder, f"parsed_texts/{file.split('.pdf')[0]}.json"), "a") as f: 38 | json.dump(parsed_res, f) -------------------------------------------------------------------------------- /src/prompt.py: -------------------------------------------------------------------------------- 1 | CONTEXT_PROMPT_TEMPLATE = """### Instruction to Assistant 2 | You are an expert in computer science who is also skilled at teaching and explaining concepts. Your task is to respond to the User's current query by synthesizing information from both the text and image sources. 3 | Ensure that your response is coherent with the dialogue history and accurately references relevant sources. Be concise, informative, and consider both the textual and visual context. 4 | If you do not know the answer, reply with 'I am sorry, I dont have enough information'. 5 | 6 | ### Contextual Information 7 | """ 8 | 9 | DEFAULT_PROMPT_TEMPLATE = """### Instruction to Assistant 10 | You are an expert in computer science who is also skilled at teaching and explaining concepts. Your task is to respond to the User's current query. 11 | Ensure that your response is coherent with the dialogue history. Be concise and informative. 12 | If you do not know the answer, reply with 'I am sorry, I dont have enough information'. 13 | """ 14 | 15 | # - **Text Sources**: 16 | # {} 17 | 18 | # - **Image Sources**: 19 | # {} 20 | 21 | IMAGE_TOKEN = "./" 22 | 23 | def generate_sys_prompt(sources = None): 24 | if sources and (sources['text_sources'] or sources['image_sources']): 25 | prompt = CONTEXT_PROMPT_TEMPLATE 26 | if sources['text_sources']: 27 | prompt += "- **Text Sources**:\n" 28 | for i, item in enumerate(sources['text_sources']): 29 | text_chunk = item['text'].replace('\n', '\n ') 30 | prompt += f" - Source [{i+1}]:\n {text_chunk}\n" 31 | if sources['image_sources']: 32 | prompt += "- **Image Sources**:\n" 33 | for i, item in enumerate(sources['image_sources']): 34 | prompt += f" - Fig [{i+1}]: {IMAGE_TOKEN}\n Caption: {item['caption']}\n" 35 | else: 36 | prompt = DEFAULT_PROMPT_TEMPLATE 37 | return prompt --------------------------------------------------------------------------------