├── .env-example ├── .gitignore ├── LICENSE ├── README-CN.md ├── README.md ├── api.py ├── app.py ├── assets ├── image1.png ├── image2.gif └── image3.png ├── index_app.py ├── indexing ├── .env ├── input │ └── test.txt ├── prompts │ ├── claim_extraction.txt │ ├── community_report.txt │ ├── entity_extraction.txt │ └── summarize_descriptions.txt └── settings.yaml ├── lancedb └── empty.md ├── requirements.txt ├── settings-example.yaml └── web.py /.env-example: -------------------------------------------------------------------------------- 1 | LLM_API_BASE=http://localhost:11434/v1 2 | LLM_MODEL=mistral:7b 3 | LLM_API_KEY=ollama 4 | LLM_SERVICE_TYPE=openai_chat 5 | 6 | EMBEDDINGS_API_BASE=http://localhost:11434/v1 7 | EMBEDDINGS_MODEL=nomic-embed-text:latest 8 | EMBEDDINGS_API_KEY=ollama 9 | EMBEDDINGS_SERVICE_TYPE=openai_embedding 10 | 11 | GRAPHRAG_API_KEY=ollama 12 | ROOT_DIR=indexing 13 | INPUT_DIR=${ROOT_DIR}/output/${timestamp}/artifacts 14 | 15 | API_URL=http://localhost:8012 16 | API_PORT=8012 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | indexing/output/* 2 | indexing/prompt_tuning_config.yaml 3 | __pycache__ 4 | indexing/cache 5 | .vscode 6 | setup.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Beckett 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README-CN.md: -------------------------------------------------------------------------------- 1 | [English](./README.md) | 简体中文 2 | 3 | # GraphRAG-UI 4 | 5 | GraphRAG-UI 是 [GraphRAG](https://github.com/microsoft/graphrag) 的用户友好界面,GraphRAG 是一个强大的工具,可使用检索增强生成(RAG)方法对大量文本数据进行索引和查询。本项目支持最新版graphrag-0.3.3,旨在为 GraphRAG 提供方便的管理和交互方式,支持配置 ollama 等本地大模型服务,使其更容易为广大用户所使用。 6 | 7 | ## 致谢 8 | 9 | 本项目目前是在 [severian42](https://github.com/severian42) 及其 [GraphRAG-Local-UI](https://github.com/severian42/GraphRAG-Local-UI) 项目的基础上升级而来。在此,我要衷心感谢他,为本项目奠定了坚实的基础。后期可能会加入一些新特性。 10 | 11 | ## 特性 12 | 13 | - **直观的 Web 界面**: GraphRAG-UI 提供了用户友好的 Web 界面,可以轻松配置和使用 GraphRAG。 14 | - **索引管理**: 快速创建、更新和管理您的文本数据索引。 15 | - **查询执行**: 提交自然语言查询,并从索引数据中获取相关内容,之后从大模型获取相应结果。 16 | - **配置选项**: 自定义各种设置和参数,以微调索引和查询过程。 17 | - **日志和监控**: 通过详细的日志和状态更新,监控索引和查询任务的进度。 18 | 19 | 20 | ## 示例截图: 21 | ### 索引 22 | 23 | ![GraphRAG UI](./assets/image1.png) 24 | 25 | ### 图可视化 (GIF 图) 26 | 27 | ![GraphRAG UI](./assets/image2.gif) 28 | 29 | ### 使用 GraphRAG 聊天 30 | 31 | ![GraphRAG UI](./assets/image3.png) 32 | 33 | ## pip 安装使用 34 | 35 | 1. 安装ollama(可选): 36 | 37 | 访问 [Ollama官网](https://ollama.com/) 来安装。如果是 Linux ,可以直接运行下面命令 38 | 39 | ```bash 40 | curl -fsSL https://ollama.com/install.sh | sh 41 | ``` 42 | 43 | 2. pip 安装本软件: 44 | 45 | ```bash 46 | pip install graphrag-ui 47 | 或者 48 | pip install graphrag-ui -i https://pypi.org/simple 49 | ``` 50 | 51 | 3. 启动 API Server 52 | 53 | ```bash 54 | graphrag-ui-server 55 | ``` 56 | 57 | 4. 启动 UI 58 | 59 | 启动综合版 UI 60 | 61 | ```bash 62 | graphrag-ui 63 | ``` 64 | 65 | 或启动纯净版 UI 66 | 67 | ```bash 68 | graphrag-ui-pure 69 | ``` 70 | 71 | 72 | 73 | ## 源码安装使用 74 | 75 | 1. 创建并激活一个新的conda环境: 76 | ```bash 77 | conda create -n graphrag-ui -y 78 | conda activate graphrag-ui 79 | ``` 80 | 81 | 82 | 2. 安装ollama(可选): 83 | 84 | 访问 [Ollama官网](https://ollama.com/) 来安装。如果是 Linux ,可以直接运行下面命令 85 | 86 | ```bash 87 | curl -fsSL https://ollama.com/install.sh | sh 88 | ``` 89 | 90 | 3. 克隆存储库: 91 | 92 | ```bash 93 | git clone https://github.com/wade1010/graphrag-ui.git 94 | ``` 95 | 96 | 97 | 4. 安装所需的软件包: 98 | ```bash 99 | cd graphrag-ui 100 | pip install -r requirements.txt 101 | ``` 102 | 103 | 5. 启动API服务器 104 | ```bash 105 | python api.py --host 0.0.0.0 --port 8012 --reload 106 | ``` 107 | 108 | 6. 启动 109 | 110 | - **纯净版** 111 | 112 | 该版本只做索引、Prompt Tuning 和文件管理,没有查询功能。 113 | ```bash 114 | gradio index_app.py 115 | 或者 116 | python index_app.py 117 | ``` 118 | - **综合版** 119 | 120 | 该版本在纯净版的基础上增加了可视化图表、配置管理和使用 GraphRAG 聊天。 121 | ```bash 122 | python app.py 123 | ``` 124 | 7. 访问 UI 125 | - **纯净版**: `http://localhost:7860` 126 | - **综合版**: `http://localhost:7862` 127 | 128 | ## 安装使用博客 129 | 130 | [https://blog.csdn.net/wade1010/article/details/142374956](https://blog.csdn.net/wade1010/article/details/142374956) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | English | [简体中文](./README-CN.md) 3 | 4 | # GraphRAG-UI 5 | 6 | GraphRAG-UI is a user-friendly interface for [GraphRAG](https://github.com/microsoft/graphrag), a powerful tool that uses the Retrieval-Augmented Generation (RAG) approach to index and query large text data. 
This project supports the latest version, graphrag-0.3.3, and aims to provide a convenient way to manage and interact with GraphRAG, including support for local LLM services such as Ollama, making it easier for a wide range of users to adopt. 7 | 8 | ## Acknowledgments 9 | 10 | This project is currently an upgrade based on the work of [severian42](https://github.com/severian42) and his [GraphRAG-Local-UI](https://github.com/severian42/GraphRAG-Local-UI) project. I would like to express my sincere gratitude to him for laying a solid foundation for this project. New features may be added in the future. 11 | 12 | ## Features 13 | 14 | - **Intuitive Web Interface**: GraphRAG-UI provides a user-friendly web interface for easy configuration and use of GraphRAG. 15 | - **Index Management**: Quickly create, update, and manage your text data indexes. 16 | - **Query Execution**: Submit natural language queries and retrieve relevant content from indexed data, followed by responses from a large language model. 17 | - **Configuration Options**: Customize various settings and parameters to fine-tune the indexing and querying processes. 18 | - **Logging and Monitoring**: Monitor the progress of indexing and querying tasks through detailed logs and status updates. 19 | 20 | ## Sample screenshots: 21 | ### Indexing 22 | 23 | ![GraphRAG UI](./assets/image1.png) 24 | 25 | ### Visualize Graph (GIF image) 26 | 27 | ![GraphRAG UI](./assets/image2.gif) 28 | 29 | ### Chat With GraphRAG 30 | 31 | ![GraphRAG UI](./assets/image3.png) 32 | 33 | ## Usage with pip 34 | 35 | 1. Install Ollama (optional): 36 | 37 | Visit the [Ollama website](https://ollama.com/) to install. If you're on Linux, you can run the following command directly: 38 | 39 | ```bash 40 | curl -fsSL https://ollama.com/install.sh | sh 41 | ``` 42 | 43 | 2. Install this software via pip: 44 | 45 | ```bash 46 | pip install graphrag-ui 47 | or 48 | pip install graphrag-ui -i https://pypi.org/simple 49 | ``` 50 | 51 | 3. Start the API Server 52 | 53 | ```bash 54 | graphrag-ui-server 55 | ``` 56 | 57 | 4. Start the UI 58 | 59 | Start the comprehensive UI 60 | 61 | ```bash 62 | graphrag-ui 63 | ``` 64 | 65 | Or start the pure UI 66 | 67 | ```bash 68 | graphrag-ui-pure 69 | ``` 70 | 71 | ## Source code installation and usage 72 | 73 | 1. Create and activate a new conda environment: 74 | ```bash 75 | conda create -n graphrag-ui -y 76 | conda activate graphrag-ui 77 | ``` 78 | 2. Install Ollama (optional): 79 | 80 | Visit [Ollama's website](https://ollama.com/) for installation instructions. 81 | 82 | On Linux, run: 83 | 84 | ```bash 85 | curl -fsSL https://ollama.com/install.sh | sh 86 | ``` 87 | 88 | 3. Clone the repository: 89 | ```bash 90 | git clone https://github.com/wade1010/graphrag-ui.git 91 | ``` 92 | 93 | 4. Install the required packages: 94 | ```bash 95 | cd graphrag-ui 96 | pip install -r requirements.txt 97 | ``` 98 | 99 | 5. Start the API server: 100 | ```bash 101 | python api.py --host 0.0.0.0 --port 8012 --reload 102 | ``` 103 | 104 | 6. Start the UI: 105 | - **Clean version** 106 | 107 | This version only supports indexing, Prompt Tuning, and file management, without query functionality. 108 | ```bash 109 | gradio index_app.py 110 | or 111 | python index_app.py 112 | ``` 113 | - **Comprehensive version** 114 | 115 | This version adds visualizations, configuration management, and GraphRAG chat functionality on top of the clean version. 116 | ```bash 117 | python app.py 118 | ``` 119 | 120 | 7.
Access the UI: 121 | - **Clean version**: `http://localhost:7860` 122 | - **Comprehensive version**: `http://localhost:7862` 123 | 124 | ## Installation and Usage Blog 125 | 126 | [https://blog.csdn.net/wade1010/article/details/142374956](https://blog.csdn.net/wade1010/article/details/142374956) -------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | import sys 2 | try: 3 | import graphrag 4 | except ImportError: 5 | print("The 'graphrag' package is not installed. Please install it using 'pip install graphrag'.Since the dependency package `aiofiles` of `graphrag` conflicts with the requirements of `gradio`, it is necessary to manually install `graphrag` separately.") 6 | sys.exit(1) 7 | from dotenv import load_dotenv 8 | import os 9 | import asyncio 10 | import tempfile 11 | from collections import deque 12 | import time 13 | import uuid 14 | import json 15 | import re 16 | import pandas as pd 17 | import tiktoken 18 | import logging 19 | import yaml 20 | import shutil 21 | from fastapi import Body 22 | from fastapi import FastAPI, HTTPException, Request, BackgroundTasks, Depends 23 | from fastapi.responses import JSONResponse, StreamingResponse 24 | from pydantic import BaseModel, Field 25 | from typing import List, Optional, Dict, Any, Union 26 | from contextlib import asynccontextmanager 27 | try: 28 | from web import DuckDuckGoSearchAPIWrapper 29 | except ImportError: 30 | from .web import DuckDuckGoSearchAPIWrapper 31 | from functools import lru_cache 32 | import requests 33 | import subprocess 34 | import argparse 35 | 36 | # GraphRAG related imports 37 | from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey 38 | from graphrag.query.indexer_adapters import ( 39 | read_indexer_covariates, 40 | read_indexer_entities, 41 | read_indexer_relationships, 42 | read_indexer_reports, 43 | read_indexer_text_units, 44 | ) 45 | from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings 46 | from graphrag.query.llm.oai.chat_openai import ChatOpenAI 47 | from graphrag.query.llm.oai.embedding import OpenAIEmbedding 48 | from graphrag.query.llm.oai.typing import OpenaiApiType 49 | from graphrag.query.question_gen.local_gen import LocalQuestionGen 50 | from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext 51 | from graphrag.query.structured_search.local_search.search import LocalSearch 52 | from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext 53 | from graphrag.query.structured_search.global_search.search import GlobalSearch 54 | from graphrag.vector_stores.lancedb import LanceDBVectorStore 55 | 56 | # Set up logging 57 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 58 | logger = logging.getLogger(__name__) 59 | 60 | # Load environment variables 61 | graphrag_indexing_dir = 'indexing' 62 | project_root = os.path.abspath(os.path.dirname(__file__)) 63 | ROOT_DIR = os.path.join(project_root,graphrag_indexing_dir) 64 | env_file = os.path.join(ROOT_DIR, '.env') 65 | load_dotenv(env_file) 66 | LLM_API_BASE = os.getenv('LLM_API_BASE') 67 | LLM_MODEL = os.getenv('LLM_MODEL') 68 | LLM_PROVIDER = os.getenv('LLM_PROVIDER', 'openai').lower() 69 | EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE', '') 70 | EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL') 71 | EMBEDDINGS_PROVIDER = 
os.getenv('EMBEDDINGS_PROVIDER', 'openai').lower() 72 | OUTPUT_DIR = os.getenv('OUTPUT_DIR') 73 | PORT = int(os.getenv('API_PORT', 8012)) 74 | LANCEDB_URI = f"{ROOT_DIR}/lancedb" 75 | COMMUNITY_REPORT_TABLE = "create_final_community_reports" 76 | ENTITY_TABLE = "create_final_nodes" 77 | ENTITY_EMBEDDING_TABLE = "create_final_entities" 78 | RELATIONSHIP_TABLE = "create_final_relationships" 79 | COVARIATE_TABLE = "create_final_covariates" 80 | TEXT_UNIT_TABLE = "create_final_text_units" 81 | COMMUNITY_LEVEL = 2 82 | 83 | # Global variables for storing search engines and question generator 84 | local_search_engine = None 85 | global_search_engine = None 86 | question_generator = None 87 | 88 | # Data models 89 | class Message(BaseModel): 90 | role: str 91 | content: str 92 | 93 | class QueryOptions(BaseModel): 94 | query_type: str 95 | preset: Optional[str] = None 96 | community_level: Optional[int] = None 97 | response_type: Optional[str] = None 98 | custom_cli_args: Optional[str] = None 99 | selected_folder: Optional[str] = None 100 | 101 | class ChatCompletionRequest(BaseModel): 102 | model: str 103 | messages: List[Message] 104 | temperature: Optional[float] = 0.7 105 | max_tokens: Optional[int] = None 106 | stream: Optional[bool] = False 107 | query_options: Optional[QueryOptions] = None 108 | 109 | class ChatCompletionResponseChoice(BaseModel): 110 | index: int 111 | message: Message 112 | finish_reason: Optional[str] = None 113 | 114 | class Usage(BaseModel): 115 | prompt_tokens: int 116 | completion_tokens: int 117 | total_tokens: int 118 | 119 | class ChatCompletionResponse(BaseModel): 120 | id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}") 121 | object: str = "chat.completion" 122 | created: int = Field(default_factory=lambda: int(time.time())) 123 | model: str 124 | choices: List[ChatCompletionResponseChoice] 125 | usage: Usage 126 | system_fingerprint: Optional[str] = None 127 | 128 | def list_output_folders(): 129 | return [f for f in os.listdir(OUTPUT_DIR) if os.path.isdir(os.path.join(OUTPUT_DIR, f))] 130 | 131 | def list_folder_contents(folder_name): 132 | folder_path = os.path.join(OUTPUT_DIR, folder_name, "artifacts") 133 | if not os.path.exists(folder_path): 134 | return [] 135 | return [item for item in os.listdir(folder_path) if item.endswith('.parquet')] 136 | 137 | def normalize_api_base(api_base: str) -> str: 138 | """Normalize the API base URL by removing trailing slashes and /v1 or /api suffixes.""" 139 | api_base = api_base.rstrip('/') 140 | if api_base.endswith('/v1') or api_base.endswith('/api'): 141 | api_base = api_base[:-3] 142 | return api_base 143 | 144 | def get_models_endpoint(api_base: str, api_type: str) -> str: 145 | """Get the appropriate models endpoint based on the API type.""" 146 | normalized_base = normalize_api_base(api_base) 147 | if api_type.lower() == 'openai': 148 | return f"{normalized_base}/v1/models" 149 | elif api_type.lower() == 'azure': 150 | return f"{normalized_base}/openai/deployments?api-version=2022-12-01" 151 | else: # For other API types (e.g., local LLMs) 152 | return f"{normalized_base}/models" 153 | 154 | async def fetch_available_models(settings: Dict[str, Any]) -> List[str]: 155 | """Fetch available models from the API.""" 156 | api_base = settings['api_base'] 157 | api_type = settings['api_type'] 158 | api_key = settings['api_key'] 159 | 160 | models_endpoint = get_models_endpoint(api_base, api_type) 161 | headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} 162 | 163 | try: 164 | 
response = requests.get(models_endpoint, headers=headers, timeout=10) 165 | response.raise_for_status() 166 | data = response.json() 167 | 168 | if api_type.lower() == 'openai': 169 | return [model['id'] for model in data['data']] 170 | elif api_type.lower() == 'azure': 171 | return [model['id'] for model in data['value']] 172 | else: 173 | # Adjust this based on the actual response format of your local LLM API 174 | return [model['name'] for model in data['models']] 175 | except requests.exceptions.RequestException as e: 176 | logger.error(f"Error fetching models: {str(e)}") 177 | return [] 178 | 179 | def load_settings(): 180 | config_path = os.getenv('GRAPHRAG_CONFIG', 'config.yaml') 181 | if os.path.exists(config_path): 182 | with open(config_path, 'r') as config_file: 183 | config = yaml.safe_load(config_file) 184 | else: 185 | config = {} 186 | 187 | settings = { 188 | 'llm_model': os.getenv('LLM_MODEL', config.get('llm_model')), 189 | 'embedding_model': os.getenv('EMBEDDINGS_MODEL', config.get('embedding_model')), 190 | 'community_level': int(os.getenv('COMMUNITY_LEVEL', config.get('community_level', 2))), 191 | 'token_limit': int(os.getenv('TOKEN_LIMIT', config.get('token_limit', 4096))), 192 | 'api_key': os.getenv('GRAPHRAG_API_KEY', config.get('api_key')), 193 | 'api_base': os.getenv('LLM_API_BASE', config.get('api_base')), 194 | 'embeddings_api_base': os.getenv('EMBEDDINGS_API_BASE', config.get('embeddings_api_base')), 195 | 'api_type': os.getenv('API_TYPE', config.get('api_type', 'openai')), 196 | } 197 | 198 | return settings 199 | 200 | return settings 201 | 202 | async def setup_llm_and_embedder(settings): 203 | logger.info("Setting up LLM and embedder") 204 | try: 205 | llm = ChatOpenAI( 206 | api_key=settings['api_key'], 207 | api_base=f"{settings['api_base']}/v1", 208 | model=settings['llm_model'], 209 | api_type=OpenaiApiType[settings['api_type'].capitalize()], 210 | max_retries=20, 211 | ) 212 | 213 | token_encoder = tiktoken.get_encoding("cl100k_base") 214 | 215 | text_embedder = OpenAIEmbedding( 216 | api_key=settings['api_key'], 217 | api_base=f"{settings['embeddings_api_base']}/v1", 218 | api_type=OpenaiApiType[settings['api_type'].capitalize()], 219 | model=settings['embedding_model'], 220 | deployment_name=settings['embedding_model'], 221 | max_retries=20, 222 | ) 223 | 224 | logger.info("LLM and embedder setup complete") 225 | return llm, token_encoder, text_embedder 226 | except Exception as e: 227 | logger.error(f"Error setting up LLM and embedder: {str(e)}") 228 | raise HTTPException(status_code=500, detail=f"Failed to set up LLM and embedder: {str(e)}") 229 | 230 | async def load_context(selected_folder, settings): 231 | """ 232 | Load context data including entities, relationships, reports, text units, and covariates 233 | """ 234 | logger.info("Loading context data") 235 | try: 236 | input_dir = os.path.join(OUTPUT_DIR, selected_folder, "artifacts") 237 | entity_df = pd.read_parquet(f"{input_dir}/{ENTITY_TABLE}.parquet") 238 | entity_embedding_df = pd.read_parquet(f"{input_dir}/{ENTITY_EMBEDDING_TABLE}.parquet") 239 | entities = read_indexer_entities(entity_df, entity_embedding_df, settings['community_level']) 240 | 241 | description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings") 242 | description_embedding_store.connect(db_uri=LANCEDB_URI) 243 | store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store) 244 | 245 | relationship_df = 
pd.read_parquet(f"{input_dir}/{RELATIONSHIP_TABLE}.parquet") 246 | relationships = read_indexer_relationships(relationship_df) 247 | 248 | report_df = pd.read_parquet(f"{input_dir}/{COMMUNITY_REPORT_TABLE}.parquet") 249 | reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL) 250 | 251 | text_unit_df = pd.read_parquet(f"{input_dir}/{TEXT_UNIT_TABLE}.parquet") 252 | text_units = read_indexer_text_units(text_unit_df) 253 | 254 | covariate_df = pd.read_parquet(f"{input_dir}/{COVARIATE_TABLE}.parquet") 255 | claims = read_indexer_covariates(covariate_df) 256 | logger.info(f"Number of claim records: {len(claims)}") 257 | covariates = {"claims": claims} 258 | 259 | logger.info("Context data loading complete") 260 | return entities, relationships, reports, text_units, description_embedding_store, covariates 261 | except Exception as e: 262 | logger.error(f"Error loading context data: {str(e)}") 263 | raise 264 | 265 | async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units, 266 | description_embedding_store, covariates): 267 | """ 268 | Set up local and global search engines 269 | """ 270 | logger.info("Setting up search engines") 271 | 272 | # Set up local search engine 273 | local_context_builder = LocalSearchMixedContext( 274 | community_reports=reports, 275 | text_units=text_units, 276 | entities=entities, 277 | relationships=relationships, 278 | covariates=covariates, 279 | entity_text_embeddings=description_embedding_store, 280 | embedding_vectorstore_key=EntityVectorStoreKey.ID, 281 | text_embedder=text_embedder, 282 | token_encoder=token_encoder, 283 | ) 284 | 285 | local_context_params = { 286 | "text_unit_prop": 0.5, 287 | "community_prop": 0.1, 288 | "conversation_history_max_turns": 5, 289 | "conversation_history_user_turns_only": True, 290 | "top_k_mapped_entities": 10, 291 | "top_k_relationships": 10, 292 | "include_entity_rank": True, 293 | "include_relationship_weight": True, 294 | "include_community_rank": False, 295 | "return_candidate_context": False, 296 | "embedding_vectorstore_key": EntityVectorStoreKey.ID, 297 | "max_tokens": 12_000, 298 | } 299 | 300 | local_llm_params = { 301 | "max_tokens": 2_000, 302 | "temperature": 0.0, 303 | } 304 | 305 | local_search_engine = LocalSearch( 306 | llm=llm, 307 | context_builder=local_context_builder, 308 | token_encoder=token_encoder, 309 | llm_params=local_llm_params, 310 | context_builder_params=local_context_params, 311 | response_type="multiple paragraphs", 312 | ) 313 | 314 | # Set up global search engine 315 | global_context_builder = GlobalCommunityContext( 316 | community_reports=reports, 317 | entities=entities, 318 | token_encoder=token_encoder, 319 | ) 320 | 321 | global_context_builder_params = { 322 | "use_community_summary": False, 323 | "shuffle_data": True, 324 | "include_community_rank": True, 325 | "min_community_rank": 0, 326 | "community_rank_name": "rank", 327 | "include_community_weight": True, 328 | "community_weight_name": "occurrence weight", 329 | "normalize_community_weight": True, 330 | "max_tokens": 12_000, 331 | "context_name": "Reports", 332 | } 333 | 334 | map_llm_params = { 335 | "max_tokens": 1000, 336 | "temperature": 0.0, 337 | "response_format": {"type": "json_object"}, 338 | } 339 | 340 | reduce_llm_params = { 341 | "max_tokens": 2000, 342 | "temperature": 0.0, 343 | } 344 | 345 | global_search_engine = GlobalSearch( 346 | llm=llm, 347 | context_builder=global_context_builder, 348 | token_encoder=token_encoder, 349 | 
max_data_tokens=12_000, 350 | map_llm_params=map_llm_params, 351 | reduce_llm_params=reduce_llm_params, 352 | allow_general_knowledge=False, 353 | json_mode=True, 354 | context_builder_params=global_context_builder_params, 355 | concurrent_coroutines=32, 356 | response_type="multiple paragraphs", 357 | ) 358 | 359 | logger.info("Search engines setup complete") 360 | return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params 361 | 362 | def format_response(response): 363 | """ 364 | Format the response by adding appropriate line breaks and paragraph separations. 365 | """ 366 | paragraphs = re.split(r'\n{2,}', response) 367 | 368 | formatted_paragraphs = [] 369 | for para in paragraphs: 370 | if '```' in para: 371 | parts = para.split('```') 372 | for i, part in enumerate(parts): 373 | if i % 2 == 1: # This is a code block 374 | parts[i] = f"\n```\n{part.strip()}\n```\n" 375 | para = ''.join(parts) 376 | else: 377 | para = para.replace('. ', '.\n') 378 | 379 | formatted_paragraphs.append(para.strip()) 380 | 381 | return '\n\n'.join(formatted_paragraphs) 382 | 383 | @asynccontextmanager 384 | async def lifespan(app: FastAPI): 385 | global settings 386 | try: 387 | logger.info("Loading settings...") 388 | settings = load_settings() 389 | logger.info("Settings loaded successfully.") 390 | except Exception as e: 391 | logger.error(f"Error loading settings: {str(e)}") 392 | raise 393 | 394 | yield 395 | 396 | logger.info("Shutting down...") 397 | 398 | app = FastAPI(lifespan=lifespan) 399 | 400 | # Create a cache for loaded contexts 401 | context_cache = {} 402 | 403 | @lru_cache() 404 | def get_settings(): 405 | return load_settings() 406 | 407 | async def get_context(selected_folder: str, settings: dict = Depends(get_settings)): 408 | if selected_folder not in context_cache: 409 | try: 410 | llm, token_encoder, text_embedder = await setup_llm_and_embedder(settings) 411 | entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context(selected_folder, settings) 412 | local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines( 413 | llm, token_encoder, text_embedder, entities, relationships, reports, text_units, 414 | description_embedding_store, covariates 415 | ) 416 | question_generator = LocalQuestionGen( 417 | llm=llm, 418 | context_builder=local_context_builder, 419 | token_encoder=token_encoder, 420 | llm_params=local_llm_params, 421 | context_builder_params=local_context_params, 422 | ) 423 | context_cache[selected_folder] = { 424 | "local_search_engine": local_search_engine, 425 | "global_search_engine": global_search_engine, 426 | "question_generator": question_generator 427 | } 428 | except Exception as e: 429 | logger.error(f"Error loading context for folder {selected_folder}: {str(e)}") 430 | raise HTTPException(status_code=500, detail=f"Failed to load context for folder {selected_folder}") 431 | 432 | return context_cache[selected_folder] 433 | 434 | @app.post("/v1/chat/completions") 435 | async def chat_completions(request: ChatCompletionRequest): 436 | try: 437 | logger.info(f"Received request for model: {request.model}") 438 | if request.model == "direct-chat": 439 | logger.info("Routing to direct chat") 440 | return await run_direct_chat(request) 441 | elif request.model.startswith("graphrag-"): 442 | logger.info("Routing to GraphRAG query") 443 | if not request.query_options or not 
request.query_options.selected_folder: 444 | raise HTTPException(status_code=400, detail="Selected folder is required for GraphRAG queries") 445 | return await run_graphrag_query(request) 446 | elif request.model == "duckduckgo-search:latest": 447 | logger.info("Routing to DuckDuckGo search") 448 | return await run_duckduckgo_search(request) 449 | elif request.model == "full-model:latest": 450 | logger.info("Routing to full model search") 451 | return await run_full_model_search(request) 452 | else: 453 | raise HTTPException(status_code=400, detail=f"Invalid model specified: {request.model}") 454 | except HTTPException as he: 455 | logger.error(f"HTTP Exception: {str(he)}") 456 | raise he 457 | except Exception as e: 458 | logger.error(f"Error in chat completion: {str(e)}", exc_info=True) 459 | raise HTTPException(status_code=500, detail=str(e)) 460 | 461 | async def run_direct_chat(request: ChatCompletionRequest) -> ChatCompletionResponse: 462 | try: 463 | if not LLM_API_BASE: 464 | raise ValueError("LLM_API_BASE environment variable is not set") 465 | 466 | headers = {"Content-Type": "application/json"} 467 | 468 | payload = { 469 | "model": LLM_MODEL, 470 | "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages], 471 | "stream": False 472 | } 473 | 474 | # Optional parameters 475 | if request.temperature is not None: 476 | payload["temperature"] = request.temperature 477 | if request.max_tokens is not None: 478 | payload["max_tokens"] = request.max_tokens 479 | 480 | full_url = f"{normalize_api_base(LLM_API_BASE)}/v1/chat/completions" 481 | 482 | logger.info(f"Sending request to: {full_url}") 483 | logger.info(f"Payload: {payload}") 484 | 485 | try: 486 | response = requests.post(full_url, json=payload, headers=headers, timeout=10) 487 | response.raise_for_status() 488 | except requests.exceptions.RequestException as req_ex: 489 | logger.error(f"Request to LLM API failed: {str(req_ex)}") 490 | if isinstance(req_ex, requests.exceptions.ConnectionError): 491 | raise HTTPException(status_code=503, detail="Unable to connect to LLM API. 
Please check your API settings.") 492 | elif isinstance(req_ex, requests.exceptions.Timeout): 493 | raise HTTPException(status_code=504, detail="Request to LLM API timed out") 494 | else: 495 | raise HTTPException(status_code=500, detail=f"Request to LLM API failed: {str(req_ex)}") 496 | 497 | result = response.json() 498 | logger.info(f"Received response: {result}") 499 | 500 | content = result['choices'][0]['message']['content'] 501 | 502 | return ChatCompletionResponse( 503 | model=LLM_MODEL, 504 | choices=[ 505 | ChatCompletionResponseChoice( 506 | index=0, 507 | message=Message( 508 | role="assistant", 509 | content=content 510 | ), 511 | finish_reason=None 512 | ) 513 | ], 514 | usage=None 515 | ) 516 | except HTTPException as he: 517 | logger.error(f"HTTP Exception in direct chat: {str(he)}") 518 | raise he 519 | except Exception as e: 520 | logger.error(f"Unexpected error in direct chat: {str(e)}") 521 | raise HTTPException(status_code=500, detail=f"An unexpected error occurred during the direct chat: {str(e)}") 522 | 523 | def get_embeddings(text: str) -> List[float]: 524 | settings = load_settings() 525 | embeddings_api_base = settings['embeddings_api_base'] 526 | 527 | headers = {"Content-Type": "application/json"} 528 | 529 | if EMBEDDINGS_PROVIDER == 'ollama': 530 | payload = { 531 | "model": EMBEDDINGS_MODEL, 532 | "prompt": text 533 | } 534 | full_url = f"{embeddings_api_base}/api/embeddings" 535 | else: # OpenAI-compatible API 536 | payload = { 537 | "model": EMBEDDINGS_MODEL, 538 | "input": text 539 | } 540 | full_url = f"{embeddings_api_base}/v1/embeddings" 541 | 542 | try: 543 | response = requests.post(full_url, json=payload, headers=headers) 544 | response.raise_for_status() 545 | except requests.exceptions.RequestException as req_ex: 546 | logger.error(f"Request to Embeddings API failed: {str(req_ex)}") 547 | raise HTTPException(status_code=500, detail=f"Failed to get embeddings: {str(req_ex)}") 548 | 549 | result = response.json() 550 | 551 | if EMBEDDINGS_PROVIDER == 'ollama': 552 | return result['embedding'] 553 | else: 554 | return result['data'][0]['embedding'] 555 | 556 | 557 | async def run_graphrag_query(request: ChatCompletionRequest) -> ChatCompletionResponse: 558 | try: 559 | query_options = request.query_options 560 | query = request.messages[-1].content # Get the last user message as the query 561 | 562 | cmd = ["python", "-m", "graphrag.query"] 563 | cmd.extend(["--data", f"{ROOT_DIR}/output/{query_options.selected_folder}/artifacts"]) 564 | cmd.extend(["--method", query_options.query_type.split('-')[1]]) # 'global' or 'local' 565 | 566 | if query_options.community_level: 567 | cmd.extend(["--community_level", str(query_options.community_level)]) 568 | if query_options.response_type: 569 | cmd.extend(["--response_type", query_options.response_type]) 570 | 571 | # Handle preset CLI args 572 | if query_options.preset and query_options.preset != "Custom Query": 573 | preset_args = get_preset_args(query_options.preset) 574 | cmd.extend(preset_args) 575 | 576 | # Handle custom CLI args 577 | if query_options.custom_cli_args: 578 | cmd.extend(query_options.custom_cli_args.split()) 579 | 580 | cmd.append(query) 581 | 582 | logger.info(f"Executing GraphRAG query: {' '.join(cmd)}") 583 | 584 | result = subprocess.run(cmd, capture_output=True, text=True) 585 | if result.returncode != 0: 586 | raise Exception(f"GraphRAG query failed: {result.stderr}") 587 | 588 | return ChatCompletionResponse( 589 | model=request.model, 590 | choices=[ 591 | 
ChatCompletionResponseChoice( 592 | index=0, 593 | message=Message( 594 | role="assistant", 595 | content=result.stdout 596 | ), 597 | finish_reason="stop" 598 | ) 599 | ], 600 | usage=Usage( 601 | prompt_tokens=0, 602 | completion_tokens=0, 603 | total_tokens=0 604 | ) 605 | ) 606 | except Exception as e: 607 | logger.error(f"Error in GraphRAG query: {str(e)}") 608 | raise HTTPException(status_code=500, detail=f"An error occurred during the GraphRAG query: {str(e)}") 609 | 610 | 611 | def get_preset_args(preset: str) -> List[str]: 612 | preset_args = { 613 | "Default Global Search": ["--community_level", "2", "--response_type", "Multiple Paragraphs"], 614 | "Default Local Search": ["--community_level", "2", "--response_type", "Multiple Paragraphs"], 615 | "Detailed Global Analysis": ["--community_level", "3", "--response_type", "Multi-Page Report"], 616 | "Detailed Local Analysis": ["--community_level", "3", "--response_type", "Multi-Page Report"], 617 | "Quick Global Summary": ["--community_level", "1", "--response_type", "Single Paragraph"], 618 | "Quick Local Summary": ["--community_level", "1", "--response_type", "Single Paragraph"], 619 | "Global Bullet Points": ["--community_level", "2", "--response_type", "List of 3-7 Points"], 620 | "Local Bullet Points": ["--community_level", "2", "--response_type", "List of 3-7 Points"], 621 | "Comprehensive Global Report": ["--community_level", "4", "--response_type", "Multi-Page Report"], 622 | "Comprehensive Local Report": ["--community_level", "4", "--response_type", "Multi-Page Report"], 623 | "High-Level Global Overview": ["--community_level", "1", "--response_type", "Single Page"], 624 | "High-Level Local Overview": ["--community_level", "1", "--response_type", "Single Page"], 625 | "Focused Global Insight": ["--community_level", "3", "--response_type", "Single Paragraph"], 626 | "Focused Local Insight": ["--community_level", "3", "--response_type", "Single Paragraph"], 627 | } 628 | return preset_args.get(preset, []) 629 | 630 | ddg_search = DuckDuckGoSearchAPIWrapper(max_results=5) 631 | 632 | async def run_duckduckgo_search(request: ChatCompletionRequest) -> ChatCompletionResponse: 633 | query = request.messages[-1].content 634 | results = ddg_search.results(query, max_results=5) 635 | 636 | if not results: 637 | content = "No results found for the given query." 
638 | else: 639 | content = "DuckDuckGo Search Results:\n\n" 640 | for result in results: 641 | content += f"Title: {result['title']}\n" 642 | content += f"Snippet: {result['snippet']}\n" 643 | content += f"Link: {result['link']}\n" 644 | if 'date' in result: 645 | content += f"Date: {result['date']}\n" 646 | if 'source' in result: 647 | content += f"Source: {result['source']}\n" 648 | content += "\n" 649 | 650 | return ChatCompletionResponse( 651 | model=request.model, 652 | choices=[ 653 | ChatCompletionResponseChoice( 654 | index=0, 655 | message=Message( 656 | role="assistant", 657 | content=content 658 | ), 659 | finish_reason="stop" 660 | ) 661 | ], 662 | usage=Usage( 663 | prompt_tokens=0, 664 | completion_tokens=0, 665 | total_tokens=0 666 | ) 667 | ) 668 | 669 | async def run_full_model_search(request: ChatCompletionRequest) -> ChatCompletionResponse: 670 | query = request.messages[-1].content 671 | 672 | # Run all search types 673 | graphrag_global = await run_graphrag_query(ChatCompletionRequest(model="graphrag-global-search:latest", messages=request.messages, query_options=request.query_options)) 674 | graphrag_local = await run_graphrag_query(ChatCompletionRequest(model="graphrag-local-search:latest", messages=request.messages, query_options=request.query_options)) 675 | duckduckgo = await run_duckduckgo_search(request) 676 | 677 | # Combine results 678 | combined_content = f"""Full Model Search Results: 679 | 680 | Global Search: 681 | {graphrag_global.choices[0].message.content} 682 | 683 | Local Search: 684 | {graphrag_local.choices[0].message.content} 685 | 686 | DuckDuckGo Search: 687 | {duckduckgo.choices[0].message.content} 688 | """ 689 | 690 | return ChatCompletionResponse( 691 | model=request.model, 692 | choices=[ 693 | ChatCompletionResponseChoice( 694 | index=0, 695 | message=Message( 696 | role="assistant", 697 | content=combined_content 698 | ), 699 | finish_reason="stop" 700 | ) 701 | ], 702 | usage=Usage( 703 | prompt_tokens=0, 704 | completion_tokens=0, 705 | total_tokens=0 706 | ) 707 | ) 708 | 709 | @app.get("/health") 710 | async def health_check(): 711 | return {"status": "ok"} 712 | 713 | @app.get("/v1/models") 714 | async def list_models(): 715 | settings = load_settings() 716 | try: 717 | api_models = await fetch_available_models(settings) 718 | except Exception as e: 719 | logger.error(f"Error fetching API models: {str(e)}") 720 | api_models = [] 721 | 722 | # Include the hardcoded models 723 | hardcoded_models = [ 724 | {"id": "graphrag-local-search:latest", "object": "model", "owned_by": "graphrag"}, 725 | {"id": "graphrag-global-search:latest", "object": "model", "owned_by": "graphrag"}, 726 | {"id": "duckduckgo-search:latest", "object": "model", "owned_by": "duckduckgo"}, 727 | {"id": "full-model:latest", "object": "model", "owned_by": "combined"}, 728 | ] 729 | 730 | # Combine API models with hardcoded models 731 | all_models = [{"id": model, "object": "model", "owned_by": "api"} for model in api_models] + hardcoded_models 732 | 733 | return JSONResponse(content={"data": all_models}) 734 | 735 | class PromptTuneRequest(BaseModel): 736 | root: str = "./{ROOT_DIR}" 737 | config: str = "./{ROOT_DIR}/settings.yaml" 738 | domain: Optional[str] = None 739 | method: str = "random" 740 | limit: int = 15 741 | language: Optional[str] = None 742 | max_tokens: int = 2000 743 | chunk_size: int = 200 744 | no_entity_types: bool = False 745 | output: str = "./{ROOT_DIR}/prompts" 746 | 747 | class PromptTuneResponse(BaseModel): 748 | status: str 749 | message: 
str 750 | 751 | # Global variable to store the latest logs 752 | prompt_tune_logs = deque(maxlen=100) 753 | 754 | async def run_prompt_tuning(request: PromptTuneRequest): 755 | cmd = ["python", "-m", "graphrag.prompt_tune"] 756 | 757 | # Create a temporary directory for output 758 | with tempfile.TemporaryDirectory() as temp_output: 759 | # Expand environment variables in the root path 760 | root_path = os.path.expandvars(request.root) 761 | 762 | cmd.extend(["--root", root_path]) 763 | cmd.extend(["--config", request.config]) 764 | cmd.extend(["--selection-method", request.method]) 765 | cmd.extend(["--limit", str(request.limit)]) 766 | 767 | if request.domain: 768 | cmd.extend(["--domain", request.domain]) 769 | 770 | if request.language: 771 | cmd.extend(["--language", request.language]) 772 | 773 | cmd.extend(["--max-tokens", str(request.max_tokens)]) 774 | cmd.extend(["--chunk-size", str(request.chunk_size)]) 775 | 776 | if request.no_entity_types: 777 | cmd.append("--no-entity-types") 778 | 779 | # Use the temporary directory for output 780 | cmd.extend(["--output", temp_output]) 781 | 782 | logger.info(f"Executing prompt tuning command: {' '.join(cmd)}") 783 | 784 | try: 785 | process = await asyncio.create_subprocess_exec( 786 | *cmd, 787 | stdout=asyncio.subprocess.PIPE, 788 | stderr=asyncio.subprocess.PIPE 789 | ) 790 | 791 | async def read_stream(stream): 792 | while True: 793 | line = await stream.readline() 794 | if not line: 795 | break 796 | line = line.decode().strip() 797 | prompt_tune_logs.append(line) 798 | logger.info(line) 799 | 800 | await asyncio.gather( 801 | read_stream(process.stdout), 802 | read_stream(process.stderr) 803 | ) 804 | 805 | await process.wait() 806 | 807 | if process.returncode == 0: 808 | logger.info("Prompt tuning completed successfully") 809 | 810 | # Replace the existing template files with the newly generated prompts 811 | dest_dir = os.path.join(ROOT_DIR, "prompts") 812 | 813 | for filename in os.listdir(temp_output): 814 | if filename.endswith(".txt"): 815 | source_file = os.path.join(temp_output, filename) 816 | dest_file = os.path.join(dest_dir, filename) 817 | shutil.move(source_file, dest_file) 818 | logger.info(f"Replaced {filename} in {dest_file}") 819 | 820 | return PromptTuneResponse(status="success", message="Prompt tuning completed successfully. Existing prompts have been replaced.") 821 | else: 822 | logger.error("Prompt tuning failed") 823 | return PromptTuneResponse(status="error", message="Prompt tuning failed. 
Check logs for details.") 824 | except Exception as e: 825 | logger.error(f"Prompt tuning failed: {str(e)}") 826 | return PromptTuneResponse(status="error", message=f"Prompt tuning failed: {str(e)}") 827 | 828 | @app.post("/v1/prompt_tune") 829 | async def prompt_tune(request: PromptTuneRequest, background_tasks: BackgroundTasks): 830 | background_tasks.add_task(run_prompt_tuning, request) 831 | return {"status": "started", "message": "Prompt tuning process has been started in the background"} 832 | 833 | @app.get("/v1/prompt_tune_status") 834 | async def prompt_tune_status(): 835 | return { 836 | "status": "running" if prompt_tune_logs else "idle", 837 | "logs": list(prompt_tune_logs) 838 | } 839 | 840 | class IndexingRequest(BaseModel): 841 | llm_model: str 842 | embed_model: str 843 | llm_api_base: str 844 | embed_api_base: str 845 | root: str 846 | verbose: bool = False 847 | nocache: bool = False 848 | resume: Optional[str] = None 849 | reporter: str = "rich" 850 | emit: List[str] = ["parquet"] 851 | custom_args: Optional[str] = None 852 | llm_params: Dict[str, Any] = Field(default_factory=dict) 853 | embed_params: Dict[str, Any] = Field(default_factory=dict) 854 | 855 | # Global variable to store the latest indexing logs 856 | indexing_logs = deque(maxlen=100) 857 | 858 | async def run_indexing(request: IndexingRequest): 859 | cmd = ["python", "-m", "graphrag.index"] 860 | 861 | cmd.extend(["--root", request.root]) 862 | 863 | if request.verbose: 864 | cmd.append("--verbose") 865 | 866 | if request.nocache: 867 | cmd.append("--nocache") 868 | 869 | if request.resume: 870 | cmd.extend(["--resume", request.resume]) 871 | 872 | cmd.extend(["--reporter", request.reporter]) 873 | cmd.extend(["--emit", ",".join(request.emit)]) 874 | 875 | # Set environment variables for LLM and embedding models 876 | env: Dict[str, Any] = os.environ.copy() 877 | env["GRAPHRAG_LLM_MODEL"] = request.llm_model 878 | env["GRAPHRAG_EMBED_MODEL"] = request.embed_model 879 | env["GRAPHRAG_LLM_API_BASE"] = LLM_API_BASE 880 | env["GRAPHRAG_EMBED_API_BASE"] = EMBEDDINGS_API_BASE 881 | 882 | # Set environment variables for LLM parameters 883 | for key, value in request.llm_params.items(): 884 | env[f"GRAPHRAG_LLM_{key.upper()}"] = str(value) 885 | 886 | # Set environment variables for embedding parameters 887 | for key, value in request.embed_params.items(): 888 | env[f"GRAPHRAG_EMBED_{key.upper()}"] = str(value) 889 | 890 | # Add custom CLI arguments 891 | if request.custom_args: 892 | cmd.extend(request.custom_args.split()) 893 | 894 | logger.info(f"Executing indexing command: {' '.join(cmd)}") 895 | logger.info(f"Environment variables: {env}") 896 | 897 | try: 898 | process = await asyncio.create_subprocess_exec( 899 | *cmd, 900 | stdout=asyncio.subprocess.PIPE, 901 | stderr=asyncio.subprocess.PIPE, 902 | env=env 903 | ) 904 | 905 | async def read_stream(stream): 906 | while True: 907 | line = await stream.readline() 908 | if not line: 909 | break 910 | line = line.decode().strip() 911 | indexing_logs.append(line) 912 | logger.info(line) 913 | 914 | await asyncio.gather( 915 | read_stream(process.stdout), 916 | read_stream(process.stderr) 917 | ) 918 | 919 | await process.wait() 920 | 921 | if process.returncode == 0: 922 | logger.info("Indexing completed successfully") 923 | return {"status": "success", "message": "Indexing completed successfully"} 924 | else: 925 | logger.error("Indexing failed") 926 | return {"status": "error", "message": "Indexing failed. 
Check logs for details."} 927 | except Exception as e: 928 | logger.error(f"Indexing failed: {str(e)}") 929 | return {"status": "error", "message": f"Indexing failed: {str(e)}"} 930 | 931 | 932 | @app.post("/v1/index") 933 | async def start_indexing(request: IndexingRequest, background_tasks: BackgroundTasks): 934 | background_tasks.add_task(run_indexing, request) 935 | return {"status": "started", "message": "Indexing process has been started in the background"} 936 | 937 | @app.get("/v1/index_status") 938 | async def indexing_status(): 939 | return { 940 | "status": "running" if indexing_logs else "idle", 941 | "logs": list(indexing_logs) 942 | } 943 | 944 | def main(): 945 | parser = argparse.ArgumentParser(description="Launch the GraphRAG API server") 946 | parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to") 947 | parser.add_argument("--port", type=int, default=PORT, help="Port to bind the server to") 948 | parser.add_argument("--reload", action="store_true", help="Enable auto-reload mode") 949 | args = parser.parse_args() 950 | 951 | import uvicorn 952 | 953 | if __name__ == "__main__": 954 | app_name = 'api:app' 955 | else: 956 | app_name = 'graphrag_ui.api:app' 957 | 958 | uvicorn.run( 959 | app_name, 960 | host=args.host, 961 | port=args.port, 962 | reload=args.reload 963 | ) 964 | 965 | if __name__ == "__main__": 966 | main() 967 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | try: 3 | import graphrag 4 | except ImportError: 5 | print("The 'graphrag' package is not installed. Please install it using 'pip install graphrag'.Since the dependency package `aiofiles` of `graphrag` conflicts with the requirements of `gradio`, it is necessary to manually install `graphrag` separately.") 6 | sys.exit(1) 7 | import gradio as gr 8 | from gradio.helpers import Progress 9 | import asyncio 10 | import subprocess 11 | import yaml 12 | import os 13 | import networkx as nx 14 | import plotly.graph_objects as go 15 | import numpy as np 16 | import plotly.io as pio 17 | import lancedb 18 | import random 19 | import io 20 | import shutil 21 | import logging 22 | import queue 23 | import threading 24 | import time 25 | from collections import deque 26 | import re 27 | import glob 28 | from datetime import datetime 29 | import json 30 | import requests 31 | import aiohttp 32 | from openai import OpenAI 33 | from openai import AsyncOpenAI 34 | import pyarrow.parquet as pq 35 | import pandas as pd 36 | import colorsys 37 | from dotenv import load_dotenv, set_key 38 | import argparse 39 | import socket 40 | import tiktoken 41 | from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey 42 | from graphrag.query.indexer_adapters import ( 43 | read_indexer_covariates, 44 | read_indexer_entities, 45 | read_indexer_relationships, 46 | read_indexer_reports, 47 | read_indexer_text_units, 48 | ) 49 | from graphrag.llm.openai import create_openai_chat_llm 50 | from graphrag.llm.openai.factories import create_openai_embedding_llm 51 | from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings 52 | from graphrag.query.llm.oai.chat_openai import ChatOpenAI 53 | from graphrag.llm.openai.openai_configuration import OpenAIConfiguration 54 | from graphrag.llm.openai.openai_embeddings_llm import OpenAIEmbeddingsLLM 55 | from graphrag.query.llm.oai.typing import OpenaiApiType 56 | from 
graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext 57 | from graphrag.query.structured_search.local_search.search import LocalSearch 58 | from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext 59 | from graphrag.query.structured_search.global_search.search import GlobalSearch 60 | from graphrag.vector_stores.lancedb import LanceDBVectorStore 61 | import textwrap 62 | 63 | 64 | 65 | # Suppress warnings 66 | import warnings 67 | warnings.filterwarnings("ignore", category=UserWarning, module="gradio_client.documentation") 68 | 69 | graphrag_indexing_dir = 'indexing' 70 | project_root = os.path.abspath(os.path.dirname(__file__)) 71 | ROOT_DIR = os.path.join(project_root,graphrag_indexing_dir) 72 | 73 | env_file = os.path.join(ROOT_DIR, '.env') 74 | load_dotenv(env_file) 75 | 76 | # LLM 相关配置 77 | LLM_API_BASE = os.getenv('LLM_API_BASE') 78 | LLM_MODEL = os.getenv('LLM_MODEL') 79 | LLM_API_KEY = os.getenv('LLM_API_KEY') 80 | LLM_SERVICE_TYPE = os.getenv('LLM_SERVICE_TYPE') 81 | 82 | # EMBEDDINGS 相关配置 83 | EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE') 84 | EMBEDDINGS_MODEL = os.getenv('EMBEDDINGS_MODEL') 85 | EMBEDDINGS_API_KEY = os.getenv('EMBEDDINGS_API_KEY') 86 | EMBEDDINGS_SERVICE_TYPE = os.getenv('EMBEDDINGS_SERVICE_TYPE') 87 | 88 | # 其他配置 89 | ROOT_DIR = os.path.join(project_root,graphrag_indexing_dir) 90 | INPUT_DIR = os.getenv('INPUT_DIR') 91 | 92 | 93 | # Set up logging 94 | log_queue = queue.Queue() 95 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 96 | 97 | 98 | llm = None 99 | text_embedder = None 100 | 101 | class QueueHandler(logging.Handler): 102 | def __init__(self, log_queue): 103 | super().__init__() 104 | self.log_queue = log_queue 105 | 106 | def emit(self, record): 107 | self.log_queue.put(self.format(record)) 108 | queue_handler = QueueHandler(log_queue) 109 | logging.getLogger().addHandler(queue_handler) 110 | 111 | 112 | 113 | def initialize_models(): 114 | global llm, text_embedder 115 | logging.info("Fetching models...") 116 | models = fetch_models(LLM_API_BASE, LLM_API_KEY, LLM_SERVICE_TYPE) 117 | 118 | # Use the same models list for both LLM and embeddings 119 | llm_models = models 120 | embeddings_models = models 121 | 122 | # Initialize LLM 123 | if LLM_SERVICE_TYPE == "openai_chat": 124 | llm = ChatOpenAI( 125 | api_key=LLM_API_KEY, 126 | api_base=f"{LLM_API_BASE}/v1", 127 | model=LLM_MODEL, 128 | api_type=OpenaiApiType.OpenAI, 129 | max_retries=20, 130 | ) 131 | # Initialize OpenAI client for embeddings 132 | openai_client = OpenAI( 133 | api_key=EMBEDDINGS_API_KEY or "dummy_key", 134 | base_url=f"{EMBEDDINGS_API_BASE}/v1" 135 | ) 136 | 137 | # Initialize text embedder using OpenAIEmbeddingsLLM 138 | text_embedder = OpenAIEmbeddingsLLM( 139 | client=openai_client, 140 | configuration={ 141 | "model": EMBEDDINGS_MODEL, 142 | "api_type": "open_ai", 143 | "api_base": EMBEDDINGS_API_BASE, 144 | "api_key": EMBEDDINGS_API_KEY or None, 145 | "provider": EMBEDDINGS_SERVICE_TYPE 146 | } 147 | ) 148 | 149 | return llm_models, embeddings_models, 1, 1, 1, 1, text_embedder 150 | # return llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder 151 | 152 | def find_latest_output_folder(): 153 | root_dir = os.path.join(ROOT_DIR,'output') 154 | folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))] 155 | 156 | if folders is None: 157 | raise 
ValueError("No output folders found") 158 | elif len(folders) == 0: 159 | return None,None 160 | 161 | # Sort folders by creation time, most recent first 162 | sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True) 163 | 164 | latest_folder = None 165 | timestamp = None 166 | 167 | for folder in sorted_folders: 168 | try: 169 | # Try to parse the folder name as a timestamp 170 | timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S") 171 | latest_folder = folder 172 | break 173 | except ValueError: 174 | # If the folder name is not a valid timestamp, skip it 175 | continue 176 | 177 | if latest_folder is None: 178 | raise ValueError("No valid timestamp folders found") 179 | 180 | latest_path = os.path.join(root_dir, latest_folder) 181 | artifacts_path = os.path.join(latest_path, "artifacts") 182 | 183 | if not os.path.exists(artifacts_path): 184 | raise ValueError(f"Artifacts folder not found in {latest_path}") 185 | 186 | return latest_path, latest_folder 187 | 188 | def initialize_data(): 189 | global entity_df, relationship_df, text_unit_df, report_df, covariate_df 190 | 191 | tables = { 192 | "entity_df": "create_final_nodes", 193 | "relationship_df": "create_final_edges", 194 | "text_unit_df": "create_final_text_units", 195 | "report_df": "create_final_reports", 196 | "covariate_df": "create_final_covariates" 197 | } 198 | 199 | timestamp = None # Initialize timestamp to None 200 | 201 | try: 202 | latest_output_folder, timestamp = find_latest_output_folder() 203 | if latest_output_folder is None: 204 | return None 205 | artifacts_folder = os.path.join(latest_output_folder, "artifacts") 206 | 207 | for df_name, file_prefix in tables.items(): 208 | file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet") 209 | matching_files = glob.glob(file_pattern) 210 | 211 | if matching_files: 212 | latest_file = max(matching_files, key=os.path.getctime) 213 | df = pd.read_parquet(latest_file) 214 | globals()[df_name] = df 215 | logging.info(f"Successfully loaded {df_name} from {latest_file}") 216 | else: 217 | logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. 
Initializing as an empty DataFrame.") 218 | globals()[df_name] = pd.DataFrame() 219 | 220 | except Exception as e: 221 | logging.error(f"Error initializing data: {str(e)}") 222 | for df_name in tables.keys(): 223 | globals()[df_name] = pd.DataFrame() 224 | 225 | return timestamp 226 | 227 | # Call initialize_data and store the timestamp 228 | current_timestamp = initialize_data() 229 | 230 | 231 | def find_available_port(start_port, max_attempts=100): 232 | for port in range(start_port, start_port + max_attempts): 233 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 234 | try: 235 | s.bind(('', port)) 236 | return port 237 | except OSError: 238 | continue 239 | raise IOError("No free ports found") 240 | 241 | def start_api_server(port): 242 | subprocess.Popen([sys.executable, "api_server.py", "--port", str(port)]) 243 | 244 | def wait_for_api_server(port): 245 | max_retries = 30 246 | for _ in range(max_retries): 247 | try: 248 | response = requests.get(f"http://localhost:{port}") 249 | if response.status_code == 200: 250 | print(f"API server is up and running on port {port}") 251 | return 252 | else: 253 | print(f"Unexpected response from API server: {response.status_code}") 254 | except requests.ConnectionError: 255 | time.sleep(1) 256 | print("Failed to connect to API server") 257 | 258 | def load_settings(): 259 | try: 260 | with open(f"{ROOT_DIR}/settings.yaml", "r") as f: 261 | return yaml.safe_load(f) or {} 262 | except FileNotFoundError: 263 | return {} 264 | 265 | def update_setting(key, value): 266 | settings = load_settings() 267 | try: 268 | settings[key] = json.loads(value) 269 | except json.JSONDecodeError: 270 | settings[key] = value 271 | 272 | try: 273 | with open(os.path.join(ROOT_DIR,'"settings.yaml"'), "w") as f: 274 | yaml.dump(settings, f, default_flow_style=False) 275 | return f"Setting '{key}' updated successfully" 276 | except Exception as e: 277 | return f"Error updating setting '{key}': {str(e)}" 278 | 279 | def create_setting_component(key, value): 280 | with gr.Accordion(key, open=False): 281 | if isinstance(value, (dict, list)): 282 | value_str = json.dumps(value, indent=2) 283 | lines = value_str.count('\n') + 1 284 | else: 285 | value_str = str(value) 286 | lines = 1 287 | 288 | text_area = gr.TextArea(value=value_str, label="Value", lines=lines, max_lines=20) 289 | update_btn = gr.Button("Update", variant="primary") 290 | status = gr.Textbox(label="Status", visible=False) 291 | 292 | update_btn.click( 293 | fn=update_setting, 294 | inputs=[gr.Textbox(value=key, visible=False), text_area], 295 | outputs=[status] 296 | ).then( 297 | fn=lambda: gr.update(visible=True), 298 | outputs=[status] 299 | ) 300 | 301 | 302 | 303 | def get_openai_client(): 304 | return OpenAI( 305 | base_url=os.getenv("LLM_API_BASE"), 306 | api_key=os.getenv("LLM_API_KEY"), 307 | llm_model = os.getenv("LLM_MODEL") 308 | ) 309 | 310 | async def chat_with_openai(messages, model, temperature, max_tokens, api_base): 311 | client = AsyncOpenAI( 312 | base_url=api_base, 313 | api_key=os.getenv("LLM_API_KEY") 314 | ) 315 | 316 | try: 317 | response = await client.chat.completions.create( 318 | model=model, 319 | messages=messages, 320 | temperature=temperature, 321 | max_tokens=max_tokens 322 | ) 323 | return response.choices[0].message.content 324 | except Exception as e: 325 | logging.error(f"Error in chat_with_openai: {str(e)}") 326 | return f"An error occurred: {str(e)}" 327 | return f"Error: {str(e)}" 328 | 329 | def chat_with_llm(query, history, system_message, 
temperature, max_tokens, model, api_base): 330 | try: 331 | messages = [{"role": "system", "content": system_message}] 332 | for item in history: 333 | if isinstance(item, tuple) and len(item) == 2: 334 | human, ai = item 335 | messages.append({"role": "user", "content": human}) 336 | messages.append({"role": "assistant", "content": ai}) 337 | messages.append({"role": "user", "content": query}) 338 | 339 | logging.info(f"Sending chat request to {api_base} with model {model}") 340 | client = OpenAI(base_url=api_base, api_key=os.getenv("LLM_API_KEY", "dummy-key")) 341 | response = client.chat.completions.create( 342 | model=model, 343 | messages=messages, 344 | temperature=temperature, 345 | max_tokens=max_tokens 346 | ) 347 | return response.choices[0].message.content 348 | except Exception as e: 349 | logging.error(f"Error in chat_with_llm: {str(e)}") 350 | logging.error(f"Attempted with model: {model}, api_base: {api_base}") 351 | raise RuntimeError(f"Chat request failed: {str(e)}") 352 | 353 | def run_graphrag_query(cli_args): 354 | try: 355 | command = ' '.join(cli_args) 356 | logging.info(f"Executing command: {command}") 357 | result = subprocess.run(cli_args, capture_output=True, text=True, check=True) 358 | return result.stdout.strip() 359 | except subprocess.CalledProcessError as e: 360 | logging.error(f"Error running GraphRAG query: {e}") 361 | logging.error(f"Command output (stdout): {e.stdout}") 362 | logging.error(f"Command output (stderr): {e.stderr}") 363 | raise RuntimeError(f"GraphRAG query failed: {e.stderr}") 364 | 365 | def parse_query_response(response: str): 366 | try: 367 | # Split the response into metadata and content 368 | parts = response.split("\n\n", 1) 369 | if len(parts) < 2: 370 | return response # Return original response if it doesn't contain metadata 371 | 372 | metadata_str, content = parts 373 | metadata = json.loads(metadata_str) 374 | 375 | # Extract relevant information from metadata 376 | query_type = metadata.get("query_type", "Unknown") 377 | execution_time = metadata.get("execution_time", "N/A") 378 | tokens_used = metadata.get("tokens_used", "N/A") 379 | 380 | # Remove unwanted lines from the content 381 | content_lines = content.split('\n') 382 | filtered_content = '\n'.join([line for line in content_lines if not line.startswith("INFO:") and not line.startswith("creating llm client")]) 383 | 384 | # Format the parsed response 385 | parsed_response = f""" 386 | Query Type: {query_type} 387 | Execution Time: {execution_time} seconds 388 | Tokens Used: {tokens_used} 389 | 390 | {filtered_content.strip()} 391 | """ 392 | return parsed_response 393 | except Exception as e: 394 | print(f"Error parsing query response: {str(e)}") 395 | return response 396 | 397 | def send_message(query_type, query, history, system_message, temperature, max_tokens, preset, community_level, response_type, custom_cli_args, selected_folder): 398 | try: 399 | if query_type in ["global", "local"]: 400 | cli_args = construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder) 401 | logging.info(f"Executing {query_type} search with command: {' '.join(cli_args)}") 402 | result = run_graphrag_query(cli_args) 403 | parsed_result = parse_query_response(result) 404 | logging.info(f"Parsed query result: {parsed_result}") 405 | else: # Direct chat 406 | llm_model = os.getenv("LLM_MODEL") 407 | api_base = os.getenv("LLM_API_BASE") 408 | logging.info(f"Executing direct chat with model: {llm_model}") 409 | 410 | try: 411 | result = 
chat_with_llm(query, history, system_message, temperature, max_tokens, llm_model, api_base) 412 | parsed_result = result # No parsing needed for direct chat 413 | logging.info(f"Direct chat result: {parsed_result[:100]}...") # Log first 100 chars of result 414 | except Exception as chat_error: 415 | logging.error(f"Error in chat_with_llm: {str(chat_error)}") 416 | raise RuntimeError(f"Direct chat failed: {str(chat_error)}") 417 | 418 | history.append((query, parsed_result)) 419 | except Exception as e: 420 | error_message = f"An error occurred: {str(e)}" 421 | logging.error(error_message) 422 | logging.exception("Exception details:") 423 | history.append((query, error_message)) 424 | 425 | return history, gr.update(value=""), update_logs() 426 | 427 | def construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder): 428 | if not selected_folder: 429 | raise ValueError("No folder selected. Please select an output folder before querying.") 430 | 431 | artifacts_folder = os.path.join(f"{ROOT_DIR}/output", selected_folder, "artifacts") 432 | if not os.path.exists(artifacts_folder): 433 | raise ValueError(f"Artifacts folder not found in {artifacts_folder}") 434 | 435 | base_args = [ 436 | "python", "-m", "graphrag.query", 437 | "--root", ROOT_DIR, 438 | "--method", query_type, 439 | ] 440 | 441 | # Apply preset configurations 442 | if preset.startswith("Default"): 443 | base_args.extend(["--community_level", "2", "--response_type", "Multiple Paragraphs"]) 444 | elif preset.startswith("Detailed"): 445 | base_args.extend(["--community_level", "4", "--response_type", "Multi-Page Report"]) 446 | elif preset.startswith("Quick"): 447 | base_args.extend(["--community_level", "1", "--response_type", "Single Paragraph"]) 448 | elif preset.startswith("Bullet"): 449 | base_args.extend(["--community_level", "2", "--response_type", "List of 3-7 Points"]) 450 | elif preset.startswith("Comprehensive"): 451 | base_args.extend(["--community_level", "5", "--response_type", "Multi-Page Report"]) 452 | elif preset.startswith("High-Level"): 453 | base_args.extend(["--community_level", "1", "--response_type", "Single Page"]) 454 | elif preset.startswith("Focused"): 455 | base_args.extend(["--community_level", "3", "--response_type", "Multiple Paragraphs"]) 456 | elif preset == "Custom Query": 457 | base_args.extend([ 458 | "--community_level", str(community_level), 459 | "--response_type", f'"{response_type}"', 460 | ]) 461 | if custom_cli_args: 462 | base_args.extend(custom_cli_args.split()) 463 | 464 | # Add the query at the end 465 | base_args.append(query) 466 | 467 | return base_args 468 | 469 | 470 | 471 | 472 | 473 | 474 | def upload_file(file): 475 | if file is not None: 476 | input_dir = os.path.join(ROOT_DIR, "input") 477 | os.makedirs(input_dir, exist_ok=True) 478 | 479 | # Get the original filename from the uploaded file 480 | original_filename = file.name 481 | 482 | # Create the destination path 483 | destination_path = os.path.join(input_dir, os.path.basename(original_filename)) 484 | 485 | # Move the uploaded file to the destination path 486 | shutil.move(file.name, destination_path) 487 | 488 | logging.info(f"File uploaded and moved to: {destination_path}") 489 | status = f"File uploaded: {os.path.basename(original_filename)}" 490 | else: 491 | status = "No file uploaded" 492 | 493 | # Get the updated file list 494 | updated_file_list = [f["path"] for f in list_input_files()] 495 | 496 | return status, gr.update(choices=updated_file_list), 
update_logs() 497 | 498 | def list_input_files(): 499 | input_dir = os.path.join(ROOT_DIR, "input") 500 | files = [] 501 | if os.path.exists(input_dir): 502 | files = os.listdir(input_dir) 503 | return [{"name": f, "path": os.path.join(input_dir, f)} for f in files] 504 | 505 | def delete_file(file_path): 506 | try: 507 | os.remove(file_path) 508 | logging.info(f"File deleted: {file_path}") 509 | status = f"File deleted: {os.path.basename(file_path)}" 510 | except Exception as e: 511 | logging.error(f"Error deleting file: {str(e)}") 512 | status = f"Error deleting file: {str(e)}" 513 | 514 | # Get the updated file list 515 | updated_file_list = [f["path"] for f in list_input_files()] 516 | 517 | return status, gr.update(choices=updated_file_list), update_logs() 518 | 519 | def read_file_content(file_path): 520 | try: 521 | if file_path.endswith('.parquet'): 522 | df = pd.read_parquet(file_path) 523 | 524 | # Get basic information about the DataFrame 525 | info = f"Parquet File: {os.path.basename(file_path)}\n" 526 | info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n" 527 | info += "Column Names:\n" + "\n".join(df.columns) + "\n\n" 528 | 529 | # Display first few rows 530 | info += "First 5 rows:\n" 531 | info += df.head().to_string() + "\n\n" 532 | 533 | # Display basic statistics 534 | info += "Basic Statistics:\n" 535 | info += df.describe().to_string() 536 | 537 | return info 538 | else: 539 | with open(file_path, 'r', encoding='utf-8', errors='replace') as file: 540 | content = file.read() 541 | return content 542 | except Exception as e: 543 | logging.error(f"Error reading file: {str(e)}") 544 | return f"Error reading file: {str(e)}" 545 | 546 | def save_file_content(file_path, content): 547 | try: 548 | with open(file_path, 'w') as file: 549 | file.write(content) 550 | logging.info(f"File saved: {file_path}") 551 | status = f"File saved: {os.path.basename(file_path)}" 552 | except Exception as e: 553 | logging.error(f"Error saving file: {str(e)}") 554 | status = f"Error saving file: {str(e)}" 555 | return status, update_logs() 556 | 557 | def manage_data(): 558 | db = lancedb.connect(os.path.join(ROOT_DIR,"lancedb")) 559 | tables = db.table_names() 560 | table_info = "" 561 | if tables: 562 | table = db[tables[0]] 563 | table_info = f"Table: {tables[0]}\nSchema: {table.schema}" 564 | 565 | input_files = list_input_files() 566 | 567 | return { 568 | "database_info": f"Tables: {', '.join(tables)}\n\n{table_info}", 569 | "input_files": input_files 570 | } 571 | 572 | 573 | def find_latest_graph_file(root_dir): 574 | pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml") 575 | graph_files = glob.glob(pattern) 576 | if not graph_files: 577 | # If no files found, try excluding .DS_Store 578 | output_dir = os.path.join(root_dir, "output") 579 | run_dirs = [d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"] 580 | if run_dirs: 581 | latest_run = max(run_dirs) 582 | pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml") 583 | graph_files = glob.glob(pattern) 584 | 585 | if not graph_files: 586 | return None 587 | 588 | # Sort files by modification time, most recent first 589 | latest_file = max(graph_files, key=os.path.getmtime) 590 | return latest_file 591 | 592 | def update_visualization(folder_name, file_name, layout_type, node_size, edge_width, node_color_attribute, color_scheme, show_labels, label_size): 593 | root_dir = ROOT_DIR 594 | if not folder_name or not file_name: 595 | return 
None, "Please select a folder and a GraphML file." 596 | file_name = file_name.split("] ")[1] if "]" in file_name else file_name # Remove file type prefix 597 | graph_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name) 598 | if not graph_path.endswith('.graphml'): 599 | return None, "Please select a GraphML file for visualization." 600 | try: 601 | # Load the GraphML file 602 | graph = nx.read_graphml(graph_path) 603 | 604 | # Create layout based on user selection 605 | if layout_type == "3D Spring": 606 | pos = nx.spring_layout(graph, dim=3, seed=42, k=0.5) 607 | elif layout_type == "2D Spring": 608 | pos = nx.spring_layout(graph, dim=2, seed=42, k=0.5) 609 | else: # Circular 610 | pos = nx.circular_layout(graph) 611 | 612 | # Extract node positions 613 | if layout_type == "3D Spring": 614 | x_nodes = [pos[node][0] for node in graph.nodes()] 615 | y_nodes = [pos[node][1] for node in graph.nodes()] 616 | z_nodes = [pos[node][2] for node in graph.nodes()] 617 | else: 618 | x_nodes = [pos[node][0] for node in graph.nodes()] 619 | y_nodes = [pos[node][1] for node in graph.nodes()] 620 | z_nodes = [0] * len(graph.nodes()) # Set all z-coordinates to 0 for 2D layouts 621 | 622 | # Extract edge positions 623 | x_edges, y_edges, z_edges = [], [], [] 624 | for edge in graph.edges(): 625 | x_edges.extend([pos[edge[0]][0], pos[edge[1]][0], None]) 626 | y_edges.extend([pos[edge[0]][1], pos[edge[1]][1], None]) 627 | if layout_type == "3D Spring": 628 | z_edges.extend([pos[edge[0]][2], pos[edge[1]][2], None]) 629 | else: 630 | z_edges.extend([0, 0, None]) 631 | 632 | # Generate node colors based on user selection 633 | if node_color_attribute == "Degree": 634 | node_colors = [graph.degree(node) for node in graph.nodes()] 635 | else: # Random 636 | node_colors = [random.random() for _ in graph.nodes()] 637 | node_colors = np.array(node_colors) 638 | node_colors = (node_colors - node_colors.min()) / (node_colors.max() - node_colors.min()) 639 | 640 | # Create the trace for edges 641 | edge_trace = go.Scatter3d( 642 | x=x_edges, y=y_edges, z=z_edges, 643 | mode='lines', 644 | line=dict(color='lightgray', width=edge_width), 645 | hoverinfo='none' 646 | ) 647 | 648 | # Create the trace for nodes 649 | node_trace = go.Scatter3d( 650 | x=x_nodes, y=y_nodes, z=z_nodes, 651 | mode='markers+text' if show_labels else 'markers', 652 | marker=dict( 653 | size=node_size, 654 | color=node_colors, 655 | colorscale=color_scheme, 656 | colorbar=dict( 657 | title='Node Degree' if node_color_attribute == "Degree" else "Random Value", 658 | thickness=10, 659 | x=1.1, 660 | tickvals=[0, 1], 661 | ticktext=['Low', 'High'] 662 | ), 663 | line=dict(width=1) 664 | ), 665 | text=[node for node in graph.nodes()], 666 | textposition="top center", 667 | textfont=dict(size=label_size, color='black'), 668 | hoverinfo='text' 669 | ) 670 | 671 | # Create the plot 672 | fig = go.Figure(data=[edge_trace, node_trace]) 673 | 674 | # Update layout for better visualization 675 | fig.update_layout( 676 | title=f'{layout_type} Graph Visualization: {os.path.basename(graph_path)}', 677 | showlegend=False, 678 | scene=dict( 679 | xaxis=dict(showbackground=False, showticklabels=False, title=''), 680 | yaxis=dict(showbackground=False, showticklabels=False, title=''), 681 | zaxis=dict(showbackground=False, showticklabels=False, title='') 682 | ), 683 | margin=dict(l=0, r=0, b=0, t=40), 684 | annotations=[ 685 | dict( 686 | showarrow=False, 687 | text=f"Interactive {layout_type} visualization of GraphML data", 688 | 
xref="paper", 689 | yref="paper", 690 | x=0, 691 | y=0 692 | ) 693 | ], 694 | autosize=True 695 | ) 696 | 697 | fig.update_layout(autosize=True) 698 | fig.update_layout(height=600) # Set a fixed height 699 | return fig, f"Graph visualization generated successfully. Using file: {graph_path}" 700 | except Exception as e: 701 | return go.Figure(), f"Error visualizing graph: {str(e)}" 702 | 703 | 704 | 705 | 706 | 707 | def update_logs(): 708 | logs = [] 709 | while not log_queue.empty(): 710 | logs.append(log_queue.get()) 711 | return "\n".join(logs) 712 | 713 | 714 | 715 | def fetch_models(base_url, api_key, service_type): 716 | try: 717 | if service_type.lower() == "ollama": 718 | response = requests.get(f"{base_url}/tags", timeout=10) 719 | else: # OpenAI Compatible 720 | headers = { 721 | "Authorization": f"Bearer {api_key}", 722 | "Content-Type": "application/json" 723 | } 724 | response = requests.get(f"{base_url}/models", headers=headers, timeout=10) 725 | 726 | logging.info(f"Raw API response: {response.text}") 727 | 728 | if response.status_code == 200: 729 | data = response.json() 730 | if service_type.lower() == "ollama": 731 | models = [model.get('name', '') for model in data.get('models', data) if isinstance(model, dict)] 732 | else: # OpenAI Compatible 733 | models = [model.get('id', '') for model in data.get('data', []) if isinstance(model, dict)] 734 | 735 | models = [model for model in models if model] # Remove empty strings 736 | 737 | if not models: 738 | logging.warning(f"No models found in {service_type} API response") 739 | return ["No models available"] 740 | 741 | logging.info(f"Successfully fetched {service_type} models: {models}") 742 | return models 743 | else: 744 | logging.error(f"Error fetching {service_type} models. Status code: {response.status_code}, Response: {response.text}") 745 | return ["Error fetching models"] 746 | except requests.RequestException as e: 747 | logging.error(f"Exception while fetching {service_type} models: {str(e)}") 748 | return ["Error: Connection failed"] 749 | except Exception as e: 750 | logging.error(f"Unexpected error in fetch_models: {str(e)}") 751 | return ["Error: Unexpected issue"] 752 | 753 | def update_model_choices(base_url, api_key, service_type, settings_key): 754 | models = fetch_models(base_url, api_key, service_type) 755 | 756 | if not models: 757 | logging.warning(f"No models fetched for {service_type}.") 758 | 759 | # Get the current model from settings 760 | if settings_key=='llm': 761 | current_model = settings.get(settings_key, {}).get('model') 762 | else: 763 | current_model = settings.get(settings_key, {}).get('llm').get('model') 764 | 765 | match = re.match(r'\$\{(.+)\}', current_model) 766 | if match: 767 | # Extract the variable name 768 | variable_name = match.group(1) 769 | # Get the value from the variable 770 | current_model = globals().get(variable_name) or locals().get(variable_name) 771 | 772 | # If the current model is not in the list, add it 773 | if current_model and current_model not in models: 774 | models.append(current_model) 775 | 776 | return gr.update(choices=models, value=current_model if current_model in models else (models[0] if models else None)) 777 | 778 | def update_llm_model_choices(base_url, api_key, service_type): 779 | return update_model_choices(base_url, api_key, service_type, 'llm') 780 | 781 | def update_embeddings_model_choices(base_url, api_key, service_type): 782 | return update_model_choices(base_url, api_key, service_type, 'embeddings') 783 | 784 | 785 | 786 | 787 | def 
update_llm_settings(llm_model, embeddings_model, context_window, system_message, temperature, max_tokens, 788 | llm_api_base, llm_api_key, 789 | embeddings_api_base, embeddings_api_key, embeddings_service_type): 790 | try: 791 | # 2024-9-11 10:55:57 There is no need to modify it here 792 | # # Update settings.yaml 793 | # settings = load_settings() 794 | # settings['llm'].update({ 795 | # "model": llm_model, 796 | # "api_base": llm_api_base, 797 | # "api_key": llm_api_key, 798 | # "temperature": temperature, 799 | # "max_tokens": max_tokens, 800 | # }) 801 | # settings['embeddings']['llm'].update({ 802 | # "model": embeddings_model, 803 | # "api_base": embeddings_api_base, 804 | # "api_key": embeddings_api_key, 805 | # "provider": embeddings_service_type 806 | # }) 807 | 808 | # with open(f"{ROOT_DIR}/settings.yaml", 'w') as f: 809 | # yaml.dump(settings, f, default_flow_style=False) 810 | 811 | # Update .env file 812 | update_env_file("LLM_API_BASE", llm_api_base) 813 | update_env_file("LLM_API_KEY", llm_api_key) 814 | if llm_model != "${LLM_MODEL}": 815 | update_env_file("LLM_MODEL", llm_model) 816 | update_env_file("EMBEDDINGS_API_BASE", embeddings_api_base) 817 | update_env_file("EMBEDDINGS_API_KEY", embeddings_api_key) 818 | if embeddings_model != "${EMBEDDINGS_MODEL}": 819 | update_env_file("EMBEDDINGS_MODEL", embeddings_model) 820 | update_env_file("CONTEXT_WINDOW", str(context_window)) 821 | update_env_file("SYSTEM_MESSAGE", system_message) 822 | update_env_file("TEMPERATURE", str(temperature)) 823 | update_env_file("MAX_TOKENS", str(max_tokens)) 824 | 825 | # Reload environment variables 826 | load_dotenv(dotenv_path=env_file, override=True) 827 | 828 | return "LLM and embeddings settings updated successfully in both settings.yaml and .env files." 
829 | except Exception as e: 830 | return f"Error updating LLM and embeddings settings: {str(e)}" 831 | 832 | def update_env_file(key, value): 833 | with open(env_file, 'r') as file: 834 | lines = file.readlines() 835 | 836 | updated = False 837 | for i, line in enumerate(lines): 838 | if line.startswith(f"{key}="): 839 | lines[i] = f"{key}={value}\n" 840 | updated = True 841 | break 842 | 843 | if not updated: 844 | lines.append(f"{key}={value}\n") 845 | 846 | with open(env_file, 'w') as file: 847 | file.writelines(lines) 848 | 849 | custom_css = """ 850 | html, body { 851 | margin: 0; 852 | padding: 0; 853 | height: 100vh; 854 | overflow: hidden; 855 | } 856 | 857 | .gradio-container { 858 | margin: 0 !important; 859 | padding: 0 !important; 860 | width: 100vw !important; 861 | max-width: 100vw !important; 862 | height: 100vh !important; 863 | max-height: 100vh !important; 864 | overflow: auto; 865 | display: flex; 866 | flex-direction: column; 867 | } 868 | 869 | #main-container { 870 | flex: 1; 871 | display: flex; 872 | overflow: hidden; 873 | } 874 | 875 | #left-column, #right-column { 876 | height: 100%; 877 | overflow-y: auto; 878 | padding: 10px; 879 | } 880 | 881 | #left-column { 882 | flex: 1; 883 | } 884 | 885 | #right-column { 886 | flex: 2; 887 | display: flex; 888 | flex-direction: column; 889 | } 890 | 891 | #chat-container { 892 | flex: 0 0 auto; /* Don't allow this to grow */ 893 | height: 100%; 894 | display: flex; 895 | flex-direction: column; 896 | overflow: hidden; 897 | border: 1px solid var(--color-accent); 898 | border-radius: 8px; 899 | padding: 10px; 900 | overflow-y: auto; 901 | } 902 | 903 | #chatbot { 904 | overflow-y: hidden; 905 | height: 100%; 906 | } 907 | 908 | #chat-input-row { 909 | margin-top: 10px; 910 | } 911 | 912 | #visualization-plot { 913 | width: 100%; 914 | aspect-ratio: 1 / 1; 915 | max-height: 600px; /* Adjust this value as needed */ 916 | } 917 | 918 | #vis-controls-row { 919 | display: flex; 920 | justify-content: space-between; 921 | align-items: center; 922 | margin-top: 10px; 923 | } 924 | 925 | #vis-controls-row > * { 926 | flex: 1; 927 | margin: 0 5px; 928 | } 929 | 930 | #vis-status { 931 | margin-top: 10px; 932 | } 933 | 934 | /* Chat input styling */ 935 | #chat-input-row { 936 | display: flex; 937 | flex-direction: column; 938 | } 939 | 940 | #chat-input-row > div { 941 | width: 100% !important; 942 | } 943 | 944 | #chat-input-row input[type="text"] { 945 | width: 100% !important; 946 | } 947 | 948 | /* Adjust padding for all containers */ 949 | .gr-box, .gr-form, .gr-panel { 950 | padding: 10px !important; 951 | } 952 | 953 | /* Ensure all textboxes and textareas have full height */ 954 | .gr-textbox, .gr-textarea { 955 | height: auto !important; 956 | min-height: 100px !important; 957 | } 958 | 959 | /* Ensure all dropdowns have full width */ 960 | .gr-dropdown { 961 | width: 100% !important; 962 | } 963 | 964 | :root { 965 | --color-background: #ffffff; 966 | --color-foreground: #3F4E4F; 967 | --color-accent: #A27B5C; 968 | --color-text: #DCD7C9; 969 | } 970 | 971 | body, .gradio-container { 972 | background-color: var(--color-background); 973 | color: var(--color-text); 974 | } 975 | 976 | .gr-button { 977 | background-color: var(--color-accent); 978 | color: var(--color-text); 979 | } 980 | 981 | .gr-input, .gr-textarea, .gr-dropdown { 982 | background-color: var(--color-foreground); 983 | color: var(--color-text); 984 | border: 1px solid var(--color-accent); 985 | } 986 | 987 | .gr-panel { 988 | background-color: 
var(--color-foreground); 989 | border: 1px solid var(--color-accent); 990 | } 991 | 992 | .gr-box { 993 | border-radius: 8px; 994 | margin-bottom: 10px; 995 | background-color: var(--color-foreground); 996 | } 997 | 998 | .gr-padded { 999 | padding: 10px; 1000 | } 1001 | 1002 | .gr-form { 1003 | background-color: var(--color-foreground); 1004 | } 1005 | 1006 | .gr-input-label, .gr-radio-label { 1007 | color: var(--color-text); 1008 | } 1009 | 1010 | .gr-checkbox-label { 1011 | color: var(--color-text); 1012 | } 1013 | 1014 | .gr-markdown { 1015 | color: var(--color-text); 1016 | } 1017 | 1018 | .gr-accordion { 1019 | background-color: var(--color-foreground); 1020 | border: 1px solid var(--color-accent); 1021 | } 1022 | 1023 | .gr-accordion-header { 1024 | background-color: var(--color-accent); 1025 | color: var(--color-text); 1026 | } 1027 | 1028 | #visualization-container { 1029 | display: flex; 1030 | flex-direction: column; 1031 | border: 2px solid var(--color-accent); 1032 | border-radius: 8px; 1033 | margin-top: 20px; 1034 | padding: 10px; 1035 | background-color: var(--color-foreground); 1036 | height: calc(100vh - 300px); /* Adjust this value as needed */ 1037 | } 1038 | 1039 | #visualization-plot { 1040 | width: 100%; 1041 | height: 100%; 1042 | } 1043 | 1044 | #vis-controls-row { 1045 | display: flex; 1046 | justify-content: space-between; 1047 | align-items: center; 1048 | margin-top: 10px; 1049 | } 1050 | 1051 | #vis-controls-row > * { 1052 | flex: 1; 1053 | margin: 0 5px; 1054 | } 1055 | 1056 | #vis-status { 1057 | margin-top: 10px; 1058 | } 1059 | 1060 | #log-container { 1061 | background-color: var(--color-foreground); 1062 | border: 1px solid var(--color-accent); 1063 | border-radius: 8px; 1064 | padding: 10px; 1065 | margin-top: 20px; 1066 | max-height: auto; 1067 | overflow-y: auto; 1068 | } 1069 | 1070 | .setting-accordion .label-wrap { 1071 | cursor: pointer; 1072 | } 1073 | 1074 | .setting-accordion .icon { 1075 | transition: transform 0.3s ease; 1076 | } 1077 | 1078 | .setting-accordion[open] .icon { 1079 | transform: rotate(90deg); 1080 | } 1081 | 1082 | .gr-form.gr-box { 1083 | border: none !important; 1084 | background: none !important; 1085 | } 1086 | 1087 | .model-params { 1088 | border-top: 1px solid var(--color-accent); 1089 | margin-top: 10px; 1090 | padding-top: 10px; 1091 | } 1092 | """ 1093 | 1094 | def list_output_files(root_dir): 1095 | output_dir = os.path.join(root_dir, "output") 1096 | files = [] 1097 | for root, _, filenames in os.walk(output_dir): 1098 | for filename in filenames: 1099 | files.append(os.path.join(root, filename)) 1100 | return files 1101 | 1102 | def update_file_list(): 1103 | files = list_input_files() 1104 | return gr.update(choices=[f["path"] for f in files]) 1105 | 1106 | def update_file_content(file_path): 1107 | if not file_path: 1108 | return "" 1109 | try: 1110 | with open(file_path, 'r', encoding='utf-8') as file: 1111 | content = file.read() 1112 | return content 1113 | except Exception as e: 1114 | logging.error(f"Error reading file: {str(e)}") 1115 | return f"Error reading file: {str(e)}" 1116 | 1117 | def list_output_folders(root_dir): 1118 | output_dir = os.path.join(root_dir, "output") 1119 | folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))] 1120 | return sorted(folders, reverse=True) 1121 | 1122 | def list_folder_contents(folder_path): 1123 | contents = [] 1124 | for item in os.listdir(folder_path): 1125 | item_path = os.path.join(folder_path, item) 1126 | if 
os.path.isdir(item_path): 1127 | contents.append(f"[DIR] {item}") 1128 | else: 1129 | _, ext = os.path.splitext(item) 1130 | contents.append(f"[{ext[1:].upper()}] {item}") 1131 | return contents 1132 | 1133 | def update_output_folder_list(): 1134 | root_dir = ROOT_DIR 1135 | folders = list_output_folders(root_dir) 1136 | return gr.update(choices=folders, value=folders[0] if folders else None) 1137 | 1138 | def update_folder_content_list(folder_name): 1139 | root_dir = ROOT_DIR 1140 | if not folder_name: 1141 | return gr.update(choices=[]) 1142 | contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, "artifacts")) 1143 | return gr.update(choices=contents) 1144 | 1145 | def handle_content_selection(folder_name, selected_item): 1146 | root_dir = ROOT_DIR 1147 | if isinstance(selected_item, list) and selected_item: 1148 | selected_item = selected_item[0] # Take the first item if it's a list 1149 | 1150 | if isinstance(selected_item, str) and selected_item.startswith("[DIR]"): 1151 | dir_name = selected_item[6:] # Remove "[DIR] " prefix 1152 | sub_contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, dir_name)) 1153 | return gr.update(choices=sub_contents), "", "" 1154 | elif isinstance(selected_item, str): 1155 | file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item # Remove file type prefix if present 1156 | file_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name) 1157 | file_size = os.path.getsize(file_path) 1158 | file_type = os.path.splitext(file_name)[1] 1159 | file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}" 1160 | content = read_file_content(file_path) 1161 | return gr.update(), file_info, content 1162 | else: 1163 | return gr.update(), "", "" 1164 | 1165 | def initialize_selected_folder(folder_name): 1166 | root_dir = ROOT_DIR 1167 | if not folder_name: 1168 | return "Please select a folder first.", gr.update(choices=[]) 1169 | folder_path = os.path.join(root_dir, "output", folder_name, "artifacts") 1170 | if not os.path.exists(folder_path): 1171 | return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[]) 1172 | contents = list_folder_contents(folder_path) 1173 | return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents) 1174 | 1175 | 1176 | settings = load_settings() 1177 | default_model = settings['llm']['model'] 1178 | cli_args = gr.State({}) 1179 | stop_indexing = threading.Event() 1180 | indexing_thread = None 1181 | 1182 | def start_indexing(*args): 1183 | global indexing_thread, stop_indexing 1184 | stop_indexing = threading.Event() # Reset the stop_indexing event 1185 | indexing_thread = threading.Thread(target=run_indexing, args=args) 1186 | indexing_thread.start() 1187 | return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False) 1188 | 1189 | def stop_indexing_process(): 1190 | global indexing_thread 1191 | logging.info("Stop indexing requested") 1192 | stop_indexing.set() 1193 | if indexing_thread and indexing_thread.is_alive(): 1194 | logging.info("Waiting for indexing thread to finish") 1195 | indexing_thread.join(timeout=10) 1196 | logging.info("Indexing thread finished" if not indexing_thread.is_alive() else "Indexing thread did not finish within timeout") 1197 | indexing_thread = None # Reset the thread 1198 | return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True) 1199 | 1200 | def 
refresh_indexing(): 1201 | global indexing_thread, stop_indexing 1202 | if indexing_thread and indexing_thread.is_alive(): 1203 | logging.info("Cannot refresh: Indexing is still running") 1204 | return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), "Cannot refresh: Indexing is still running" 1205 | else: 1206 | stop_indexing = threading.Event() # Reset the stop_indexing event 1207 | indexing_thread = None # Reset the thread 1208 | return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True), "Indexing process refreshed. You can start indexing again." 1209 | 1210 | 1211 | 1212 | def run_indexing(root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_args): 1213 | if not root_dir or root_dir == '.' or root_dir == './': 1214 | root_dir = ROOT_DIR 1215 | elif not os.path.exists(root_dir): 1216 | logging.error(f"Root directory '{root_dir}' does not exist.") 1217 | return ("\n".join(["Root directory does not exist."]), 1218 | "Root directory does not exist.", 1219 | 100, 1220 | gr.update(interactive=True), 1221 | gr.update(interactive=False), 1222 | gr.update(interactive=True), 1223 | "0") 1224 | 1225 | # Set default config_file if None 1226 | if not config_file: 1227 | config_file = "settings.yaml" 1228 | elif not os.path.exists(config_file): 1229 | logging.error(f"Config file '{config_file}' does not exist.") 1230 | return ("\n".join([f"Config file '{config_file}' does not exist."]), 1231 | f"Config file '{config_file}' does not exist.", 1232 | 100, 1233 | gr.update(interactive=True), 1234 | gr.update(interactive=False), 1235 | gr.update(interactive=True), 1236 | "0") 1237 | 1238 | cmd = ["python", "-m", "graphrag.index", "--root", root_dir, "--config", config_file] 1239 | 1240 | # Add other CLI arguments 1241 | if verbose: 1242 | cmd.append("--verbose") 1243 | if nocache: 1244 | cmd.append("--nocache") 1245 | if resume: 1246 | cmd.extend(["--resume", resume]) 1247 | if reporter: 1248 | cmd.extend(["--reporter", reporter]) 1249 | if emit_formats: 1250 | cmd.extend(["--emit", ','.join(emit_formats)]) 1251 | 1252 | # Add custom CLI arguments 1253 | if custom_args: 1254 | cmd.extend(custom_args.split()) 1255 | 1256 | logging.info(f"Executing command: {' '.join(cmd)}") 1257 | 1258 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, encoding='utf-8', universal_newlines=True) 1259 | 1260 | 1261 | output = [] 1262 | progress_value = 0 1263 | iterations_completed = 0 1264 | 1265 | while True: 1266 | if stop_indexing.is_set(): 1267 | process.terminate() 1268 | process.wait(timeout=5) 1269 | if process.poll() is None: 1270 | process.kill() 1271 | return ("\n".join(output + ["Indexing stopped by user."]), 1272 | "Indexing stopped.", 1273 | 100, 1274 | gr.update(interactive=True), 1275 | gr.update(interactive=False), 1276 | gr.update(interactive=True), 1277 | str(iterations_completed)) 1278 | 1279 | try: 1280 | line = process.stdout.readline() 1281 | if not line and process.poll() is not None: 1282 | break 1283 | 1284 | if line: 1285 | line = line.strip() 1286 | output.append(line) 1287 | 1288 | if "Processing file" in line: 1289 | progress_value += 1 1290 | iterations_completed += 1 1291 | elif "Indexing completed" in line: 1292 | progress_value = 100 1293 | elif "ERROR" in line: 1294 | line = f"🚨 ERROR: {line}" 1295 | 1296 | yield ("\n".join(output), 1297 | line, 1298 | progress_value, 1299 | gr.update(interactive=False), 1300 | 
gr.update(interactive=True), 1301 | gr.update(interactive=False), 1302 | str(iterations_completed)) 1303 | except Exception as e: 1304 | logging.error(f"Error during indexing: {str(e)}") 1305 | return ("\n".join(output + [f"Error: {str(e)}"]), 1306 | "Error occurred during indexing.", 1307 | 100, 1308 | gr.update(interactive=True), 1309 | gr.update(interactive=False), 1310 | gr.update(interactive=True), 1311 | str(iterations_completed)) 1312 | 1313 | if process.returncode != 0 and not stop_indexing.is_set(): 1314 | final_output = "\n".join(output + [f"Error: Process exited with return code {process.returncode}"]) 1315 | final_progress = "Indexing failed. Check output for details." 1316 | else: 1317 | final_output = "\n".join(output) 1318 | final_progress = "Indexing completed successfully!" 1319 | 1320 | return (final_output, 1321 | final_progress, 1322 | 100, 1323 | gr.update(interactive=True), 1324 | gr.update(interactive=False), 1325 | gr.update(interactive=True), 1326 | str(iterations_completed)) 1327 | 1328 | global_vector_store_wrapper = None 1329 | 1330 | def create_gradio_interface(): 1331 | global global_vector_store_wrapper 1332 | llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder = initialize_models() 1333 | settings = load_settings() 1334 | 1335 | 1336 | log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False, visible=False) 1337 | 1338 | with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo: 1339 | gr.Markdown("# GraphRAG Local UI", elem_id="title") 1340 | 1341 | with gr.Row(elem_id="main-container"): 1342 | with gr.Column(scale=1, elem_id="left-column"): 1343 | with gr.Tabs(): 1344 | with gr.TabItem("Data Management"): 1345 | with gr.Accordion("File Upload (.txt)", open=True): 1346 | file_upload = gr.File(label="Upload .txt File", file_types=[".txt"]) 1347 | upload_btn = gr.Button("Upload File", variant="primary") 1348 | upload_output = gr.Textbox(label="Upload Status", visible=False) 1349 | 1350 | with gr.Accordion("File Management", open=True): 1351 | file_list = gr.Dropdown(label="Select File", choices=[], interactive=True) 1352 | refresh_btn = gr.Button("Refresh File List", variant="secondary") 1353 | 1354 | file_content = gr.TextArea(label="File Content", lines=10) 1355 | 1356 | with gr.Row(): 1357 | delete_btn = gr.Button("Delete Selected File", variant="stop") 1358 | save_btn = gr.Button("Save Changes", variant="primary") 1359 | 1360 | operation_status = gr.Textbox(label="Operation Status", visible=False) 1361 | 1362 | 1363 | 1364 | with gr.TabItem("Indexing"): 1365 | root_dir = gr.Textbox(label="Root Directory", value=f"{ROOT_DIR}") 1366 | config_file = gr.File(label="Config File (optional)") 1367 | with gr.Row(): 1368 | verbose = gr.Checkbox(label="Verbose", value=True) 1369 | nocache = gr.Checkbox(label="No Cache", value=True) 1370 | with gr.Row(): 1371 | resume = gr.Textbox(label="Resume Timestamp (optional)") 1372 | reporter = gr.Dropdown(label="Reporter", choices=["rich", "print", "none"], value=None) 1373 | with gr.Row(): 1374 | emit_formats = gr.CheckboxGroup(label="Emit Formats", choices=["json", "csv", "parquet"], value=None) 1375 | with gr.Row(): 1376 | run_index_button = gr.Button("Run Indexing") 1377 | stop_index_button = gr.Button("Stop Indexing", variant="stop") 1378 | refresh_index_button = gr.Button("Refresh Indexing", variant="secondary") 1379 | 1380 | with gr.Accordion("Custom CLI Arguments", open=True): 1381 | custom_cli_args = gr.Textbox( 1382 | 
label="Custom CLI Arguments", 1383 | placeholder="--arg1 value1 --arg2 value2", 1384 | lines=3 1385 | ) 1386 | cli_guide = gr.Markdown( 1387 | textwrap.dedent(""" 1388 | ### CLI Argument Key Guide: 1389 | - `--root `: Set the root directory for the project 1390 | - `--config `: Specify a custom configuration file 1391 | - `--verbose`: Enable verbose output 1392 | - `--nocache`: Disable caching 1393 | - `--resume `: Resume from a specific timestamp 1394 | - `--reporter `: Set the reporter type (rich, print, none) 1395 | - `--emit `: Specify output formats (json, csv, parquet) 1396 | 1397 | Example: `--verbose --nocache --emit json,csv` 1398 | """) 1399 | ) 1400 | 1401 | index_output = gr.Textbox(label="Indexing Output", lines=20, max_lines=30) 1402 | index_progress = gr.Textbox(label="Indexing Progress", lines=3) 1403 | iterations_completed = gr.Textbox(label="Iterations Completed", value="0") 1404 | refresh_status = gr.Textbox(label="Refresh Status", visible=True) 1405 | 1406 | run_index_button.click( 1407 | fn=start_indexing, 1408 | inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args], 1409 | outputs=[run_index_button, stop_index_button, refresh_index_button] 1410 | ).then( 1411 | fn=run_indexing, 1412 | inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args], 1413 | outputs=[index_output, index_progress, run_index_button, stop_index_button, refresh_index_button, iterations_completed] 1414 | ) 1415 | 1416 | stop_index_button.click( 1417 | fn=stop_indexing_process, 1418 | outputs=[run_index_button, stop_index_button, refresh_index_button] 1419 | ) 1420 | 1421 | refresh_index_button.click( 1422 | fn=refresh_indexing, 1423 | outputs=[run_index_button, stop_index_button, refresh_index_button, refresh_status] 1424 | ) 1425 | 1426 | with gr.TabItem("Indexing Outputs/Visuals"): 1427 | output_folder_list = gr.Dropdown(label="Select Output Folder (Select GraphML File to Visualize)", choices=list_output_folders(ROOT_DIR), interactive=True) 1428 | refresh_folder_btn = gr.Button("Refresh Folder List", variant="secondary") 1429 | initialize_folder_btn = gr.Button("Initialize Selected Folder", variant="primary") 1430 | folder_content_list = gr.Dropdown(label="Select File or Directory", choices=[], interactive=True) 1431 | file_info = gr.Textbox(label="File Information", interactive=False) 1432 | output_content = gr.TextArea(label="File Content", lines=20, interactive=False) 1433 | initialization_status = gr.Textbox(label="Initialization Status") 1434 | 1435 | with gr.TabItem("LLM Settings"): 1436 | llm_base_url = gr.Textbox(label="LLM API Base URL", value=os.getenv("LLM_API_BASE")) 1437 | llm_api_key = gr.Textbox(label="LLM API Key", value=os.getenv("LLM_API_KEY"), type="password") 1438 | llm_service_type = gr.Radio( 1439 | label="LLM Service Type", 1440 | choices=["openai", "ollama"], 1441 | value="openai", 1442 | visible=False # Hide this if you want to always use OpenAI 1443 | ) 1444 | 1445 | llm_model_dropdown = gr.Dropdown( 1446 | label="LLM Model", 1447 | choices=[], # Start with an empty list 1448 | value=settings['llm'].get('model'), 1449 | allow_custom_value=True 1450 | ) 1451 | refresh_llm_models_btn = gr.Button("Refresh LLM Models", variant="secondary") 1452 | 1453 | embeddings_base_url = gr.Textbox(label="Embeddings API Base URL", value=os.getenv("EMBEDDINGS_API_BASE")) 1454 | embeddings_api_key = gr.Textbox(label="Embeddings API Key", value=os.getenv("EMBEDDINGS_API_KEY"), type="password") 1455 | 
embeddings_service_type = gr.Radio( 1456 | label="Embeddings Service Type", 1457 | choices=["openai", "ollama"], 1458 | value=settings.get('embeddings', {}).get('llm', {}).get('type', 'openai'), 1459 | visible=False, 1460 | ) 1461 | 1462 | embeddings_model_dropdown = gr.Dropdown( 1463 | label="Embeddings Model", 1464 | choices=[], 1465 | value=settings.get('embeddings', {}).get('llm', {}).get('model'), 1466 | allow_custom_value=True 1467 | ) 1468 | refresh_embeddings_models_btn = gr.Button("Refresh Embedding Models", variant="secondary") 1469 | system_message = gr.Textbox( 1470 | lines=5, 1471 | label="System Message", 1472 | value=os.getenv("SYSTEM_MESSAGE", "You are a helpful AI assistant.") 1473 | ) 1474 | context_window = gr.Slider( 1475 | label="Context Window", 1476 | minimum=512, 1477 | maximum=32768, 1478 | step=512, 1479 | value=int(os.getenv("CONTEXT_WINDOW", 4096)) 1480 | ) 1481 | temperature = gr.Slider( 1482 | label="Temperature", 1483 | minimum=0.0, 1484 | maximum=2.0, 1485 | step=0.1, 1486 | value=float(settings['llm'].get('TEMPERATURE', 0.5)) 1487 | ) 1488 | max_tokens = gr.Slider( 1489 | label="Max Tokens", 1490 | minimum=1, 1491 | maximum=8192, 1492 | step=1, 1493 | value=int(settings['llm'].get('MAX_TOKENS', 1024)) 1494 | ) 1495 | update_settings_btn = gr.Button("Update LLM Settings", variant="primary") 1496 | llm_settings_status = gr.Textbox(label="Status", interactive=False) 1497 | 1498 | # llm_base_url.change( 1499 | # fn=update_model_choices, 1500 | # inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], 1501 | # outputs=llm_model_dropdown 1502 | # ) 1503 | # Update Embeddings model choices when service type or base URL changes 1504 | embeddings_service_type.change( 1505 | fn=update_embeddings_model_choices, 1506 | inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type], 1507 | outputs=embeddings_model_dropdown 1508 | ) 1509 | 1510 | # embeddings_base_url.change( 1511 | # fn=update_model_choices, 1512 | # inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], 1513 | # outputs=embeddings_model_dropdown 1514 | # ) 1515 | 1516 | update_settings_btn.click( 1517 | fn=update_llm_settings, 1518 | inputs=[ 1519 | llm_model_dropdown, 1520 | embeddings_model_dropdown, 1521 | context_window, 1522 | system_message, 1523 | temperature, 1524 | max_tokens, 1525 | llm_base_url, 1526 | llm_api_key, 1527 | embeddings_base_url, 1528 | embeddings_api_key, 1529 | embeddings_service_type 1530 | ], 1531 | outputs=[llm_settings_status] 1532 | ) 1533 | 1534 | 1535 | refresh_llm_models_btn.click( 1536 | fn=update_model_choices, 1537 | inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], 1538 | outputs=[llm_model_dropdown] 1539 | ) 1540 | 1541 | refresh_embeddings_models_btn.click( 1542 | fn=update_model_choices, 1543 | inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], 1544 | outputs=[embeddings_model_dropdown] 1545 | ) 1546 | 1547 | with gr.TabItem("YAML Settings"): 1548 | settings = load_settings() 1549 | with gr.Group(): 1550 | for key, value in settings.items(): 1551 | if key != 'llm': 1552 | create_setting_component(key, value) 1553 | 1554 | with gr.Group(elem_id="log-container"): 1555 | gr.Markdown("### Logs") 1556 | log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False) 1557 | 1558 | with gr.Column(scale=2, 
elem_id="right-column"): 1559 | with gr.Group(elem_id="chat-container"): 1560 | chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot") 1561 | with gr.Row(elem_id="chat-input-row"): 1562 | with gr.Column(scale=1): 1563 | query_input = gr.Textbox( 1564 | label="Input", 1565 | placeholder="Enter your query here...", 1566 | elem_id="query-input" 1567 | ) 1568 | query_btn = gr.Button("Send Query", variant="primary") 1569 | 1570 | with gr.Accordion("Query Parameters", open=True): 1571 | query_type = gr.Radio( 1572 | ["global", "local", "direct"], 1573 | label="Query Type", 1574 | value="global", 1575 | info="Global: community-based search, Local: entity-based search, Direct: LLM chat" 1576 | ) 1577 | preset_dropdown = gr.Dropdown( 1578 | label="Preset Query Options", 1579 | choices=[ 1580 | "Default Global Search", 1581 | "Default Local Search", 1582 | "Detailed Global Analysis", 1583 | "Detailed Local Analysis", 1584 | "Quick Global Summary", 1585 | "Quick Local Summary", 1586 | "Global Bullet Points", 1587 | "Local Bullet Points", 1588 | "Comprehensive Global Report", 1589 | "Comprehensive Local Report", 1590 | "High-Level Global Overview", 1591 | "High-Level Local Overview", 1592 | "Focused Global Insight", 1593 | "Focused Local Insight", 1594 | "Custom Query" 1595 | ], 1596 | value="Default Global Search", 1597 | info="Select a preset or choose 'Custom Query' for manual configuration" 1598 | ) 1599 | selected_folder = gr.Dropdown( 1600 | label="Select Index Folder to Chat With", 1601 | choices=list_output_folders(ROOT_DIR), 1602 | value=None, 1603 | interactive=True 1604 | ) 1605 | refresh_folder_btn = gr.Button("Refresh Folders", variant="secondary") 1606 | clear_chat_btn = gr.Button("Clear Chat", variant="secondary") 1607 | 1608 | with gr.Group(visible=False) as custom_options: 1609 | community_level = gr.Slider( 1610 | label="Community Level", 1611 | minimum=1, 1612 | maximum=10, 1613 | value=2, 1614 | step=1, 1615 | info="Higher values use reports on smaller communities" 1616 | ) 1617 | response_type = gr.Dropdown( 1618 | label="Response Type", 1619 | choices=[ 1620 | "Multiple Paragraphs", 1621 | "Single Paragraph", 1622 | "Single Sentence", 1623 | "List of 3-7 Points", 1624 | "Single Page", 1625 | "Multi-Page Report" 1626 | ], 1627 | value="Multiple Paragraphs", 1628 | info="Specify the desired format of the response" 1629 | ) 1630 | custom_cli_args = gr.Textbox( 1631 | label="Custom CLI Arguments", 1632 | placeholder="--arg1 value1 --arg2 value2", 1633 | info="Additional CLI arguments for advanced users" 1634 | ) 1635 | 1636 | def update_custom_options(preset): 1637 | if preset == "Custom Query": 1638 | return gr.update(visible=True) 1639 | else: 1640 | return gr.update(visible=False) 1641 | 1642 | preset_dropdown.change(fn=update_custom_options, inputs=[preset_dropdown], outputs=[custom_options]) 1643 | 1644 | 1645 | 1646 | 1647 | with gr.Group(elem_id="visualization-container"): 1648 | vis_output = gr.Plot(label="Graph Visualization", elem_id="visualization-plot") 1649 | with gr.Row(elem_id="vis-controls-row"): 1650 | vis_btn = gr.Button("Visualize Graph", variant="secondary") 1651 | 1652 | # Add new controls for customization 1653 | with gr.Accordion("Visualization Settings", open=False): 1654 | layout_type = gr.Dropdown(["3D Spring", "2D Spring", "Circular"], label="Layout Type", value="3D Spring") 1655 | node_size = gr.Slider(1, 20, 7, label="Node Size", step=1) 1656 | edge_width = gr.Slider(0.1, 5, 0.5, label="Edge Width", step=0.1) 1657 | node_color_attribute = 
gr.Dropdown(["Degree", "Random"], label="Node Color Attribute", value="Degree") 1658 | color_scheme = gr.Dropdown(["Viridis", "Plasma", "Inferno", "Magma", "Cividis"], label="Color Scheme", value="Viridis") 1659 | show_labels = gr.Checkbox(label="Show Node Labels", value=True) 1660 | label_size = gr.Slider(5, 20, 10, label="Label Size", step=1) 1661 | 1662 | 1663 | # Event handlers 1664 | upload_btn.click(fn=upload_file, inputs=[file_upload], outputs=[upload_output, file_list, log_output]) 1665 | refresh_btn.click(fn=update_file_list, outputs=[file_list]).then( 1666 | fn=update_logs, 1667 | outputs=[log_output] 1668 | ) 1669 | file_list.change(fn=update_file_content, inputs=[file_list], outputs=[file_content]).then( 1670 | fn=update_logs, 1671 | outputs=[log_output] 1672 | ) 1673 | delete_btn.click(fn=delete_file, inputs=[file_list], outputs=[operation_status, file_list, log_output]) 1674 | save_btn.click(fn=save_file_content, inputs=[file_list, file_content], outputs=[operation_status, log_output]) 1675 | 1676 | refresh_folder_btn.click( 1677 | fn=lambda: gr.update(choices=list_output_folders(ROOT_DIR)), 1678 | outputs=[selected_folder] 1679 | ) 1680 | 1681 | clear_chat_btn.click( 1682 | fn=lambda: ([], ""), 1683 | outputs=[chatbot, query_input] 1684 | ) 1685 | 1686 | refresh_folder_btn.click( 1687 | fn=update_output_folder_list, 1688 | outputs=[output_folder_list] 1689 | ).then( 1690 | fn=update_logs, 1691 | outputs=[log_output] 1692 | ) 1693 | 1694 | output_folder_list.change( 1695 | fn=update_folder_content_list, 1696 | inputs=[output_folder_list], 1697 | outputs=[folder_content_list] 1698 | ).then( 1699 | fn=update_logs, 1700 | outputs=[log_output] 1701 | ) 1702 | 1703 | folder_content_list.change( 1704 | fn=handle_content_selection, 1705 | inputs=[output_folder_list, folder_content_list], 1706 | outputs=[folder_content_list, file_info, output_content] 1707 | ).then( 1708 | fn=update_logs, 1709 | outputs=[log_output] 1710 | ) 1711 | 1712 | initialize_folder_btn.click( 1713 | fn=initialize_selected_folder, 1714 | inputs=[output_folder_list], 1715 | outputs=[initialization_status, folder_content_list] 1716 | ).then( 1717 | fn=update_logs, 1718 | outputs=[log_output] 1719 | ) 1720 | 1721 | vis_btn.click( 1722 | fn=update_visualization, 1723 | inputs=[ 1724 | output_folder_list, 1725 | folder_content_list, 1726 | layout_type, 1727 | node_size, 1728 | edge_width, 1729 | node_color_attribute, 1730 | color_scheme, 1731 | show_labels, 1732 | label_size 1733 | ], 1734 | outputs=[vis_output, gr.Textbox(label="Visualization Status")] 1735 | ) 1736 | 1737 | query_btn.click( 1738 | fn=send_message, 1739 | inputs=[ 1740 | query_type, 1741 | query_input, 1742 | chatbot, 1743 | system_message, 1744 | temperature, 1745 | max_tokens, 1746 | preset_dropdown, 1747 | community_level, 1748 | response_type, 1749 | custom_cli_args, 1750 | selected_folder 1751 | ], 1752 | outputs=[chatbot, query_input, log_output] 1753 | ) 1754 | 1755 | query_input.submit( 1756 | fn=send_message, 1757 | inputs=[ 1758 | query_type, 1759 | query_input, 1760 | chatbot, 1761 | system_message, 1762 | temperature, 1763 | max_tokens, 1764 | preset_dropdown, 1765 | community_level, 1766 | response_type, 1767 | custom_cli_args, 1768 | selected_folder 1769 | ], 1770 | outputs=[chatbot, query_input, log_output] 1771 | ) 1772 | 1773 | # Add this JavaScript to enable Shift+Enter functionality 1774 | demo.load(js=""" 1775 | function addShiftEnterListener() { 1776 | const queryInput = document.getElementById('query-input'); 1777 | if 
(queryInput) { 1778 | queryInput.addEventListener('keydown', function(event) { 1779 | if (event.key === 'Enter' && event.shiftKey) { 1780 | event.preventDefault(); 1781 | const submitButton = queryInput.closest('.gradio-container').querySelector('button.primary'); 1782 | if (submitButton) { 1783 | submitButton.click(); 1784 | } 1785 | } 1786 | }); 1787 | } 1788 | } 1789 | document.addEventListener('DOMContentLoaded', addShiftEnterListener); 1790 | """) 1791 | 1792 | return demo.queue() 1793 | 1794 | async def main(): 1795 | api_port = 8088 1796 | gradio_port = 7860 1797 | 1798 | 1799 | print(f"Starting API server on port {api_port}") 1800 | start_api_server(api_port) 1801 | 1802 | # Wait for the API server to start in a separate thread 1803 | threading.Thread(target=wait_for_api_server, args=(api_port,)).start() 1804 | 1805 | # Create the Gradio app 1806 | demo = create_gradio_interface() 1807 | 1808 | print(f"Starting Gradio app on port {gradio_port}") 1809 | # Launch the Gradio app 1810 | demo.launch(server_port=gradio_port, share=True) 1811 | 1812 | 1813 | demo = create_gradio_interface() 1814 | app = demo.app 1815 | 1816 | def main(): 1817 | initialize_data() 1818 | demo.launch(server_name='0.0.0.0', server_port=7862, share=False) 1819 | 1820 | if __name__ == "__main__": 1821 | main() -------------------------------------------------------------------------------- /assets/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wade1010/graphrag-ui/493e051a05890803e6a566c811a96467bc20d68e/assets/image1.png -------------------------------------------------------------------------------- /assets/image2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wade1010/graphrag-ui/493e051a05890803e6a566c811a96467bc20d68e/assets/image2.gif -------------------------------------------------------------------------------- /assets/image3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wade1010/graphrag-ui/493e051a05890803e6a566c811a96467bc20d68e/assets/image3.png -------------------------------------------------------------------------------- /index_app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | try: 3 | import graphrag 4 | except ImportError: 5 | print("The 'graphrag' package is not installed. 
Please install it using 'pip install graphrag'.Since the dependency package `aiofiles` of `graphrag` conflicts with the requirements of `gradio`, it is necessary to manually install `graphrag` separately.") 6 | sys.exit(1) 7 | import gradio as gr 8 | import requests 9 | import logging 10 | import os 11 | import json 12 | import shutil 13 | import glob 14 | import queue 15 | import lancedb 16 | from datetime import datetime 17 | from dotenv import load_dotenv, set_key 18 | import yaml 19 | import pandas as pd 20 | from typing import List, Optional 21 | from pydantic import BaseModel 22 | 23 | # Set up logging 24 | log_queue = queue.Queue() 25 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 26 | logger = logging.getLogger(__name__) 27 | 28 | graphrag_indexing_dir = 'indexing' 29 | project_root = os.path.abspath(os.path.dirname(__file__)) 30 | ROOT_DIR = os.path.join(project_root,graphrag_indexing_dir) 31 | env_file = os.path.join(ROOT_DIR, '.env') 32 | 33 | load_dotenv(env_file) 34 | 35 | API_BASE_URL = os.getenv('API_BASE_URL', 'http://localhost:8012') 36 | LLM_API_BASE = os.getenv('LLM_API_BASE', 'http://localhost:11434') 37 | EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE', 'http://localhost:11434') 38 | 39 | # Data models 40 | class IndexingRequest(BaseModel): 41 | llm_model: str 42 | embed_model: str 43 | llm_api_base: str 44 | embed_api_base: str 45 | root: str 46 | verbose: bool = False 47 | nocache: bool = False 48 | resume: Optional[str] = None 49 | reporter: str = "rich" 50 | emit: List[str] = ["parquet"] 51 | custom_args: Optional[str] = None 52 | 53 | class PromptTuneRequest(BaseModel): 54 | root: str = "{ROOT_DIR}" 55 | config: str = "{ROOT_DIR}/settings.yaml" 56 | domain: Optional[str] = None 57 | method: str = "random" 58 | limit: int = 15 59 | language: Optional[str] = None 60 | max_tokens: int = 2000 61 | chunk_size: int = 200 62 | no_entity_types: bool = False 63 | output: str = "{ROOT_DIR}/prompts" 64 | 65 | class QueueHandler(logging.Handler): 66 | def __init__(self, log_queue): 67 | super().__init__() 68 | self.log_queue = log_queue 69 | 70 | def emit(self, record): 71 | self.log_queue.put(self.format(record)) 72 | queue_handler = QueueHandler(log_queue) 73 | logging.getLogger().addHandler(queue_handler) 74 | 75 | 76 | def update_logs(): 77 | logs = [] 78 | while not log_queue.empty(): 79 | logs.append(log_queue.get()) 80 | return "\n".join(logs) 81 | 82 | ##########SETTINGS################ 83 | def load_settings(): 84 | config_path = os.getenv('GRAPHRAG_CONFIG', 'config.yaml') 85 | if os.path.exists(config_path): 86 | with open(config_path, 'r') as config_file: 87 | config = yaml.safe_load(config_file) 88 | else: 89 | config = {} 90 | 91 | settings = { 92 | 'llm_model': os.getenv('LLM_MODEL', config.get('llm_model')), 93 | 'embedding_model': os.getenv('EMBEDDINGS_MODEL', config.get('embedding_model')), 94 | 'community_level': int(os.getenv('COMMUNITY_LEVEL', config.get('community_level', 2))), 95 | 'token_limit': int(os.getenv('TOKEN_LIMIT', config.get('token_limit', 4096))), 96 | 'api_key': os.getenv('GRAPHRAG_API_KEY', config.get('api_key')), 97 | 'api_base': os.getenv('LLM_API_BASE', config.get('api_base')), 98 | 'embeddings_api_base': os.getenv('EMBEDDINGS_API_BASE', config.get('embeddings_api_base')), 99 | 'api_type': os.getenv('API_TYPE', config.get('api_type', 'openai')), 100 | } 101 | 102 | return settings 103 | 104 | 105 | #######FILE_MANAGEMENT############## 106 | def list_output_files(root_dir): 107 | 
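# list_output_files walks <root_dir>/output recursively and returns the full path of every file it finds.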
output_dir = os.path.join(root_dir, "output") 108 | files = [] 109 | for root, _, filenames in os.walk(output_dir): 110 | for filename in filenames: 111 | files.append(os.path.join(root, filename)) 112 | return files 113 | 114 | def update_file_list(): 115 | files = list_input_files() 116 | return gr.update(choices=[f["path"] for f in files]) 117 | 118 | def update_file_content(file_path): 119 | if not file_path: 120 | return "" 121 | try: 122 | with open(file_path, 'r', encoding='utf-8') as file: 123 | content = file.read() 124 | return content 125 | except Exception as e: 126 | logging.error(f"Error reading file: {str(e)}") 127 | return f"Error reading file: {str(e)}" 128 | 129 | def list_output_folders(): 130 | output_dir = os.path.join(ROOT_DIR, "output") 131 | folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))] 132 | return sorted(folders, reverse=True) 133 | 134 | def update_output_folder_list(): 135 | folders = list_output_folders() 136 | return gr.update(choices=folders, value=folders[0] if folders else None) 137 | 138 | def list_folder_contents(folder_name): 139 | folder_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts") 140 | contents = [] 141 | if os.path.exists(folder_path): 142 | for item in os.listdir(folder_path): 143 | item_path = os.path.join(folder_path, item) 144 | if os.path.isdir(item_path): 145 | contents.append(f"[DIR] {item}") 146 | else: 147 | _, ext = os.path.splitext(item) 148 | contents.append(f"[{ext[1:].upper()}] {item}") 149 | return contents 150 | 151 | def update_folder_content_list(folder_name): 152 | if isinstance(folder_name, list) and folder_name: 153 | folder_name = folder_name[0] 154 | elif not folder_name: 155 | return gr.update(choices=[]) 156 | 157 | contents = list_folder_contents(folder_name) 158 | return gr.update(choices=contents) 159 | 160 | def handle_content_selection(folder_name, selected_item): 161 | if isinstance(selected_item, list) and selected_item: 162 | selected_item = selected_item[0] # Take the first item if it's a list 163 | 164 | if isinstance(selected_item, str) and selected_item.startswith("[DIR]"): 165 | dir_name = selected_item[6:] # Remove "[DIR] " prefix 166 | sub_contents = list_folder_contents(os.path.join(ROOT_DIR, "output", folder_name, dir_name)) 167 | return gr.update(choices=sub_contents), "", "" 168 | elif isinstance(selected_item, str): 169 | file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item # Remove file type prefix if present 170 | file_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts", file_name) 171 | file_size = os.path.getsize(file_path) 172 | file_type = os.path.splitext(file_name)[1] 173 | file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}" 174 | content = read_file_content(file_path) 175 | return gr.update(), file_info, content 176 | else: 177 | return gr.update(), "", "" 178 | 179 | def initialize_selected_folder(folder_name): 180 | if not folder_name: 181 | return "Please select a folder first.", gr.update(choices=[]) 182 | folder_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts") 183 | if not os.path.exists(folder_path): 184 | return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[]) 185 | contents = list_folder_contents(folder_path) 186 | return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents) 187 | 188 | def upload_file(file): 189 | if file is not None: 190 | input_dir = 
os.path.join(ROOT_DIR, 'input') 191 | os.makedirs(input_dir, exist_ok=True) 192 | 193 | # Get the original filename from the uploaded file 194 | original_filename = file.name 195 | 196 | # Create the destination path 197 | destination_path = os.path.join(input_dir, os.path.basename(original_filename)) 198 | 199 | # Move the uploaded file to the destination path 200 | shutil.move(file.name, destination_path) 201 | 202 | logging.info(f"File uploaded and moved to: {destination_path}") 203 | status = f"File uploaded: {os.path.basename(original_filename)}" 204 | else: 205 | status = "No file uploaded" 206 | 207 | # Get the updated file list 208 | updated_file_list = [f["path"] for f in list_input_files()] 209 | 210 | return status, gr.update(choices=updated_file_list), update_logs() 211 | 212 | def list_input_files(): 213 | input_dir = os.path.join(ROOT_DIR, 'input') 214 | files = [] 215 | if os.path.exists(input_dir): 216 | files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] 217 | return [{"name": f, "path": os.path.join(input_dir, f)} for f in files] 218 | 219 | def delete_file(file_path): 220 | try: 221 | os.remove(file_path) 222 | logging.info(f"File deleted: {file_path}") 223 | status = f"File deleted: {os.path.basename(file_path)}" 224 | except Exception as e: 225 | logging.error(f"Error deleting file: {str(e)}") 226 | status = f"Error deleting file: {str(e)}" 227 | 228 | # Get the updated file list 229 | updated_file_list = [f["path"] for f in list_input_files()] 230 | 231 | return status, gr.update(choices=updated_file_list), update_logs() 232 | 233 | def read_file_content(file_path): 234 | try: 235 | if file_path.endswith('.parquet'): 236 | df = pd.read_parquet(file_path) 237 | 238 | # Get basic information about the DataFrame 239 | info = f"Parquet File: {os.path.basename(file_path)}\n" 240 | info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n" 241 | info += "Column Names:\n" + "\n".join(df.columns) + "\n\n" 242 | 243 | # Display first few rows 244 | info += "First 5 rows:\n" 245 | info += df.head().to_string() + "\n\n" 246 | 247 | # Display basic statistics 248 | info += "Basic Statistics:\n" 249 | info += df.describe().to_string() 250 | 251 | return info 252 | else: 253 | with open(file_path, 'r', encoding='utf-8', errors='replace') as file: 254 | content = file.read() 255 | return content 256 | except Exception as e: 257 | logging.error(f"Error reading file: {str(e)}") 258 | return f"Error reading file: {str(e)}" 259 | 260 | def save_file_content(file_path, content): 261 | try: 262 | with open(file_path, 'w') as file: 263 | file.write(content) 264 | logging.info(f"File saved: {file_path}") 265 | status = f"File saved: {os.path.basename(file_path)}" 266 | except Exception as e: 267 | logging.error(f"Error saving file: {str(e)}") 268 | status = f"Error saving file: {str(e)}" 269 | return status, update_logs() 270 | 271 | def manage_data(): 272 | db = lancedb.connect(f"{ROOT_DIR}/lancedb") 273 | tables = db.table_names() 274 | table_info = "" 275 | if tables: 276 | table = db[tables[0]] 277 | table_info = f"Table: {tables[0]}\nSchema: {table.schema}" 278 | 279 | input_files = list_input_files() 280 | 281 | return { 282 | "database_info": f"Tables: {', '.join(tables)}\n\n{table_info}", 283 | "input_files": input_files 284 | } 285 | 286 | 287 | def find_latest_graph_file(root_dir): 288 | pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml") 289 | graph_files = glob.glob(pattern) 290 | if not graph_files: 291 | # If no files 
found, try excluding .DS_Store 292 | output_dir = os.path.join(root_dir, "output") 293 | run_dirs = [d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"] 294 | if run_dirs: 295 | latest_run = max(run_dirs) 296 | pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml") 297 | graph_files = glob.glob(pattern) 298 | 299 | if not graph_files: 300 | return None 301 | 302 | # Sort files by modification time, most recent first 303 | latest_file = max(graph_files, key=os.path.getmtime) 304 | return latest_file 305 | 306 | def find_latest_output_folder(): 307 | root_dir =f"{ROOT_DIR}/output" 308 | folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))] 309 | 310 | if folders is None: 311 | raise ValueError("No output folders found") 312 | elif len(folders) == 0: 313 | return None,None 314 | 315 | # Sort folders by creation time, most recent first 316 | sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True) 317 | 318 | latest_folder = None 319 | timestamp = None 320 | 321 | for folder in sorted_folders: 322 | try: 323 | # Try to parse the folder name as a timestamp 324 | timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S") 325 | latest_folder = folder 326 | break 327 | except ValueError: 328 | # If the folder name is not a valid timestamp, skip it 329 | continue 330 | 331 | if latest_folder is None: 332 | raise ValueError("No valid timestamp folders found") 333 | 334 | latest_path = os.path.join(root_dir, latest_folder) 335 | artifacts_path = os.path.join(latest_path, "artifacts") 336 | 337 | if not os.path.exists(artifacts_path): 338 | raise ValueError(f"Artifacts folder not found in {latest_path}") 339 | 340 | return latest_path, latest_folder 341 | 342 | def initialize_data(): 343 | global entity_df, relationship_df, text_unit_df, report_df, covariate_df 344 | 345 | tables = { 346 | "entity_df": "create_final_nodes", 347 | "relationship_df": "create_final_edges", 348 | "text_unit_df": "create_final_text_units", 349 | "report_df": "create_final_reports", 350 | "covariate_df": "create_final_covariates" 351 | } 352 | 353 | timestamp = None # Initialize timestamp to None 354 | 355 | try: 356 | latest_output_folder, timestamp = find_latest_output_folder() 357 | if latest_output_folder is None: 358 | return None 359 | artifacts_folder = os.path.join(latest_output_folder, "artifacts") 360 | 361 | for df_name, file_prefix in tables.items(): 362 | file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet") 363 | matching_files = glob.glob(file_pattern) 364 | 365 | if matching_files: 366 | latest_file = max(matching_files, key=os.path.getctime) 367 | df = pd.read_parquet(latest_file) 368 | globals()[df_name] = df 369 | logging.info(f"Successfully loaded {df_name} from {latest_file}") 370 | else: 371 | logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. 
Initializing as an empty DataFrame.") 372 | globals()[df_name] = pd.DataFrame() 373 | 374 | except Exception as e: 375 | logging.error(f"Error initializing data: {str(e)}") 376 | for df_name in tables.keys(): 377 | globals()[df_name] = pd.DataFrame() 378 | 379 | return timestamp 380 | 381 | # Call initialize_data and store the timestamp 382 | current_timestamp = initialize_data() 383 | 384 | 385 | ###########MODELS################## 386 | def normalize_api_base(api_base: str) -> str: 387 | """Normalize the API base URL by removing trailing slashes and /v1 or /api suffixes.""" 388 | api_base = api_base.rstrip('/') 389 | if api_base.endswith('/v1') or api_base.endswith('/api'): 390 | api_base = api_base[:-3] 391 | return api_base 392 | 393 | def is_ollama_api(base_url: str) -> bool: 394 | """Check if the given base URL is for Ollama API.""" 395 | try: 396 | response = requests.get(f"{normalize_api_base(base_url)}/api/tags") 397 | return response.status_code == 200 398 | except requests.RequestException: 399 | return False 400 | 401 | def get_ollama_models(base_url: str) -> List[str]: 402 | """Fetch available models from Ollama API.""" 403 | try: 404 | response = requests.get(f"{normalize_api_base(base_url)}/api/tags") 405 | response.raise_for_status() 406 | models = response.json().get('models', []) 407 | return [model['name'] for model in models] 408 | except requests.RequestException as e: 409 | logger.error(f"Error fetching Ollama models: {str(e)}") 410 | return [] 411 | 412 | def get_openai_compatible_models(base_url: str) -> List[str]: 413 | """Fetch available models from OpenAI-compatible API.""" 414 | try: 415 | response = requests.get(f"{normalize_api_base(base_url)}/v1/models") 416 | response.raise_for_status() 417 | models = response.json().get('data', []) 418 | return [model['id'] for model in models] 419 | except requests.RequestException as e: 420 | logger.error(f"Error fetching OpenAI-compatible models: {str(e)}") 421 | return [] 422 | 423 | def get_local_models(base_url: str) -> List[str]: 424 | """Get available models based on the API type.""" 425 | if is_ollama_api(base_url): 426 | return get_ollama_models(base_url) 427 | else: 428 | return get_openai_compatible_models(base_url) 429 | 430 | def get_model_params(base_url: str, model_name: str) -> dict: 431 | """Get model parameters for Ollama models.""" 432 | if is_ollama_api(base_url): 433 | try: 434 | response = requests.post(f"{normalize_api_base(base_url)}/api/show", json={"name": model_name}) 435 | response.raise_for_status() 436 | model_info = response.json() 437 | return model_info.get('parameters', {}) 438 | except requests.RequestException as e: 439 | logger.error(f"Error fetching Ollama model parameters: {str(e)}") 440 | return {} 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | #########API########### 450 | def start_indexing(request: IndexingRequest): 451 | url = f"{API_BASE_URL}/v1/index" 452 | 453 | try: 454 | response = requests.post(url, json=request.dict()) 455 | response.raise_for_status() 456 | result = response.json() 457 | return result['message'], gr.update(interactive=False), gr.update(interactive=True) 458 | except requests.RequestException as e: 459 | logger.error(f"Error starting indexing: {str(e)}") 460 | return f"Error: {str(e)}", gr.update(interactive=True), gr.update(interactive=False) 461 | 462 | def check_indexing_status(): 463 | url = f"{API_BASE_URL}/v1/index_status" 464 | try: 465 | response = requests.get(url) 466 | response.raise_for_status() 467 | result = response.json() 468 | 
return result['status'], "\n".join(result['logs']) 469 | except requests.RequestException as e: 470 | logger.error(f"Error checking indexing status: {str(e)}") 471 | return "Error", f"Failed to check indexing status: {str(e)}" 472 | 473 | def start_prompt_tuning(request: PromptTuneRequest): 474 | url = f"{API_BASE_URL}/v1/prompt_tune" 475 | 476 | try: 477 | response = requests.post(url, json=request.model_dump()) 478 | response.raise_for_status() 479 | result = response.json() 480 | return result['message'], gr.update(interactive=False) 481 | except requests.RequestException as e: 482 | logger.error(f"Error starting prompt tuning: {str(e)}") 483 | return f"Error: {str(e)}", gr.update(interactive=True) 484 | 485 | def check_prompt_tuning_status(): 486 | url = f"{API_BASE_URL}/v1/prompt_tune_status" 487 | try: 488 | response = requests.get(url) 489 | response.raise_for_status() 490 | result = response.json() 491 | return result['status'], "\n".join(result['logs']) 492 | except requests.RequestException as e: 493 | logger.error(f"Error checking prompt tuning status: {str(e)}") 494 | return "Error", f"Failed to check prompt tuning status: {str(e)}" 495 | 496 | def update_model_params(model_name): 497 | params = get_model_params(model_name) 498 | return gr.update(value=json.dumps(params, indent=2)) 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | ########################### 509 | css = """ 510 | html, body { 511 | margin: 0; 512 | padding: 0; 513 | height: 100vh; 514 | overflow: hidden; 515 | } 516 | 517 | .gradio-container { 518 | margin: 0 !important; 519 | padding: 0 !important; 520 | width: 100vw !important; 521 | max-width: 100vw !important; 522 | height: 100vh !important; 523 | max-height: 100vh !important; 524 | overflow: auto; 525 | display: flex; 526 | flex-direction: column; 527 | } 528 | 529 | #main-container { 530 | flex: 1; 531 | display: flex; 532 | overflow: hidden; 533 | } 534 | 535 | #left-column, #right-column { 536 | height: 100%; 537 | overflow-y: auto; 538 | padding: 10px; 539 | } 540 | 541 | #left-column { 542 | flex: 1; 543 | } 544 | 545 | #right-column { 546 | flex: 2; 547 | display: flex; 548 | flex-direction: column; 549 | } 550 | 551 | #chat-container { 552 | flex: 0 0 auto; /* Don't allow this to grow */ 553 | height: 100%; 554 | display: flex; 555 | flex-direction: column; 556 | overflow: hidden; 557 | border: 1px solid var(--color-accent); 558 | border-radius: 8px; 559 | padding: 10px; 560 | overflow-y: auto; 561 | } 562 | 563 | #chatbot { 564 | overflow-y: hidden; 565 | height: 100%; 566 | } 567 | 568 | #chat-input-row { 569 | margin-top: 10px; 570 | } 571 | 572 | #visualization-plot { 573 | width: 100%; 574 | aspect-ratio: 1 / 1; 575 | max-height: 600px; /* Adjust this value as needed */ 576 | } 577 | 578 | #vis-controls-row { 579 | display: flex; 580 | justify-content: space-between; 581 | align-items: center; 582 | margin-top: 10px; 583 | } 584 | 585 | #vis-controls-row > * { 586 | flex: 1; 587 | margin: 0 5px; 588 | } 589 | 590 | #vis-status { 591 | margin-top: 10px; 592 | } 593 | 594 | /* Chat input styling */ 595 | #chat-input-row { 596 | display: flex; 597 | flex-direction: column; 598 | } 599 | 600 | #chat-input-row > div { 601 | width: 100% !important; 602 | } 603 | 604 | #chat-input-row input[type="text"] { 605 | width: 100% !important; 606 | } 607 | 608 | /* Adjust padding for all containers */ 609 | .gr-box, .gr-form, .gr-panel { 610 | padding: 10px !important; 611 | } 612 | 613 | /* Ensure all textboxes and textareas have full height */ 614 
| .gr-textbox, .gr-textarea { 615 | height: auto !important; 616 | min-height: 100px !important; 617 | } 618 | 619 | /* Ensure all dropdowns have full width */ 620 | .gr-dropdown { 621 | width: 100% !important; 622 | } 623 | 624 | :root { 625 | --color-background: #ffffff; 626 | --color-foreground: #3F4E4F; 627 | --color-accent: #A27B5C; 628 | --color-text: #DCD7C9; 629 | } 630 | 631 | body, .gradio-container { 632 | background-color: var(--color-background); 633 | color: var(--color-text); 634 | } 635 | 636 | .gr-button { 637 | background-color: var(--color-accent); 638 | color: var(--color-text); 639 | } 640 | 641 | .gr-input, .gr-textarea, .gr-dropdown { 642 | background-color: var(--color-foreground); 643 | color: var(--color-text); 644 | border: 1px solid var(--color-accent); 645 | } 646 | 647 | .gr-panel { 648 | background-color: var(--color-foreground); 649 | border: 1px solid var(--color-accent); 650 | } 651 | 652 | .gr-box { 653 | border-radius: 8px; 654 | margin-bottom: 10px; 655 | background-color: var(--color-foreground); 656 | } 657 | 658 | .gr-padded { 659 | padding: 10px; 660 | } 661 | 662 | .gr-form { 663 | background-color: var(--color-foreground); 664 | } 665 | 666 | .gr-input-label, .gr-radio-label { 667 | color: var(--color-text); 668 | } 669 | 670 | .gr-checkbox-label { 671 | color: var(--color-text); 672 | } 673 | 674 | .gr-markdown { 675 | color: var(--color-text); 676 | } 677 | 678 | .gr-accordion { 679 | background-color: var(--color-foreground); 680 | border: 1px solid var(--color-accent); 681 | } 682 | 683 | .gr-accordion-header { 684 | background-color: var(--color-accent); 685 | color: var(--color-text); 686 | } 687 | 688 | #visualization-container { 689 | display: flex; 690 | flex-direction: column; 691 | border: 2px solid var(--color-accent); 692 | border-radius: 8px; 693 | margin-top: 20px; 694 | padding: 10px; 695 | background-color: var(--color-foreground); 696 | height: calc(100vh - 300px); /* Adjust this value as needed */ 697 | } 698 | 699 | #visualization-plot { 700 | width: 100%; 701 | height: 100%; 702 | } 703 | 704 | #vis-controls-row { 705 | display: flex; 706 | justify-content: space-between; 707 | align-items: center; 708 | margin-top: 10px; 709 | } 710 | 711 | #vis-controls-row > * { 712 | flex: 1; 713 | margin: 0 5px; 714 | } 715 | 716 | #vis-status { 717 | margin-top: 10px; 718 | } 719 | 720 | #log-container { 721 | background-color: var(--color-foreground); 722 | border: 1px solid var(--color-accent); 723 | border-radius: 8px; 724 | padding: 10px; 725 | margin-top: 20px; 726 | max-height: auto; 727 | overflow-y: auto; 728 | } 729 | 730 | .setting-accordion .label-wrap { 731 | cursor: pointer; 732 | } 733 | 734 | .pointer-cursor { 735 | cursor: pointer; 736 | } 737 | 738 | .setting-accordion .icon { 739 | transition: transform 0.3s ease; 740 | } 741 | 742 | .setting-accordion[open] .icon { 743 | transform: rotate(90deg); 744 | } 745 | 746 | .gr-form.gr-box { 747 | border: none !important; 748 | background: none !important; 749 | } 750 | 751 | .model-params { 752 | border-top: 1px solid var(--color-accent); 753 | margin-top: 10px; 754 | padding-top: 10px; 755 | } 756 | """ 757 | 758 | 759 | def create_interface(): 760 | settings = load_settings() 761 | llm_api_base = normalize_api_base(settings['api_base']) 762 | embeddings_api_base = normalize_api_base(settings['embeddings_api_base']) 763 | 764 | with gr.Blocks(theme=gr.themes.Base(), css=css) as demo: 765 | gr.Markdown("# GraphRAG Indexer") 766 | 767 | with gr.Tabs(): 768 | with 
gr.TabItem("Indexing"): 769 | with gr.Row(): 770 | with gr.Column(scale=1): 771 | gr.Markdown("## Indexing Configuration") 772 | 773 | with gr.Row(): 774 | llm_name = gr.Dropdown(label="LLM Model", choices=[], value=settings['llm_model'], allow_custom_value=True) 775 | refresh_llm_btn = gr.Button("🔄", size='sm', scale=0) 776 | 777 | with gr.Row(): 778 | embed_name = gr.Dropdown(label="Embedding Model", choices=[], value=settings['embedding_model'], allow_custom_value=True) 779 | refresh_embed_btn = gr.Button("🔄", size='sm', scale=0) 780 | 781 | save_config_button = gr.Button("Save Configuration", variant="primary") 782 | config_status = gr.Textbox(label="Configuration Status", lines=2) 783 | 784 | with gr.Row(): 785 | with gr.Column(scale=1): 786 | root_dir = gr.Textbox(label="Root Directory (Edit in .env file)", value=f"{ROOT_DIR}") 787 | with gr.Group(): 788 | verbose = gr.Checkbox(label="Verbose", interactive=True, value=True) 789 | nocache = gr.Checkbox(label="No Cache", interactive=True, value=True) 790 | 791 | with gr.Accordion("Advanced Options", open=True): 792 | resume = gr.Textbox(label="Resume Timestamp (optional)") 793 | reporter = gr.Dropdown( 794 | label="Reporter", 795 | choices=["rich", "print", "none"], 796 | value="rich", 797 | interactive=True 798 | ) 799 | emit_formats = gr.CheckboxGroup( 800 | label="Emit Formats", 801 | choices=["json", "csv", "parquet"], 802 | value=["parquet"], 803 | interactive=True 804 | ) 805 | custom_args = gr.Textbox(label="Custom CLI Arguments", placeholder="--arg1 value1 --arg2 value2") 806 | import textwrap 807 | cli_guide = gr.Markdown( 808 | textwrap.dedent(""" 809 | ### CLI Argument Key Guide: 810 | - `--root `: Set the root directory for the project 811 | - `--config `: Specify a custom configuration file 812 | - `--verbose`: Enable verbose output 813 | - `--nocache`: Disable caching 814 | - `--resume `: Resume from a specific timestamp 815 | - `--reporter `: Set the reporter type (rich, print, none) 816 | - `--emit `: Specify output formats (json, csv, parquet) 817 | 818 | Example: `--verbose --nocache --emit json,csv` 819 | """) 820 | ) 821 | 822 | with gr.Column(scale=1): 823 | gr.Markdown("## Indexing Output") 824 | index_output = gr.Textbox(label="Output", lines=10) 825 | index_status = gr.Textbox(label="Status", lines=2) 826 | 827 | run_index_button = gr.Button("Run Indexing", variant="primary") 828 | check_status_button = gr.Button("Check Indexing Status") 829 | 830 | 831 | with gr.TabItem("Prompt Tuning"): 832 | with gr.Row(): 833 | with gr.Column(scale=1): 834 | gr.Markdown("## Prompt Tuning Configuration") 835 | 836 | pt_root = gr.Textbox(label="Root Directory", value=f"{ROOT_DIR}", interactive=True) 837 | pt_config = gr.Textbox(label="Config File", value=f"{ROOT_DIR}/settings.yaml", interactive=True) 838 | pt_domain = gr.Textbox(label="Domain", placeholder="optional") 839 | pt_method = gr.Dropdown( 840 | label="Method", 841 | choices=["random", "top", "all"], 842 | value="random", 843 | interactive=True, 844 | elem_classes="pointer-cursor" 845 | ) 846 | pt_limit = gr.Number(label="Limit", value=15, precision=0, interactive=True) 847 | pt_language = gr.Textbox(label="Language", placeholder="optional. 
eg: Chinese、Vietnamese", interactive=True) 848 | pt_max_tokens = gr.Number(label="Max Tokens", value=2000, precision=0, interactive=True) 849 | pt_chunk_size = gr.Number(label="Chunk Size", value=200, precision=0, interactive=True) 850 | pt_no_entity_types = gr.Checkbox(label="No Entity Types", value=False) 851 | pt_output_dir = gr.Textbox(label="Output Directory", value=f"{ROOT_DIR}/prompts", interactive=True) 852 | save_pt_config_button = gr.Button("Save Prompt Tuning Configuration", variant="primary") 853 | 854 | with gr.Column(scale=1): 855 | gr.Markdown("## Prompt Tuning Output") 856 | pt_output = gr.Textbox(label="Output", lines=10) 857 | pt_status = gr.Textbox(label="Status", lines=10) 858 | 859 | run_pt_button = gr.Button("Run Prompt Tuning", variant="primary") 860 | check_pt_status_button = gr.Button("Check Prompt Tuning Status") 861 | 862 | with gr.TabItem("Data Management"): 863 | with gr.Row(): 864 | with gr.Column(scale=1): 865 | with gr.Accordion("File Upload", open=True): 866 | file_upload = gr.File(label="Upload File", file_types=[".txt", ".csv", ".parquet"]) 867 | upload_btn = gr.Button("Upload File", variant="primary") 868 | upload_output = gr.Textbox(label="Upload Status", visible=True) 869 | 870 | with gr.Accordion("File Management", open=True): 871 | file_list = gr.Dropdown(label="Select File", choices=[], interactive=True) 872 | refresh_btn = gr.Button("Refresh File List", variant="secondary") 873 | 874 | file_content = gr.TextArea(label="File Content", lines=10) 875 | 876 | with gr.Row(): 877 | delete_btn = gr.Button("Delete Selected File", variant="stop") 878 | save_btn = gr.Button("Save Changes", variant="primary") 879 | 880 | operation_status = gr.Textbox(label="Operation Status", visible=True) 881 | 882 | with gr.Column(scale=1): 883 | with gr.Accordion("Output Folders", open=True): 884 | output_folder_list = gr.Dropdown(label="Select Output Folder", choices=[], interactive=True) 885 | refresh_output_btn = gr.Button("Refresh Output Folders", variant="secondary") 886 | folder_content_list = gr.Dropdown(label="Folder Contents", choices=[], interactive=True, multiselect=False) 887 | 888 | file_info = gr.Textbox(label="File Info", lines=3) 889 | output_content = gr.TextArea(label="File Content", lines=10) 890 | 891 | 892 | 893 | # Event handlers 894 | def refresh_llm_models(): 895 | models = get_local_models(llm_api_base) 896 | return gr.update(choices=models) 897 | 898 | def refresh_embed_models(): 899 | models = get_local_models(embeddings_api_base) 900 | return gr.update(choices=models) 901 | 902 | refresh_llm_btn.click( 903 | refresh_llm_models, 904 | outputs=[llm_name] 905 | ) 906 | 907 | refresh_embed_btn.click( 908 | refresh_embed_models, 909 | outputs=[embed_name] 910 | ) 911 | 912 | # Initialize model lists on page load 913 | demo.load(refresh_llm_models, outputs=[llm_name]) 914 | demo.load(refresh_embed_models, outputs=[embed_name]) 915 | 916 | def create_indexing_request(): 917 | return IndexingRequest( 918 | llm_model=llm_name.value, 919 | embed_model=embed_name.value, 920 | llm_api_base=llm_api_base, 921 | embed_api_base=embeddings_api_base, 922 | root=root_dir.value, 923 | verbose=verbose.value, 924 | nocache=nocache.value, 925 | resume=resume.value if resume.value else None, 926 | reporter=reporter.value, 927 | emit=[fmt for fmt in emit_formats.value], 928 | custom_args=custom_args.value if custom_args.value else None 929 | ) 930 | 931 | run_index_button.click( 932 | lambda: start_indexing(create_indexing_request()), 933 | outputs=[index_output, 
run_index_button, check_status_button] 934 | ) 935 | 936 | check_status_button.click( 937 | check_indexing_status, 938 | outputs=[index_status, index_output] 939 | ) 940 | 941 | # def create_prompt_tune_request(): 942 | # return PromptTuneRequest( 943 | # root=pt_root.value, 944 | # config=pt_config.value, 945 | # domain=pt_domain.value if pt_domain.value else None, 946 | # method=pt_method.value, 947 | # limit=int(pt_limit.value), 948 | # language=pt_language.value if pt_language.value else None, 949 | # max_tokens=int(pt_max_tokens.value), 950 | # chunk_size=int(pt_chunk_size.value), 951 | # no_entity_types=pt_no_entity_types.value, 952 | # output=pt_output_dir.value 953 | # ) 954 | def create_prompt_tune_request(root, config, domain, method, limit, language, max_tokens, chunk_size, no_entity_types, output): 955 | return PromptTuneRequest( 956 | root=root, 957 | config=config, 958 | domain=domain if domain else None, 959 | method=method, 960 | limit=int(limit), 961 | language=language if language else None, 962 | max_tokens=int(max_tokens), 963 | chunk_size=int(chunk_size), 964 | no_entity_types=no_entity_types, 965 | output=output 966 | ) 967 | 968 | # def update_pt_output(request): 969 | # result, button_update = start_prompt_tuning(request) 970 | # return result, button_update, gr.update(value=f"Request: {request.dict()}") 971 | 972 | def update_pt_output(pt_root, pt_config, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir): 973 | request = create_prompt_tune_request( 974 | pt_root, pt_config, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir 975 | ) 976 | result, button_update = start_prompt_tuning(request) 977 | return result, button_update, gr.update(value=f"Request: {request.dict()}") 978 | 979 | 980 | # run_pt_button.click( 981 | # lambda: update_pt_output(create_prompt_tune_request()), 982 | # outputs=[pt_output, run_pt_button, pt_status] 983 | # ) 984 | run_pt_button.click(update_pt_output, inputs=[pt_root, pt_config, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir], outputs=[pt_output, run_pt_button, pt_status]) 985 | 986 | 987 | check_pt_status_button.click( 988 | check_prompt_tuning_status, 989 | outputs=[pt_status, pt_output] 990 | ) 991 | 992 | # Add event handlers for real-time updates 993 | pt_root.change(lambda x: gr.update(value=f"Root Directory changed to: {x}"), inputs=[pt_root], outputs=[pt_status]) 994 | pt_limit.change(lambda x: gr.update(value=f"Limit changed to: {x}"), inputs=[pt_limit], outputs=[pt_status]) 995 | pt_max_tokens.change(lambda x: gr.update(value=f"Max Tokens changed to: {x}"), inputs=[pt_max_tokens], outputs=[pt_status]) 996 | pt_chunk_size.change(lambda x: gr.update(value=f"Chunk Size changed to: {x}"), inputs=[pt_chunk_size], outputs=[pt_status]) 997 | pt_output_dir.change(lambda x: gr.update(value=f"Output Directory changed to: {x}"), inputs=[pt_output_dir], outputs=[pt_status]) 998 | 999 | # Event handlers for Data Management 1000 | upload_btn.click( 1001 | upload_file, 1002 | inputs=[file_upload], 1003 | outputs=[upload_output, file_list, operation_status] 1004 | ) 1005 | 1006 | refresh_btn.click( 1007 | update_file_list, 1008 | outputs=[file_list] 1009 | ) 1010 | 1011 | refresh_output_btn.click( 1012 | update_output_folder_list, 1013 | outputs=[output_folder_list] 1014 | ) 1015 | 1016 | file_list.change( 1017 | update_file_content, 1018 | inputs=[file_list], 1019 | 
outputs=[file_content] 1020 | ) 1021 | 1022 | delete_btn.click( 1023 | delete_file, 1024 | inputs=[file_list], 1025 | outputs=[operation_status, file_list, operation_status] 1026 | ) 1027 | 1028 | save_btn.click( 1029 | save_file_content, 1030 | inputs=[file_list, file_content], 1031 | outputs=[operation_status, operation_status] 1032 | ) 1033 | 1034 | output_folder_list.change( 1035 | update_folder_content_list, 1036 | inputs=[output_folder_list], 1037 | outputs=[folder_content_list] 1038 | ) 1039 | 1040 | folder_content_list.change( 1041 | handle_content_selection, 1042 | inputs=[output_folder_list, folder_content_list], 1043 | outputs=[folder_content_list, file_info, output_content] 1044 | ) 1045 | 1046 | # Event handler for saving configuration 1047 | save_config_button.click( 1048 | update_env_file, 1049 | inputs=[llm_name, embed_name], 1050 | outputs=[config_status] 1051 | ) 1052 | 1053 | # Event handler for saving prompt tuning configuration 1054 | save_pt_config_button.click( 1055 | save_prompt_tuning_config, 1056 | inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language, pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir], 1057 | outputs=[pt_status] 1058 | ) 1059 | 1060 | # Initialize file list and output folder list 1061 | demo.load(update_file_list, outputs=[file_list]) 1062 | demo.load(update_output_folder_list, outputs=[output_folder_list]) 1063 | 1064 | return demo 1065 | 1066 | def update_env_file(llm_model, embed_model): 1067 | env_path = os.path.join(ROOT_DIR, '.env') 1068 | 1069 | set_key(env_path, 'LLM_MODEL', llm_model) 1070 | set_key(env_path, 'EMBEDDINGS_MODEL', embed_model) 1071 | 1072 | # Reload the environment variables 1073 | load_dotenv(env_path, override=True) 1074 | 1075 | return f"Environment updated: LLM_MODEL={llm_model}, EMBEDDINGS_MODEL={embed_model}" 1076 | 1077 | def save_prompt_tuning_config(root, domain, method, limit, language, max_tokens, chunk_size, no_entity_types, output_dir): 1078 | config = { 1079 | 'prompt_tuning': { 1080 | 'root': root, 1081 | 'domain': domain, 1082 | 'method': method, 1083 | 'limit': limit, 1084 | 'language': language, 1085 | 'max_tokens': max_tokens, 1086 | 'chunk_size': chunk_size, 1087 | 'no_entity_types': no_entity_types, 1088 | 'output': output_dir 1089 | } 1090 | } 1091 | 1092 | config_path = os.path.join(ROOT_DIR, 'prompt_tuning_config.yaml') 1093 | with open(config_path, 'w') as f: 1094 | yaml.dump(config, f) 1095 | 1096 | return f"Prompt Tuning configuration saved to {config_path}" 1097 | 1098 | demo = create_interface() 1099 | 1100 | def main(): 1101 | demo.launch(server_name='0.0.0.0', server_port=7860, share=False) 1102 | 1103 | if __name__ == "__main__": 1104 | main() -------------------------------------------------------------------------------- /indexing/.env: -------------------------------------------------------------------------------- 1 | LLM_API_BASE=http://localhost:11434/v1 2 | LLM_MODEL=qwen2:latest 3 | LLM_API_KEY=ollama 4 | LLM_SERVICE_TYPE=openai_chat 5 | 6 | EMBEDDINGS_API_BASE=http://localhost:11434/v1 7 | EMBEDDINGS_MODEL=nomic-embed-text:latest 8 | EMBEDDINGS_API_KEY=ollama 9 | EMBEDDINGS_SERVICE_TYPE=openai_embedding 10 | 11 | GRAPHRAG_API_KEY=ollama 12 | OUTPUT_DIR=${ROOT_DIR}/output/${timestamp}/artifacts 13 | 14 | API_URL=http://localhost:8012 15 | API_PORT=8012 16 | 17 | CONTEXT_WINDOW=4096 18 | SYSTEM_MESSAGE=You are a helpful AI assistant. 
19 | TEMPERATURE=0.5 20 | MAX_TOKENS=1024 21 | -------------------------------------------------------------------------------- /indexing/input/test.txt: -------------------------------------------------------------------------------- 1 | 第1回 惊天地美猴王出世 2 | 3 |   这是一个神话故事,传说在很久很久以前,天下分为东胜神洲、西牛贺洲、南赡部洲、北俱芦洲。在东胜神洲傲来国,有一座花果山,山上有一块仙石,一天仙石崩裂,从石头中滚出一个卵,这个卵一见风就变成一个石猴,猴眼射出一道道金光,向四方朝拜。 4 |   那猴能走、能跑,渴了就喝些山涧中的泉水,饿了就吃些山上的果子。 5 |   整天和山中的动物一起玩乐,过得十分快活。一天,天气特别热,猴子们为了躲避炎热的天气,跑到山涧里洗澡。它们看见这泉水哗哗地流,就顺着涧往前走,去寻找它的源头。 6 |   猴子们爬呀、爬呀,走到了尽头,却看见一股瀑布,像是从天而降一样。猴子们觉得惊奇,商量说∶“哪个敢钻进瀑布,把泉水的源头找出来,又不伤身体,就拜他为王。”连喊了三遍,那石猴呼地跳了出来,高声喊道∶“我进去,我进去!” 7 |   那石猴闭眼纵身跳入瀑布,觉得不像是在水中,这才睁开眼,四处打量,发现自己站在一座铁板桥上,桥下的水冲贯于石窍之间,倒挂着流出来,将桥门遮住,使外面的人看不到里面。石猴走过桥,发现这真是个好地方,石椅、石床、石盆、石碗,样样都有。 8 |   这里就像不久以前有人住过一样,天然的房子,安静整洁,锅、碗、瓢、盆,整齐地放在炉灶上。正当中有一块石碑,上面刻着∶花果山福地,水帘洞洞天。石猴高兴得不得了,忙转身向外走去,嗖的一下跳出了洞。 9 |   猴子们见石猴出来了,身上又一点伤也没有,又惊又喜,把他团团围住,争著问他里面的情况。石猴抓抓腮,挠挠痒,笑嘻嘻地对大家说∶“里面没有水,是一个安身的好地方,刮大风我们有地方躲,下大雨我们也不怕淋。”猴子们一听,一个个高兴得又蹦又跳。 10 |   猴子们随着石猴穿过了瀑布,进入水帘洞中,看见了这么多的好东西,一个个你争我夺,拿盆的拿盆,拿碗的拿碗,占灶的占灶,争床的争床,搬过来,移过去,直到精疲力尽为止。猴子们都遵照诺言,拜石猴为王,石猴从此登上王位,将石字省去,自称“美猴王”。 11 |   美猴王每天带着猴子们游山玩水,很快三、五百年过去了。一天正在玩乐时,美猴王想到自己将来难免一死,不由悲伤得掉下眼泪来,这时猴群中跳出个通背猿猴来,说∶“大王想要长生不老,只有去学佛、学仙、学神之术。” 12 |   美猴王决定走遍天涯海角,也要找到神仙,学那长生不老的本领。第二天,猴子们为他做了一个木筏,又准备了一些野果,于是美猴王告别了群猴们,一个人撑着木筏,奔向汪洋大海。 13 |   大概是美猴王的运气好,连日的东南风,将他送到西北岸边。他下了木筏,登上了岸,看见岸边有许多人都在干活,有的捉鱼,有的打天上的大雁,有的挖蛤蜊,有的淘盐,他悄悄地走过去,没想到,吓得那些人将东西一扔,四处逃命。 14 |   这一天,他来到一座高山前,突然从半山腰的树林里传出一阵美妙的歌声,唱的是一些关于成仙的话。猴王想∶这个唱歌的人一定是神仙,就顺着歌声找去。 15 |   唱歌的是一个正在树林里砍柴的青年人,猴王从这青年人的口中了解到,这座山叫灵台方寸山,离这儿七八里路,有个斜月三星洞,洞中住着一个称为菩提祖师的神仙。 16 |   美猴王告别打柴的青年人,出了树林,走过山坡,果然远远地看见一座洞府,只见洞门紧紧地闭着,洞门对面的山岗上立着一块石碑,大约有三丈多高,八尺多宽,上面写着十个大字∶“灵台方寸山斜月三星洞”。正在看时,门却忽然打开了,走出来一个仙童。 17 |   美猴王赶快走上前,深深地鞠了一个躬,说明来意,那仙童说∶“我师父刚才正要讲道,忽然叫我出来开门,说外面来了个拜师学艺的,原来就是你呀!跟我来吧!”美猴王赶紧整整衣服,恭恭敬敬地跟着仙童进到洞内,来到祖师讲道的法台跟前。 18 |   猴王看见菩提祖师端端正正地坐在台上,台下两边站着三十多个仙童,就赶紧跪下叩头。祖师问清楚他的来意,很高兴,见他没有姓名,便说∶“你就叫悟空吧!” 19 |   祖师叫孙悟空又拜见了各位师兄,并给悟空找了间空房住下。从此悟空跟着师兄学习生活常识,讲究经典,写字烧香,空时做些扫地挑水的活。 20 |   很快七年过去了,一天,祖师讲道结束后,问悟空想学什么本领。孙悟空不管祖师讲什么求神拜佛、打坐修行,只要一听不能长生不老,就不愿意学,菩提祖师对此非常生气。 21 |   祖师从高台上跳了下来,手里拿着戒尺指着孙悟空说∶“你这猴子,这也不学,那也不学,你要学些什么?”说完走过去在悟空头上打了三下,倒背着手走到里间,关上了门。师兄们看到师父生气了,感到很害怕,纷纷责怪孙悟空。 22 |   孙悟空既不怕,又不生气,心里反而十分高兴。当天晚上,悟空假装睡着了,可是一到半夜,就悄悄起来,从前门出去,等到三更,绕到后门口,看见门半开半闭,高兴地不得了,心想∶“哈哈,我没有猜错师父的意思。” 23 |   孙悟空走了进去,看见祖师面朝里睡着,就跪在床前说∶“师父,我跪在这里等着您呢!”祖师听见声音就起来了,盘着腿坐好后,严厉地问孙悟空来做什么,悟空说∶“师父白天当着大家的面不是答应我,让我三更时从后门进来,教我长生不老的法术吗?” 24 |   菩提祖师听到这话心里很高兴。心想∶“这个猴子果然是天地生成的,不然,怎么能猜透我的暗谜。”于是,让孙悟空跪在床前,教给他长生不老的法术。孙悟空洗耳恭听,用心理解,牢牢记住口诀,并叩头拜谢了祖师的恩情。 25 |   很快三年又过去了,祖师又教了孙悟空七十二般变化的法术和驾筋斗云的本领,学会了这个本领,一个筋斗便能翻出十万八千里路程。孙悟空是个猴子,本来就喜欢蹦蹦跳跳的,所以学起筋斗云来很容易。 26 |   有一个夏天,孙悟空和师兄们在洞门前玩耍,大家要孙悟空变个东西看看,孙悟空心里感到很高兴,得意地念起咒语,摇身一变变成了一棵大树。 27 |   师兄们见了,鼓着掌称赞他。 28 |   大家的吵闹声,让菩提祖师听到了,他拄着拐杖出来,问∶“是谁在吵闹?你们这样大吵大叫的,哪里像个出家修行的人呢?”大家都赶紧停住了笑,孙悟空也恢复了原样,给师父解释,请求原谅。 29 |   菩提祖师看见孙悟空刚刚学会了一些本领就卖弄起来,十分生气。祖师叫其他人离开,把悟空狠狠地教训了一顿,并且要把孙悟空赶走。孙悟空着急了,哀求祖师不要赶他走,祖师却不肯留下他,并要他立下誓言∶任何时候都不能说孙悟空是菩提祖师的徒弟。 30 | 31 | 32 | -------------------------------------------------------------------------------- /indexing/prompts/claim_extraction.txt: -------------------------------------------------------------------------------- 1 | 2 | -Target activity- 3 | You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document. 4 | 5 | -Goal- 6 | Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities. 7 | 8 | -Steps- 9 | 1. 
Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types. 10 | 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim. 11 | For each claim, extract the following information: 12 | - Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1. 13 | - Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**. 14 | - Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type 15 | - Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified. 16 | - Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references. 17 | - Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**. 18 | - Claim Source Text: List of **all** quotes from the original text that are relevant to the claim. 19 | 20 | Format each claim as ({tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 21 | 22 | 3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. 23 | 24 | 4. When finished, output {completion_delimiter} 25 | 26 | -Examples- 27 | Example 1: 28 | Entity specification: organization 29 | Claim description: red flags associated with an entity 30 | Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. 31 | Output: 32 | 33 | (COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) 34 | {completion_delimiter} 35 | 36 | Example 2: 37 | Entity specification: Company A, Person C 38 | Claim description: red flags associated with an entity 39 | Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. 
40 | Output: 41 | 42 | (COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) 43 | {record_delimiter} 44 | (PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) 45 | {completion_delimiter} 46 | 47 | -Real Data- 48 | Use the following input for your answer. 49 | Entity specification: {entity_specs} 50 | Claim description: {claim_description} 51 | Text: {input_text} 52 | Output: -------------------------------------------------------------------------------- /indexing/prompts/community_report.txt: -------------------------------------------------------------------------------- 1 | 2 | You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. 3 | 4 | # Goal 5 | Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. 6 | 7 | # Report Structure 8 | 9 | The report should include the following sections: 10 | 11 | - TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. 12 | - SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. 13 | - IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. 14 | - RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. 15 | - DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 
16 | 17 | Return output as a well-formed JSON-formatted string with the following format: 18 | {{ 19 | "title": , 20 | "summary": , 21 | "rating": , 22 | "rating_explanation": , 23 | "findings": [ 24 | {{ 25 | "summary":, 26 | "explanation": 27 | }}, 28 | {{ 29 | "summary":, 30 | "explanation": 31 | }} 32 | ] 33 | }} 34 | 35 | # Grounding Rules 36 | 37 | Points supported by data should list their data references as follows: 38 | 39 | "This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." 40 | 41 | Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 42 | 43 | For example: 44 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]." 45 | 46 | where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record. 47 | 48 | Do not include information where the supporting evidence for it is not provided. 49 | 50 | 51 | # Example Input 52 | ----------- 53 | Text: 54 | 55 | Entities 56 | 57 | id,entity,description 58 | 5,VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March 59 | 6,HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza 60 | 61 | Relationships 62 | 63 | id,source,target,description 64 | 37,VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March 65 | 38,VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza 66 | 39,VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza 67 | 40,VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza 68 | 41,VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march 69 | 43,HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March 70 | 71 | Output: 72 | {{ 73 | "title": "Verdant Oasis Plaza and Unity March", 74 | "summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.", 75 | "rating": 5.0, 76 | "rating_explanation": "The impact severity rating is moderate due to the potential for unrest or conflict during the Unity March.", 77 | "findings": [ 78 | {{ 79 | "summary": "Verdant Oasis Plaza as the central location", 80 | "explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes. [Data: Entities (5), Relationships (37, 38, 39, 40, 41,+more)]" 81 | }}, 82 | {{ 83 | "summary": "Harmony Assembly's role in the community", 84 | "explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. 
The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community. [Data: Entities(6), Relationships (38, 43)]" 85 | }}, 86 | {{ 87 | "summary": "Unity March as a significant event", 88 | "explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community. [Data: Relationships (39)]" 89 | }}, 90 | {{ 91 | "summary": "Role of Tribune Spotlight", 92 | "explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved. [Data: Relationships (40)]" 93 | }} 94 | ] 95 | }} 96 | 97 | 98 | # Real Data 99 | 100 | Use the following text for your answer. Do not make anything up in your answer. 101 | 102 | Text: 103 | {input_text} 104 | 105 | The report should include the following sections: 106 | 107 | - TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. 108 | - SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. 109 | - IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community. 110 | - RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating. 111 | - DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. 112 | 113 | Return output as a well-formed JSON-formatted string with the following format: 114 | {{ 115 | "title": , 116 | "summary": , 117 | "rating": , 118 | "rating_explanation": , 119 | "findings": [ 120 | {{ 121 | "summary":, 122 | "explanation": 123 | }}, 124 | {{ 125 | "summary":, 126 | "explanation": 127 | }} 128 | ] 129 | }} 130 | 131 | # Grounding Rules 132 | 133 | Points supported by data should list their data references as follows: 134 | 135 | "This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." 136 | 137 | Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 138 | 139 | For example: 140 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (1), Entities (5, 7); Relationships (23); Claims (7, 2, 34, 64, 46, +more)]." 141 | 142 | where 1, 5, 7, 23, 2, 34, 46, and 64 represent the id (not the index) of the relevant data record. 143 | 144 | Do not include information where the supporting evidence for it is not provided. 
145 | 146 | Output: -------------------------------------------------------------------------------- /indexing/prompts/entity_extraction.txt: -------------------------------------------------------------------------------- 1 | 2 | -Goal- 3 | Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. 4 | 5 | -Steps- 6 | 1. Identify all entities. For each identified entity, extract the following information: 7 | - entity_name: Name of the entity, capitalized 8 | - entity_type: One of the following types: [{entity_types}] 9 | - entity_description: Comprehensive description of the entity's attributes and activities 10 | Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 11 | 12 | 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. 13 | For each pair of related entities, extract the following information: 14 | - source_entity: name of the source entity, as identified in step 1 15 | - target_entity: name of the target entity, as identified in step 1 16 | - relationship_description: explanation as to why you think the source entity and the target entity are related to each other 17 | - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity 18 | Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 19 | 20 | 3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. 21 | 22 | 4. When finished, output {completion_delimiter} 23 | 24 | ###################### 25 | -Examples- 26 | ###################### 27 | Example 1: 28 | Entity_types: ORGANIZATION,PERSON 29 | Text: 30 | The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. 31 | ###################### 32 | Output: 33 | ("entity"{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) 34 | {record_delimiter} 35 | ("entity"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}PERSON{tuple_delimiter}Martin Smith is the chair of the Central Institution) 36 | {record_delimiter} 37 | ("entity"{tuple_delimiter}MARKET STRATEGY COMMITTEE{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) 38 | {record_delimiter} 39 | ("relationship"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{tuple_delimiter}9) 40 | {completion_delimiter} 41 | 42 | ###################### 43 | Example 2: 44 | Entity_types: ORGANIZATION 45 | Text: 46 | TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. 
But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform. 47 | 48 | TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. 49 | ###################### 50 | Output: 51 | ("entity"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}ORGANIZATION{tuple_delimiter}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) 52 | {record_delimiter} 53 | ("entity"{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}ORGANIZATION{tuple_delimiter}Vision Holdings is a firm that previously owned TechGlobal) 54 | {record_delimiter} 55 | ("relationship"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}Vision Holdings formerly owned TechGlobal from 2014 until present{tuple_delimiter}5) 56 | {completion_delimiter} 57 | 58 | ###################### 59 | Example 3: 60 | Entity_types: ORGANIZATION,GEO,PERSON 61 | Text: 62 | Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. 63 | 64 | The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. 65 | 66 | The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. 67 | 68 | They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. 69 | 70 | The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
71 | ###################### 72 | Output: 73 | ("entity"{tuple_delimiter}FIRUZABAD{tuple_delimiter}GEO{tuple_delimiter}Firuzabad held Aurelians as hostages) 74 | {record_delimiter} 75 | ("entity"{tuple_delimiter}AURELIA{tuple_delimiter}GEO{tuple_delimiter}Country seeking to release hostages) 76 | {record_delimiter} 77 | ("entity"{tuple_delimiter}QUINTARA{tuple_delimiter}GEO{tuple_delimiter}Country that negotiated a swap of money in exchange for hostages) 78 | {record_delimiter} 79 | {record_delimiter} 80 | ("entity"{tuple_delimiter}TIRUZIA{tuple_delimiter}GEO{tuple_delimiter}Capital of Firuzabad where the Aurelians were being held) 81 | {record_delimiter} 82 | ("entity"{tuple_delimiter}KROHAARA{tuple_delimiter}GEO{tuple_delimiter}Capital city in Quintara) 83 | {record_delimiter} 84 | ("entity"{tuple_delimiter}CASHION{tuple_delimiter}GEO{tuple_delimiter}Capital city in Aurelia) 85 | {record_delimiter} 86 | ("entity"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}PERSON{tuple_delimiter}Aurelian who spent time in Tiruzia's Alhamia Prison) 87 | {record_delimiter} 88 | ("entity"{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}GEO{tuple_delimiter}Prison in Tiruzia) 89 | {record_delimiter} 90 | ("entity"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}PERSON{tuple_delimiter}Aurelian journalist who was held hostage) 91 | {record_delimiter} 92 | ("entity"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}PERSON{tuple_delimiter}Bratinas national and environmentalist who was held hostage) 93 | {record_delimiter} 94 | ("relationship"{tuple_delimiter}FIRUZABAD{tuple_delimiter}AURELIA{tuple_delimiter}Firuzabad negotiated a hostage exchange with Aurelia{tuple_delimiter}2) 95 | {record_delimiter} 96 | ("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}AURELIA{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2) 97 | {record_delimiter} 98 | ("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2) 99 | {record_delimiter} 100 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}Samuel Namara was a prisoner at Alhamia prison{tuple_delimiter}8) 101 | {record_delimiter} 102 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{tuple_delimiter}2) 103 | {record_delimiter} 104 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2) 105 | {record_delimiter} 106 | ("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2) 107 | {record_delimiter} 108 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Samuel Namara was a hostage in Firuzabad{tuple_delimiter}2) 109 | {record_delimiter} 110 | ("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}FIRUZABAD{tuple_delimiter}Meggie Tazbah was a hostage in Firuzabad{tuple_delimiter}2) 111 | {record_delimiter} 112 | ("relationship"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}FIRUZABAD{tuple_delimiter}Durke Bataglani was a hostage in Firuzabad{tuple_delimiter}2) 113 | {completion_delimiter} 114 | 115 | ###################### 116 | -Real Data- 117 | 
###################### 118 | Entity_types: {entity_types} 119 | Text: {input_text} 120 | ###################### 121 | Output: -------------------------------------------------------------------------------- /indexing/prompts/summarize_descriptions.txt: -------------------------------------------------------------------------------- 1 | 2 | You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. 3 | Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. 4 | Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. 5 | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. 6 | Make sure it is written in third person, and include the entity names so we have the full context. 7 | 8 | ####### 9 | -Data- 10 | Entities: {entity_name} 11 | Description List: {description_list} 12 | ####### 13 | Output: 14 | -------------------------------------------------------------------------------- /indexing/settings.yaml: -------------------------------------------------------------------------------- 1 | 2 | encoding_model: cl100k_base 3 | skip_workflows: [] 4 | llm: 5 | api_key: ${LLM_API_KEY} 6 | type: ${LLM_SERVICE_TYPE} # or azure_openai_chat 7 | model: ${LLM_MODEL} 8 | api_base: ${LLM_API_BASE} 9 | model_supports_json: true # recommended if this is available for your model. 10 | # max_tokens: 4000 11 | # request_timeout: 180.0 12 | # api_base: https://.openai.azure.com 13 | # api_version: 2024-02-15-preview 14 | # organization: 15 | # deployment_name: 16 | # tokens_per_minute: 150_000 # set a leaky bucket throttle 17 | # requests_per_minute: 10_000 # set a leaky bucket throttle 18 | # max_retries: 10 19 | # max_retry_wait: 10.0 20 | # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times 21 | # concurrent_requests: 25 # the number of parallel inflight requests that may be made 22 | # temperature: 0 # temperature for sampling 23 | # top_p: 1 # top-p sampling 24 | # n: 1 # Number of completions to generate 25 | 26 | parallelization: 27 | stagger: 0.3 28 | # num_threads: 50 # the number of threads to use for parallel processing 29 | 30 | async_mode: threaded # or asyncio 31 | 32 | embeddings: 33 | ## parallelization: override the global parallelization settings for embeddings 34 | async_mode: threaded # or asyncio 35 | # target: required # or all 36 | llm: 37 | api_key: ${EMBEDDINGS_API_KEY} 38 | type: ${EMBEDDINGS_SERVICE_TYPE} # or azure_openai_embedding 39 | model: ${EMBEDDINGS_MODEL} 40 | api_base: ${EMBEDDINGS_API_BASE} 41 | # api_base: https://.openai.azure.com 42 | # api_version: 2024-02-15-preview 43 | # organization: 44 | # deployment_name: 45 | # tokens_per_minute: 150_000 # set a leaky bucket throttle 46 | # requests_per_minute: 10_000 # set a leaky bucket throttle 47 | # max_retries: 10 48 | # max_retry_wait: 10.0 49 | # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times 50 | # concurrent_requests: 25 # the number of parallel inflight requests that may be made 51 | # batch_size: 16 # the number of documents to send in a single request 52 | # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request 53 | 54 | 55 | 56 | 57 | chunks: 58 | size: 1200 59 | overlap: 100 60 | group_by_columns: [id] # by default, we don't allow chunks to cross 
documents 61 | 62 | input: 63 | type: file # or blob 64 | file_type: text # or csv 65 | base_dir: "input" 66 | file_encoding: utf-8 67 | file_pattern: ".*\\.txt$" 68 | 69 | cache: 70 | type: file # or blob 71 | base_dir: "cache" 72 | # connection_string: 73 | # container_name: 74 | 75 | storage: 76 | type: file # or blob 77 | base_dir: "output/${timestamp}/artifacts" 78 | # connection_string: 79 | # container_name: 80 | 81 | reporting: 82 | type: file # or console, blob 83 | base_dir: "output/${timestamp}/reports" 84 | # connection_string: 85 | # container_name: 86 | 87 | entity_extraction: 88 | ## llm: override the global llm settings for this task 89 | ## parallelization: override the global parallelization settings for this task 90 | ## async_mode: override the global async_mode settings for this task 91 | prompt: "prompts/entity_extraction.txt" 92 | entity_types: [organization,person,geo,event] 93 | max_gleanings: 1 94 | 95 | summarize_descriptions: 96 | ## llm: override the global llm settings for this task 97 | ## parallelization: override the global parallelization settings for this task 98 | ## async_mode: override the global async_mode settings for this task 99 | prompt: "prompts/summarize_descriptions.txt" 100 | max_length: 500 101 | 102 | claim_extraction: 103 | ## llm: override the global llm settings for this task 104 | ## parallelization: override the global parallelization settings for this task 105 | ## async_mode: override the global async_mode settings for this task 106 | enabled: true 107 | prompt: "prompts/claim_extraction.txt" 108 | description: "Any claims or facts that could be relevant to information discovery." 109 | max_gleanings: 1 110 | 111 | community_reports: 112 | ## llm: override the global llm settings for this task 113 | ## parallelization: override the global parallelization settings for this task 114 | ## async_mode: override the global async_mode settings for this task 115 | prompt: "prompts/community_report.txt" 116 | max_length: 2000 117 | max_input_length: 8000 118 | 119 | cluster_graph: 120 | max_cluster_size: 10 121 | 122 | embed_graph: 123 | enabled: true # if true, will generate node2vec embeddings for nodes 124 | # num_walks: 10 125 | # walk_length: 40 126 | # window_size: 2 127 | # iterations: 3 128 | # random_seed: 597832 129 | 130 | umap: 131 | enabled: true # if true, will generate UMAP embeddings for nodes 132 | 133 | snapshots: 134 | graphml: true 135 | raw_entities: true 136 | top_level_nodes: true 137 | 138 | local_search: 139 | # text_unit_prop: 0.5 140 | # community_prop: 0.1 141 | # conversation_history_max_turns: 5 142 | # top_k_mapped_entities: 10 143 | # top_k_relationships: 10 144 | # llm_temperature: 0 # temperature for sampling 145 | # llm_top_p: 1 # top-p sampling 146 | # llm_n: 1 # Number of completions to generate 147 | # max_tokens: 12000 148 | 149 | global_search: 150 | # llm_temperature: 0 # temperature for sampling 151 | # llm_top_p: 1 # top-p sampling 152 | # llm_n: 1 # Number of completions to generate 153 | # max_tokens: 12000 154 | # data_max_tokens: 12000 155 | # map_max_tokens: 1000 156 | # reduce_max_tokens: 2000 157 | # concurrency: 32 158 | -------------------------------------------------------------------------------- /lancedb/empty.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wade1010/graphrag-ui/493e051a05890803e6a566c811a96467bc20d68e/lancedb/empty.md -------------------------------------------------------------------------------- 
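The `storage` and `reporting` sections of `indexing/settings.yaml` above write each indexing run into a timestamped folder (`output/${timestamp}/artifacts` and `output/${timestamp}/reports`). A minimal, hypothetical sketch for locating the newest artifacts directory at query time — the helper name and the `indexing` root default are assumptions for illustration, not part of the repository:

```python
from pathlib import Path


def latest_artifacts_dir(root: str = "indexing") -> Path:
    """Return the newest output/<timestamp>/artifacts directory under `root`.

    Assumes the default storage layout from settings.yaml
    (base_dir: "output/${timestamp}/artifacts") and that the timestamped
    folder names sort chronologically.
    """
    candidates = sorted(Path(root, "output").glob("*/artifacts"))
    if not candidates:
        raise FileNotFoundError(f"no artifacts directories under {root}/output")
    return candidates[-1]


if __name__ == "__main__":
    print(latest_artifacts_dir())
```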
/requirements.txt: -------------------------------------------------------------------------------- 1 | gradio==4.43.0 2 | fastapi==0.112.4 3 | uvicorn==0.30.6 4 | python-dotenv==1.0.1 5 | pydantic==2.9.0 6 | pandas==2.2.2 7 | tiktoken==0.7.0 8 | langchain-community==0.2.16 9 | aiohttp==3.10.5 10 | PyYAML==6.0.2 11 | requests==2.32.3 12 | duckduckgo-search==6.2.11 13 | ollama==0.3.2 14 | plotly==5.24.0 -------------------------------------------------------------------------------- /settings-example.yaml: -------------------------------------------------------------------------------- 1 | 2 | encoding_model: cl100k_base 3 | skip_workflows: [] 4 | llm: 5 | api_key: ${LLM_API_KEY} 6 | type: ${LLM_SERVICE_TYPE} # or azure_openai_chat 7 | model: ${LLM_MODEL} 8 | api_base: ${LLM_API_BASE} 9 | model_supports_json: true # recommended if this is available for your model. 10 | # max_tokens: 4000 11 | # request_timeout: 180.0 12 | # api_base: https://.openai.azure.com 13 | # api_version: 2024-02-15-preview 14 | # organization: 15 | # deployment_name: 16 | # tokens_per_minute: 150_000 # set a leaky bucket throttle 17 | # requests_per_minute: 10_000 # set a leaky bucket throttle 18 | # max_retries: 10 19 | # max_retry_wait: 10.0 20 | # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times 21 | # concurrent_requests: 25 # the number of parallel inflight requests that may be made 22 | # temperature: 0 # temperature for sampling 23 | # top_p: 1 # top-p sampling 24 | # n: 1 # Number of completions to generate 25 | 26 | parallelization: 27 | stagger: 0.3 28 | # num_threads: 50 # the number of threads to use for parallel processing 29 | 30 | async_mode: threaded # or asyncio 31 | 32 | embeddings: 33 | ## parallelization: override the global parallelization settings for embeddings 34 | async_mode: threaded # or asyncio 35 | # target: required # or all 36 | llm: 37 | api_key: ${EMBEDDINGS_API_KEY} 38 | type: ${EMBEDDINGS_SERVICE_TYPE} # or azure_openai_embedding 39 | model: ${EMBEDDINGS_MODEL} 40 | api_base: ${EMBEDDINGS_API_BASE} 41 | # api_base: https://.openai.azure.com 42 | # api_version: 2024-02-15-preview 43 | # organization: 44 | # deployment_name: 45 | # tokens_per_minute: 150_000 # set a leaky bucket throttle 46 | # requests_per_minute: 10_000 # set a leaky bucket throttle 47 | # max_retries: 10 48 | # max_retry_wait: 10.0 49 | # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times 50 | # concurrent_requests: 25 # the number of parallel inflight requests that may be made 51 | # batch_size: 16 # the number of documents to send in a single request 52 | # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request 53 | 54 | 55 | 56 | 57 | chunks: 58 | size: 1200 59 | overlap: 100 60 | group_by_columns: [id] # by default, we don't allow chunks to cross documents 61 | 62 | input: 63 | type: file # or blob 64 | file_type: text # or csv 65 | base_dir: "input" 66 | file_encoding: utf-8 67 | file_pattern: ".*\\.txt$" 68 | 69 | cache: 70 | type: file # or blob 71 | base_dir: "cache" 72 | # connection_string: 73 | # container_name: 74 | 75 | storage: 76 | type: file # or blob 77 | base_dir: "output/${timestamp}/artifacts" 78 | # connection_string: 79 | # container_name: 80 | 81 | reporting: 82 | type: file # or console, blob 83 | base_dir: "output/${timestamp}/reports" 84 | # connection_string: 85 | # container_name: 86 | 87 | entity_extraction: 88 | ## llm: override the global llm settings for this task 89 | ## 
parallelization: override the global parallelization settings for this task 90 | ## async_mode: override the global async_mode settings for this task 91 | prompt: "prompts/entity_extraction.txt" 92 | entity_types: [organization,person,geo,event] 93 | max_gleanings: 1 94 | 95 | summarize_descriptions: 96 | ## llm: override the global llm settings for this task 97 | ## parallelization: override the global parallelization settings for this task 98 | ## async_mode: override the global async_mode settings for this task 99 | prompt: "prompts/summarize_descriptions.txt" 100 | max_length: 500 101 | 102 | claim_extraction: 103 | ## llm: override the global llm settings for this task 104 | ## parallelization: override the global parallelization settings for this task 105 | ## async_mode: override the global async_mode settings for this task 106 | enabled: true 107 | prompt: "prompts/claim_extraction.txt" 108 | description: "Any claims or facts that could be relevant to information discovery." 109 | max_gleanings: 1 110 | 111 | community_reports: 112 | ## llm: override the global llm settings for this task 113 | ## parallelization: override the global parallelization settings for this task 114 | ## async_mode: override the global async_mode settings for this task 115 | prompt: "prompts/community_report.txt" 116 | max_length: 2000 117 | max_input_length: 8000 118 | 119 | cluster_graph: 120 | max_cluster_size: 10 121 | 122 | embed_graph: 123 | enabled: false # if true, will generate node2vec embeddings for nodes 124 | # num_walks: 10 125 | # walk_length: 40 126 | # window_size: 2 127 | # iterations: 3 128 | # random_seed: 597832 129 | 130 | umap: 131 | enabled: false # if true, will generate UMAP embeddings for nodes 132 | 133 | snapshots: 134 | graphml: false 135 | raw_entities: false 136 | top_level_nodes: false 137 | 138 | local_search: 139 | # text_unit_prop: 0.5 140 | # community_prop: 0.1 141 | # conversation_history_max_turns: 5 142 | # top_k_mapped_entities: 10 143 | # top_k_relationships: 10 144 | # llm_temperature: 0 # temperature for sampling 145 | # llm_top_p: 1 # top-p sampling 146 | # llm_n: 1 # Number of completions to generate 147 | # max_tokens: 12000 148 | 149 | global_search: 150 | # llm_temperature: 0 # temperature for sampling 151 | # llm_top_p: 1 # top-p sampling 152 | # llm_n: 1 # Number of completions to generate 153 | # max_tokens: 12000 154 | # data_max_tokens: 12000 155 | # map_max_tokens: 1000 156 | # reduce_max_tokens: 2000 157 | # concurrency: 32 158 | -------------------------------------------------------------------------------- /web.py: -------------------------------------------------------------------------------- 1 | """Util that calls DuckDuckGo Search. 2 | 3 | No setup required. Free. 4 | https://pypi.org/project/duckduckgo-search/ 5 | """ 6 | 7 | from typing import Dict, List, Optional 8 | 9 | from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator 10 | 11 | 12 | class DuckDuckGoSearchAPIWrapper(BaseModel): 13 | """Wrapper for DuckDuckGo Search API. 14 | 15 | Free and does not require any setup. 
16 | """ 17 | 18 | region: Optional[str] = "wt-wt" 19 | """ 20 | See https://pypi.org/project/duckduckgo-search/#regions 21 | """ 22 | safesearch: str = "moderate" 23 | """ 24 | Options: strict, moderate, off 25 | """ 26 | time: Optional[str] = "y" 27 | """ 28 | Options: d, w, m, y 29 | """ 30 | max_results: int = 5 31 | backend: str = "api" 32 | """ 33 | Options: api, html, lite 34 | """ 35 | source: str = "text" 36 | """ 37 | Options: text, news 38 | """ 39 | 40 | class Config: 41 | """Configuration for this pydantic object.""" 42 | 43 | extra = Extra.forbid 44 | 45 | @root_validator(pre=True) 46 | def validate_environment(cls, values: Dict) -> Dict: 47 | """Validate that python package exists in environment.""" 48 | try: 49 | from duckduckgo_search import DDGS # noqa: F401 50 | except ImportError: 51 | raise ImportError( 52 | "Could not import duckduckgo-search python package. " 53 | "Please install it with `pip install -U duckduckgo-search`." 54 | ) 55 | return values 56 | 57 | def _ddgs_text( 58 | self, query: str, max_results: Optional[int] = None 59 | ) -> List[Dict[str, str]]: 60 | """Run query through DuckDuckGo text search and return results.""" 61 | from duckduckgo_search import DDGS 62 | 63 | with DDGS() as ddgs: 64 | ddgs_gen = ddgs.text( 65 | query, 66 | region=self.region, 67 | safesearch=self.safesearch, 68 | timelimit=self.time, 69 | max_results=max_results or self.max_results, 70 | backend=self.backend, 71 | ) 72 | if ddgs_gen: 73 | return [r for r in ddgs_gen] 74 | return [] 75 | 76 | def _ddgs_news( 77 | self, query: str, max_results: Optional[int] = None 78 | ) -> List[Dict[str, str]]: 79 | """Run query through DuckDuckGo news search and return results.""" 80 | from duckduckgo_search import DDGS 81 | 82 | with DDGS() as ddgs: 83 | ddgs_gen = ddgs.news( 84 | query, 85 | region=self.region, 86 | safesearch=self.safesearch, 87 | timelimit=self.time, 88 | max_results=max_results or self.max_results, 89 | ) 90 | if ddgs_gen: 91 | return [r for r in ddgs_gen] 92 | return [] 93 | 94 | def run(self, query: str) -> str: 95 | """Run query through DuckDuckGo and return concatenated results.""" 96 | if self.source == "text": 97 | results = self._ddgs_text(query) 98 | elif self.source == "news": 99 | results = self._ddgs_news(query) 100 | else: 101 | results = [] 102 | 103 | if not results: 104 | return "No good DuckDuckGo Search Result was found" 105 | return " ".join(r["body"] for r in results) 106 | 107 | 108 | def results( 109 | self, query: str, max_results: int, source: Optional[str] = None 110 | ) -> List[Dict[str, str]]: 111 | """Run query through DuckDuckGo and return metadata. 112 | 113 | Args: 114 | query: The query to search for. 115 | max_results: The number of results to return. 116 | source: The source to look from. 117 | 118 | Returns: 119 | A list of dictionaries with the following keys: 120 | snippet - The description of the result. 121 | title - The title of the result. 122 | link - The link to the result. 
123 | """ 124 | source = source or self.source 125 | if source == "text": 126 | results = [ 127 | {"snippet": r["body"], "title": r["title"], "link": r["href"]} 128 | for r in self._ddgs_text(query, max_results=max_results) 129 | ] 130 | elif source == "news": 131 | results = [ 132 | { 133 | "snippet": r["body"], 134 | "title": r["title"], 135 | "link": r["url"], 136 | "date": r["date"], 137 | "source": r["source"], 138 | } 139 | for r in self._ddgs_news(query, max_results=max_results) 140 | ] 141 | else: 142 | results = [] 143 | 144 | if results is None: 145 | results = [{"Result": "No good DuckDuckGo Search Result was found"}] 146 | 147 | return results --------------------------------------------------------------------------------