├── .gitignore ├── images ├── U2.webp ├── UI1.webp ├── flowchart.webp └── 1721017707759.jpg ├── requirements.txt ├── chainlit.md ├── utils ├── openai_embeddings_llm.py ├── chainlit_agents.py ├── settings.yaml ├── pdf_to_markdown.py └── embedding.py ├── README.md └── appUI.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | input 3 | cache 4 | output 5 | settings.yaml 6 | .env 7 | prompts -------------------------------------------------------------------------------- /images/U2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karthik-codex/Autogen_GraphRAG_Ollama/HEAD/images/U2.webp -------------------------------------------------------------------------------- /images/UI1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karthik-codex/Autogen_GraphRAG_Ollama/HEAD/images/UI1.webp -------------------------------------------------------------------------------- /images/flowchart.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karthik-codex/Autogen_GraphRAG_Ollama/HEAD/images/flowchart.webp -------------------------------------------------------------------------------- /images/1721017707759.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karthik-codex/Autogen_GraphRAG_Ollama/HEAD/images/1721017707759.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | litellm[proxy] 2 | ollama 3 | pyautogen[retrievechat] 4 | tiktoken 5 | chainlit 6 | graphrag 7 | marker-pdf 8 | torch 9 | -------------------------------------------------------------------------------- /chainlit.md: -------------------------------------------------------------------------------- 1 | # Multi-Agent AI Superbot using AutoGen and GraphRAG 2 | 3 | This application integrates GraphRAG with AutoGen agents, powered by local LLMs from Ollama, for free and offline embedding and inference. Key highlights include: 4 | - **Agentic-RAG:** - Integrating GraphRAG's knowledge search method with an AutoGen agent via function calling. 5 | - **Offline LLM Support:** - Configuring GraphRAG (local & global search) to support local models from Ollama for inference 6 | and embedding. 7 | - **Non-OpenAI Function Calling:** - Extending AutoGen to support function calling with non-OpenAI LLMs from Ollama via Lite-LLM proxy 8 | server. 9 | - **Interactive UI:** - Deploying Chainlit UI to handle continuous conversations, multi-threading, and user input settings. 10 | 11 | ## Useful Links 🔗 12 | 13 | - **Medium Article:** Microsoft's GraphRAG + AutoGen + Ollama + Chainlit = Fully Local & Free Multi-Agent RAG Superbot [Medium.com](https://medium.com/@karthik.codex/microsofts-graphrag-autogen-ollama-chainlit-fully-local-free-multi-agent-rag-superbot-61ad3759f06f) 📚 14 | -------------------------------------------------------------------------------- /utils/openai_embeddings_llm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Microsoft Corporation. 
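# NOTE: patched copy of graphrag/llm/openai/openai_embeddings_llm.py (see README step 4);
# it overrides the OpenAI embeddings call so embeddings are generated locally with
# Ollama's nomic-embed-text model instead of the OpenAI API.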
2 | # Licensed under the MIT License 3 | 4 | """The EmbeddingsLLM class.""" 5 | 6 | from typing_extensions import Unpack 7 | 8 | from graphrag.llm.base import BaseLLM 9 | from graphrag.llm.types import ( 10 | EmbeddingInput, 11 | EmbeddingOutput, 12 | LLMInput, 13 | ) 14 | 15 | from .openai_configuration import OpenAIConfiguration 16 | from .types import OpenAIClientTypes 17 | import ollama 18 | 19 | class OpenAIEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]): 20 | """A text-embedding generator LLM.""" 21 | 22 | _client: OpenAIClientTypes 23 | _configuration: OpenAIConfiguration 24 | 25 | def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration): 26 | self.client = client 27 | self.configuration = configuration 28 | 29 | async def _execute_llm( 30 | self, input: EmbeddingInput, **kwargs: Unpack[LLMInput] 31 | ) -> EmbeddingOutput | None: 32 | args = { 33 | "model": self.configuration.model, 34 | **(kwargs.get("model_parameters") or {}), 35 | } 36 | embedding_list = [] 37 | for inp in input: 38 | embedding = ollama.embeddings(model="nomic-embed-text", prompt=inp) 39 | embedding_list.append(embedding["embedding"]) 40 | return embedding_list 41 | -------------------------------------------------------------------------------- /utils/chainlit_agents.py: -------------------------------------------------------------------------------- 1 | from autogen.agentchat import Agent, AssistantAgent, UserProxyAgent 2 | from typing import Dict, Optional, Union, Callable 3 | import chainlit as cl 4 | 5 | async def ask_helper(func, **kwargs): 6 | res = await func(**kwargs).send() 7 | while not res: 8 | res = await func(**kwargs).send() 9 | return res 10 | 11 | class ChainlitAssistantAgent(AssistantAgent): 12 | """ 13 | Wrapper for AutoGens Assistant Agent 14 | """ 15 | def send( 16 | self, 17 | message: Union[Dict, str], 18 | recipient: Agent, 19 | request_reply: Optional[bool] = None, 20 | silent: Optional[bool] = False, 21 | ) -> bool: 22 | cl.run_sync( 23 | cl.Message( 24 | content=f'*Sending message to "{recipient.name}":*\n\n{message}', 25 | author=self.name, 26 | ).send() 27 | ) 28 | super(ChainlitAssistantAgent, self).send( 29 | message=message, 30 | recipient=recipient, 31 | request_reply=request_reply, 32 | silent=silent, 33 | ) 34 | 35 | class ChainlitUserProxyAgent(UserProxyAgent): 36 | """ 37 | Wrapper for AutoGens UserProxy Agent. Simplifies the UI by adding CL Actions. 38 | """ 39 | def get_human_input(self, prompt: str) -> str: 40 | if prompt.startswith( 41 | "Provide feedback to chat_manager. 
Press enter to skip and use auto-reply" 42 | ): 43 | res = cl.run_sync( 44 | ask_helper( 45 | cl.AskActionMessage, 46 | content="Continue or provide feedback?", 47 | actions=[ 48 | cl.Action( name="continue", value="continue", label="✅ Continue" ), 49 | cl.Action( name="feedback",value="feedback", label="💬 Provide feedback"), 50 | cl.Action( name="exit",value="exit", label="🔚 Exit Conversation" ) 51 | ], 52 | ) 53 | ) 54 | if res.get("value") == "continue": 55 | return "" 56 | if res.get("value") == "exit": 57 | return "exit" 58 | 59 | reply = cl.run_sync(ask_helper(cl.AskUserMessage, content=prompt, timeout=60)) 60 | 61 | return reply["output"].strip() 62 | 63 | def send( 64 | self, 65 | message: Union[Dict, str], 66 | recipient: Agent, 67 | request_reply: Optional[bool] = None, 68 | silent: Optional[bool] = False, 69 | ): 70 | #cl.run_sync( 71 | #cl.Message( 72 | # content=f'*Sending message to "{recipient.name}"*:\n\n{message}', 73 | # author=self.name, 74 | #).send() 75 | #) 76 | super(ChainlitUserProxyAgent, self).send( 77 | message=message, 78 | recipient=recipient, 79 | request_reply=request_reply, 80 | silent=silent, 81 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphRAG + AutoGen + Ollama + Chainlit UI = Local Multi-Agent RAG Superbot 2 | 3 | ![Graphical Abstract](https://github.com/karthik-codex/autogen_graphRAG/blob/main/images/1721017707759.jpg?raw=true) 4 | 5 | This application integrates GraphRAG with AutoGen agents, powered by local LLMs from Ollama, for free and offline embedding and inference. Key highlights include: 6 | - **Agentic-RAG:** - Integrating GraphRAG's knowledge search method with an AutoGen agent via function calling. 7 | - **Offline LLM Support:** - Configuring GraphRAG (local & global search) to support local models from Ollama for inference 8 | and embedding. 9 | - **Non-OpenAI Function Calling:** - Extending AutoGen to support function calling with non-OpenAI LLMs from Ollama via Lite-LLM proxy 10 | server. 11 | - **Interactive UI:** - Deploying Chainlit UI to handle continuous conversations, multi-threading, and user input settings. 12 | 13 | ![Main Interfacce](https://github.com/karthik-codex/autogen_graphRAG/blob/main/images/UI1.webp?raw=true) 14 | ![Widget Settings](https://github.com/karthik-codex/autogen_graphRAG/blob/main/images/U2.webp?raw=true) 15 | 16 | ## Useful Links 🔗 17 | 18 | - **Full Guide:** Microsoft's GraphRAG + AutoGen + Ollama + Chainlit = Fully Local & Free Multi-Agent RAG Superbot [Medium.com](https://medium.com/@karthik.codex/microsofts-graphrag-autogen-ollama-chainlit-fully-local-free-multi-agent-rag-superbot-61ad3759f06f) 📚 19 | 20 | ## 📦 Installation and Setup Linux 21 | 22 | Follow these steps to set up and run AutoGen GraphRAG Local with Ollama and Chainlit UI: 23 | 24 | 1. **Install LLMs:** 25 | 26 | Visit [Ollama's website](https://ollama.com/) for installation files. 27 | 28 | ```bash 29 | ollama pull mistral 30 | ollama pull nomic-embed-text 31 | ollama pull llama3 32 | ollama serve 33 | ``` 34 | 35 | 2. **Create conda environment and install packages:** 36 | ```bash 37 | conda create -n RAG_agents python=3.12 38 | conda activate RAG_agents 39 | git clone https://github.com/karthik-codex/autogen_graphRAG.git 40 | cd autogen_graphRAG 41 | pip install -r requirements.txt 42 | ``` 43 | 3. 
**Initialize the GraphRAG root folder:** 44 | ```bash 45 | mkdir -p ./input 46 | python -m graphrag.index --init --root . 47 | mv ./utils/settings.yaml ./ 48 | ``` 49 | 4. **Replace 'embedding.py' and 'openai_embeddings_llm.py' in the GraphRAG package folder with the versions from the utils folder (locate the installed files with the commands below, then copy the utils versions over them):** 50 | ```bash 51 | sudo find / -name openai_embeddings_llm.py 52 | sudo find / -name embedding.py 53 | ``` 54 | 5. **Create embeddings and knowledge graph:** 55 | ```bash 56 | python -m graphrag.index --root . 57 | ``` 58 | 6. **Start Lite-LLM proxy server:** 59 | ```bash 60 | litellm --model ollama_chat/llama3 61 | ``` 62 | 7. **Run app:** 63 | ```bash 64 | chainlit run appUI.py 65 | ``` 66 | 67 | ## 📦 Installation and Setup Windows 68 | 69 | Follow these steps to set up and run AutoGen GraphRAG Local with Ollama and Chainlit UI on Windows: 70 | 71 | 1. **Install LLMs:** 72 | 73 | Visit [Ollama's website](https://ollama.com/) for installation files. 74 | 75 | ```pwsh 76 | ollama pull mistral 77 | ollama pull nomic-embed-text 78 | ollama pull llama3 79 | ollama serve 80 | ``` 81 | 82 | 2. **Create a Python virtual environment and install packages:** 83 | ```pwsh 84 | git clone https://github.com/karthik-codex/autogen_graphRAG.git 85 | cd autogen_graphRAG 86 | python -m venv venv 87 | ./venv/Scripts/activate 88 | pip install -r requirements.txt 89 | ``` 90 | 3. **Initialize the GraphRAG root folder:** 91 | ```pwsh 92 | mkdir input 93 | python -m graphrag.index --init --root . 94 | cp ./utils/settings.yaml ./ 95 | ``` 96 | 4. **Replace 'embedding.py' and 'openai_embeddings_llm.py' in the GraphRAG package folder with the files from the utils folder:** 97 | ```pwsh 98 | cp ./utils/openai_embeddings_llm.py .\venv\Lib\site-packages\graphrag\llm\openai\openai_embeddings_llm.py 99 | cp ./utils/embedding.py .\venv\Lib\site-packages\graphrag\query\llm\oai\embedding.py 100 | ``` 101 | 5. **Create embeddings and knowledge graph:** 102 | ```pwsh 103 | python -m graphrag.index --root . 104 | ``` 105 | 6. **Start Lite-LLM proxy server:** 106 | ```pwsh 107 | litellm --model ollama_chat/llama3 108 | ``` 109 | 7. **Run app:** 110 | ```pwsh 111 | chainlit run appUI.py 112 | ``` 113 | -------------------------------------------------------------------------------- /utils/settings.yaml: -------------------------------------------------------------------------------- 1 | 2 | encoding_model: cl100k_base 3 | skip_workflows: [] 4 | llm: 5 | api_key: ${GRAPHRAG_API_KEY} 6 | type: openai_chat # or azure_openai_chat 7 | model: mistral 8 | model_supports_json: true # recommended if this is available for your model.
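  # NOTE: this llm block drives indexing through Ollama's OpenAI-compatible /v1
  # endpoint (see api_base below) using the `mistral` model pulled in README step 1.
  # Ollama does not check API keys, so ${GRAPHRAG_API_KEY} only has to resolve to
  # some non-empty string; a minimal .env in the project root is enough, e.g.
  # (placeholder value, any string works):
  #   GRAPHRAG_API_KEY=ollama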
9 | # max_tokens: 4000 10 | # request_timeout: 180.0 11 | api_base: http://localhost:11434/v1 12 | # api_version: 2024-02-15-preview 13 | # organization: 14 | # deployment_name: 15 | # tokens_per_minute: 150_000 # set a leaky bucket throttle 16 | # requests_per_minute: 10_000 # set a leaky bucket throttle 17 | # max_retries: 10 18 | # max_retry_wait: 10.0 19 | # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times 20 | # concurrent_requests: 25 # the number of parallel inflight requests that may be made 21 | 22 | parallelization: 23 | stagger: 0.3 24 | # num_threads: 50 # the number of threads to use for parallel processing 25 | 26 | async_mode: threaded # or asyncio 27 | 28 | embeddings: 29 | ## parallelization: override the global parallelization settings for embeddings 30 | async_mode: threaded # or asyncio 31 | llm: 32 | api_key: ${GRAPHRAG_API_KEY} 33 | type: openai_embedding # or azure_openai_embedding 34 | model: nomic_embed_text #text-embedding-3-large #mxbai-embed-large # 35 | api_base: http://localhost:11434/api 36 | # api_version: 2024-02-15-preview 37 | # organization: 38 | # deployment_name: 39 | # tokens_per_minute: 150_000 # set a leaky bucket throttle 40 | # requests_per_minute: 10_000 # set a leaky bucket throttle 41 | # max_retries: 10 42 | # max_retry_wait: 10.0 43 | # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times 44 | concurrent_requests: 25 # the number of parallel inflight requests that may be made 45 | # batch_size: 16 # the number of documents to send in a single request 46 | # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request 47 | # target: required # or optional 48 | 49 | 50 | 51 | chunks: 52 | size: 300 53 | overlap: 100 54 | group_by_columns: [id] # by default, we don't allow chunks to cross documents 55 | 56 | input: 57 | type: file # or blob 58 | file_type: text # or csv 59 | base_dir: "input/markdown" 60 | file_encoding: utf-8 61 | file_pattern: ".*\\.md$" 62 | 63 | cache: 64 | type: file # or blob 65 | base_dir: "cache" 66 | # connection_string: 67 | # container_name: 68 | 69 | storage: 70 | type: file # or blob 71 | base_dir: "output/${timestamp}/artifacts" 72 | # connection_string: 73 | # container_name: 74 | 75 | reporting: 76 | type: file # or console, blob 77 | base_dir: "output/${timestamp}/reports" 78 | # connection_string: 79 | # container_name: 80 | 81 | entity_extraction: 82 | ## llm: override the global llm settings for this task 83 | ## parallelization: override the global parallelization settings for this task 84 | ## async_mode: override the global async_mode settings for this task 85 | prompt: "prompts/entity_extraction.txt" 86 | entity_types: [organization,person,geo,event] 87 | max_gleanings: 0 88 | 89 | summarize_descriptions: 90 | ## llm: override the global llm settings for this task 91 | ## parallelization: override the global parallelization settings for this task 92 | ## async_mode: override the global async_mode settings for this task 93 | prompt: "prompts/summarize_descriptions.txt" 94 | max_length: 500 95 | 96 | claim_extraction: 97 | ## llm: override the global llm settings for this task 98 | ## parallelization: override the global parallelization settings for this task 99 | ## async_mode: override the global async_mode settings for this task 100 | # enabled: true 101 | prompt: "prompts/claim_extraction.txt" 102 | description: "Any claims or facts that could be relevant to information discovery." 
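  # NOTE: two details worth knowing when adapting this config:
  #   * the embeddings llm block above points at Ollama's native endpoint
  #     (http://localhost:11434/api), and the patched utils/openai_embeddings_llm.py
  #     hardcodes the "nomic-embed-text" model, so the `model: nomic_embed_text`
  #     value above is effectively bypassed during indexing;
  #   * the prompts/*.txt files referenced in this config are generated by
  #     `python -m graphrag.index --init --root .` (README step 3) and are listed
  #     in .gitignore, so run the init step before indexing.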
103 | max_gleanings: 0 104 | 105 | community_report: 106 | ## llm: override the global llm settings for this task 107 | ## parallelization: override the global parallelization settings for this task 108 | ## async_mode: override the global async_mode settings for this task 109 | prompt: "prompts/community_report.txt" 110 | max_length: 2000 111 | max_input_length: 8000 112 | 113 | cluster_graph: 114 | max_cluster_size: 10 115 | 116 | embed_graph: 117 | enabled: false # if true, will generate node2vec embeddings for nodes 118 | # num_walks: 10 119 | # walk_length: 40 120 | # window_size: 2 121 | # iterations: 3 122 | # random_seed: 597832 123 | 124 | umap: 125 | enabled: false # if true, will generate UMAP embeddings for nodes 126 | 127 | snapshots: 128 | graphml: True 129 | raw_entities: false 130 | top_level_nodes: True 131 | 132 | local_search: 133 | # text_unit_prop: 0.5 134 | # community_prop: 0.1 135 | # conversation_history_max_turns: 5 136 | # top_k_mapped_entities: 10 137 | # top_k_relationships: 10 138 | # max_tokens: 12000 139 | 140 | global_search: 141 | # max_tokens: 12000 142 | # data_max_tokens: 12000 143 | # map_max_tokens: 1000 144 | # reduce_max_tokens: 2000 145 | # concurrency: 32 146 | -------------------------------------------------------------------------------- /utils/pdf_to_markdown.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS 3 | from marker.convert import convert_single_pdf 4 | from marker.logger import configure_logging 5 | from marker.models import load_all_models 6 | os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya 7 | os.environ["PDFTEXT_CPU_WORKERS"] = "1" # Avoid multiprocessing inside pdftext 8 | import pypdfium2 # Needs to be at the top to avoid warnings 9 | import argparse 10 | import torch.multiprocessing as mp 11 | from tqdm import tqdm 12 | import math 13 | from marker.output import markdown_exists, save_markdown 14 | from marker.pdf.utils import find_filetype 15 | from marker.pdf.extract_text import get_length_of_text 16 | from marker.settings import settings 17 | import traceback 18 | import json 19 | 20 | configure_logging() 21 | 22 | def worker_init(shared_model): 23 | if shared_model is None: 24 | shared_model = load_all_models() 25 | 26 | global model_refs 27 | model_refs = shared_model 28 | 29 | def worker_exit(): 30 | global model_refs 31 | del model_refs 32 | 33 | def process_single_pdf(args): 34 | filepath, out_folder, metadata, min_length = args 35 | 36 | fname = os.path.basename(filepath) 37 | if markdown_exists(out_folder, fname): 38 | return 39 | 40 | try: 41 | # Skip trying to convert files that don't have a lot of embedded text 42 | # This can indicate that they were scanned, and not OCRed properly 43 | # Usually these files are not recent/high-quality 44 | if min_length: 45 | filetype = find_filetype(filepath) 46 | if filetype == "other": 47 | return 0 48 | 49 | length = get_length_of_text(filepath) 50 | if length < min_length: 51 | return 52 | 53 | full_text, images, out_metadata = convert_single_pdf(filepath, model_refs, metadata=metadata, batch_multiplier=2) 54 | if len(full_text.strip()) > 0: 55 | save_markdown(out_folder, fname, full_text, images, out_metadata) 56 | else: 57 | print(f"Empty file: {filepath}. 
Could not convert.") 58 | except Exception as e: 59 | print(f"Error converting {filepath}: {e}") 60 | print(traceback.format_exc()) 61 | 62 | 63 | def multiple(): 64 | chunk_idx = 0 65 | num_chunks = 1 66 | max = None 67 | workers = 10 68 | meta = None 69 | min_len = None 70 | in_folder = 'input/toray' #os.path.abspath(args.in_folder) 71 | out_folder = 'input/markdown' #os.path.abspath(args.out_folder) 72 | 73 | files = [os.path.join(in_folder, f) for f in os.listdir(in_folder)] 74 | files = [f for f in files if os.path.isfile(f)] 75 | os.makedirs(out_folder, exist_ok=True) 76 | 77 | # Handle chunks if we're processing in parallel 78 | # Ensure we get all files into a chunk 79 | chunk_size = math.ceil(len(files) / num_chunks) 80 | start_idx = chunk_idx * chunk_size 81 | end_idx = start_idx + chunk_size 82 | files_to_convert = files[start_idx:end_idx] 83 | 84 | # Limit files converted if needed 85 | if max: 86 | files_to_convert = files_to_convert[:max] 87 | 88 | metadata = {} 89 | if meta: 90 | metadata_file = os.path.abspath(meta) 91 | with open(metadata_file, "r") as f: 92 | metadata = json.load(f) 93 | 94 | total_processes = min(len(files_to_convert), workers) 95 | 96 | # Dynamically set GPU allocation per task based on GPU ram 97 | if settings.CUDA: 98 | tasks_per_gpu = settings.INFERENCE_RAM // settings.VRAM_PER_TASK if settings.CUDA else 0 99 | total_processes = int(min(tasks_per_gpu, total_processes)) 100 | else: 101 | total_processes = int(total_processes) 102 | 103 | try: 104 | mp.set_start_method('spawn') # Required for CUDA, forkserver doesn't work 105 | except RuntimeError: 106 | raise RuntimeError("Set start method to spawn twice. This may be a temporary issue with the script. Please try running it again.") 107 | 108 | if settings.TORCH_DEVICE == "mps" or settings.TORCH_DEVICE_MODEL == "mps": 109 | print("Cannot use MPS with torch multiprocessing share_memory. This will make things less memory efficient. If you want to share memory, you have to use CUDA or CPU. 
Set the TORCH_DEVICE environment variable to change the device.") 110 | 111 | model_lst = None 112 | else: 113 | model_lst = load_all_models() 114 | 115 | for model in model_lst: 116 | if model is None: 117 | continue 118 | model.share_memory() 119 | 120 | print(f"Converting {len(files_to_convert)} pdfs in chunk {chunk_idx + 1}/{num_chunks} with {total_processes} processes, and storing in {out_folder}") 121 | task_args = [(f, out_folder, metadata.get(os.path.basename(f)), min_len) for f in files_to_convert] 122 | 123 | with mp.Pool(processes=total_processes, initializer=worker_init, initargs=(model_lst,)) as pool: 124 | list(tqdm(pool.imap(process_single_pdf, task_args), total=len(task_args), desc="Processing PDFs", unit="pdf")) 125 | 126 | pool._worker_handler.terminate = worker_exit 127 | 128 | # Delete all CUDA tensors 129 | del model_lst 130 | 131 | 132 | def single(): 133 | fname = 'input/toray/Toray-Cetex-TC910_PA6_PDS.pdf' #'input/solvay/Composite_Aerospace_Brochure.pdf' 134 | model_lst = load_all_models() 135 | full_text, images, out_meta = convert_single_pdf(fname, model_lst, max_pages=None, langs=None, batch_multiplier=2, start_page=None) 136 | 137 | fname = os.path.basename(fname) 138 | 139 | output = 'input/markdown' 140 | subfolder_path = save_markdown(output, fname, full_text, images, out_meta) 141 | 142 | print(f"Saved markdown to the {subfolder_path} folder") 143 | 144 | 145 | if __name__ == "__main__": 146 | single() 147 | #multiple() -------------------------------------------------------------------------------- /utils/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Microsoft Corporation. 2 | # Licensed under the MIT License 3 | 4 | """OpenAI Embedding model implementation.""" 5 | 6 | import asyncio 7 | from collections.abc import Callable 8 | from typing import Any 9 | import ollama 10 | import numpy as np 11 | import tiktoken 12 | from tenacity import ( 13 | AsyncRetrying, 14 | RetryError, 15 | Retrying, 16 | retry_if_exception_type, 17 | stop_after_attempt, 18 | wait_exponential_jitter, 19 | ) 20 | 21 | from graphrag.query.llm.base import BaseTextEmbedding 22 | from graphrag.query.llm.oai.base import OpenAILLMImpl 23 | from graphrag.query.llm.oai.typing import ( 24 | OPENAI_RETRY_ERROR_TYPES, 25 | OpenaiApiType, 26 | ) 27 | from graphrag.query.llm.text_utils import chunk_text 28 | from graphrag.query.progress import StatusReporter 29 | 30 | 31 | class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl): 32 | """Wrapper for OpenAI Embedding models.""" 33 | 34 | def __init__( 35 | self, 36 | api_key: str | None = None, 37 | azure_ad_token_provider: Callable | None = None, 38 | model: str = "text-embedding-3-small", 39 | deployment_name: str | None = None, 40 | api_base: str | None = None, 41 | api_version: str | None = None, 42 | api_type: OpenaiApiType = OpenaiApiType.OpenAI, 43 | organization: str | None = None, 44 | encoding_name: str = "cl100k_base", 45 | max_tokens: int = 8191, 46 | max_retries: int = 10, 47 | request_timeout: float = 180.0, 48 | retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore 49 | reporter: StatusReporter | None = None, 50 | ): 51 | OpenAILLMImpl.__init__( 52 | self=self, 53 | api_key=api_key, 54 | azure_ad_token_provider=azure_ad_token_provider, 55 | deployment_name=deployment_name, 56 | api_base=api_base, 57 | api_version=api_version, 58 | api_type=api_type, # type: ignore 59 | organization=organization, 60 | max_retries=max_retries, 61 
| request_timeout=request_timeout, 62 | reporter=reporter, 63 | ) 64 | 65 | self.model = model 66 | self.encoding_name = encoding_name 67 | self.max_tokens = max_tokens 68 | self.token_encoder = tiktoken.get_encoding(self.encoding_name) 69 | self.retry_error_types = retry_error_types 70 | self.embedding_dim = 384 # Nomic-embed-text model dimension 71 | self.ollama_client = ollama.Client() 72 | 73 | def embed(self, text: str, **kwargs: Any) -> list[float]: 74 | """Embed text using Ollama's nomic-embed-text model.""" 75 | try: 76 | embedding = self.ollama_client.embeddings(model="nomic-embed-text", prompt=text) 77 | return embedding["embedding"] 78 | except Exception as e: 79 | self._reporter.error( 80 | message="Error embedding text", 81 | details={self.__class__.__name__: str(e)}, 82 | ) 83 | return np.zeros(self.embedding_dim).tolist() 84 | 85 | async def aembed(self, text: str, **kwargs: Any) -> list[float]: 86 | """Embed text using Ollama's nomic-embed-text model asynchronously.""" 87 | try: 88 | embedding = await self.ollama_client.embeddings(model="nomic-embed-text", prompt=text) 89 | return embedding["embedding"] 90 | except Exception as e: 91 | self._reporter.error( 92 | message="Error embedding text asynchronously", 93 | details={self.__class__.__name__: str(e)}, 94 | ) 95 | return np.zeros(self.embedding_dim).tolist() 96 | 97 | def _embed_with_retry( 98 | self, text: str | tuple, **kwargs: Any #str | tuple 99 | ) -> tuple[list[float], int]: 100 | try: 101 | retryer = Retrying( 102 | stop=stop_after_attempt(self.max_retries), 103 | wait=wait_exponential_jitter(max=10), 104 | reraise=True, 105 | retry=retry_if_exception_type(self.retry_error_types), 106 | ) 107 | for attempt in retryer: 108 | with attempt: 109 | embedding = ( 110 | self.sync_client.embeddings.create( # type: ignore 111 | input=text, 112 | model=self.model, 113 | **kwargs, # type: ignore 114 | ) 115 | .data[0] 116 | .embedding 117 | or [] 118 | ) 119 | return (embedding["embedding"], len(text)) 120 | except RetryError as e: 121 | self._reporter.error( 122 | message="Error at embed_with_retry()", 123 | details={self.__class__.__name__: str(e)}, 124 | ) 125 | return ([], 0) 126 | else: 127 | # TODO: why not just throw in this case? 128 | return ([], 0) 129 | 130 | async def _aembed_with_retry( 131 | self, text: str | tuple, **kwargs: Any 132 | ) -> tuple[list[float], int]: 133 | try: 134 | retryer = AsyncRetrying( 135 | stop=stop_after_attempt(self.max_retries), 136 | wait=wait_exponential_jitter(max=10), 137 | reraise=True, 138 | retry=retry_if_exception_type(self.retry_error_types), 139 | ) 140 | async for attempt in retryer: 141 | with attempt: 142 | embedding = ( 143 | await self.async_client.embeddings.create( # type: ignore 144 | input=text, 145 | model=self.model, 146 | **kwargs, # type: ignore 147 | ) 148 | ).data[0].embedding or [] 149 | return (embedding, len(text)) 150 | except RetryError as e: 151 | self._reporter.error( 152 | message="Error at embed_with_retry()", 153 | details={self.__class__.__name__: str(e)}, 154 | ) 155 | return ([], 0) 156 | else: 157 | # TODO: why not just throw in this case? 
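        # NOTE: in this patched copy the embed()/aembed() overrides above call Ollama
        # directly, so these OpenAI-client retry helpers are no longer reached from them.
        # Two caveats in the Ollama path above:
        #   * aembed() awaits ollama.Client().embeddings(), which is a synchronous call;
        #     a truly async variant could use the ollama package's AsyncClient, e.g.
        #         resp = await ollama.AsyncClient().embeddings(
        #             model="nomic-embed-text", prompt=text)
        #   * nomic-embed-text normally returns 768-dimensional vectors, while
        #     embedding_dim is set to 384 above, so the zero-vector fallback length
        #     does not match real embeddings.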
158 | return ([], 0) 159 | -------------------------------------------------------------------------------- /appUI.py: -------------------------------------------------------------------------------- 1 | import autogen 2 | from rich import print 3 | import chainlit as cl 4 | from typing_extensions import Annotated 5 | from chainlit.input_widget import ( 6 | Select, Slider, Switch) 7 | from autogen import AssistantAgent, UserProxyAgent 8 | from utils.chainlit_agents import ChainlitUserProxyAgent, ChainlitAssistantAgent 9 | from graphrag.query.cli import run_global_search, run_local_search 10 | 11 | # LLama3 LLM from Lite-LLM Server for Agents # 12 | llm_config_autogen = { 13 | "seed": 42, # change the seed for different trials 14 | "temperature": 0, 15 | "config_list": [{"model": "litellm", 16 | "base_url": "http://0.0.0.0:4000/", 17 | 'api_key': 'ollama'}, 18 | ], 19 | "timeout": 60000, 20 | } 21 | 22 | @cl.on_chat_start 23 | async def on_chat_start(): 24 | try: 25 | settings = await cl.ChatSettings( 26 | [ 27 | Switch(id="Search_type", label="(GraphRAG) Local Search", initial=True), 28 | Select( 29 | id="Gen_type", 30 | label="(GraphRAG) Content Type", 31 | values=["prioritized list", "single paragraph", "multiple paragraphs", "multiple-page report"], 32 | initial_index=1, 33 | ), 34 | Slider( 35 | id="Community", 36 | label="(GraphRAG) Community Level", 37 | initial=0, 38 | min=0, 39 | max=2, 40 | step=1, 41 | ), 42 | 43 | ] 44 | ).send() 45 | 46 | response_type = settings["Gen_type"] 47 | community = settings["Community"] 48 | local_search = settings["Search_type"] 49 | 50 | cl.user_session.set("Gen_type", response_type) 51 | cl.user_session.set("Community", community) 52 | cl.user_session.set("Search_type", local_search) 53 | 54 | retriever = AssistantAgent( 55 | name="Retriever", 56 | llm_config=llm_config_autogen, 57 | system_message="""Only execute the function query_graphRAG to look for context. 58 | Output 'TERMINATE' when an answer has been provided.""", 59 | max_consecutive_auto_reply=1, 60 | human_input_mode="NEVER", 61 | description="Retriever Agent" 62 | ) 63 | 64 | user_proxy = ChainlitUserProxyAgent( 65 | name="User_Proxy", 66 | human_input_mode="ALWAYS", 67 | llm_config=llm_config_autogen, 68 | is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"), 69 | code_execution_config=False, 70 | system_message='''A human admin. Interact with the retriever to provide any context''', 71 | description="User Proxy Agent" 72 | ) 73 | 74 | print("Set agents.") 75 | 76 | cl.user_session.set("Query Agent", user_proxy) 77 | cl.user_session.set("Retriever", retriever) 78 | 79 | msg = cl.Message(content=f"""Hello! What task would you like to get done today? 80 | """, 81 | author="User_Proxy") 82 | await msg.send() 83 | 84 | print("Message sent.") 85 | 86 | except Exception as e: 87 | print("Error: ", e) 88 | pass 89 | 90 | @cl.on_settings_update 91 | async def setup_agent(settings): 92 | response_type = settings["Gen_type"] 93 | community = settings["Community"] 94 | local_search = settings["Search_type"] 95 | cl.user_session.set("Gen_type", response_type) 96 | cl.user_session.set("Community", community) 97 | cl.user_session.set("Search_type", local_search) 98 | print("on_settings_update", settings) 99 | 100 | @cl.on_message 101 | async def run_conversation(message: cl.Message): 102 | print("Running conversation") 103 | INPUT_DIR = None 104 | ROOT_DIR = '.' 
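    # NOTE: these values feed run_local_search()/run_global_search() below. With
    # INPUT_DIR left as None, the GraphRAG query CLI helpers (in the version this app
    # targets) resolve the data directory from ROOT_DIR, typically the most recent
    # output/<timestamp>/artifacts folder produced by indexing; COMMUNITY is the
    # community level from the UI slider, and RESPONSE_TYPE is the free-text format
    # hint chosen in the chat settings.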
105 | CONTEXT = message.content 106 | MAX_ITER = 10 107 | RESPONSE_TYPE = cl.user_session.get("Gen_type") 108 | COMMUNITY = cl.user_session.get("Community") 109 | LOCAL_SEARCH = cl.user_session.get("Search_type") 110 | 111 | retriever = cl.user_session.get("Retriever") 112 | user_proxy = cl.user_session.get("Query Agent") 113 | print("Setting groupchat") 114 | 115 | def state_transition(last_speaker, groupchat): 116 | messages = groupchat.messages 117 | if last_speaker is user_proxy: 118 | return retriever 119 | if last_speaker is retriever: 120 | if messages[-1]["content"].lower() not in ['math_expert','physics_expert']: 121 | return user_proxy 122 | else: 123 | if messages[-1]["content"].lower() == 'math_expert': 124 | return user_proxy 125 | else: 126 | return user_proxy 127 | else: 128 | pass 129 | return None 130 | 131 | async def query_graphRAG( 132 | question: Annotated[str, 'Query string containing information that you want from RAG search'] 133 | ) -> str: 134 | if LOCAL_SEARCH: 135 | print(LOCAL_SEARCH) 136 | result = run_local_search(INPUT_DIR, ROOT_DIR, COMMUNITY ,RESPONSE_TYPE, question) 137 | else: 138 | result = run_global_search(INPUT_DIR, ROOT_DIR, COMMUNITY ,RESPONSE_TYPE, question) 139 | await cl.Message(content=result).send() 140 | return result 141 | 142 | for caller in [retriever]: 143 | d_retrieve_content = caller.register_for_llm( 144 | description="retrieve content for code generation and question answering.", api_style="function" 145 | )(query_graphRAG) 146 | 147 | for agents in [user_proxy, retriever]: 148 | agents.register_for_execution()(d_retrieve_content) 149 | 150 | groupchat = autogen.GroupChat( 151 | agents=[user_proxy, retriever], 152 | messages=[], 153 | max_round=MAX_ITER, 154 | speaker_selection_method=state_transition, 155 | allow_repeat_speaker=True, 156 | ) 157 | manager = autogen.GroupChatManager(groupchat=groupchat, 158 | llm_config=llm_config_autogen, 159 | is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"), 160 | code_execution_config=False, 161 | ) 162 | 163 | # -------------------- Conversation Logic. Edit to change your first message based on the Task you want to get done. ----------------------------- # 164 | if len(groupchat.messages) == 0: 165 | await cl.make_async(user_proxy.initiate_chat)( manager, message=CONTEXT, ) 166 | elif len(groupchat.messages) < MAX_ITER: 167 | await cl.make_async(user_proxy.send)( manager, message=CONTEXT, ) 168 | elif len(groupchat.messages) == MAX_ITER: 169 | await cl.make_async(user_proxy.send)( manager, message="exit", ) 170 | 171 | --------------------------------------------------------------------------------
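As a quick sanity check before launching the Chainlit UI, the same GraphRAG searches that `query_graphRAG` wraps (`run_local_search` / `run_global_search` from `graphrag.query.cli`) can be run from the command line once indexing (README step 5) has finished. The exact flags depend on the GraphRAG version pinned in `requirements.txt`, and the questions below are placeholders, so treat this as a sketch:

```bash
# Global search over community summaries
python -m graphrag.query --root . --method global "What are the main topics in the input documents?"

# Local search over entities and relationships
python -m graphrag.query --root . --method local "What do the documents say about <some entity>?"
```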