GraphRAG4OpenWebUI
├── test
├── requirements.txt
├── README_ZH-CN.md
├── README.md
├── graphrag3dknowledge.py
├── LICENSE
├── main-cn.py
└── main-en.py
/GraphRAG4OpenWebUI/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | uvicorn
3 | pandas
4 | tiktoken
5 | graphrag
6 | tavily-python
7 | pydantic
8 | python-dotenv
9 | # asyncio is part of the Python standard library; do not install it from PyPI
10 | aiohttp
11 | numpy
12 | scikit-learn
13 | matplotlib
14 | seaborn
15 | nltk
16 | spacy
17 | transformers
18 | torch
19 | torchvision
20 | torchaudio
21 |
--------------------------------------------------------------------------------
/README_ZH-CN.md:
--------------------------------------------------------------------------------
1 | ### 🔥🔥🔥 If you have any questions, contact me on WeChat: stoeng
2 | ### 🔥🔥🔥 For a video demo of this project, see https://youtu.be/z4Si6O5NQ4c
3 |
4 | # GraphRAG4OpenWebUI
5 |
6 | Integrate Microsoft's GraphRAG technology into Open WebUI for advanced information retrieval
7 | English | 简体中文
8 |
9 |
10 | GraphRAG4OpenWebUI is an API interface specifically designed for Open WebUI, aiming to integrate Microsoft Research's GraphRAG (Graph-based Retrieval-Augmented Generation) technology. This project provides a powerful information retrieval system that supports multiple search models, particularly suitable for use in open web user interfaces.
11 |
12 | ## Project Overview
13 |
14 | The main goal of this project is to provide a convenient interface for Open WebUI to leverage the powerful features of GraphRAG. It integrates three main retrieval methods and offers a comprehensive search option, allowing users to obtain thorough and precise search results.
15 |
16 | ### Key Retrieval Features
17 |
18 | 1. **Local Search**
19 | - Utilizes GraphRAG technology for efficient retrieval in local knowledge bases
20 | - Suitable for quick access to pre-defined structured information
21 | - Leverages graph structures to improve retrieval accuracy and relevance
22 |
23 | 2. **Global Search**
24 | - Searches for information in a broader scope, beyond local knowledge bases
25 | - Suitable for queries requiring more comprehensive information
26 | - Utilizes GraphRAG's global context understanding capabilities to provide richer search results
27 |
28 | 3. **Tavily Search**
29 | - Integrates the external Tavily search API
30 | - Provides additional internet search capabilities, expanding information sources
31 | - Suitable for queries requiring the latest or extensive web information
32 |
33 | 4. **Full Model Search**
34 | - Combines all three search methods above
35 | - Provides the most comprehensive search results, meeting complex information needs
36 | - Automatically integrates and ranks information from different sources
37 |
38 | ### Local LLM and Embedding Model Support
39 |
40 | GraphRAG4OpenWebUI now supports the use of local language models (LLMs) and embedding models, increasing the project's flexibility and privacy. Specifically, we support the following local models (see the configuration sketch below):
41 |
42 | 1. **Ollama**
43 | - Supports various open-source LLMs run through Ollama, such as Llama 2, Mistral, etc.
44 | - Can be configured by setting the `API_BASE` environment variable to point to Ollama's API endpoint
45 |
46 | 2. **LM Studio**
47 | - Compatible with models run by LM Studio
48 | - Connect to LM Studio's service by configuring the `API_BASE` environment variable
49 |
50 | 3. **Local Embedding Models**
51 | - Supports locally run embedding models, such as SentenceTransformers
52 | - Specify the embedding model by setting the `GRAPHRAG_EMBEDDING_MODEL` environment variable
53 |
54 | This support for local models allows GraphRAG4OpenWebUI to run without relying on external APIs, enhancing data privacy and reducing usage costs.
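As a concrete example, the following settings point the server at a local Ollama instance for the LLM while keeping a remote endpoint for embeddings. This is a minimal sketch: the model name and the placeholder key value are illustrative, and Ollama's OpenAI-compatible endpoint is assumed to be at its default address (for LM Studio, the default would be `http://localhost:1234/v1`).

```bash
# LLM served locally by Ollama (OpenAI-compatible API)
export API_BASE="http://localhost:11434/v1"
export GRAPHRAG_LLM_MODEL="gemma2"        # any model you have pulled into Ollama
export GRAPHRAG_API_KEY="ollama"          # placeholder; a local server typically ignores it

# Embeddings still served by OpenAI (or point this at a local embedding server)
export API_BASE_EMBEDDING="https://api.openai.com/v1"
export GRAPHRAG_EMBEDDING_MODEL="text-embedding-3-small"
export GRAPHRAG_API_KEY_EMBEDDING="your_openai_api_key"
```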
55 |
56 | ## Installation
57 | Ensure that you have Python 3.8 or higher installed on your system. Then, follow these steps to install:
58 | 1. Clone the repository:
59 | ```bash
60 | git clone https://github.com/your-username/GraphRAG4OpenWebUI.git
61 | cd GraphRAG4OpenWebUI
62 | ```
63 |
64 | 2. Create and activate a virtual environment:
65 | ```bash
66 | python -m venv venv
67 | source venv/bin/activate # On Windows use venv\Scripts\activate
68 | ```
69 |
70 | 3. Install dependencies:
71 | ```bash
72 | pip install -r requirements.txt
73 | ```
74 | Note: The graphrag package might need to be installed from a specific source. If the above command fails to install graphrag, please refer to Microsoft Research's specific instructions or contact the maintainer for the correct installation method.
75 |
76 | ## Configuration
77 |
78 | Before running the API, you need to set the following environment variables. You can do this by creating a `.env` file or exporting them directly in your terminal:
79 |
80 |
81 | ```bash
82 | export TAVILY_API_KEY="your_tavily_api_key"
83 |
84 | export INPUT_DIR="/path/to/your/input/directory"
85 |
86 | # Set the API key for the LLM
87 | export GRAPHRAG_API_KEY="your_actual_api_key_here"
88 |
89 | # Set the embedding API key (if different from GRAPHRAG_API_KEY)
90 | export GRAPHRAG_API_KEY_EMBEDDING="your_embedding_api_key_here"
91 |
92 | # Set the LLM model (default is "gemma2")
93 | export GRAPHRAG_LLM_MODEL="gemma2"
94 |
95 | # Set the API base URL (default is the local server)
96 | export API_BASE="http://localhost:11434/v1"
97 |
98 | # Set the embedding API base URL (default is OpenAI's API)
99 | export API_BASE_EMBEDDING="https://api.openai.com/v1"
100 |
101 | # Set the embedding model (default is "text-embedding-3-small")
102 | export GRAPHRAG_EMBEDDING_MODEL="text-embedding-3-small"
103 | ```
104 |
105 | Make sure to replace the placeholders in the above commands with your actual API keys and paths.
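If you use a `.env` file instead of exports, the same variables can live at the project root. One caveat, stated as an assumption: the server scripts read plain environment variables and do not load the file themselves, so either load it via python-dotenv's `load_dotenv()` (the package is already listed in requirements.txt) or export it in your shell with `set -a; source .env; set +a`. A sketch:

```
# .env (example values, not real credentials)
TAVILY_API_KEY=your_tavily_api_key
INPUT_DIR=/path/to/your/input/directory
GRAPHRAG_API_KEY=your_actual_api_key_here
GRAPHRAG_LLM_MODEL=gemma2
API_BASE=http://localhost:11434/v1
API_BASE_EMBEDDING=https://api.openai.com/v1
GRAPHRAG_EMBEDDING_MODEL=text-embedding-3-small
```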
106 |
107 | ## Usage
108 |
109 | 1. Start the server:
110 | ```
111 | python main-cn.py
112 | ```
113 | The server will run on `http://localhost:8012`.
114 |
115 | 2. API Endpoints (see the curl examples below):
116 | - `/v1/chat/completions`: POST request for performing searches
117 | - `/v1/models`: GET request to retrieve the list of available models
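A quick way to smoke-test both endpoints from the command line; the port, paths, and model name below are taken directly from the server code:

```bash
# List the available models
curl http://localhost:8012/v1/models

# Run a local search through the chat endpoint
curl http://localhost:8012/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "graphrag-local-search:latest", "messages": [{"role": "user", "content": "test query"}]}'
```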
118 |
119 | 3. Integration with Open WebUI:
120 | In the Open WebUI configuration, set the API endpoint to `http://localhost:8012/v1/chat/completions`. This allows Open WebUI to use the search functionality of GraphRAG4OpenWebUI.
121 |
122 | 4. Example search request:
123 | ```python
124 | import requests
125 | import json
126 |
127 | url = "http://localhost:8012/v1/chat/completions"
128 | headers = {"Content-Type": "application/json"}
129 | data = {
130 | "model": "full-model:latest",
131 |     "messages": [{"role": "user", "content": "Your search query"}],
132 | "temperature": 0.7
133 | }
134 |
135 | response = requests.post(url, headers=headers, data=json.dumps(data))
136 | print(response.json())
137 | ```
138 |
139 | ## Available Models
140 |
141 | - `graphrag-local-search:latest`: Local search
142 | - `graphrag-global-search:latest`: Global search
143 | - `tavily-search:latest`: Tavily search
144 | - `full-model:latest`: Comprehensive search (combines all of the search methods above)
145 |
146 | ## Notes
147 |
148 | - Ensure that the correct input files (such as Parquet files) are present in the `INPUT_DIR` directory (see the expected layout below).
149 | - The API uses asynchronous programming; make sure your environment supports async operations.
150 | - For large-scale deployment, consider using a production-grade ASGI server.
151 | - This project is specifically designed for Open WebUI and can be easily integrated into various web-based applications.
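For reference, the server expects the following layout inside `INPUT_DIR`; the Parquet table names are taken from the constants in main-cn.py, and `lancedb/` is the vector store location (`LANCEDB_URI`) used for entity description embeddings:

```
INPUT_DIR/
├── create_final_community_reports.parquet
├── create_final_nodes.parquet
├── create_final_entities.parquet
├── create_final_relationships.parquet
├── create_final_covariates.parquet
├── create_final_text_units.parquet
└── lancedb/
```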
152 |
153 | ## Contributing
154 |
155 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
156 |
157 | ## License
158 |
159 | [Apache-2.0 License](LICENSE)
160 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # GraphRAG4OpenWebUI
3 |
4 | Integrate Microsoft's GraphRAG Technology into Open WebUI for Advanced Information Retrieval
5 | English | 简体中文
6 |
7 |
8 | GraphRAG4OpenWebUI is an API interface specifically designed for Open WebUI, aiming to integrate Microsoft Research's GraphRAG (Graph-based Retrieval-Augmented Generation) technology. This project provides a powerful information retrieval system that supports multiple search models, particularly suitable for use in open web user interfaces.
9 |
10 | ## Project Overview
11 |
12 | The main goal of this project is to provide a convenient interface for Open WebUI to leverage the powerful features of GraphRAG. It integrates three main retrieval methods and offers a comprehensive search option, allowing users to obtain thorough and precise search results.
13 |
14 | ### Key Retrieval Features
15 |
16 | 1. **Local Search**
17 | - Utilizes GraphRAG technology for efficient retrieval in local knowledge bases
18 | - Suitable for quick access to pre-defined structured information
19 | - Leverages graph structures to improve retrieval accuracy and relevance
20 |
21 | 2. **Global Search**
22 | - Searches for information in a broader scope, beyond local knowledge bases
23 | - Suitable for queries requiring more comprehensive information
24 | - Utilizes GraphRAG's global context understanding capabilities to provide richer search results
25 |
26 | 3. **Tavily Search**
27 | - Integrates external Tavily search API
28 | - Provides additional internet search capabilities, expanding information sources
29 | - Suitable for queries requiring the latest or extensive web information
30 |
31 | 4. **Full Model Search**
32 | - Combines all three search methods above
33 | - Provides the most comprehensive search results, meeting complex information needs
34 | - Automatically integrates and ranks information from different sources
35 |
36 | ### Local LLM and Embedding Model Support
37 |
38 | GraphRAG4OpenWebUI now supports the use of local language models (LLMs) and embedding models, increasing the project's flexibility and privacy. Specifically, we support the following local models (see the configuration sketch below):
39 |
40 | 1. **Ollama**
41 | - Supports various open-source LLMs run through Ollama, such as Llama 2, Mistral, etc.
42 | - Can be configured by setting the `API_BASE` environment variable to point to Ollama's API endpoint
43 |
44 | 2. **LM Studio**
45 | - Compatible with models run by LM Studio
46 | - Connect to LM Studio's service by configuring the `API_BASE` environment variable
47 |
48 | 3. **Local Embedding Models**
49 | - Supports the use of locally run embedding models, such as SentenceTransformers
50 | - Specify the embedding model to use by setting the `GRAPHRAG_EMBEDDING_MODEL` environment variable
51 |
52 | This support for local models allows GraphRAG4OpenWebUI to run without relying on external APIs, enhancing data privacy and reducing usage costs.
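As a concrete example, the following settings point the server at a local Ollama instance for the LLM while keeping a remote endpoint for embeddings. This is a minimal sketch: the model name and the placeholder key value are illustrative, and Ollama's OpenAI-compatible endpoint is assumed to be at its default address (for LM Studio, the default would be `http://localhost:1234/v1`).

```bash
# LLM served locally by Ollama (OpenAI-compatible API)
export API_BASE="http://localhost:11434/v1"
export GRAPHRAG_LLM_MODEL="gemma2"        # any model you have pulled into Ollama
export GRAPHRAG_API_KEY="ollama"          # placeholder; a local server typically ignores it

# Embeddings still served by OpenAI (or point this at a local embedding server)
export API_BASE_EMBEDDING="https://api.openai.com/v1"
export GRAPHRAG_EMBEDDING_MODEL="text-embedding-3-small"
export GRAPHRAG_API_KEY_EMBEDDING="your_openai_api_key"
```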
53 |
54 | ## Installation
55 | Ensure that you have Python 3.8 or higher installed on your system. Then, follow these steps to install:
56 | 1. Clone the repository:
57 | ```bash
58 | git clone https://github.com/your-username/GraphRAG4OpenWebUI.git
59 | cd GraphRAG4OpenWebUI
60 | ```
61 |
62 | 2. Create and activate a virtual environment:
63 | ```bash
64 | python -m venv venv
65 | source venv/bin/activate # On Windows use venv\Scripts\activate
66 | ```
67 |
68 | 3. Install dependencies:
69 | ```bash
70 | pip install -r requirements.txt
71 | ```
72 | Note: The graphrag package might need to be installed from a specific source. If the above command fails to install graphrag, please refer to Microsoft Research's specific instructions or contact the maintainer for the correct installation method.
73 |
74 | ## Configuration
75 |
76 | Before running the API, you need to set the following environment variables. You can do this by creating a `.env` file or exporting them directly in your terminal:
77 |
78 |
79 |
80 | ```bash
81 | # Set the TAVILY API key
82 | export TAVILY_API_KEY="your_tavily_api_key"
83 |
84 | export INPUT_DIR="/path/to/your/input/directory"
85 |
86 | # Set the API key for LLM
87 | export GRAPHRAG_API_KEY="your_actual_api_key_here"
88 |
89 | # Set the API key for embedding (if different from GRAPHRAG_API_KEY)
90 | export GRAPHRAG_API_KEY_EMBEDDING="your_embedding_api_key_here"
91 |
92 | # Set the LLM model
93 | export GRAPHRAG_LLM_MODEL="gemma2"
94 |
95 | # Set the API base URL
96 | export API_BASE="http://localhost:11434/v1"
97 |
98 | # Set the embedding API base URL (default is OpenAI's API)
99 | export API_BASE_EMBEDDING="https://api.openai.com/v1"
100 |
101 | # Set the embedding model (default is "text-embedding-3-small")
102 | export GRAPHRAG_EMBEDDING_MODEL="text-embedding-3-small"
103 | ```
104 |
105 | Make sure to replace the placeholders in the above commands with your actual API keys and paths.
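If you use a `.env` file instead of exports, the same variables can live at the project root. One caveat, stated as an assumption: the server scripts read plain environment variables and do not load the file themselves, so either load it via python-dotenv's `load_dotenv()` (the package is already listed in requirements.txt) or export it in your shell with `set -a; source .env; set +a`. A sketch:

```
# .env (example values, not real credentials)
TAVILY_API_KEY=your_tavily_api_key
INPUT_DIR=/path/to/your/input/directory
GRAPHRAG_API_KEY=your_actual_api_key_here
GRAPHRAG_LLM_MODEL=gemma2
API_BASE=http://localhost:11434/v1
API_BASE_EMBEDDING=https://api.openai.com/v1
GRAPHRAG_EMBEDDING_MODEL=text-embedding-3-small
```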
106 |
107 | ## Usage
108 |
109 | 1. Start the server:
110 | ```
111 | python main-en.py
112 | ```
113 | The server will run on `http://localhost:8012`.
114 |
115 | 2. API Endpoints (see the curl examples below):
116 | - `/v1/chat/completions`: POST request for performing searches
117 | - `/v1/models`: GET request to retrieve the list of available models
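A quick way to smoke-test both endpoints from the command line; the port, paths, and model name below are taken directly from the server code:

```bash
# List the available models
curl http://localhost:8012/v1/models

# Run a local search through the chat endpoint
curl http://localhost:8012/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "graphrag-local-search:latest", "messages": [{"role": "user", "content": "test query"}]}'
```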
118 |
119 | 3. Integration with Open WebUI:
120 | In the Open WebUI configuration, set the API endpoint to `http://localhost:8012/v1/chat/completions`. This will allow Open WebUI to use the search functionality of GraphRAG4OpenWebUI.
121 |
122 | 4. Example search request:
123 | ```python
124 | import requests
125 | import json
126 |
127 | url = "http://localhost:8012/v1/chat/completions"
128 | headers = {"Content-Type": "application/json"}
129 | data = {
130 | "model": "full-model:latest",
131 | "messages": [{"role": "user", "content": "Your search query"}],
132 | "temperature": 0.7
133 | }
134 |
135 | response = requests.post(url, headers=headers, data=json.dumps(data))
136 | print(response.json())
137 | ```
138 |
139 | ## Available Models
140 |
141 | - `graphrag-local-search:latest`: Local search
142 | - `graphrag-global-search:latest`: Global search
143 | - `tavily-search:latest`: Tavily search
144 | - `full-model:latest`: Comprehensive search (includes all search methods above)
145 |
146 | ## Notes
147 |
148 | - Ensure that you have the correct input files (such as Parquet files) in the `INPUT_DIR` directory (see the expected layout below).
149 | - The API uses asynchronous programming; make sure your environment supports async operations.
150 | - For large-scale deployment, consider using a production-grade ASGI server.
151 | - This project is specifically designed for Open WebUI and can be easily integrated into various web-based applications.
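For reference, the server expects the following layout inside `INPUT_DIR`; the Parquet table names are taken from the constants in main-en.py, and `lancedb/` is the vector store location (`LANCEDB_URI`) used for entity description embeddings:

```
INPUT_DIR/
├── create_final_community_reports.parquet
├── create_final_nodes.parquet
├── create_final_entities.parquet
├── create_final_relationships.parquet
├── create_final_covariates.parquet
├── create_final_text_units.parquet
└── lancedb/
```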
152 |
153 | ## Contributing
154 |
155 | Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
156 |
157 | ## License
158 |
159 | [Apache-2.0 License](LICENSE)
160 |
--------------------------------------------------------------------------------
/graphrag3dknowledge.py:
--------------------------------------------------------------------------------
1 | import os  # file system operations
2 | import pandas as pd  # data processing and manipulation
3 | import networkx as nx  # building and analyzing graph structures
4 | import plotly.graph_objects as go  # low-level Plotly graph objects for interactive visualizations
5 | from plotly.subplots import make_subplots  # creating subplots
6 | import plotly.express as px  # quick statistical charts
7 |
8 | def read_parquet_files(directory):
9 |     """
10 |     Read all Parquet files in the given directory and merge them into one DataFrame.
11 |     Implementation: walk the directory with os.listdir, read each file with
12 |     pd.read_parquet, then merge everything with pd.concat.
13 |     """
14 | dataframes = []
15 | for filename in os.listdir(directory):
16 | if filename.endswith('.parquet'):
17 | file_path = os.path.join(directory, filename)
18 | df = pd.read_parquet(file_path)
19 | dataframes.append(df)
20 | return pd.concat(dataframes, ignore_index=True) if dataframes else pd.DataFrame()
21 |
22 |
23 | def clean_dataframe(df):
24 |     """
25 |     Clean the DataFrame by removing invalid rows.
26 |     Implementation: drop rows with null values in the source and target columns,
27 |     then cast both columns to string.
28 |     """
29 | df = df.dropna(subset=['source', 'target'])
30 | df['source'] = df['source'].astype(str)
31 | df['target'] = df['target'].astype(str)
32 | return df
33 |
34 |
35 | def create_knowledge_graph(df):
36 |     """
37 |     Build a knowledge graph from the DataFrame.
38 |     Implementation: create a directed graph with networkx, iterate over each row
39 |     of the DataFrame, and add edges with their attributes.
40 |     """
41 | G = nx.DiGraph()
42 | for _, row in df.iterrows():
43 | source = row['source']
44 | target = row['target']
45 | attributes = {k: v for k, v in row.items() if k not in ['source', 'target']}
46 | G.add_edge(source, target, **attributes)
47 | return G
48 |
49 |
50 | def create_node_link_trace(G, pos):
51 |     """
52 |     Create 3D traces for nodes and edges.
53 |     Implementation: build Plotly Scatter3d objects from the networkx layout positions.
54 |     """
55 | edge_x = []
56 | edge_y = []
57 | edge_z = []
58 | for edge in G.edges():
59 | x0, y0, z0 = pos[edge[0]]
60 | x1, y1, z1 = pos[edge[1]]
61 | edge_x.extend([x0, x1, None])
62 | edge_y.extend([y0, y1, None])
63 | edge_z.extend([z0, z1, None])
64 |
65 | edge_trace = go.Scatter3d(
66 | x=edge_x, y=edge_y, z=edge_z,
67 | line=dict(width=0.5, color='#888'),
68 | hoverinfo='none',
69 | mode='lines')
70 |
71 | node_x = [pos[node][0] for node in G.nodes()]
72 | node_y = [pos[node][1] for node in G.nodes()]
73 | node_z = [pos[node][2] for node in G.nodes()]
74 |
75 | node_trace = go.Scatter3d(
76 | x=node_x, y=node_y, z=node_z,
77 | mode='markers',
78 | hoverinfo='text',
79 | marker=dict(
80 | showscale=True,
81 | colorscale='YlGnBu',
82 | size=10,
83 | colorbar=dict(
84 | thickness=15,
85 | title='Node Connections',
86 | xanchor='left',
87 | titleside='right'
88 | )
89 | )
90 | )
91 |
92 | node_adjacencies = []
93 | node_text = []
94 | for node, adjacencies in G.adjacency():
95 | node_adjacencies.append(len(adjacencies))
96 |         node_text.append(f'Node: {node}<br># of connections: {len(adjacencies)}')
97 |
98 | node_trace.marker.color = node_adjacencies
99 | node_trace.text = node_text
100 |
101 | return edge_trace, node_trace
102 |
103 |
104 | def create_edge_label_trace(G, pos, edge_labels):
105 |     """
106 |     Create a 3D trace for edge labels.
107 |     Implementation: compute each edge's midpoint and create a Scatter3d object to display the labels.
108 |     """
109 | return go.Scatter3d(
110 | x=[pos[edge[0]][0] + (pos[edge[1]][0] - pos[edge[0]][0]) / 2 for edge in edge_labels],
111 | y=[pos[edge[0]][1] + (pos[edge[1]][1] - pos[edge[0]][1]) / 2 for edge in edge_labels],
112 | z=[pos[edge[0]][2] + (pos[edge[1]][2] - pos[edge[0]][2]) / 2 for edge in edge_labels],
113 | mode='text',
114 | text=list(edge_labels.values()),
115 | textposition='middle center',
116 | hoverinfo='none'
117 | )
118 |
119 |
120 | def create_degree_distribution(G):
121 |     """
122 |     Create a histogram of the node degree distribution.
123 |     Implementation: use plotly.express to create the histogram.
124 |     """
125 | degrees = [d for n, d in G.degree()]
126 | fig = px.histogram(x=degrees, nbins=20, labels={'x': 'Degree', 'y': 'Count'})
127 | fig.update_layout(
128 | title_text='Node Degree Distribution',
129 | margin=dict(l=0, r=0, t=30, b=0),
130 | height=300
131 | )
132 | return fig
133 |
134 |
135 | def create_centrality_plot(G):
136 |     """
137 |     Create a box plot of node centrality.
138 |     Implementation: compute degree centrality and create a box plot with plotly.express.
139 |     """
140 | centrality = nx.degree_centrality(G)
141 | centrality_values = list(centrality.values())
142 | fig = px.box(y=centrality_values, labels={'y': 'Centrality'})
143 | fig.update_layout(
144 | title_text='Degree Centrality Distribution',
145 | margin=dict(l=0, r=0, t=30, b=0),
146 | height=300
147 | )
148 | return fig
149 |
150 |
151 | def visualize_graph_plotly(G):
152 |     """Create an advanced interactive knowledge-graph visualization with a tuned layout using Plotly.
153 |     Implementation:
154 |     - compute a 3D layout
155 |     - generate node and edge traces
156 |     - create subplots: the 3D graph, a degree-distribution histogram, and a centrality box plot
157 |     - add interactive buttons and a slider
158 |     - tune the overall layout
159 |     """
160 | if G.number_of_nodes() == 0:
161 | print("Graph is empty. Nothing to visualize.")
162 | return
163 |
164 | pos = nx.spring_layout(G, dim=3) # 3D layout
165 | edge_trace, node_trace = create_node_link_trace(G, pos)
166 |
167 | edge_labels = nx.get_edge_attributes(G, 'relation')
168 | edge_label_trace = create_edge_label_trace(G, pos, edge_labels)
169 |
170 | degree_dist_fig = create_degree_distribution(G)
171 | centrality_fig = create_centrality_plot(G)
172 |
173 | fig = make_subplots(
174 | rows=2, cols=2,
175 | column_widths=[0.7, 0.3],
176 | row_heights=[0.7, 0.3],
177 | specs=[
178 | [{"type": "scene", "rowspan": 2}, {"type": "xy"}],
179 | [None, {"type": "xy"}]
180 | ],
181 | subplot_titles=("3D Knowledge Graph Code by AI超元域频道", "Node Degree Distribution", "Degree Centrality Distribution")
182 | )
183 |
184 | fig.add_trace(edge_trace, row=1, col=1)
185 | fig.add_trace(node_trace, row=1, col=1)
186 | fig.add_trace(edge_label_trace, row=1, col=1)
187 |
188 | fig.add_trace(degree_dist_fig.data[0], row=1, col=2)
189 | fig.add_trace(centrality_fig.data[0], row=2, col=2)
190 |
191 | # Update 3D layout
192 | fig.update_layout(
193 | scene=dict(
194 | xaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
195 | yaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
196 | zaxis=dict(showticklabels=False, showgrid=False, zeroline=False),
197 | aspectmode='cube'
198 | ),
199 | scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
200 | )
201 |
202 | # Add buttons for different layouts
203 | fig.update_layout(
204 | updatemenus=[
205 | dict(
206 | type="buttons",
207 | direction="left",
208 | buttons=list([
209 | dict(args=[{"visible": [True, True, True, True, True]}], label="Show All", method="update"),
210 | dict(args=[{"visible": [True, True, False, True, True]}], label="Hide Edge Labels",
211 | method="update"),
212 | dict(args=[{"visible": [False, True, False, True, True]}], label="Nodes Only", method="update")
213 | ]),
214 | pad={"r": 10, "t": 10},
215 | showactive=True,
216 | x=0.05,
217 | xanchor="left",
218 | y=1.1,
219 | yanchor="top"
220 | ),
221 | ]
222 | )
223 |
224 | # Add slider for node size
225 | fig.update_layout(
226 | sliders=[dict(
227 | active=0,
228 | currentvalue={"prefix": "Node Size: "},
229 | pad={"t": 50},
230 | steps=[dict(method='update',
231 | args=[{'marker.size': [i] * len(G.nodes)}],
232 | label=str(i)) for i in range(5, 21, 5)]
233 | )]
234 | )
235 |
236 |     # Tune the overall layout
237 |     # fig.update_layout(
238 |     #     height=1198,  # increase overall height
239 |     #     width=2055,  # increase overall width
240 |     #     title_text="Advanced Interactive Knowledge Graph",
241 |     #     margin=dict(l=10, r=10, t=25, b=10),
242 |     #     legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
243 |     # )
244 |
245 | fig.show()
246 |
247 |
248 | def main():
249 |     """Main entry point coordinating the whole pipeline.
250 |     Implementation:
251 |     - read the Parquet files
252 |     - clean the data
253 |     - build the knowledge graph
254 |     - print graph statistics
255 |     - call the visualization function
256 |     """
257 |     directory = '/Users/charlesqin/PycharmProjects/RAGCode/inputs/artifacts'  # replace with your actual directory path
258 | df = read_parquet_files(directory)
259 |
260 | if df.empty:
261 | print("No data found in the specified directory.")
262 | return
263 |
264 | print("Original DataFrame shape:", df.shape)
265 | print("Original DataFrame columns:", df.columns.tolist())
266 | print("Original DataFrame head:")
267 | print(df.head())
268 |
269 | df = clean_dataframe(df)
270 |
271 | print("\nCleaned DataFrame shape:", df.shape)
272 | print("Cleaned DataFrame head:")
273 | print(df.head())
274 |
275 | if df.empty:
276 | print("No valid data remaining after cleaning.")
277 | return
278 |
279 | G = create_knowledge_graph(df)
280 |
281 | print(f"\nGraph statistics:")
282 | print(f"Nodes: {G.number_of_nodes()}")
283 | print(f"Edges: {G.number_of_edges()}")
284 |
285 | if G.number_of_nodes() > 0:
286 | print(f"Connected components: {nx.number_connected_components(G.to_undirected())}")
287 | visualize_graph_plotly(G)
288 | else:
289 | print("Graph is empty. Cannot visualize.")
290 |
291 |
292 | if __name__ == "__main__":
293 | main()
294 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/main-cn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | import time
4 | import uuid
5 | import json
6 | import re
7 | import pandas as pd
8 | import tiktoken
9 | import logging
10 | from fastapi import FastAPI, HTTPException, Request
11 | from fastapi.responses import JSONResponse, StreamingResponse
12 | from pydantic import BaseModel, Field
13 | from typing import List, Optional, Dict, Any, Union
14 | from contextlib import asynccontextmanager
15 | from tavily import TavilyClient
16 |
17 |
18 | # GraphRAG related imports
19 | from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
20 | from graphrag.query.indexer_adapters import (
21 | read_indexer_covariates,
22 | read_indexer_entities,
23 | read_indexer_relationships,
24 | read_indexer_reports,
25 | read_indexer_text_units,
26 | )
27 | from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
28 | from graphrag.query.llm.oai.chat_openai import ChatOpenAI
29 | from graphrag.query.llm.oai.embedding import OpenAIEmbedding
30 | from graphrag.query.llm.oai.typing import OpenaiApiType
31 | from graphrag.query.question_gen.local_gen import LocalQuestionGen
32 | from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
33 | from graphrag.query.structured_search.local_search.search import LocalSearch
34 | from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
35 | from graphrag.query.structured_search.global_search.search import GlobalSearch
36 | from graphrag.vector_stores.lancedb import LanceDBVectorStore
37 |
38 | # Set up logging
39 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
40 | logger = logging.getLogger(__name__)
41 |
42 | # Set constants and configurations
43 | INPUT_DIR = os.getenv('INPUT_DIR')
44 | LANCEDB_URI = f"{INPUT_DIR}/lancedb"
45 | COMMUNITY_REPORT_TABLE = "create_final_community_reports"
46 | ENTITY_TABLE = "create_final_nodes"
47 | ENTITY_EMBEDDING_TABLE = "create_final_entities"
48 | RELATIONSHIP_TABLE = "create_final_relationships"
49 | COVARIATE_TABLE = "create_final_covariates"
50 | TEXT_UNIT_TABLE = "create_final_text_units"
51 | COMMUNITY_LEVEL = 2
52 | PORT = 8012
53 |
54 | # Global variables for storing search engines and question generator
55 | local_search_engine = None
56 | global_search_engine = None
57 | question_generator = None
58 |
59 |
60 | # Data models
61 | class Message(BaseModel):
62 | role: str
63 | content: str
64 |
65 |
66 | class ChatCompletionRequest(BaseModel):
67 | model: str
68 | messages: List[Message]
69 | temperature: Optional[float] = 1.0
70 | top_p: Optional[float] = 1.0
71 | n: Optional[int] = 1
72 | stream: Optional[bool] = False
73 | stop: Optional[Union[str, List[str]]] = None
74 | max_tokens: Optional[int] = None
75 | presence_penalty: Optional[float] = 0
76 | frequency_penalty: Optional[float] = 0
77 | logit_bias: Optional[Dict[str, float]] = None
78 | user: Optional[str] = None
79 |
80 |
81 | class ChatCompletionResponseChoice(BaseModel):
82 | index: int
83 | message: Message
84 | finish_reason: Optional[str] = None
85 |
86 |
87 | class Usage(BaseModel):
88 | prompt_tokens: int
89 | completion_tokens: int
90 | total_tokens: int
91 |
92 |
93 | class ChatCompletionResponse(BaseModel):
94 | id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
95 | object: str = "chat.completion"
96 | created: int = Field(default_factory=lambda: int(time.time()))
97 | model: str
98 | choices: List[ChatCompletionResponseChoice]
99 | usage: Usage
100 | system_fingerprint: Optional[str] = None
101 |
102 |
103 | async def setup_llm_and_embedder():
104 |     """
105 |     Set up the Language Model (LLM) and embedding model
106 |     """
107 |     logger.info("Setting up LLM and embedder")
108 |
109 |     # Get API keys and base URLs
110 |     api_key = os.environ.get("GRAPHRAG_API_KEY", "YOUR_API_KEY")
111 |     api_key_embedding = os.environ.get("GRAPHRAG_API_KEY_EMBEDDING", api_key)
112 |     api_base = os.environ.get("API_BASE", "https://api.openai.com/v1")
113 |     api_base_embedding = os.environ.get("API_BASE_EMBEDDING", "https://api.openai.com/v1")
114 |
115 |     # Get model names
116 |     llm_model = os.environ.get("GRAPHRAG_LLM_MODEL", "gpt-3.5-turbo-0125")
117 |     embedding_model = os.environ.get("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small")
118 |
119 |     # Check if the API key exists
120 |     if api_key == "YOUR_API_KEY":
121 |         logger.error("Valid GRAPHRAG_API_KEY not found in environment variables")
122 |         raise ValueError("GRAPHRAG_API_KEY is not set correctly")
123 |
124 |     # Initialize ChatOpenAI instance
125 | llm = ChatOpenAI(
126 | api_key=api_key,
127 | api_base=api_base,
128 | model=llm_model,
129 | api_type=OpenaiApiType.OpenAI,
130 | max_retries=20,
131 | )
132 |
133 |     # Initialize token encoder
134 |     token_encoder = tiktoken.get_encoding("cl100k_base")
135 |
136 |     # Initialize text embedding model
137 | text_embedder = OpenAIEmbedding(
138 | api_key=api_key_embedding,
139 | api_base=api_base_embedding,
140 | api_type=OpenaiApiType.OpenAI,
141 | model=embedding_model,
142 | deployment_name=embedding_model,
143 | max_retries=20,
144 | )
145 |
146 |
147 |     logger.info("LLM and embedder setup complete")
148 | return llm, token_encoder, text_embedder
149 |
150 |
151 | async def load_context():
152 |     """
153 |     Load context data including entities, relationships, reports, text units, and covariates
154 |     """
155 |     logger.info("Loading context data")
156 | try:
157 | entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
158 | entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")
159 | entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)
160 |
161 | description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings")
162 | description_embedding_store.connect(db_uri=LANCEDB_URI)
163 | store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store)
164 |
165 | relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
166 | relationships = read_indexer_relationships(relationship_df)
167 |
168 | report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
169 | reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
170 |
171 | text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
172 | text_units = read_indexer_text_units(text_unit_df)
173 |
174 | covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")
175 | claims = read_indexer_covariates(covariate_df)
176 |         logger.info(f"Number of claim records: {len(claims)}")
177 |         covariates = {"claims": claims}
178 |
179 |         logger.info("Context data loading complete")
180 |         return entities, relationships, reports, text_units, description_embedding_store, covariates
181 |     except Exception as e:
182 |         logger.error(f"Error loading context data: {str(e)}")
183 | raise
184 |
185 |
186 | async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
187 | description_embedding_store, covariates):
188 |     """
189 |     Set up local and global search engines
190 |     """
191 |     logger.info("Setting up search engines")
192 |
193 |     # Set up local search engine
194 | local_context_builder = LocalSearchMixedContext(
195 | community_reports=reports,
196 | text_units=text_units,
197 | entities=entities,
198 | relationships=relationships,
199 | covariates=covariates,
200 | entity_text_embeddings=description_embedding_store,
201 | embedding_vectorstore_key=EntityVectorStoreKey.ID,
202 | text_embedder=text_embedder,
203 | token_encoder=token_encoder,
204 | )
205 |
206 | local_context_params = {
207 | "text_unit_prop": 0.5,
208 | "community_prop": 0.1,
209 | "conversation_history_max_turns": 5,
210 | "conversation_history_user_turns_only": True,
211 | "top_k_mapped_entities": 10,
212 | "top_k_relationships": 10,
213 | "include_entity_rank": True,
214 | "include_relationship_weight": True,
215 | "include_community_rank": False,
216 | "return_candidate_context": False,
217 | "embedding_vectorstore_key": EntityVectorStoreKey.ID,
218 | "max_tokens": 12_000,
219 | }
220 |
221 | local_llm_params = {
222 | "max_tokens": 2_000,
223 | "temperature": 0.0,
224 | }
225 |
226 | local_search_engine = LocalSearch(
227 | llm=llm,
228 | context_builder=local_context_builder,
229 | token_encoder=token_encoder,
230 | llm_params=local_llm_params,
231 | context_builder_params=local_context_params,
232 | response_type="multiple paragraphs",
233 | )
234 |
235 |     # Set up global search engine
236 | global_context_builder = GlobalCommunityContext(
237 | community_reports=reports,
238 | entities=entities,
239 | token_encoder=token_encoder,
240 | )
241 |
242 | global_context_builder_params = {
243 | "use_community_summary": False,
244 | "shuffle_data": True,
245 | "include_community_rank": True,
246 | "min_community_rank": 0,
247 | "community_rank_name": "rank",
248 | "include_community_weight": True,
249 | "community_weight_name": "occurrence weight",
250 | "normalize_community_weight": True,
251 | "max_tokens": 12_000,
252 | "context_name": "Reports",
253 | }
254 |
255 | map_llm_params = {
256 | "max_tokens": 1000,
257 | "temperature": 0.0,
258 | "response_format": {"type": "json_object"},
259 | }
260 |
261 | reduce_llm_params = {
262 | "max_tokens": 2000,
263 | "temperature": 0.0,
264 | }
265 |
266 | global_search_engine = GlobalSearch(
267 | llm=llm,
268 | context_builder=global_context_builder,
269 | token_encoder=token_encoder,
270 | max_data_tokens=12_000,
271 | map_llm_params=map_llm_params,
272 | reduce_llm_params=reduce_llm_params,
273 | allow_general_knowledge=False,
274 | json_mode=True,
275 | context_builder_params=global_context_builder_params,
276 | concurrent_coroutines=32,
277 | response_type="multiple paragraphs",
278 | )
279 |
280 |     logger.info("Search engines setup complete")
281 | return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params
282 |
283 |
284 | def format_response(response):
285 |     """
286 |     Format the response by adding appropriate line breaks and paragraph separations.
287 |     """
288 | paragraphs = re.split(r'\n{2,}', response)
289 |
290 | formatted_paragraphs = []
291 | for para in paragraphs:
292 | if '```' in para:
293 | parts = para.split('```')
294 | for i, part in enumerate(parts):
295 |                 if i % 2 == 1:  # This is a code block
296 | parts[i] = f"\n```\n{part.strip()}\n```\n"
297 | para = ''.join(parts)
298 | else:
299 | para = para.replace('. ', '.\n')
300 |
301 | formatted_paragraphs.append(para.strip())
302 |
303 | return '\n\n'.join(formatted_paragraphs)
304 |
305 |
306 | async def tavily_search(prompt: str):
307 |     """
308 |     Perform a search using the Tavily API
309 |     """
310 | try:
311 | client = TavilyClient(api_key=os.environ['TAVILY_API_KEY'])
312 | resp = client.search(prompt, search_depth="advanced")
313 |
314 |         # Convert the Tavily response to Markdown format
315 |         markdown_response = "# Search Results\n\n"
316 | for result in resp.get('results', []):
317 | markdown_response += f"## [{result['title']}]({result['url']})\n\n"
318 | markdown_response += f"{result['content']}\n\n"
319 |
320 | return markdown_response
321 | except Exception as e:
322 |         raise HTTPException(status_code=500, detail=f"Tavily search error: {str(e)}")
323 |
324 |
325 | @asynccontextmanager
326 | async def lifespan(app: FastAPI):
327 |     # Runs at startup
328 |     global local_search_engine, global_search_engine, question_generator
329 |     try:
330 |         logger.info("Initializing search engines and question generator...")
331 | llm, token_encoder, text_embedder = await setup_llm_and_embedder()
332 | entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context()
333 | local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines(
334 | llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
335 | description_embedding_store, covariates
336 | )
337 |
338 | question_generator = LocalQuestionGen(
339 | llm=llm,
340 | context_builder=local_context_builder,
341 | token_encoder=token_encoder,
342 | llm_params=local_llm_params,
343 | context_builder_params=local_context_params,
344 | )
345 |         logger.info("Initialization complete.")
346 |     except Exception as e:
347 |         logger.error(f"Error during initialization: {str(e)}")
348 | raise
349 |
350 | yield
351 |
352 |     # Runs at shutdown
353 |     logger.info("Shutting down...")
354 |
355 |
356 | app = FastAPI(lifespan=lifespan)
357 |
358 |
359 | # Helper used by the chat_completions endpoint below
360 |
361 | async def full_model_search(prompt: str):
362 |     """
363 |     Perform a full-model search combining local search, global search, and Tavily search
364 |     """
365 | local_result = await local_search_engine.asearch(prompt)
366 | global_result = await global_search_engine.asearch(prompt)
367 | tavily_result = await tavily_search(prompt)
368 |
369 |     # Format the results
370 |     formatted_result = "# 🔥🔥🔥 Combined Search Results\n\n"
371 |
372 |     formatted_result += "## 🔥🔥🔥 Local Search Results\n"
373 |     formatted_result += format_response(local_result.response) + "\n\n"
374 |
375 |     formatted_result += "## 🔥🔥🔥 Global Search Results\n"
376 |     formatted_result += format_response(global_result.response) + "\n\n"
377 |
378 |     formatted_result += "## 🔥🔥🔥 Tavily Search Results\n"
379 | formatted_result += tavily_result + "\n\n"
380 |
381 | return formatted_result
382 |
383 |
384 | @app.post("/v1/chat/completions")
385 | async def chat_completions(request: ChatCompletionRequest):
386 | if not local_search_engine or not global_search_engine:
387 |         logger.error("Search engines not initialized")
388 |         raise HTTPException(status_code=500, detail="Search engines not initialized")
389 |
390 | try:
391 |         logger.info(f"Received chat completion request: {request}")
392 |         prompt = request.messages[-1].content
393 |         logger.info(f"Processing prompt: {prompt}")
394 |
395 |         # Choose the search method based on the requested model
396 | if request.model == "graphrag-global-search:latest":
397 | result = await global_search_engine.asearch(prompt)
398 | formatted_response = format_response(result.response)
399 | elif request.model == "tavily-search:latest":
400 | result = await tavily_search(prompt)
401 | formatted_response = result
402 | elif request.model == "full-model:latest":
403 | formatted_response = await full_model_search(prompt)
404 |         else:  # Default to local search
405 | result = await local_search_engine.asearch(prompt)
406 | formatted_response = format_response(result.response)
407 |
408 |         logger.info(f"Formatted search result: {formatted_response}")
409 |
410 |         # Handle streaming and non-streaming responses
411 | if request.stream:
412 | async def generate_stream():
413 | chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
414 | lines = formatted_response.split('\n')
415 | for i, line in enumerate(lines):
416 | chunk = {
417 | "id": chunk_id,
418 | "object": "chat.completion.chunk",
419 | "created": int(time.time()),
420 | "model": request.model,
421 | "choices": [
422 | {
423 | "index": 0,
424 | "delta": {"content": line + '\n'}, # if i > 0 else {"role": "assistant", "content": ""},
425 | "finish_reason": None
426 | }
427 | ]
428 | }
429 | yield f"data: {json.dumps(chunk)}\n\n"
430 | await asyncio.sleep(0.05)
431 |
432 | final_chunk = {
433 | "id": chunk_id,
434 | "object": "chat.completion.chunk",
435 | "created": int(time.time()),
436 | "model": request.model,
437 | "choices": [
438 | {
439 | "index": 0,
440 | "delta": {},
441 | "finish_reason": "stop"
442 | }
443 | ]
444 | }
445 | yield f"data: {json.dumps(final_chunk)}\n\n"
446 | yield "data: [DONE]\n\n"
447 |
448 | return StreamingResponse(generate_stream(), media_type="text/event-stream")
449 | else:
450 | response = ChatCompletionResponse(
451 | model=request.model,
452 | choices=[
453 | ChatCompletionResponseChoice(
454 | index=0,
455 | message=Message(role="assistant", content=formatted_response),
456 | finish_reason="stop"
457 | )
458 | ],
459 | usage=Usage(
460 | prompt_tokens=len(prompt.split()),
461 | completion_tokens=len(formatted_response.split()),
462 | total_tokens=len(prompt.split()) + len(formatted_response.split())
463 | )
464 | )
465 |             logger.info(f"Sending response: {response}")
466 | return JSONResponse(content=response.dict())
467 |
468 | except Exception as e:
469 |         logger.error(f"Error processing chat completion: {str(e)}")
470 | raise HTTPException(status_code=500, detail=str(e))
471 |
472 | @app.get("/v1/models")
473 | async def list_models():
474 |     """
475 |     Return the list of available models
476 |     """
477 |     logger.info("Received model list request")
478 | current_time = int(time.time())
479 | models = [
480 | {"id": "graphrag-local-search:latest", "object": "model", "created": current_time - 100000, "owned_by": "graphrag"},
481 | {"id": "graphrag-global-search:latest", "object": "model", "created": current_time - 95000, "owned_by": "graphrag"},
482 | # {"id": "graphrag-question-generator:latest", "object": "model", "created": current_time - 90000, "owned_by": "graphrag"},
483 | # {"id": "gpt-3.5-turbo:latest", "object": "model", "created": current_time - 80000, "owned_by": "openai"},
484 | # {"id": "text-embedding-3-small:latest", "object": "model", "created": current_time - 70000, "owned_by": "openai"},
485 | {"id": "tavily-search:latest", "object": "model", "created": current_time - 85000, "owned_by": "tavily"},
486 | {"id": "full-model:latest", "object": "model", "created": current_time - 80000, "owned_by": "combined"}
487 |
488 | ]
489 |
490 | response = {
491 | "object": "list",
492 | "data": models
493 | }
494 |
495 |     logger.info(f"Sending model list: {response}")
496 | return JSONResponse(content=response)
497 |
498 | if __name__ == "__main__":
499 | import uvicorn
500 |
501 |     logger.info(f"Starting server on port {PORT}")
502 | uvicorn.run(app, host="0.0.0.0", port=PORT)
503 |
504 |
--------------------------------------------------------------------------------
/main-en.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | import time
4 | import uuid
5 | import json
6 | import re
7 | import pandas as pd
8 | import tiktoken
9 | import logging
10 | from fastapi import FastAPI, HTTPException, Request
11 | from fastapi.responses import JSONResponse, StreamingResponse
12 | from pydantic import BaseModel, Field
13 | from typing import List, Optional, Dict, Any, Union
14 | from contextlib import asynccontextmanager
15 | from tavily import TavilyClient
16 |
17 |
18 | # GraphRAG related imports
19 | from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
20 | from graphrag.query.indexer_adapters import (
21 | read_indexer_covariates,
22 | read_indexer_entities,
23 | read_indexer_relationships,
24 | read_indexer_reports,
25 | read_indexer_text_units,
26 | )
27 | from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
28 | from graphrag.query.llm.oai.chat_openai import ChatOpenAI
29 | from graphrag.query.llm.oai.embedding import OpenAIEmbedding
30 | from graphrag.query.llm.oai.typing import OpenaiApiType
31 | from graphrag.query.question_gen.local_gen import LocalQuestionGen
32 | from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
33 | from graphrag.query.structured_search.local_search.search import LocalSearch
34 | from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
35 | from graphrag.query.structured_search.global_search.search import GlobalSearch
36 | from graphrag.vector_stores.lancedb import LanceDBVectorStore
37 |
38 | # Set up logging
39 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
40 | logger = logging.getLogger(__name__)
41 |
42 | # Set constants and configurations
43 | INPUT_DIR = os.getenv('INPUT_DIR')
44 | LANCEDB_URI = f"{INPUT_DIR}/lancedb"
45 | COMMUNITY_REPORT_TABLE = "create_final_community_reports"
46 | ENTITY_TABLE = "create_final_nodes"
47 | ENTITY_EMBEDDING_TABLE = "create_final_entities"
48 | RELATIONSHIP_TABLE = "create_final_relationships"
49 | COVARIATE_TABLE = "create_final_covariates"
50 | TEXT_UNIT_TABLE = "create_final_text_units"
51 | COMMUNITY_LEVEL = 2
52 | PORT = 8012
53 |
54 | # Global variables for storing search engines and question generator
55 | local_search_engine = None
56 | global_search_engine = None
57 | question_generator = None
58 |
59 |
60 | # Data models
61 | class Message(BaseModel):
62 | role: str
63 | content: str
64 |
65 |
66 | class ChatCompletionRequest(BaseModel):
67 | model: str
68 | messages: List[Message]
69 | temperature: Optional[float] = 1.0
70 | top_p: Optional[float] = 1.0
71 | n: Optional[int] = 1
72 | stream: Optional[bool] = False
73 | stop: Optional[Union[str, List[str]]] = None
74 | max_tokens: Optional[int] = None
75 | presence_penalty: Optional[float] = 0
76 | frequency_penalty: Optional[float] = 0
77 | logit_bias: Optional[Dict[str, float]] = None
78 | user: Optional[str] = None
79 |
80 |
81 | class ChatCompletionResponseChoice(BaseModel):
82 | index: int
83 | message: Message
84 | finish_reason: Optional[str] = None
85 |
86 |
87 | class Usage(BaseModel):
88 | prompt_tokens: int
89 | completion_tokens: int
90 | total_tokens: int
91 |
92 |
93 | class ChatCompletionResponse(BaseModel):
94 | id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
95 | object: str = "chat.completion"
96 | created: int = Field(default_factory=lambda: int(time.time()))
97 | model: str
98 | choices: List[ChatCompletionResponseChoice]
99 | usage: Usage
100 | system_fingerprint: Optional[str] = None
101 |
102 |
103 | async def setup_llm_and_embedder():
104 | """
105 | Set up Language Model (LLM) and embedding model
106 | """
107 | logger.info("Setting up LLM and embedder")
108 |
109 | # Get API keys and base URLs
110 | api_key = os.environ.get("GRAPHRAG_API_KEY", "YOUR_API_KEY")
111 | api_key_embedding = os.environ.get("GRAPHRAG_API_KEY_EMBEDDING", api_key)
112 | api_base = os.environ.get("API_BASE", "https://api.openai.com/v1")
113 | api_base_embedding = os.environ.get("API_BASE_EMBEDDING", "https://api.openai.com/v1")
114 |
115 | # Get model names
116 | llm_model = os.environ.get("GRAPHRAG_LLM_MODEL", "gpt-3.5-turbo-0125")
117 | embedding_model = os.environ.get("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small")
118 |
119 | # Check if API key exists
120 | if api_key == "YOUR_API_KEY":
121 | logger.error("Valid GRAPHRAG_API_KEY not found in environment variables")
122 | raise ValueError("GRAPHRAG_API_KEY is not set correctly")
123 |
124 | # Initialize ChatOpenAI instance
125 | llm = ChatOpenAI(
126 | api_key=api_key,
127 | api_base=api_base,
128 | model=llm_model,
129 | api_type=OpenaiApiType.OpenAI,
130 | max_retries=20,
131 | )
132 |
133 | # Initialize token encoder
134 | token_encoder = tiktoken.get_encoding("cl100k_base")
135 |
136 | # Initialize text embedding model
137 | text_embedder = OpenAIEmbedding(
138 | api_key=api_key_embedding,
139 | api_base=api_base_embedding,
140 | api_type=OpenaiApiType.OpenAI,
141 | model=embedding_model,
142 | deployment_name=embedding_model,
143 | max_retries=20,
144 | )
145 |
146 |
147 | logger.info("LLM and embedder setup complete")
148 | return llm, token_encoder, text_embedder
149 |
150 |
151 | async def load_context():
152 | """
153 | Load context data including entities, relationships, reports, text units, and covariates
154 | """
155 | logger.info("Loading context data")
156 | try:
157 | entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
158 | entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")
159 | entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)
160 |
161 | description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings")
162 | description_embedding_store.connect(db_uri=LANCEDB_URI)
163 | store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store)
164 |
165 | relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
166 | relationships = read_indexer_relationships(relationship_df)
167 |
168 | report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
169 | reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)
170 |
171 | text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
172 | text_units = read_indexer_text_units(text_unit_df)
173 |
174 | covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")
175 | claims = read_indexer_covariates(covariate_df)
176 | logger.info(f"Number of claim records: {len(claims)}")
177 | covariates = {"claims": claims}
178 |
179 | logger.info("Context data loading complete")
180 | return entities, relationships, reports, text_units, description_embedding_store, covariates
181 | except Exception as e:
182 | logger.error(f"Error loading context data: {str(e)}")
183 | raise
184 |
185 |
186 | async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
187 | description_embedding_store, covariates):
188 | """
189 | Set up local and global search engines
190 | """
191 | logger.info("Setting up search engines")
192 |
193 | # Set up local search engine
194 | local_context_builder = LocalSearchMixedContext(
195 | community_reports=reports,
196 | text_units=text_units,
197 | entities=entities,
198 | relationships=relationships,
199 | covariates=covariates,
200 | entity_text_embeddings=description_embedding_store,
201 | embedding_vectorstore_key=EntityVectorStoreKey.ID,
202 | text_embedder=text_embedder,
203 | token_encoder=token_encoder,
204 | )
205 |
206 | local_context_params = {
207 | "text_unit_prop": 0.5,
208 | "community_prop": 0.1,
209 | "conversation_history_max_turns": 5,
210 | "conversation_history_user_turns_only": True,
211 | "top_k_mapped_entities": 10,
212 | "top_k_relationships": 10,
213 | "include_entity_rank": True,
214 | "include_relationship_weight": True,
215 | "include_community_rank": False,
216 | "return_candidate_context": False,
217 | "embedding_vectorstore_key": EntityVectorStoreKey.ID,
218 | "max_tokens": 12_000,
219 | }
220 |
221 | local_llm_params = {
222 | "max_tokens": 2_000,
223 | "temperature": 0.0,
224 | }
225 |
226 | local_search_engine = LocalSearch(
227 | llm=llm,
228 | context_builder=local_context_builder,
229 | token_encoder=token_encoder,
230 | llm_params=local_llm_params,
231 | context_builder_params=local_context_params,
232 | response_type="multiple paragraphs",
233 | )
234 |
235 | # Set up global search engine
236 | global_context_builder = GlobalCommunityContext(
237 | community_reports=reports,
238 | entities=entities,
239 | token_encoder=token_encoder,
240 | )
241 |
242 | global_context_builder_params = {
243 | "use_community_summary": False,
244 | "shuffle_data": True,
245 | "include_community_rank": True,
246 | "min_community_rank": 0,
247 | "community_rank_name": "rank",
248 | "include_community_weight": True,
249 | "community_weight_name": "occurrence weight",
250 | "normalize_community_weight": True,
251 | "max_tokens": 12_000,
252 | "context_name": "Reports",
253 | }
254 |
255 | map_llm_params = {
256 | "max_tokens": 1000,
257 | "temperature": 0.0,
258 | "response_format": {"type": "json_object"},
259 | }
260 |
261 | reduce_llm_params = {
262 | "max_tokens": 2000,
263 | "temperature": 0.0,
264 | }
265 |
266 | global_search_engine = GlobalSearch(
267 | llm=llm,
268 | context_builder=global_context_builder,
269 | token_encoder=token_encoder,
270 | max_data_tokens=12_000,
271 | map_llm_params=map_llm_params,
272 | reduce_llm_params=reduce_llm_params,
273 | allow_general_knowledge=False,
274 | json_mode=True,
275 | context_builder_params=global_context_builder_params,
276 | concurrent_coroutines=32,
277 | response_type="multiple paragraphs",
278 | )
279 |
280 | logger.info("Search engines setup complete")
281 | return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params
282 |
283 |
284 | def format_response(response):
285 | """
286 | Format the response by adding appropriate line breaks and paragraph separations.
287 | """
288 | paragraphs = re.split(r'\n{2,}', response)
289 |
290 | formatted_paragraphs = []
291 | for para in paragraphs:
292 | if '```' in para:
293 | parts = para.split('```')
294 | for i, part in enumerate(parts):
295 | if i % 2 == 1: # This is a code block
296 | parts[i] = f"\n```\n{part.strip()}\n```\n"
297 | para = ''.join(parts)
298 | else:
299 |                 para = para.replace('. ', '.\n')  # crude sentence splitter; also breaks after abbreviations like "e.g. "
300 |
301 | formatted_paragraphs.append(para.strip())
302 |
303 | return '\n\n'.join(formatted_paragraphs)
304 |
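    | # Illustrative behaviour of format_response (a sketch, shown as comments
    | # so nothing runs at import time):
    | #
    | #   >>> format_response("First sentence. Second sentence.\n\n```\nprint('hi')\n```")
    | #   'First sentence.\nSecond sentence.\n\n```\nprint('hi')\n```'
    | #
    | # Prose paragraphs get one sentence per line, while fenced code blocks
    | # are re-wrapped on their own lines and left unsplit.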
305 |
306 | async def tavily_search(prompt: str):
307 | """
308 | Perform a search using the Tavily API
309 | """
310 | try:
311 |         client = TavilyClient(api_key=os.environ['TAVILY_API_KEY'])  # KeyError if unset; surfaced as the HTTP 500 below
312 | resp = client.search(prompt, search_depth="advanced")
313 |
314 | # Convert Tavily response to Markdown format
315 | markdown_response = "# Search Results\n\n"
316 | for result in resp.get('results', []):
317 | markdown_response += f"## [{result['title']}]({result['url']})\n\n"
318 | markdown_response += f"{result['content']}\n\n"
319 |
320 | return markdown_response
321 | except Exception as e:
322 | raise HTTPException(status_code=500, detail=f"Tavily search error: {str(e)}")
323 |
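    | # Usage sketch (requires TAVILY_API_KEY to be set and an event loop,
    | # e.g. inside one of the async endpoints below):
    | #
    | #   markdown = await tavily_search("GraphRAG knowledge graphs")
    | #   # -> "# Search Results\n\n## [Page title](https://...)\n\n..."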
324 |
325 | @asynccontextmanager
326 | async def lifespan(app: FastAPI):
327 | # Execute on startup
328 | global local_search_engine, global_search_engine, question_generator
329 | try:
330 | logger.info("Initializing search engines and question generator...")
331 | llm, token_encoder, text_embedder = await setup_llm_and_embedder()
332 | entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context()
333 | local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines(
334 | llm, token_encoder, text_embedder, entities, relationships, reports, text_units,
335 | description_embedding_store, covariates
336 | )
337 |
338 | question_generator = LocalQuestionGen(
339 | llm=llm,
340 | context_builder=local_context_builder,
341 | token_encoder=token_encoder,
342 | llm_params=local_llm_params,
343 | context_builder_params=local_context_params,
344 | )
345 | logger.info("Initialization complete.")
346 | except Exception as e:
347 | logger.error(f"Error during initialization: {str(e)}")
348 | raise
349 |
350 | yield
351 |
352 | # Execute on shutdown
353 | logger.info("Shutting down...")
354 |
355 |
356 | app = FastAPI(lifespan=lifespan)
357 |
358 |
359 | # Combined retrieval helper used by the "full-model:latest" route below
360 |
361 | async def full_model_search(prompt: str):
362 | """
363 | Perform a full model search, including local retrieval, global retrieval, and Tavily search
364 | """
365 | local_result = await local_search_engine.asearch(prompt)
366 | global_result = await global_search_engine.asearch(prompt)
367 | tavily_result = await tavily_search(prompt)
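    |     # Note: the three awaits above run sequentially. Because each call is
    |     # I/O-bound, a concurrent variant (a sketch, not the default here)
    |     # could overlap them with asyncio.gather:
    |     #
    |     #   local_result, global_result, tavily_result = await asyncio.gather(
    |     #       local_search_engine.asearch(prompt),
    |     #       global_search_engine.asearch(prompt),
    |     #       tavily_search(prompt),
    |     #   )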
368 |
369 | # Format results
370 | formatted_result = "# 🔥🔥🔥Comprehensive Search Results\n\n"
371 |
372 | formatted_result += "## 🔥🔥🔥Local Retrieval Results\n"
373 | formatted_result += format_response(local_result.response) + "\n\n"
374 |
375 | formatted_result += "## 🔥🔥🔥Global Retrieval Results\n"
376 | formatted_result += format_response(global_result.response) + "\n\n"
377 |
378 | formatted_result += "## 🔥🔥🔥Tavily Search Results\n"
379 | formatted_result += tavily_result + "\n\n"
380 |
381 | return formatted_result
382 |
383 | @app.post("/v1/chat/completions")
384 | async def chat_completions(request: ChatCompletionRequest):
385 | if not local_search_engine or not global_search_engine:
386 | logger.error("Search engines not initialized")
387 | raise HTTPException(status_code=500, detail="Search engines not initialized")
388 |
389 | try:
390 | logger.info(f"Received chat completion request: {request}")
391 | prompt = request.messages[-1].content
392 | logger.info(f"Processing prompt: {prompt}")
393 |
394 | # Choose different search methods based on the model
395 | if request.model == "graphrag-global-search:latest":
396 | result = await global_search_engine.asearch(prompt)
397 | formatted_response = format_response(result.response)
398 | elif request.model == "tavily-search:latest":
399 | result = await tavily_search(prompt)
400 | formatted_response = result
401 | elif request.model == "full-model:latest":
402 | formatted_response = await full_model_search(prompt)
403 | else: # Default to local search
404 | result = await local_search_engine.asearch(prompt)
405 | formatted_response = format_response(result.response)
406 |
407 | logger.info(f"Formatted search result: {formatted_response}")
408 |
409 | # Handle streaming and non-streaming responses
410 | if request.stream:
411 | async def generate_stream():
412 | chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
413 | lines = formatted_response.split('\n')
414 | for i, line in enumerate(lines):
415 | chunk = {
416 | "id": chunk_id,
417 | "object": "chat.completion.chunk",
418 | "created": int(time.time()),
419 | "model": request.model,
420 | "choices": [
421 | {
422 | "index": 0,
423 |                         "delta": {"role": "assistant", "content": line + '\n'} if i == 0 else {"content": line + '\n'},  # first chunk also carries the role, per the OpenAI streaming convention
424 | "finish_reason": None
425 | }
426 | ]
427 | }
428 | yield f"data: {json.dumps(chunk)}\n\n"
429 |                     await asyncio.sleep(0.05)  # pace the stream so clients render lines progressively
430 |
431 | final_chunk = {
432 | "id": chunk_id,
433 | "object": "chat.completion.chunk",
434 | "created": int(time.time()),
435 | "model": request.model,
436 | "choices": [
437 | {
438 | "index": 0,
439 | "delta": {},
440 | "finish_reason": "stop"
441 | }
442 | ]
443 | }
444 | yield f"data: {json.dumps(final_chunk)}\n\n"
445 | yield "data: [DONE]\n\n"
446 |
447 | return StreamingResponse(generate_stream(), media_type="text/event-stream")
448 | else:
449 | response = ChatCompletionResponse(
450 | model=request.model,
451 | choices=[
452 | ChatCompletionResponseChoice(
453 | index=0,
454 | message=Message(role="assistant", content=formatted_response),
455 | finish_reason="stop"
456 | )
457 | ],
458 |                 usage=Usage(  # approximate usage: whitespace-split word counts, not real token counts
459 | prompt_tokens=len(prompt.split()),
460 | completion_tokens=len(formatted_response.split()),
461 | total_tokens=len(prompt.split()) + len(formatted_response.split())
462 | )
463 | )
464 | logger.info(f"Sending response: {response}")
465 |             return JSONResponse(content=response.dict())  # Pydantic v1 API; on Pydantic v2 use response.model_dump()
466 |
467 | except Exception as e:
468 | logger.error(f"Error processing chat completion: {str(e)}")
469 | raise HTTPException(status_code=500, detail=str(e))
470 |
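    | # Example request against the endpoint above (a sketch; substitute the
    | # PORT value this server is actually started with):
    | #
    | #   curl http://localhost:<PORT>/v1/chat/completions \
    | #     -H "Content-Type: application/json" \
    | #     -d '{"model": "graphrag-local-search:latest",
    | #          "messages": [{"role": "user", "content": "What is GraphRAG?"}],
    | #          "stream": false}'
    | 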
471 | @app.get("/v1/models")
472 | async def list_models():
473 | """
474 | Return a list of available models
475 | """
476 | logger.info("Received model list request")
477 | current_time = int(time.time())
478 | models = [
479 | {"id": "graphrag-local-search:latest", "object": "model", "created": current_time - 100000, "owned_by": "graphrag"},
480 | {"id": "graphrag-global-search:latest", "object": "model", "created": current_time - 95000, "owned_by": "graphrag"},
481 | # {"id": "graphrag-question-generator:latest", "object": "model", "created": current_time - 90000, "owned_by": "graphrag"},
482 | # {"id": "gpt-3.5-turbo:latest", "object": "model", "created": current_time - 80000, "owned_by": "openai"},
483 | # {"id": "text-embedding-3-small:latest", "object": "model", "created": current_time - 70000, "owned_by": "openai"},
484 | {"id": "tavily-search:latest", "object": "model", "created": current_time - 85000, "owned_by": "tavily"},
485 | {"id": "full-model:latest", "object": "model", "created": current_time - 80000, "owned_by": "combined"}
486 | ]
487 |
488 | response = {
489 | "object": "list",
490 | "data": models
491 | }
492 |
493 | logger.info(f"Sending model list: {response}")
494 | return JSONResponse(content=response)
495 |
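    | # Open WebUI discovers these IDs via GET /v1/models; a quick check
    | # (same <PORT> placeholder as above):
    | #
    | #   curl http://localhost:<PORT>/v1/models
    | 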
496 | if __name__ == "__main__":
497 | import uvicorn
498 |
499 | logger.info(f"Starting server on port {PORT}")
500 | uvicorn.run(app, host="0.0.0.0", port=PORT)
501 |
--------------------------------------------------------------------------------