├── GraphRAG4OpenWebUI
│   └── test
├── requirements.txt
├── README_ZH-CN.md
├── README.md
├── graphrag3dknowledge.py
├── LICENSE
├── main-cn.py
└── main-en.py

/GraphRAG4OpenWebUI/test:
--------------------------------------------------------------------------------
 1 |
 2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | fastapi
 2 | uvicorn
 3 | pandas
 4 | tiktoken
 5 | graphrag
 6 | tavily-python
 7 | pydantic
 8 | python-dotenv
 9 | aiohttp
10 | numpy
11 | scikit-learn
12 | matplotlib
13 | seaborn
14 | nltk
15 | spacy
16 | transformers
17 | torch
18 | torchvision
19 | torchaudio
20 |
--------------------------------------------------------------------------------
/README_ZH-CN.md:
--------------------------------------------------------------------------------
 1 | ### 🔥🔥🔥如有问题请联系我的微信 stoeng
 2 | ### 🔥🔥🔥项目对应的视频演示请看 https://youtu.be/z4Si6O5NQ4c
 3 |
 4 | # GraphRAG4OpenWebUI
 5 |
 6 | ### 将微软的 GraphRAG 技术集成到 Open WebUI 中,实现高级信息检索
 7 | English | 简体中文
 8 |
9 | 10 | GraphRAG4OpenWebUI 是一个专为 Open WebUI 设计的 API 接口,旨在集成微软研究院的 GraphRAG(基于图的检索增强生成)技术。该项目提供了一个强大的信息检索系统,支持多种搜索模型,特别适合在开放式 Web 用户界面中使用。 11 | 12 | ## 项目概述 13 | 14 | 本项目的主要目标是为 Open WebUI 提供一个便捷的接口,以利用 GraphRAG 的强大功能。它集成了三种主要的检索方法,并提供了一个综合搜索选项,使用户能够获得全面而精确的搜索结果。 15 | 16 | ### 主要检索功能 17 | 18 | 1. **本地搜索(Local Search)** 19 | - 利用 GraphRAG 技术在本地知识库中进行高效检索 20 | - 适用于快速访问预先定义的结构化信息 21 | - 利用图结构提高检索的准确性和相关性 22 | 23 | 2. **全局搜索(Global Search)** 24 | - 在更广泛的范围内搜索信息,超越本地知识库的限制 25 | - 适用于需要更全面信息的查询 26 | - 利用 GraphRAG 的全局上下文理解能力,提供更丰富的搜索结果 27 | 28 | 3. **Tavily 搜索** 29 | - 集成外部 Tavily 搜索 API 30 | - 提供额外的互联网搜索能力,扩展信息源 31 | - 适用于需要最新或广泛网络信息的查询 32 | 33 | 4. **全模型搜索(Full Model Search)** 34 | - 综合上述三种搜索方法 35 | - 提供最全面的搜索结果,满足复杂的信息需求 36 | - 自动整合和排序来自不同来源的信息 37 | 38 | ### 本地LLM和Embedding模型支持 39 | 40 | GraphRAG4OpenWebUI 现在支持使用本地的语言模型(LLM)和嵌入模型,增加了项目的灵活性和隐私性。特别地,我们支持以下本地模型: 41 | 42 | 1. **Ollama** 43 | - 支持使用 Ollama 运行的各种开源 LLM,如 Llama 2、Mistral 等 44 | - 可以通过设置 `API_BASE` 环境变量来指向 Ollama 的 API 端点 45 | 46 | 2. **LM Studio** 47 | - 兼容 LM Studio 运行的模型 48 | - 通过配置 `API_BASE` 环境变量连接到 LM Studio 的服务 49 | 50 | 3. **本地 Embedding 模型** 51 | - 支持使用本地运行的嵌入模型,如 SentenceTransformers 52 | - 通过设置 `GRAPHRAG_EMBEDDING_MODEL` 环境变量来指定使用的嵌入模型 53 | 54 | 这些本地模型的支持使得 GraphRAG4OpenWebUI 能够在不依赖外部API的情况下运行,提高了数据隐私和降低了使用成本。 55 | 56 | ## 安装 57 | 确保您的系统中已安装 Python 3.8 或更高版本。然后,按照以下步骤安装: 58 | 1. 克隆仓库: 59 | ```bash 60 | git clone https://github.com/your-username/GraphRAG4OpenWebUI.git 61 | cd GraphRAG4OpenWebUI 62 | ``` 63 | 64 | 2. 创建并激活虚拟环境: 65 | ```bash 66 | python -m venv venv 67 | source venv/bin/activate # 在 Windows 上使用 venv\Scripts\activate 68 | ``` 69 | 70 | 3. 安装依赖: 71 | ```bash 72 | pip install -r requirements.txt 73 | ``` 74 | 注意:graphrag 包可能需要从特定的源安装。如果上述命令无法安装 graphrag,请参考微软研究院的具体说明或联系维护者获取正确的安装方法。 75 | 76 | ## 配置 77 | 78 | 在运行 API 之前,需要设置以下环境变量。您可以通过创建 `.env` 文件或直接在终端中导出这些变量: 79 | 80 | 81 | ```bash 82 | export TAVILY_API_KEY="your_tavily_api_key" 83 | 84 | export INPUT_DIR="/path/to/your/input/directory" 85 | 86 | # 设置llm API密钥 87 | export GRAPHRAG_API_KEY="your_actual_api_key_here" 88 | 89 | # 设置嵌入API密钥(如果与GRAPHRAG_API_KEY不同) 90 | export GRAPHRAG_API_KEY_EMBEDDING="your_embedding_api_key_here" 91 | 92 | # 设置LLM模型(默认为"gemma2") 93 | export GRAPHRAG_LLM_MODEL="gemma2" 94 | 95 | # 设置API基础URL(默认为本地服务器) 96 | export API_BASE="http://localhost:11434/v1" 97 | 98 | # 设置嵌入API基础URL(默认为OpenAI的API) 99 | export API_BASE_EMBEDDING="https://api.openai.com/v1" 100 | 101 | # 设置嵌入模型(默认为"text-embedding-3-small") 102 | export GRAPHRAG_EMBEDDING_MODEL="text-embedding-3-small" 103 | ``` 104 | 105 | 请确保将上述命令中的占位符替换为实际的 API 密钥和路径。 106 | 107 | ## 使用方法 108 | 109 | 1. 启动服务器: 110 | ``` 111 | python main-cn.py 112 | ``` 113 | 服务器将在 `http://localhost:8012` 上运行。 114 | 115 | 2. API 端点: 116 | - `/v1/chat/completions`: POST 请求,用于执行搜索 117 | - `/v1/models`: GET 请求,获取可用模型列表 118 | 119 | 3. 在 Open WebUI 中集成: 120 | 在 Open WebUI 的配置中,将 API 端点设置为 `http://localhost:8012/v1/chat/completions`。这将允许 Open WebUI 使用 GraphRAG4OpenWebUI 的搜索功能。 121 | 122 | 4. 
发送搜索请求示例:
123 | ```python
124 | import requests
125 | import json
126 |
127 | url = "http://localhost:8012/v1/chat/completions"
128 | headers = {"Content-Type": "application/json"}
129 | data = {
130 |     "model": "full-model:latest",
131 |     "messages": [{"role": "user", "content": "您的搜索查询"}],
132 |     "temperature": 0.7
133 | }
134 |
135 | response = requests.post(url, headers=headers, data=json.dumps(data))
136 | print(response.json())
137 | ```
138 |
139 | ## 可用模型
140 |
141 | - `graphrag-local-search:latest`: 本地搜索
142 | - `graphrag-global-search:latest`: 全局搜索
143 | - `tavily-search:latest`: Tavily 搜索
144 | - `full-model:latest`: 综合搜索(包含上述所有搜索方法)
145 |
146 | ## 注意事项
147 |
148 | - 确保在 `INPUT_DIR` 目录中有正确的输入文件(如 Parquet 文件)。
149 | - API 使用异步编程,确保您的环境支持异步操作。
150 | - 对于大规模部署,建议使用生产级的 ASGI 服务器。
151 | - 本项目专为 Open WebUI 设计,可以轻松集成到各种基于 Web 的应用中。
152 |
153 | ## 贡献
154 |
155 | 我们欢迎您提交 Pull Requests 来改进这个项目。对于重大变更,请先开 issue 讨论您想要改变的内容。
156 |
157 | ## 许可证
158 |
159 | [Apache-2.0 许可证](LICENSE)
160 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 |
 2 | # GraphRAG4OpenWebUI
 3 |
 4 | ### Integrate Microsoft's GraphRAG Technology into Open WebUI for Advanced Information Retrieval
 5 | English
 6 |
7 | 8 | GraphRAG4OpenWebUI is an API interface specifically designed for Open WebUI, aiming to integrate Microsoft Research's GraphRAG (Graph-based Retrieval-Augmented Generation) technology. This project provides a powerful information retrieval system that supports multiple search models, particularly suitable for use in open web user interfaces. 9 | 10 | ## Project Overview 11 | 12 | The main goal of this project is to provide a convenient interface for Open WebUI to leverage the powerful features of GraphRAG. It integrates three main retrieval methods and offers a comprehensive search option, allowing users to obtain thorough and precise search results. 13 | 14 | ### Key Retrieval Features 15 | 16 | 1. **Local Search** 17 | - Utilizes GraphRAG technology for efficient retrieval in local knowledge bases 18 | - Suitable for quick access to pre-defined structured information 19 | - Leverages graph structures to improve retrieval accuracy and relevance 20 | 21 | 2. **Global Search** 22 | - Searches for information in a broader scope, beyond local knowledge bases 23 | - Suitable for queries requiring more comprehensive information 24 | - Utilizes GraphRAG's global context understanding capabilities to provide richer search results 25 | 26 | 3. **Tavily Search** 27 | - Integrates external Tavily search API 28 | - Provides additional internet search capabilities, expanding information sources 29 | - Suitable for queries requiring the latest or extensive web information 30 | 31 | 4. **Full Model Search** 32 | - Combines all three search methods above 33 | - Provides the most comprehensive search results, meeting complex information needs 34 | - Automatically integrates and ranks information from different sources 35 | 36 | ### Local LLM and Embedding Model Support 37 | 38 | GraphRAG4OpenWebUI now supports the use of local language models (LLMs) and embedding models, increasing the project's flexibility and privacy. Specifically, we support the following local models: 39 | 40 | 1. **Ollama** 41 | - Supports various open-source LLMs run through Ollama, such as Llama 2, Mistral, etc. 42 | - Can be configured by setting the `API_BASE` environment variable to point to Ollama's API endpoint 43 | 44 | 2. **LM Studio** 45 | - Compatible with models run by LM Studio 46 | - Connect to LM Studio's service by configuring the `API_BASE` environment variable 47 | 48 | 3. **Local Embedding Models** 49 | - Supports the use of locally run embedding models, such as SentenceTransformers 50 | - Specify the embedding model to use by setting the `GRAPHRAG_EMBEDDING_MODEL` environment variable 51 | 52 | This support for local models allows GraphRAG4OpenWebUI to run without relying on external APIs, enhancing data privacy and reducing usage costs. 53 | 54 | ## Installation 55 | Ensure that you have Python 3.8 or higher installed on your system. Then, follow these steps to install: 56 | 1. Clone the repository: 57 | ```bash 58 | git clone https://github.com/your-username/GraphRAG4OpenWebUI.git 59 | cd GraphRAG4OpenWebUI 60 | ``` 61 | 62 | 2. Create and activate a virtual environment: 63 | ```bash 64 | python -m venv venv 65 | source venv/bin/activate # On Windows use venv\Scripts\activate 66 | ``` 67 | 68 | 3. Install dependencies: 69 | ```bash 70 | pip install -r requirements.txt 71 | ``` 72 | Note: The graphrag package might need to be installed from a specific source. 
If the above command fails to install graphrag, please refer to Microsoft Research's specific instructions or contact the maintainer for the correct installation method. 73 | 74 | ## Configuration 75 | 76 | Before running the API, you need to set the following environment variables. You can do this by creating a `.env` file or exporting them directly in your terminal: 77 | 78 | 79 | 80 | ```bash 81 | # Set the TAVILY API key 82 | export TAVILY_API_KEY="your_tavily_api_key" 83 | 84 | export INPUT_DIR="/path/to/your/input/directory" 85 | 86 | # Set the API key for LLM 87 | export GRAPHRAG_API_KEY="your_actual_api_key_here" 88 | 89 | # Set the API key for embedding (if different from GRAPHRAG_API_KEY) 90 | export GRAPHRAG_API_KEY_EMBEDDING="your_embedding_api_key_here" 91 | 92 | # Set the LLM model 93 | export GRAPHRAG_LLM_MODEL="gemma2" 94 | 95 | # Set the API base URL 96 | export API_BASE="http://localhost:11434/v1" 97 | 98 | # Set the embedding API base URL (default is OpenAI's API) 99 | export API_BASE_EMBEDDING="https://api.openai.com/v1" 100 | 101 | # Set the embedding model (default is "text-embedding-3-small") 102 | export GRAPHRAG_EMBEDDING_MODEL="text-embedding-3-small" 103 | ``` 104 | 105 | Make sure to replace the placeholders in the above commands with your actual API keys and paths. 106 | 107 | ## Usage 108 | 109 | 1. Start the server: 110 | ``` 111 | python main-en.py 112 | ``` 113 | The server will run on `http://localhost:8012`. 114 | 115 | 2. API Endpoints: 116 | - `/v1/chat/completions`: POST request for performing searches 117 | - `/v1/models`: GET request to retrieve the list of available models 118 | 119 | 3. Integration with Open WebUI: 120 | In the Open WebUI configuration, set the API endpoint to `http://localhost:8012/v1/chat/completions`. This will allow Open WebUI to use the search functionality of GraphRAG4OpenWebUI. 121 | 122 | 4. Example search request: 123 | ```python 124 | import requests 125 | import json 126 | 127 | url = "http://localhost:8012/v1/chat/completions" 128 | headers = {"Content-Type": "application/json"} 129 | data = { 130 | "model": "full-model:latest", 131 | "messages": [{"role": "user", "content": "Your search query"}], 132 | "temperature": 0.7 133 | } 134 | 135 | response = requests.post(url, headers=headers, data=json.dumps(data)) 136 | print(response.json()) 137 | ``` 138 | 139 | ## Available Models 140 | 141 | - `graphrag-local-search:latest`: Local search 142 | - `graphrag-global-search:latest`: Global search 143 | - `tavily-search:latest`: Tavily search 144 | - `full-model:latest`: Comprehensive search (includes all search methods above) 145 | 146 | ## Notes 147 | 148 | - Ensure that you have the correct input files (such as Parquet files) in the `INPUT_DIR` directory. 149 | - The API uses asynchronous programming, make sure your environment supports async operations. 150 | - For large-scale deployment, consider using a production-grade ASGI server. 151 | - This project is specifically designed for Open WebUI and can be easily integrated into various web-based applications. 152 | 153 | ## Contributing 154 | 155 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change. 
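As a quick smoke test of the two endpoints described in the Usage section above, the sketch below first lists the models the server exposes and then streams a completion. It is an illustrative suggestion rather than code shipped with this repository: it assumes the server is already running on the default port 8012, and it parses the OpenAI-style `data: ...` / `data: [DONE]` SSE framing that `main-en.py` emits when `stream` is true.

```python
import json
import requests

BASE_URL = "http://localhost:8012/v1"  # default port used by main-en.py

# 1. List the search models the server exposes (GET /v1/models)
models = requests.get(f"{BASE_URL}/models").json()
print([m["id"] for m in models["data"]])

# 2. Stream a completion (POST /v1/chat/completions with "stream": true).
# The server emits one "data: {...}" SSE line per chunk and ends with "data: [DONE]".
payload = {
    "model": "graphrag-local-search:latest",
    "messages": [{"role": "user", "content": "Your search query"}],
    "stream": True,
}
with requests.post(f"{BASE_URL}/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip the blank separator lines between SSE events
        data = line[len(b"data: "):].decode("utf-8")
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Content chunks carry a "delta" dict; the final chunk's delta is empty
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
```

The same loop works with any of the model IDs listed under Available Models; the non-streaming variant shown in the Usage section is simpler when a single blocking response is enough.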
156 |
157 | ## License
158 |
159 | [Apache-2.0 License](LICENSE)
160 |
--------------------------------------------------------------------------------
/graphrag3dknowledge.py:
--------------------------------------------------------------------------------
  1 | import os  # 用于文件系统操作
  2 | import pandas as pd  # 用于数据处理和操作
  3 | import networkx as nx  # 用于创建和分析图结构
  4 | import plotly.graph_objects as go  # plotly:用于创建交互式可视化 plotly.graph_objects:用于创建低级的plotly图形对象
  5 | from plotly.subplots import make_subplots  # 用于创建子图
  6 | import plotly.express as px  # 用于快速创建统计图表
  7 |
  8 | def read_parquet_files(directory):
  9 |     """
 10 |     读取指定目录下的所有Parquet文件并合并
 11 |     功能:读取指定目录下的所有Parquet文件并合并成一个DataFrame
 12 |     实现:使用os.listdir遍历目录,pd.read_parquet读取每个文件,然后用pd.concat合并
 13 |     """
 14 |     dataframes = []
 15 |     for filename in os.listdir(directory):
 16 |         if filename.endswith('.parquet'):
 17 |             file_path = os.path.join(directory, filename)
 18 |             df = pd.read_parquet(file_path)
 19 |             dataframes.append(df)
 20 |     return pd.concat(dataframes, ignore_index=True) if dataframes else pd.DataFrame()
 21 |
 22 |
 23 | def clean_dataframe(df):
 24 |     """
 25 |     清理DataFrame,移除无效的行
 26 |     功能:清理DataFrame,移除无效的行
 27 |     实现:删除source和target列中的空值,将这两列转换为字符串类型
 28 |     """
 29 |     df = df.dropna(subset=['source', 'target'])
 30 |     df['source'] = df['source'].astype(str)
 31 |     df['target'] = df['target'].astype(str)
 32 |     return df
 33 |
 34 |
 35 | def create_knowledge_graph(df):
 36 |     """
 37 |     从DataFrame创建知识图谱
 38 |     功能:从DataFrame创建知识图谱
 39 |     实现:使用networkx创建有向图,遍历DataFrame的每一行,添加边和属性
 40 |     """
 41 |     G = nx.DiGraph()
 42 |     for _, row in df.iterrows():
 43 |         source = row['source']
 44 |         target = row['target']
 45 |         attributes = {k: v for k, v in row.items() if k not in ['source', 'target']}
 46 |         G.add_edge(source, target, **attributes)
 47 |     return G
 48 |
 49 |
 50 | def create_node_link_trace(G, pos):
 51 |     """
 52 |     功能:创建节点和边的3D轨迹
 53 |     实现:使用networkx的布局信息创建Plotly的Scatter3d对象
 54 |     """
 55 |     edge_x = []
 56 |     edge_y = []
 57 |     edge_z = []
 58 |     for edge in G.edges():
 59 |         x0, y0, z0 = pos[edge[0]]
 60 |         x1, y1, z1 = pos[edge[1]]
 61 |         edge_x.extend([x0, x1, None])
 62 |         edge_y.extend([y0, y1, None])
 63 |         edge_z.extend([z0, z1, None])
 64 |
 65 |     edge_trace = go.Scatter3d(
 66 |         x=edge_x, y=edge_y, z=edge_z,
 67 |         line=dict(width=0.5, color='#888'),
 68 |         hoverinfo='none',
 69 |         mode='lines')
 70 |
 71 |     node_x = [pos[node][0] for node in G.nodes()]
 72 |     node_y = [pos[node][1] for node in G.nodes()]
 73 |     node_z = [pos[node][2] for node in G.nodes()]
 74 |
 75 |     node_trace = go.Scatter3d(
 76 |         x=node_x, y=node_y, z=node_z,
 77 |         mode='markers',
 78 |         hoverinfo='text',
 79 |         marker=dict(
 80 |             showscale=True,
 81 |             colorscale='YlGnBu',
 82 |             size=10,
 83 |             colorbar=dict(
 84 |                 thickness=15,
 85 |                 title='Node Connections',
 86 |                 xanchor='left',
 87 |                 titleside='right'
 88 |             )
 89 |         )
 90 |     )
 91 |
 92 |     node_adjacencies = []
 93 |     node_text = []
 94 |     for node, adjacencies in G.adjacency():
 95 |         node_adjacencies.append(len(adjacencies))
 96 |         node_text.append(f'Node: {node}<br># of connections: {len(adjacencies)}')
 97 |
 98 |     node_trace.marker.color = node_adjacencies
 99 |     node_trace.text = node_text
100 |
101 |     return edge_trace, node_trace
102 |
103 |
104 | def create_edge_label_trace(G, pos, edge_labels):
105 |     """
106 |     功能:创建边标签的3D轨迹
107 |     实现:计算边的中点位置,创建Scatter3d对象显示标签
108 |     """
109 |     return go.Scatter3d(
110 |         x=[pos[edge[0]][0] + (pos[edge[1]][0] - pos[edge[0]][0]) / 2 for edge in edge_labels],
111 |         y=[pos[edge[0]][1] + (pos[edge[1]][1] - pos[edge[0]][1]) / 2 for edge in edge_labels],
112 |         z=[pos[edge[0]][2] + (pos[edge[1]][2] - pos[edge[0]][2]) / 2 for edge in edge_labels],
113 |         mode='text',
114 |         text=list(edge_labels.values()),
115 |         textposition='middle center',
116 |         hoverinfo='none'
117 |     )
118 |
119 |
120 | def create_degree_distribution(G):
121 |     """
122 |     功能:创建节点度分布直方图
123 |     实现:使用plotly.express创建直方图
124 |     """
125 |     degrees = [d for n, d in G.degree()]
126 |     fig = px.histogram(x=degrees, nbins=20, labels={'x': 'Degree', 'y': 'Count'})
127 |     fig.update_layout(
128 |         title_text='Node Degree Distribution',
129 |         margin=dict(l=0, r=0, t=30, b=0),
130 |         height=300
131 |     )
132 |     return fig
133 |
134 |
135 | def create_centrality_plot(G):
136 |     """
137 |     功能:创建节点中心性分布箱线图
138 |     实现:计算度中心性,使用plotly.express创建箱线图
139 |     """
140 |     centrality = nx.degree_centrality(G)
141 |     centrality_values = list(centrality.values())
142 |     fig = px.box(y=centrality_values, labels={'y': 'Centrality'})
143 |     fig.update_layout(
144 |         title_text='Degree Centrality Distribution',
145 |         margin=dict(l=0, r=0, t=30, b=0),
146 |         height=300
147 |     )
148 |     return fig
149 |
150 |
151 | def visualize_graph_plotly(G):
152 |     """功能:使用Plotly创建全面优化布局的高级交互式知识图谱可视化
153 |     实现:
154 |     创建3D布局
155 |     生成节点和边的轨迹
156 |     创建子图,包括3D图、度分布图和中心性分布图
157 |     添加交互式按钮和滑块
158 |     优化整体布局
159 |     """
160 |     if G.number_of_nodes() == 0:
161 |         print("Graph is empty. 
Nothing to visualize.") 162 | return 163 | 164 | pos = nx.spring_layout(G, dim=3) # 3D layout 165 | edge_trace, node_trace = create_node_link_trace(G, pos) 166 | 167 | edge_labels = nx.get_edge_attributes(G, 'relation') 168 | edge_label_trace = create_edge_label_trace(G, pos, edge_labels) 169 | 170 | degree_dist_fig = create_degree_distribution(G) 171 | centrality_fig = create_centrality_plot(G) 172 | 173 | fig = make_subplots( 174 | rows=2, cols=2, 175 | column_widths=[0.7, 0.3], 176 | row_heights=[0.7, 0.3], 177 | specs=[ 178 | [{"type": "scene", "rowspan": 2}, {"type": "xy"}], 179 | [None, {"type": "xy"}] 180 | ], 181 | subplot_titles=("3D Knowledge Graph Code by AI超元域频道", "Node Degree Distribution", "Degree Centrality Distribution") 182 | ) 183 | 184 | fig.add_trace(edge_trace, row=1, col=1) 185 | fig.add_trace(node_trace, row=1, col=1) 186 | fig.add_trace(edge_label_trace, row=1, col=1) 187 | 188 | fig.add_trace(degree_dist_fig.data[0], row=1, col=2) 189 | fig.add_trace(centrality_fig.data[0], row=2, col=2) 190 | 191 | # Update 3D layout 192 | fig.update_layout( 193 | scene=dict( 194 | xaxis=dict(showticklabels=False, showgrid=False, zeroline=False), 195 | yaxis=dict(showticklabels=False, showgrid=False, zeroline=False), 196 | zaxis=dict(showticklabels=False, showgrid=False, zeroline=False), 197 | aspectmode='cube' 198 | ), 199 | scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)) 200 | ) 201 | 202 | # Add buttons for different layouts 203 | fig.update_layout( 204 | updatemenus=[ 205 | dict( 206 | type="buttons", 207 | direction="left", 208 | buttons=list([ 209 | dict(args=[{"visible": [True, True, True, True, True]}], label="Show All", method="update"), 210 | dict(args=[{"visible": [True, True, False, True, True]}], label="Hide Edge Labels", 211 | method="update"), 212 | dict(args=[{"visible": [False, True, False, True, True]}], label="Nodes Only", method="update") 213 | ]), 214 | pad={"r": 10, "t": 10}, 215 | showactive=True, 216 | x=0.05, 217 | xanchor="left", 218 | y=1.1, 219 | yanchor="top" 220 | ), 221 | ] 222 | ) 223 | 224 | # Add slider for node size 225 | fig.update_layout( 226 | sliders=[dict( 227 | active=0, 228 | currentvalue={"prefix": "Node Size: "}, 229 | pad={"t": 50}, 230 | steps=[dict(method='update', 231 | args=[{'marker.size': [i] * len(G.nodes)}], 232 | label=str(i)) for i in range(5, 21, 5)] 233 | )] 234 | ) 235 | 236 | # 优化整体布局 237 | # fig.update_layout( 238 | # height=1198, # 增加整体高度 239 | # width=2055, # 增加整体宽度 240 | # title_text="Advanced Interactive Knowledge Graph", 241 | # margin=dict(l=10, r=10, t=25, b=10), 242 | # legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01) 243 | # ) 244 | 245 | fig.show() 246 | 247 | 248 | def main(): 249 | """ 功能:主函数,协调整个程序的执行流程 250 | 实现: 251 | 读取Parquet文件 252 | 清理数据 253 | 创建知识图谱 254 | 打印图的统计信息 255 | 调用可视化函数 256 | """ 257 | directory = '/Users/charlesqin/PycharmProjects/RAGCode/inputs/artifacts' # 替换为实际的目录路径 258 | df = read_parquet_files(directory) 259 | 260 | if df.empty: 261 | print("No data found in the specified directory.") 262 | return 263 | 264 | print("Original DataFrame shape:", df.shape) 265 | print("Original DataFrame columns:", df.columns.tolist()) 266 | print("Original DataFrame head:") 267 | print(df.head()) 268 | 269 | df = clean_dataframe(df) 270 | 271 | print("\nCleaned DataFrame shape:", df.shape) 272 | print("Cleaned DataFrame head:") 273 | print(df.head()) 274 | 275 | if df.empty: 276 | print("No valid data remaining after cleaning.") 277 | return 278 | 279 | G = create_knowledge_graph(df) 280 | 281 | 
print(f"\nGraph statistics:") 282 | print(f"Nodes: {G.number_of_nodes()}") 283 | print(f"Edges: {G.number_of_edges()}") 284 | 285 | if G.number_of_nodes() > 0: 286 | print(f"Connected components: {nx.number_connected_components(G.to_undirected())}") 287 | visualize_graph_plotly(G) 288 | else: 289 | print("Graph is empty. Cannot visualize.") 290 | 291 | 292 | if __name__ == "__main__": 293 | main() 294 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /main-cn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import time 4 | import uuid 5 | import json 6 | import re 7 | import pandas as pd 8 | import tiktoken 9 | import logging 10 | from fastapi import FastAPI, HTTPException, Request 11 | from fastapi.responses import JSONResponse, StreamingResponse 12 | from pydantic import BaseModel, Field 13 | from typing import List, Optional, Dict, Any, Union 14 | from contextlib import asynccontextmanager 15 | from tavily import TavilyClient 16 | 17 | 18 | # GraphRAG 相关导入 19 | from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey 20 | from graphrag.query.indexer_adapters import ( 21 | read_indexer_covariates, 22 | read_indexer_entities, 23 | read_indexer_relationships, 24 | read_indexer_reports, 25 | read_indexer_text_units, 26 | ) 27 | from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings 28 | from graphrag.query.llm.oai.chat_openai import ChatOpenAI 29 | from graphrag.query.llm.oai.embedding import OpenAIEmbedding 30 | from graphrag.query.llm.oai.typing import OpenaiApiType 31 | from graphrag.query.question_gen.local_gen import LocalQuestionGen 32 | from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext 33 | from graphrag.query.structured_search.local_search.search import LocalSearch 34 | from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext 35 | from graphrag.query.structured_search.global_search.search import GlobalSearch 36 | from graphrag.vector_stores.lancedb import LanceDBVectorStore 37 | 38 | # 设置日志 39 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 40 | logger = logging.getLogger(__name__) 41 | 42 | # 设置常量和配置 43 | INPUT_DIR = os.getenv('INPUT_DIR') 44 | LANCEDB_URI = f"{INPUT_DIR}/lancedb" 45 | COMMUNITY_REPORT_TABLE = "create_final_community_reports" 46 | ENTITY_TABLE = "create_final_nodes" 47 | ENTITY_EMBEDDING_TABLE = "create_final_entities" 48 | RELATIONSHIP_TABLE = "create_final_relationships" 49 | COVARIATE_TABLE = "create_final_covariates" 50 | TEXT_UNIT_TABLE = "create_final_text_units" 51 | COMMUNITY_LEVEL = 2 52 | PORT = 8012 53 | 54 | # 全局变量,用于存储搜索引擎和问题生成器 55 | local_search_engine = None 56 | global_search_engine = None 57 | question_generator = None 58 | 59 | 60 | # 数据模型 61 | class Message(BaseModel): 62 | role: str 63 | content: str 64 | 65 | 66 | class ChatCompletionRequest(BaseModel): 67 | model: str 68 | messages: List[Message] 69 | temperature: Optional[float] = 1.0 70 | top_p: Optional[float] = 1.0 71 | n: Optional[int] = 1 72 | stream: Optional[bool] = False 73 | stop: Optional[Union[str, List[str]]] = None 74 | max_tokens: Optional[int] = None 75 | presence_penalty: Optional[float] = 0 76 | frequency_penalty: Optional[float] = 0 77 | logit_bias: Optional[Dict[str, float]] = None 78 | user: Optional[str] = None 79 | 80 | 81 | class ChatCompletionResponseChoice(BaseModel): 82 | index: int 83 | message: Message 84 | finish_reason: Optional[str] = None 85 | 86 | 87 | class Usage(BaseModel): 88 | prompt_tokens: int 89 | completion_tokens: int 90 | total_tokens: int 91 | 92 | 93 | class ChatCompletionResponse(BaseModel): 94 | id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}") 95 | object: str = "chat.completion" 96 | created: int = 
Field(default_factory=lambda: int(time.time())) 97 | model: str 98 | choices: List[ChatCompletionResponseChoice] 99 | usage: Usage 100 | system_fingerprint: Optional[str] = None 101 | 102 | 103 | async def setup_llm_and_embedder(): 104 | """ 105 | 设置语言模型(LLM)和嵌入模型 106 | """ 107 | logger.info("正在设置LLM和嵌入器") 108 | 109 | # 获取API密钥和基础URL 110 | api_key = os.environ.get("GRAPHRAG_API_KEY", "YOUR_API_KEY") 111 | api_key_embedding = os.environ.get("GRAPHRAG_API_KEY_EMBEDDING", api_key) 112 | api_base = os.environ.get("API_BASE", "https://api.openai.com/v1") 113 | api_base_embedding = os.environ.get("API_BASE_EMBEDDING", "https://api.openai.com/v1") 114 | 115 | # 获取模型名称 116 | llm_model = os.environ.get("GRAPHRAG_LLM_MODEL", "gpt-3.5-turbo-0125") 117 | embedding_model = os.environ.get("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small") 118 | 119 | # 检查API密钥是否存在 120 | if api_key == "YOUR_API_KEY": 121 | logger.error("环境变量中未找到有效的GRAPHRAG_API_KEY") 122 | raise ValueError("GRAPHRAG_API_KEY未正确设置") 123 | 124 | # 初始化ChatOpenAI实例 125 | llm = ChatOpenAI( 126 | api_key=api_key, 127 | api_base=api_base, 128 | model=llm_model, 129 | api_type=OpenaiApiType.OpenAI, 130 | max_retries=20, 131 | ) 132 | 133 | # 初始化token编码器 134 | token_encoder = tiktoken.get_encoding("cl100k_base") 135 | 136 | # 初始化文本嵌入模型 137 | text_embedder = OpenAIEmbedding( 138 | api_key=api_key_embedding, 139 | api_base=api_base_embedding, 140 | api_type=OpenaiApiType.OpenAI, 141 | model=embedding_model, 142 | deployment_name=embedding_model, 143 | max_retries=20, 144 | ) 145 | 146 | 147 | logger.info("LLM和嵌入器设置完成") 148 | return llm, token_encoder, text_embedder 149 | 150 | 151 | async def load_context(): 152 | """ 153 | 加载上下文数据,包括实体、关系、报告、文本单元和协变量 154 | """ 155 | logger.info("正在加载上下文数据") 156 | try: 157 | entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet") 158 | entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet") 159 | entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL) 160 | 161 | description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings") 162 | description_embedding_store.connect(db_uri=LANCEDB_URI) 163 | store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store) 164 | 165 | relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet") 166 | relationships = read_indexer_relationships(relationship_df) 167 | 168 | report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet") 169 | reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL) 170 | 171 | text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet") 172 | text_units = read_indexer_text_units(text_unit_df) 173 | 174 | covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet") 175 | claims = read_indexer_covariates(covariate_df) 176 | logger.info(f"声明记录数: {len(claims)}") 177 | covariates = {"claims": claims} 178 | 179 | logger.info("上下文数据加载完成") 180 | return entities, relationships, reports, text_units, description_embedding_store, covariates 181 | except Exception as e: 182 | logger.error(f"加载上下文数据时出错: {str(e)}") 183 | raise 184 | 185 | 186 | async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units, 187 | description_embedding_store, covariates): 188 | """ 189 | 设置本地搜索引擎和全局搜索引擎 190 | """ 191 | logger.info("正在设置搜索引擎") 192 | 193 | # 设置本地搜索引擎 194 | local_context_builder = LocalSearchMixedContext( 195 | 
community_reports=reports, 196 | text_units=text_units, 197 | entities=entities, 198 | relationships=relationships, 199 | covariates=covariates, 200 | entity_text_embeddings=description_embedding_store, 201 | embedding_vectorstore_key=EntityVectorStoreKey.ID, 202 | text_embedder=text_embedder, 203 | token_encoder=token_encoder, 204 | ) 205 | 206 | local_context_params = { 207 | "text_unit_prop": 0.5, 208 | "community_prop": 0.1, 209 | "conversation_history_max_turns": 5, 210 | "conversation_history_user_turns_only": True, 211 | "top_k_mapped_entities": 10, 212 | "top_k_relationships": 10, 213 | "include_entity_rank": True, 214 | "include_relationship_weight": True, 215 | "include_community_rank": False, 216 | "return_candidate_context": False, 217 | "embedding_vectorstore_key": EntityVectorStoreKey.ID, 218 | "max_tokens": 12_000, 219 | } 220 | 221 | local_llm_params = { 222 | "max_tokens": 2_000, 223 | "temperature": 0.0, 224 | } 225 | 226 | local_search_engine = LocalSearch( 227 | llm=llm, 228 | context_builder=local_context_builder, 229 | token_encoder=token_encoder, 230 | llm_params=local_llm_params, 231 | context_builder_params=local_context_params, 232 | response_type="multiple paragraphs", 233 | ) 234 | 235 | # 设置全局搜索引擎 236 | global_context_builder = GlobalCommunityContext( 237 | community_reports=reports, 238 | entities=entities, 239 | token_encoder=token_encoder, 240 | ) 241 | 242 | global_context_builder_params = { 243 | "use_community_summary": False, 244 | "shuffle_data": True, 245 | "include_community_rank": True, 246 | "min_community_rank": 0, 247 | "community_rank_name": "rank", 248 | "include_community_weight": True, 249 | "community_weight_name": "occurrence weight", 250 | "normalize_community_weight": True, 251 | "max_tokens": 12_000, 252 | "context_name": "Reports", 253 | } 254 | 255 | map_llm_params = { 256 | "max_tokens": 1000, 257 | "temperature": 0.0, 258 | "response_format": {"type": "json_object"}, 259 | } 260 | 261 | reduce_llm_params = { 262 | "max_tokens": 2000, 263 | "temperature": 0.0, 264 | } 265 | 266 | global_search_engine = GlobalSearch( 267 | llm=llm, 268 | context_builder=global_context_builder, 269 | token_encoder=token_encoder, 270 | max_data_tokens=12_000, 271 | map_llm_params=map_llm_params, 272 | reduce_llm_params=reduce_llm_params, 273 | allow_general_knowledge=False, 274 | json_mode=True, 275 | context_builder_params=global_context_builder_params, 276 | concurrent_coroutines=32, 277 | response_type="multiple paragraphs", 278 | ) 279 | 280 | logger.info("搜索引擎设置完成") 281 | return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params 282 | 283 | 284 | def format_response(response): 285 | """ 286 | 格式化响应,添加适当的换行和段落分隔。 287 | """ 288 | paragraphs = re.split(r'\n{2,}', response) 289 | 290 | formatted_paragraphs = [] 291 | for para in paragraphs: 292 | if '```' in para: 293 | parts = para.split('```') 294 | for i, part in enumerate(parts): 295 | if i % 2 == 1: # 这是代码块 296 | parts[i] = f"\n```\n{part.strip()}\n```\n" 297 | para = ''.join(parts) 298 | else: 299 | para = para.replace('. 
', '.\n') 300 | 301 | formatted_paragraphs.append(para.strip()) 302 | 303 | return '\n\n'.join(formatted_paragraphs) 304 | 305 | 306 | async def tavily_search(prompt: str): 307 | """ 308 | 使用Tavily API进行搜索 309 | """ 310 | try: 311 | client = TavilyClient(api_key=os.environ['TAVILY_API_KEY']) 312 | resp = client.search(prompt, search_depth="advanced") 313 | 314 | # 将Tavily响应转换为Markdown格式 315 | markdown_response = "# 搜索结果\n\n" 316 | for result in resp.get('results', []): 317 | markdown_response += f"## [{result['title']}]({result['url']})\n\n" 318 | markdown_response += f"{result['content']}\n\n" 319 | 320 | return markdown_response 321 | except Exception as e: 322 | raise HTTPException(status_code=500, detail=f"Tavily搜索错误: {str(e)}") 323 | 324 | 325 | @asynccontextmanager 326 | async def lifespan(app: FastAPI): 327 | # 启动时执行 328 | global local_search_engine, global_search_engine, question_generator 329 | try: 330 | logger.info("正在初始化搜索引擎和问题生成器...") 331 | llm, token_encoder, text_embedder = await setup_llm_and_embedder() 332 | entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context() 333 | local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines( 334 | llm, token_encoder, text_embedder, entities, relationships, reports, text_units, 335 | description_embedding_store, covariates 336 | ) 337 | 338 | question_generator = LocalQuestionGen( 339 | llm=llm, 340 | context_builder=local_context_builder, 341 | token_encoder=token_encoder, 342 | llm_params=local_llm_params, 343 | context_builder_params=local_context_params, 344 | ) 345 | logger.info("初始化完成。") 346 | except Exception as e: 347 | logger.error(f"初始化过程中出错: {str(e)}") 348 | raise 349 | 350 | yield 351 | 352 | # 关闭时执行 353 | logger.info("正在关闭...") 354 | 355 | 356 | app = FastAPI(lifespan=lifespan) 357 | 358 | 359 | # 在 chat_completions 函数中添加以下代码 360 | 361 | async def full_model_search(prompt: str): 362 | """ 363 | 执行全模型搜索,包括本地检索、全局检索和 Tavily 搜索 364 | """ 365 | local_result = await local_search_engine.asearch(prompt) 366 | global_result = await global_search_engine.asearch(prompt) 367 | tavily_result = await tavily_search(prompt) 368 | 369 | # 格式化结果 370 | formatted_result = "# 🔥🔥🔥综合搜索结果\n\n" 371 | 372 | formatted_result += "## 🔥🔥🔥本地检索结果\n" 373 | formatted_result += format_response(local_result.response) + "\n\n" 374 | 375 | formatted_result += "## 🔥🔥🔥全局检索结果\n" 376 | formatted_result += format_response(global_result.response) + "\n\n" 377 | 378 | formatted_result += "## 🔥🔥🔥Tavily 搜索结果\n" 379 | formatted_result += tavily_result + "\n\n" 380 | 381 | return formatted_result 382 | 383 | 384 | @app.post("/v1/chat/completions") 385 | async def chat_completions(request: ChatCompletionRequest): 386 | if not local_search_engine or not global_search_engine: 387 | logger.error("搜索引擎未初始化") 388 | raise HTTPException(status_code=500, detail="搜索引擎未初始化") 389 | 390 | try: 391 | logger.info(f"收到聊天完成请求: {request}") 392 | prompt = request.messages[-1].content 393 | logger.info(f"处理提示: {prompt}") 394 | 395 | # 根据模型选择使用不同的搜索方法 396 | if request.model == "graphrag-global-search:latest": 397 | result = await global_search_engine.asearch(prompt) 398 | formatted_response = format_response(result.response) 399 | elif request.model == "tavily-search:latest": 400 | result = await tavily_search(prompt) 401 | formatted_response = result 402 | elif request.model == "full-model:latest": 403 | formatted_response = await full_model_search(prompt) 404 | 
else: # 默认使用本地搜索 405 | result = await local_search_engine.asearch(prompt) 406 | formatted_response = format_response(result.response) 407 | 408 | logger.info(f"格式化的搜索结果: {formatted_response}") 409 | 410 | # 流式响应和非流式响应的处理保持不变 411 | if request.stream: 412 | async def generate_stream(): 413 | chunk_id = f"chatcmpl-{uuid.uuid4().hex}" 414 | lines = formatted_response.split('\n') 415 | for i, line in enumerate(lines): 416 | chunk = { 417 | "id": chunk_id, 418 | "object": "chat.completion.chunk", 419 | "created": int(time.time()), 420 | "model": request.model, 421 | "choices": [ 422 | { 423 | "index": 0, 424 | "delta": {"content": line + '\n'}, # if i > 0 else {"role": "assistant", "content": ""}, 425 | "finish_reason": None 426 | } 427 | ] 428 | } 429 | yield f"data: {json.dumps(chunk)}\n\n" 430 | await asyncio.sleep(0.05) 431 | 432 | final_chunk = { 433 | "id": chunk_id, 434 | "object": "chat.completion.chunk", 435 | "created": int(time.time()), 436 | "model": request.model, 437 | "choices": [ 438 | { 439 | "index": 0, 440 | "delta": {}, 441 | "finish_reason": "stop" 442 | } 443 | ] 444 | } 445 | yield f"data: {json.dumps(final_chunk)}\n\n" 446 | yield "data: [DONE]\n\n" 447 | 448 | return StreamingResponse(generate_stream(), media_type="text/event-stream") 449 | else: 450 | response = ChatCompletionResponse( 451 | model=request.model, 452 | choices=[ 453 | ChatCompletionResponseChoice( 454 | index=0, 455 | message=Message(role="assistant", content=formatted_response), 456 | finish_reason="stop" 457 | ) 458 | ], 459 | usage=Usage( 460 | prompt_tokens=len(prompt.split()), 461 | completion_tokens=len(formatted_response.split()), 462 | total_tokens=len(prompt.split()) + len(formatted_response.split()) 463 | ) 464 | ) 465 | logger.info(f"发送响应: {response}") 466 | return JSONResponse(content=response.dict()) 467 | 468 | except Exception as e: 469 | logger.error(f"处理聊天完成时出错: {str(e)}") 470 | raise HTTPException(status_code=500, detail=str(e)) 471 | 472 | @app.get("/v1/models") 473 | async def list_models(): 474 | """ 475 | 返回可用模型列表 476 | """ 477 | logger.info("收到模型列表请求") 478 | current_time = int(time.time()) 479 | models = [ 480 | {"id": "graphrag-local-search:latest", "object": "model", "created": current_time - 100000, "owned_by": "graphrag"}, 481 | {"id": "graphrag-global-search:latest", "object": "model", "created": current_time - 95000, "owned_by": "graphrag"}, 482 | # {"id": "graphrag-question-generator:latest", "object": "model", "created": current_time - 90000, "owned_by": "graphrag"}, 483 | # {"id": "gpt-3.5-turbo:latest", "object": "model", "created": current_time - 80000, "owned_by": "openai"}, 484 | # {"id": "text-embedding-3-small:latest", "object": "model", "created": current_time - 70000, "owned_by": "openai"}, 485 | {"id": "tavily-search:latest", "object": "model", "created": current_time - 85000, "owned_by": "tavily"}, 486 | {"id": "full-model:latest", "object": "model", "created": current_time - 80000, "owned_by": "combined"} 487 | 488 | ] 489 | 490 | response = { 491 | "object": "list", 492 | "data": models 493 | } 494 | 495 | logger.info(f"发送模型列表: {response}") 496 | return JSONResponse(content=response) 497 | 498 | if __name__ == "__main__": 499 | import uvicorn 500 | 501 | logger.info(f"在端口 {PORT} 上启动服务器") 502 | uvicorn.run(app, host="0.0.0.0", port=PORT) 503 | 504 | -------------------------------------------------------------------------------- /main-en.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | 
import time 4 | import uuid 5 | import json 6 | import re 7 | import pandas as pd 8 | import tiktoken 9 | import logging 10 | from fastapi import FastAPI, HTTPException, Request 11 | from fastapi.responses import JSONResponse, StreamingResponse 12 | from pydantic import BaseModel, Field 13 | from typing import List, Optional, Dict, Any, Union 14 | from contextlib import asynccontextmanager 15 | from tavily import TavilyClient 16 | 17 | 18 | # GraphRAG related imports 19 | from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey 20 | from graphrag.query.indexer_adapters import ( 21 | read_indexer_covariates, 22 | read_indexer_entities, 23 | read_indexer_relationships, 24 | read_indexer_reports, 25 | read_indexer_text_units, 26 | ) 27 | from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings 28 | from graphrag.query.llm.oai.chat_openai import ChatOpenAI 29 | from graphrag.query.llm.oai.embedding import OpenAIEmbedding 30 | from graphrag.query.llm.oai.typing import OpenaiApiType 31 | from graphrag.query.question_gen.local_gen import LocalQuestionGen 32 | from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext 33 | from graphrag.query.structured_search.local_search.search import LocalSearch 34 | from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext 35 | from graphrag.query.structured_search.global_search.search import GlobalSearch 36 | from graphrag.vector_stores.lancedb import LanceDBVectorStore 37 | 38 | # Set up logging 39 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 40 | logger = logging.getLogger(__name__) 41 | 42 | # Set constants and configurations 43 | INPUT_DIR = os.getenv('INPUT_DIR') 44 | LANCEDB_URI = f"{INPUT_DIR}/lancedb" 45 | COMMUNITY_REPORT_TABLE = "create_final_community_reports" 46 | ENTITY_TABLE = "create_final_nodes" 47 | ENTITY_EMBEDDING_TABLE = "create_final_entities" 48 | RELATIONSHIP_TABLE = "create_final_relationships" 49 | COVARIATE_TABLE = "create_final_covariates" 50 | TEXT_UNIT_TABLE = "create_final_text_units" 51 | COMMUNITY_LEVEL = 2 52 | PORT = 8012 53 | 54 | # Global variables for storing search engines and question generator 55 | local_search_engine = None 56 | global_search_engine = None 57 | question_generator = None 58 | 59 | 60 | # Data models 61 | class Message(BaseModel): 62 | role: str 63 | content: str 64 | 65 | 66 | class ChatCompletionRequest(BaseModel): 67 | model: str 68 | messages: List[Message] 69 | temperature: Optional[float] = 1.0 70 | top_p: Optional[float] = 1.0 71 | n: Optional[int] = 1 72 | stream: Optional[bool] = False 73 | stop: Optional[Union[str, List[str]]] = None 74 | max_tokens: Optional[int] = None 75 | presence_penalty: Optional[float] = 0 76 | frequency_penalty: Optional[float] = 0 77 | logit_bias: Optional[Dict[str, float]] = None 78 | user: Optional[str] = None 79 | 80 | 81 | class ChatCompletionResponseChoice(BaseModel): 82 | index: int 83 | message: Message 84 | finish_reason: Optional[str] = None 85 | 86 | 87 | class Usage(BaseModel): 88 | prompt_tokens: int 89 | completion_tokens: int 90 | total_tokens: int 91 | 92 | 93 | class ChatCompletionResponse(BaseModel): 94 | id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}") 95 | object: str = "chat.completion" 96 | created: int = Field(default_factory=lambda: int(time.time())) 97 | model: str 98 | choices: List[ChatCompletionResponseChoice] 99 | usage: Usage 100 | 
system_fingerprint: Optional[str] = None 101 | 102 | 103 | async def setup_llm_and_embedder(): 104 | """ 105 | Set up Language Model (LLM) and embedding model 106 | """ 107 | logger.info("Setting up LLM and embedder") 108 | 109 | # Get API keys and base URLs 110 | api_key = os.environ.get("GRAPHRAG_API_KEY", "YOUR_API_KEY") 111 | api_key_embedding = os.environ.get("GRAPHRAG_API_KEY_EMBEDDING", api_key) 112 | api_base = os.environ.get("API_BASE", "https://api.openai.com/v1") 113 | api_base_embedding = os.environ.get("API_BASE_EMBEDDING", "https://api.openai.com/v1") 114 | 115 | # Get model names 116 | llm_model = os.environ.get("GRAPHRAG_LLM_MODEL", "gpt-3.5-turbo-0125") 117 | embedding_model = os.environ.get("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small") 118 | 119 | # Check if API key exists 120 | if api_key == "YOUR_API_KEY": 121 | logger.error("Valid GRAPHRAG_API_KEY not found in environment variables") 122 | raise ValueError("GRAPHRAG_API_KEY is not set correctly") 123 | 124 | # Initialize ChatOpenAI instance 125 | llm = ChatOpenAI( 126 | api_key=api_key, 127 | api_base=api_base, 128 | model=llm_model, 129 | api_type=OpenaiApiType.OpenAI, 130 | max_retries=20, 131 | ) 132 | 133 | # Initialize token encoder 134 | token_encoder = tiktoken.get_encoding("cl100k_base") 135 | 136 | # Initialize text embedding model 137 | text_embedder = OpenAIEmbedding( 138 | api_key=api_key_embedding, 139 | api_base=api_base_embedding, 140 | api_type=OpenaiApiType.OpenAI, 141 | model=embedding_model, 142 | deployment_name=embedding_model, 143 | max_retries=20, 144 | ) 145 | 146 | 147 | logger.info("LLM and embedder setup complete") 148 | return llm, token_encoder, text_embedder 149 | 150 | 151 | async def load_context(): 152 | """ 153 | Load context data including entities, relationships, reports, text units, and covariates 154 | """ 155 | logger.info("Loading context data") 156 | try: 157 | entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet") 158 | entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet") 159 | entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL) 160 | 161 | description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings") 162 | description_embedding_store.connect(db_uri=LANCEDB_URI) 163 | store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store) 164 | 165 | relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet") 166 | relationships = read_indexer_relationships(relationship_df) 167 | 168 | report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet") 169 | reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL) 170 | 171 | text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet") 172 | text_units = read_indexer_text_units(text_unit_df) 173 | 174 | covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet") 175 | claims = read_indexer_covariates(covariate_df) 176 | logger.info(f"Number of claim records: {len(claims)}") 177 | covariates = {"claims": claims} 178 | 179 | logger.info("Context data loading complete") 180 | return entities, relationships, reports, text_units, description_embedding_store, covariates 181 | except Exception as e: 182 | logger.error(f"Error loading context data: {str(e)}") 183 | raise 184 | 185 | 186 | async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units, 187 | description_embedding_store, 
covariates): 188 | """ 189 | Set up local and global search engines 190 | """ 191 | logger.info("Setting up search engines") 192 | 193 | # Set up local search engine 194 | local_context_builder = LocalSearchMixedContext( 195 | community_reports=reports, 196 | text_units=text_units, 197 | entities=entities, 198 | relationships=relationships, 199 | covariates=covariates, 200 | entity_text_embeddings=description_embedding_store, 201 | embedding_vectorstore_key=EntityVectorStoreKey.ID, 202 | text_embedder=text_embedder, 203 | token_encoder=token_encoder, 204 | ) 205 | 206 | local_context_params = { 207 | "text_unit_prop": 0.5, 208 | "community_prop": 0.1, 209 | "conversation_history_max_turns": 5, 210 | "conversation_history_user_turns_only": True, 211 | "top_k_mapped_entities": 10, 212 | "top_k_relationships": 10, 213 | "include_entity_rank": True, 214 | "include_relationship_weight": True, 215 | "include_community_rank": False, 216 | "return_candidate_context": False, 217 | "embedding_vectorstore_key": EntityVectorStoreKey.ID, 218 | "max_tokens": 12_000, 219 | } 220 | 221 | local_llm_params = { 222 | "max_tokens": 2_000, 223 | "temperature": 0.0, 224 | } 225 | 226 | local_search_engine = LocalSearch( 227 | llm=llm, 228 | context_builder=local_context_builder, 229 | token_encoder=token_encoder, 230 | llm_params=local_llm_params, 231 | context_builder_params=local_context_params, 232 | response_type="multiple paragraphs", 233 | ) 234 | 235 | # Set up global search engine 236 | global_context_builder = GlobalCommunityContext( 237 | community_reports=reports, 238 | entities=entities, 239 | token_encoder=token_encoder, 240 | ) 241 | 242 | global_context_builder_params = { 243 | "use_community_summary": False, 244 | "shuffle_data": True, 245 | "include_community_rank": True, 246 | "min_community_rank": 0, 247 | "community_rank_name": "rank", 248 | "include_community_weight": True, 249 | "community_weight_name": "occurrence weight", 250 | "normalize_community_weight": True, 251 | "max_tokens": 12_000, 252 | "context_name": "Reports", 253 | } 254 | 255 | map_llm_params = { 256 | "max_tokens": 1000, 257 | "temperature": 0.0, 258 | "response_format": {"type": "json_object"}, 259 | } 260 | 261 | reduce_llm_params = { 262 | "max_tokens": 2000, 263 | "temperature": 0.0, 264 | } 265 | 266 | global_search_engine = GlobalSearch( 267 | llm=llm, 268 | context_builder=global_context_builder, 269 | token_encoder=token_encoder, 270 | max_data_tokens=12_000, 271 | map_llm_params=map_llm_params, 272 | reduce_llm_params=reduce_llm_params, 273 | allow_general_knowledge=False, 274 | json_mode=True, 275 | context_builder_params=global_context_builder_params, 276 | concurrent_coroutines=32, 277 | response_type="multiple paragraphs", 278 | ) 279 | 280 | logger.info("Search engines setup complete") 281 | return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params 282 | 283 | 284 | def format_response(response): 285 | """ 286 | Format the response by adding appropriate line breaks and paragraph separations. 287 | """ 288 | paragraphs = re.split(r'\n{2,}', response) 289 | 290 | formatted_paragraphs = [] 291 | for para in paragraphs: 292 | if '```' in para: 293 | parts = para.split('```') 294 | for i, part in enumerate(parts): 295 | if i % 2 == 1: # This is a code block 296 | parts[i] = f"\n```\n{part.strip()}\n```\n" 297 | para = ''.join(parts) 298 | else: 299 | para = para.replace('. 
async def tavily_search(prompt: str):
    """
    Perform a search using the Tavily API
    """
    try:
        client = TavilyClient(api_key=os.environ['TAVILY_API_KEY'])
        resp = client.search(prompt, search_depth="advanced")

        # Convert the Tavily response to Markdown format
        markdown_response = "# Search Results\n\n"
        for result in resp.get('results', []):
            markdown_response += f"## [{result['title']}]({result['url']})\n\n"
            markdown_response += f"{result['content']}\n\n"

        return markdown_response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Tavily search error: {str(e)}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Execute on startup
    global local_search_engine, global_search_engine, question_generator
    try:
        logger.info("Initializing search engines and question generator...")
        llm, token_encoder, text_embedder = await setup_llm_and_embedder()
        entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context()
        local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines(
            llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
            description_embedding_store, covariates
        )

        question_generator = LocalQuestionGen(
            llm=llm,
            context_builder=local_context_builder,
            token_encoder=token_encoder,
            llm_params=local_llm_params,
            context_builder_params=local_context_params,
        )
        logger.info("Initialization complete.")
    except Exception as e:
        logger.error(f"Error during initialization: {str(e)}")
        raise

    yield

    # Execute on shutdown
    logger.info("Shutting down...")


app = FastAPI(lifespan=lifespan)


# Combined search used by the "full-model:latest" option in chat_completions below
async def full_model_search(prompt: str):
    """
    Perform a full model search, combining local retrieval, global retrieval, and Tavily search
    """
    local_result = await local_search_engine.asearch(prompt)
    global_result = await global_search_engine.asearch(prompt)
    tavily_result = await tavily_search(prompt)

    # Format results
    formatted_result = "# 🔥🔥🔥Comprehensive Search Results\n\n"

    formatted_result += "## 🔥🔥🔥Local Retrieval Results\n"
    formatted_result += format_response(local_result.response) + "\n\n"

    formatted_result += "## 🔥🔥🔥Global Retrieval Results\n"
    formatted_result += format_response(global_result.response) + "\n\n"

    formatted_result += "## 🔥🔥🔥Tavily Search Results\n"
    formatted_result += tavily_result + "\n\n"

    return formatted_result
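
# full_model_search() above awaits the three searches one after another. Since
# they are independent, a possible variant (a sketch, not part of the original
# code) is to run them concurrently with asyncio.gather, which can reduce
# end-to-end latency to roughly the slowest of the three calls. Note that
# TavilyClient.search() is a blocking call, so in practice only the two GraphRAG
# searches truly overlap here.
async def full_model_search_concurrent(prompt: str):
    """Variant of full_model_search() that issues all three searches at once."""
    local_result, global_result, tavily_result = await asyncio.gather(
        local_search_engine.asearch(prompt),
        global_search_engine.asearch(prompt),
        tavily_search(prompt),
    )
    formatted_result = "# 🔥🔥🔥Comprehensive Search Results\n\n"
    formatted_result += "## 🔥🔥🔥Local Retrieval Results\n"
    formatted_result += format_response(local_result.response) + "\n\n"
    formatted_result += "## 🔥🔥🔥Global Retrieval Results\n"
    formatted_result += format_response(global_result.response) + "\n\n"
    formatted_result += "## 🔥🔥🔥Tavily Search Results\n"
    formatted_result += tavily_result + "\n\n"
    return formatted_result
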
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not local_search_engine or not global_search_engine:
        logger.error("Search engines not initialized")
        raise HTTPException(status_code=500, detail="Search engines not initialized")

    try:
        logger.info(f"Received chat completion request: {request}")
        prompt = request.messages[-1].content
        logger.info(f"Processing prompt: {prompt}")

        # Choose the search method based on the requested model name
        if request.model == "graphrag-global-search:latest":
            result = await global_search_engine.asearch(prompt)
            formatted_response = format_response(result.response)
        elif request.model == "tavily-search:latest":
            result = await tavily_search(prompt)
            formatted_response = result
        elif request.model == "full-model:latest":
            formatted_response = await full_model_search(prompt)
        else:  # Default to local search
            result = await local_search_engine.asearch(prompt)
            formatted_response = format_response(result.response)

        logger.info(f"Formatted search result: {formatted_response}")

        # Handle streaming and non-streaming responses
        if request.stream:
            async def generate_stream():
                chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
                lines = formatted_response.split('\n')
                for i, line in enumerate(lines):
                    chunk = {
                        "id": chunk_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request.model,
                        "choices": [
                            {
                                "index": 0,
                                # OpenAI-style streams carry the role in the first delta
                                "delta": {"role": "assistant", "content": line + '\n'} if i == 0 else {"content": line + '\n'},
                                "finish_reason": None
                            }
                        ]
                    }
                    yield f"data: {json.dumps(chunk)}\n\n"
                    await asyncio.sleep(0.05)

                final_chunk = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {},
                            "finish_reason": "stop"
                        }
                    ]
                }
                yield f"data: {json.dumps(final_chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")
        else:
            response = ChatCompletionResponse(
                model=request.model,
                choices=[
                    ChatCompletionResponseChoice(
                        index=0,
                        message=Message(role="assistant", content=formatted_response),
                        finish_reason="stop"
                    )
                ],
                # Whitespace-split counts are a rough approximation, not tiktoken counts
                usage=Usage(
                    prompt_tokens=len(prompt.split()),
                    completion_tokens=len(formatted_response.split()),
                    total_tokens=len(prompt.split()) + len(formatted_response.split())
                )
            )
            logger.info(f"Sending response: {response}")
            # .dict() is the Pydantic v1 API; on Pydantic v2, use response.model_dump()
            return JSONResponse(content=response.dict())

    except Exception as e:
        logger.error(f"Error processing chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
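
# Example (hypothetical request values): once the server is running, the endpoint
# can be exercised directly, e.g. with curl, before wiring it into Open WebUI:
#
#   curl http://localhost:8012/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "graphrag-local-search:latest",
#          "messages": [{"role": "user", "content": "What is GraphRAG?"}],
#          "stream": false}'
#
# Setting "model" to graphrag-global-search:latest, tavily-search:latest, or
# full-model:latest selects the other search paths handled above.
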
"combined"} 486 | ] 487 | 488 | response = { 489 | "object": "list", 490 | "data": models 491 | } 492 | 493 | logger.info(f"Sending model list: {response}") 494 | return JSONResponse(content=response) 495 | 496 | if __name__ == "__main__": 497 | import uvicorn 498 | 499 | logger.info(f"Starting server on port {PORT}") 500 | uvicorn.run(app, host="0.0.0.0", port=PORT) 501 | --------------------------------------------------------------------------------