├── .gitignore
├── LICENSE
├── case
│   ├── cluster_sparsity.py
│   ├── cluster_sparsity_ratio.py
│   ├── connect_case.py
│   ├── graph_sparsity.py
│   └── graphml_visualize.py
├── config.yaml
├── docs
│   └── notes_of_codes.md
├── eval
│   ├── batch_eval.py
│   ├── cal_time.py
│   ├── cal_tokens.py
│   ├── datasets
│   │   ├── agriculture
│   │   │   ├── agriculture_eval_hi_fastgraphrag_result_openai.jsonl
│   │   │   ├── agriculture_eval_hi_graphrag_result_openai.jsonl
│   │   │   ├── agriculture_eval_hi_kag_result_openai.jsonl
│   │   │   ├── agriculture_eval_hi_lightrag_result_openai.jsonl
│   │   │   ├── agriculture_eval_hi_naive_result_openai.jsonl
│   │   │   ├── agriculture_fastGraphrag_result_deepseek.jsonl
│   │   │   ├── agriculture_kag_result_deepseek.jsonl
│   │   │   ├── agriculture_query.jsonl
│   │   │   └── agriculture_unique_contexts.json
│   │   ├── cs
│   │   │   ├── cs_eval_hi_fastgraphrag_result_openai.jsonl
│   │   │   ├── cs_eval_hi_graphrag_result_openai.jsonl
│   │   │   ├── cs_eval_hi_kag_result_openai.jsonl
│   │   │   ├── cs_eval_hi_lightrag_result_openai.jsonl
│   │   │   ├── cs_eval_hi_naive_result_openai.jsonl
│   │   │   ├── cs_fastGraphrag_result_deepseek.jsonl
│   │   │   ├── cs_kag_result_deepseek.jsonl
│   │   │   ├── cs_local_result_deepseek.jsonl
│   │   │   ├── cs_naive_result_deepseek.jsonl
│   │   │   ├── cs_query.jsonl
│   │   │   └── cs_unique_contexts.json
│   │   ├── legal
│   │   │   ├── legal_eval_hi_fastgraphrag_result_openai.jsonl
│   │   │   ├── legal_eval_hi_graphrag_result_openai.jsonl
│   │   │   ├── legal_eval_hi_kag_result_openai.jsonl
│   │   │   ├── legal_eval_hi_lightrag_result_openai.jsonl
│   │   │   ├── legal_eval_hi_naiveR_result_openai.jsonl
│   │   │   ├── legal_fastGraphrag_result_deepseek.jsonl
│   │   │   ├── legal_kag_result_deepseek.jsonl
│   │   │   ├── legal_query.jsonl
│   │   │   └── legal_unique_contexts.json
│   │   └── mix
│   │       ├── mix.jsonl
│   │       ├── mix_eval.jsonl
│   │       ├── mix_eval_hi.jsonl
│   │       ├── mix_eval_hi_fastgraphrag_result_openai.jsonl
│   │       ├── mix_eval_hi_graphrag_result_openai.jsonl
│   │       ├── mix_eval_hi_hi_naiveR_baseindex_result_openai.jsonl
│   │       ├── mix_eval_hi_kag_result_openai.jsonl
│   │       ├── mix_eval_hi_lightrag_result_openai.jsonl
│   │       ├── mix_eval_hi_naiveR_hi.jsonl
│   │       ├── mix_eval_hi_naiveR_hi_result_openai.jsonl
│   │       ├── mix_eval_hi_naive_result_openai.jsonl
│   │       ├── mix_eval_lightrag_result_openai.jsonl
│   │       ├── mix_eval_naive_result_openai.jsonl
│   │       ├── mix_fastGraphrag_result_deepseek.jsonl
│   │       ├── mix_hi_naiveR_result_deepseek.jsonl
│   │       ├── mix_hi_naiveR_result_deepseek_baseindex.jsonl
│   │       ├── mix_hi_result_deepseek.jsonl
│   │       ├── mix_kag_result_deepseek.jsonl
│   │       ├── mix_lightrag_result_deepseek.jsonl
│   │       ├── mix_local_result_deepseek.jsonl
│   │       ├── mix_naive_result_deepseek.jsonl
│   │       └── mix_unique_contexts.json
│   ├── extract_context.py
│   ├── extract_query.py
│   ├── insert_context_deepseek.py
│   ├── insert_context_glm.py
│   ├── insert_context_openai.py
│   ├── test_deepseek.py
│   ├── test_glm.py
│   └── test_openai.py
├── hi_Search_deepseek.py
├── hi_Search_glm.py
├── hi_Search_openai.py
├── hirag
│   ├── __init__.py
│   ├── _cluster_utils.py
│   ├── _llm.py
│   ├── _op.py
│   ├── _splitter.py
│   ├── _storage
│   │   ├── __init__.py
│   │   ├── gdb_neo4j.py
│   │   ├── gdb_networkx.py
│   │   ├── kv_json.py
│   │   └── vdb_nanovectordb.py
│   ├── _utils.py
│   ├── base.py
│   ├── hirag.py
│   └── prompt.py
├── imgs
│   ├── hirag_ds.drawio.png
│   ├── hirag_ds_trans.drawio.png
│   ├── hirag_icon.png
│   └── hirag_image.png
├── readme.md
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python
3 | test_cache.json
4 | run_test*.py
5 | nano_graphrag_cache*/
6 | examples/benchmarks/fixtures/
7 | tests/original_workflow.txt
8 | # eval/datasets/agriculture
9 | # eval/datasets/legal
10 | # eval/datasets/cs
11 | eval/datasets/music
12 | eval/datasets/mix/work_dir_glm_base
13 | eval/datasets/mix/work_dir_deepseek_hi
14 | eval/datasets/mix/work_dir_openai
15 | eval/datasets/mix/work_dir_glm_hi
16 | case/work_dir_graphrag
17 | case/work_dir_hi
18 | case/work_dir_hi_small
19 | case/work_dir_base_small
20 | multi_hop_rag_test_hi
21 | multi_hop_rag_test
22 | config.yaml
23 | web3_test
24 | ### Python ###
25 | # Byte-compiled / optimized / DLL files
26 | __pycache__/
27 | *.py[cod]
28 | *$py.class
29 | .vscode
30 | .DS_Store
31 | # C extensions
32 | *.so
33 |
34 | # Distribution / packaging
35 | .Python
36 | build/
37 | develop-eggs/
38 | dist/
39 | downloads/
40 | eggs/
41 | .eggs/
42 | lib/
43 | lib64/
44 | parts/
45 | sdist/
46 | var/
47 | wheels/
48 | share/python-wheels/
49 | *.egg-info/
50 | .installed.cfg
51 | *.egg
52 | MANIFEST
53 |
54 | # PyInstaller
55 | # Usually these files are written by a python script from a template
56 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
57 | *.manifest
58 | *.spec
59 |
60 | # Installer logs
61 | pip-log.txt
62 | pip-delete-this-directory.txt
63 |
64 | # Unit test / coverage reports
65 | htmlcov/
66 | .tox/
67 | .nox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *.cover
74 | *.py,cover
75 | .hypothesis/
76 | .pytest_cache/
77 | cover/
78 |
79 | # Translations
80 | *.mo
81 | *.pot
82 |
83 | # Django stuff:
84 | *.log
85 | local_settings.py
86 | db.sqlite3
87 | db.sqlite3-journal
88 |
89 | # Flask stuff:
90 | instance/
91 | .webassets-cache
92 |
93 | # Scrapy stuff:
94 | .scrapy
95 |
96 | # Sphinx documentation
97 | docs/_build/
98 |
99 | # PyBuilder
100 | .pybuilder/
101 | target/
102 |
103 | # Jupyter Notebook
104 | .ipynb_checkpoints
105 |
106 | # IPython
107 | profile_default/
108 | ipython_config.py
109 |
110 | # pyenv
111 | # For a library or package, you might want to ignore these files since the code is
112 | # intended to run in multiple environments; otherwise, check them in:
113 | # .python-version
114 |
115 | # pipenv
116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
119 | # install all needed dependencies.
120 | #Pipfile.lock
121 |
122 | # poetry
123 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
124 | # This is especially recommended for binary packages to ensure reproducibility, and is more
125 | # commonly ignored for libraries.
126 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
127 | #poetry.lock
128 |
129 | # pdm
130 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
131 | #pdm.lock
132 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
133 | # in version control.
134 | # https://pdm.fming.dev/#use-with-ide
135 | .pdm.toml
136 |
137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
138 | __pypackages__/
139 |
140 | # Celery stuff
141 | celerybeat-schedule
142 | celerybeat.pid
143 |
144 | # SageMath parsed files
145 | *.sage.py
146 |
147 | # Environments
148 | .env
149 | .venv
150 | env/
151 | venv/
152 | ENV/
153 | env.bak/
154 | venv.bak/
155 |
156 | # Spyder project settings
157 | .spyderproject
158 | .spyproject
159 |
160 | # Rope project settings
161 | .ropeproject
162 |
163 | # mkdocs documentation
164 | /site
165 |
166 | # mypy
167 | .mypy_cache/
168 | .dmypy.json
169 | dmypy.json
170 |
171 | # Pyre type checker
172 | .pyre/
173 |
174 | # pytype static type analyzer
175 | .pytype/
176 |
177 | # Cython debug symbols
178 | cython_debug/
179 |
180 | # PyCharm
181 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
182 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
183 | # and can be added to the global gitignore or merged into this file. For a more nuclear
184 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
185 | #.idea/
186 |
187 | ### Python Patch ###
188 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
189 | poetry.toml
190 |
191 | # ruff
192 | .ruff_cache/
193 |
194 | # LSP config files
195 | pyrightconfig.json
196 |
197 | # End of https://www.toptal.com/developers/gitignore/api/python
198 |
199 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 haoyuhuang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/case/cluster_sparsity.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import statistics
5 |
6 | # Set the Seaborn style
7 | sns.set(style="whitegrid")
8 |
9 | # Sample data: cluster sparsity (%) per layer for three datasets
10 | x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
11 | y_values = [[90.31, 60.29, 69.97],
12 |             [99.51, 98.72, 87.53],
13 |             [99.74, 98.48, 99.17],
14 |             [99.74, 98.18, 99.17],
15 |             [99.72, 97.78, 99.05],
16 |             [99.69, 97.78, 98.72],
17 |             [99.67, 97.78, 98.72],
18 |             [99.67, 97.22, 98.72],
19 |             [99.64, 97.22, 98.48],
20 |             [99.64, 97.22, 98.18]]
21 | y = [statistics.mean(vals) for vals in y_values]
22 |
23 | # Create the figure
24 | plt.figure(figsize=(5, 3))
25 |
26 | # Plot the mean sparsity line
27 | sns.lineplot(x=x, y=y, marker='o', label='Sparsity')
28 |
29 | # Prepare min/max data for the filled band
30 | y_min = [min(val) for val in y_values]
31 | y_max = [max(val) for val in y_values]
32 | # Compute the per-dataset change rate between consecutive layers
33 | change_rates = []
34 | for i in range(0, len(y_values)-1):
35 |     prev_values = y_values[i]
36 |     curr_values = y_values[i+1]
37 |     rates = [(curr - prev) / prev * 100 for prev, curr in zip(prev_values, curr_values)]
38 |     change_rates.append(rates)
39 |
40 | # Mean, min, and max of the change rates
41 | change_rates_mean = [statistics.mean(rates) for rates in change_rates]
42 | change_rates_min = [min(rates) for rates in change_rates]
43 | change_rates_max = [max(rates) for rates in change_rates]
44 | print(change_rates_mean[1])  # debug output
45 | # Plot the change-rate line
46 | sns.lineplot(x=x[1:], y=change_rates_mean, marker='o', color='#BD86A4', label='Change Rate')
47 |
48 | # Fill the band between the min and max change rates
49 | plt.fill_between(x[1:], change_rates_min, change_rates_max, alpha=0.3, color='gray')
50 |
51 | # Fill the band between the min and max sparsity values
52 | plt.fill_between(x, y_min, y_max, alpha=0.3, color='gray')
53 | plt.xticks(fontsize=16)
54 | plt.yticks(fontsize=16)
55 | # plt.axhline(y=98, color='r', linestyle='--', linewidth=1, label='y=98')
56 | # yticks = plt.gca().get_yticks()
57 | # Add a text annotation at y=98, aligned with the y-axis ticks
58 | # plt.text(0, 98, '98', color='r', verticalalignment='top', horizontalalignment='center')  # Set the title and labels
59 | # plt.title('Cluster Sparsity vs. Layer')
60 | plt.xlabel('Layer', fontsize=19)
61 | plt.ylabel('Cluster Sparsity (%)', fontsize=19)
62 | plt.legend(title='', fontsize=15)
63 | plt.subplots_adjust(left=0.18, right=0.97, top=0.98, bottom=0.23)
64 |
65 | # Save the figure
66 | plt.savefig('./test.png')
--------------------------------------------------------------------------------
/case/cluster_sparsity_ratio.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import statistics
5 |
6 | # Set the Seaborn style
7 | sns.set(style="whitegrid")
8 |
9 | # Sample data: cluster sparsity (%) per layer for three datasets
10 | x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
11 | y_values = [[90.31, 60.29, 69.97],
12 |             [99.51, 98.72, 87.53],
13 |             [99.74, 98.48, 99.17],
14 |             [99.74, 98.18, 99.17],
15 |             [99.72, 97.78, 99.05],
16 |             [99.69, 97.78, 98.72],
17 |             [99.67, 97.78, 98.72],
18 |             [99.67, 97.22, 98.72],
19 |             [99.64, 97.22, 98.48],
20 |             [99.64, 97.22, 98.18]]
21 | y = [statistics.mean(vals) for vals in y_values]  # mean sparsity per layer (unused here)
22 |
23 | # Compute the per-dataset change rate between consecutive layers
24 | change_rates = []
25 | for i in range(1, len(y_values)):
26 |     prev_values = y_values[i-1]
27 |     curr_values = y_values[i]
28 |     rates = [(curr - prev) / prev * 100 for prev, curr in zip(prev_values, curr_values)]
29 |     change_rates.append(rates)
30 |
31 | # Mean, min, and max of the change rates
32 | change_rates_mean = [statistics.mean(rates) for rates in change_rates]
33 | change_rates_min = [min(rates) for rates in change_rates]
34 | change_rates_max = [max(rates) for rates in change_rates]
35 |
36 | # Create the figure
37 | plt.figure(figsize=(4, 3))
38 |
39 | # Plot the change-rate line
40 | sns.lineplot(x=x[1:], y=change_rates_mean, marker='o', color='#6656A6', label='Change Rate')
41 |
42 | # Fill the band between the min and max change rates
43 | plt.fill_between(x[1:], change_rates_min, change_rates_max, alpha=0.3, color='gray')
44 |
45 | # Set the title and labels (this figure plots the change rate, not raw sparsity)
46 | plt.title('Change Rate of Cluster Sparsity vs. Layer')
47 | plt.xlabel('Layer')
48 | plt.ylabel('Change Rate (%)')
49 | plt.subplots_adjust(left=0.18, right=0.95, top=0.9, bottom=0.17)
50 |
51 | # Add the legend
52 | plt.legend()
53 |
54 | # Save the figure
55 | plt.savefig('./test_ratio.png')
--------------------------------------------------------------------------------
/case/connect_case.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 |
5 | # Sample data
6 | data = {
7 |     'Category': ['Mix', 'Mix', 'Mix',
8 |                  'CS', 'CS', 'CS',
9 |                  'Legal', 'Legal', 'Legal',
10 |                  'Agriculture', 'Agriculture', 'Agriculture'
11 |                  ],
12 |     'Method': ['GraphRAG', 'LightRAG', 'HiRAG',
13 |                'GraphRAG', 'LightRAG', 'HiRAG',
14 |                'GraphRAG', 'LightRAG', 'HiRAG',
15 |                'GraphRAG', 'LightRAG', 'HiRAG'],
16 |     'Value': [0.0091, 0.0086, 0.0692,
17 |               0.0133, 0.0181, 0.0305,
18 |               0.0185, 0.0086, 0.0236,
19 |               0.0250, 0.0173, 0.0350]
20 | }
21 |
22 | # Convert the data to a DataFrame
23 | df = pd.DataFrame(data)
24 |
25 | # Set the Seaborn style
26 | sns.set(style="whitegrid")
27 |
28 | # Create the figure
29 | plt.figure(figsize=(6, 3))
30 |
31 | # Plot the bar chart
32 | custom_palette = ['#3C3B8B', '#917BBD', '#E5CBE1']  # one color per method
33 | sns.barplot(x='Category', y='Value', hue='Method', data=df, palette=custom_palette)
34 | plt.xticks(fontsize=12)
35 | plt.yticks(fontsize=12)
36 | # Set the title and labels
37 | plt.title('', fontsize=15)
38 | plt.xlabel('Dataset', fontsize=15)
39 | plt.ylabel('Clustering Coefficient', fontsize=15)
40 | plt.subplots_adjust(left=0.13, right=0.95, top=0.95, bottom=0.18)
41 | plt.legend(title='')
42 |
43 | # Save the figure
44 | plt.savefig("./connect.png")
--------------------------------------------------------------------------------
/case/graph_sparsity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | import argparse
6 | import numpy as np
7 | import networkx as nx
8 | sys.path.append("../")
9 | from hirag import HiRAG, QueryParam
10 | from openai import AsyncOpenAI, OpenAI
11 | from dataclasses import dataclass
12 | from hirag.base import BaseKVStorage
13 | from hirag._utils import compute_args_hash
14 | from tqdm import tqdm
15 |
16 |
17 | os.environ["OPENAI_API_KEY"] = "***"
18 | GLM_API_KEY = "***"
19 | MODEL = "deepseek-chat"
20 | DEEPSEEK_API_KEY = "***"
21 |
22 |
23 | @dataclass
24 | class EmbeddingFunc:
25 |     embedding_dim: int
26 |     max_token_size: int
27 |     func: callable
28 |
29 |     async def __call__(self, *args, **kwargs) -> np.ndarray:
30 |         return await self.func(*args, **kwargs)
31 |
32 | def wrap_embedding_func_with_attrs(**kwargs):
33 |     """Wrap a function with attributes"""
34 |
35 |     def final_decro(func) -> EmbeddingFunc:
36 |         new_func = EmbeddingFunc(**kwargs, func=func)
37 |         return new_func
38 |
39 |     return final_decro
40 |
41 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
42 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
43 |     model_name = "embedding-3"
44 |     client = OpenAI(
45 |         api_key=GLM_API_KEY,
46 |         base_url="https://open.bigmodel.cn/api/paas/v4/"
47 |     )
48 |     embedding = client.embeddings.create(
49 |         input=texts,
50 |         model=model_name,
51 |     )
52 |     final_embedding = [d.embedding for d in embedding.data]
53 |     return np.array(final_embedding)
54 |
55 |
56 | async def deepseek_model_if_cache(
57 |     prompt, system_prompt=None, history_messages=[], **kwargs
58 | ) -> str:
59 |     openai_async_client = AsyncOpenAI(
60 |         api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com"
61 |     )
62 |     messages = []
63 |     if system_prompt:
64 |         messages.append({"role": "system", "content": system_prompt})
65 |
66 |     # Return the cached response if one exists -----------------------
67 |     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
68 |     messages.extend(history_messages)
69 |     messages.append({"role": "user", "content": prompt})
70 |     if hashing_kv is not None:
71 |         args_hash = compute_args_hash(MODEL, messages)
72 |         if_cache_return = await hashing_kv.get_by_id(args_hash)
73 |         if if_cache_return is not None:
74 |             return if_cache_return["return"]
75 |     # -----------------------------------------------------------------
76 |
77 |     response = await openai_async_client.chat.completions.create(
78 |         model=MODEL, messages=messages, **kwargs
79 |     )
80 |
81 |     # Cache the response if caching is enabled --------------------------
82 |     if hashing_kv is not None:
83 |         await hashing_kv.upsert(
84 |             {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
85 |         )
86 |     # -----------------------------------------------------------------
87 |     return response.choices[0].message.content
88 |
89 | DATASET = "legal"
90 |
91 | graph_func = HiRAG(
92 |     working_dir=f"../eval/datasets/{DATASET}/work_dir_glm_hi_clustercase",
93 |     enable_llm_cache=False,
94 |     embedding_func=GLM_embedding,
95 |     best_model_func=deepseek_model_if_cache,
96 |     cheap_model_func=deepseek_model_if_cache,
97 |     enable_hierachical_mode=True,
98 |     embedding_func_max_async=16,
99 |     enable_naive_rag=True)
100 |
101 | nx_graph = graph_func.chunk_entity_relation_graph._graph
102 | num_nodes = nx_graph.number_of_nodes()
103 | num_edges = nx_graph.number_of_edges()
104 | max_edges_undirected = num_nodes * (num_nodes - 1) / 2  # max edges in an undirected simple graph
105 | sparsity = 1 - (num_edges / max_edges_undirected)
106 | global_clustering = nx.transitivity(nx_graph)  # global clustering coefficient (transitivity)
107 |
108 | print("Dataset:", DATASET)
109 | print("Sparsity:", sparsity)
110 | print("Global Clustering Coefficient:", global_clustering)
111 |
--------------------------------------------------------------------------------
/case/graphml_visualize.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import json
3 | import os
4 | import webbrowser
5 | import threading
6 | import socketserver
7 | import http.server
8 |
9 | # Load GraphML file and transfer to JSON
10 | def graphml_to_json(graphml_file):
11 |     G = nx.read_graphml(graphml_file)
12 |     data = nx.node_link_data(G)
13 |     return json.dumps(data)
14 |
15 | # Create HTML file with improved visualization
16 | def create_html(html_path):
17 |     html_content = '''
18 |
19 |
20 |
21 |
22 |
23 | Graph Visualization
24 |
25 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
244 |
245 |
246 | '''
247 |
248 |     with open(html_path, 'w', encoding='utf-8') as f:
249 |         f.write(html_content)
250 |
251 | # Create JSON file
252 | def create_json(json_data, json_path):
253 |     json_data = "var graphJson = " + json_data.replace('\\"', '').replace("'", "\\'").replace("\n", "")
254 |     with open(json_path, 'w', encoding='utf-8') as f:
255 |         f.write(json_data)
256 |
257 | # Start simple HTTP server
258 | def start_server(port):
259 |     handler = http.server.SimpleHTTPRequestHandler
260 |     with socketserver.TCPServer(("", port), handler) as httpd:
261 |         print(f"Server started at http://localhost:{port}")
262 |         httpd.serve_forever()
263 |
264 | # Main function
265 | def visualize_graphml(graphml_file, html_path, port=8000):
266 |     json_data = graphml_to_json(graphml_file)
267 |     html_dir = os.path.dirname(html_path)
268 |     if not os.path.exists(html_dir):
269 |         os.makedirs(html_dir)
270 |     json_path = os.path.join(html_dir, 'graph_json.js')
271 |     create_json(json_data, json_path)
272 |     create_html(html_path)
273 |     # Start the server in a background daemon thread
274 |     server_thread = threading.Thread(target=start_server, args=(port,))
275 |     server_thread.daemon = True
276 |     server_thread.start()
277 |     # Open the default browser
278 |     webbrowser.open(f'http://localhost:{port}/{os.path.basename(html_path)}')
279 |     print("Visualization is ready. Press Ctrl+C to exit.")
280 |     try:
281 |         # Keep the main thread alive without burning CPU in a busy loop
282 |         while True:
283 |             threading.Event().wait(1)
284 |     except KeyboardInterrupt:
285 |         print("Shutting down...")
286 |
287 | # Usage
288 | if __name__ == "__main__":
289 |     graphml_file = r"./work_dir_base_small/graph_chunk_entity_relation.graphml"  # Replace with your GraphML file path
290 |     html_path = "./graph_visualization.html"
291 |     visualize_graphml(graphml_file, html_path, 11236)
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # OpenAI Configuration
2 | openai:
3 | embedding_model: "text-embedding-ada-002"
4 | model: "gpt-4o"
5 | api_key: "***"
6 | base_url: "***"
7 |
8 | # GLM Configuration
9 | glm:
10 | model: "glm-4-plus"
11 | api_key: "***"
12 | base_url: "https://open.bigmodel.cn/api/paas/v4"
13 | embedding_model: "embedding-3"
14 |
15 | # Deepseek Configuration
16 | deepseek:
17 | model: "deepseek-chat"
18 | api_key: "***"
19 | base_url: "https://api.deepseek.com"
20 |
21 | # Model Parameters
22 | model_params:
23 | openai_embedding_dim: 1536
24 | glm_embedding_dim: 2048
25 | max_token_size: 8192
26 |
27 | # HiRAG Configuration
28 | hirag:
29 | working_dir: "your_work_dir"
30 | enable_llm_cache: false
31 | enable_hierachical_mode: true
32 | embedding_batch_num: 6
33 | embedding_func_max_async: 8
34 | enable_naive_rag: true
--------------------------------------------------------------------------------
/docs/notes_of_codes.md:
--------------------------------------------------------------------------------
1 | From the contributor [hhh2210](https://github.com/hhh2210).
2 | ## Text Chunking
3 | **Core Code**: `extract_hierarchical_entities` in `hirag/_op.py`
4 |
5 | **Key Steps**:
6 | - The function processes text chunks to extract entities and relationships
7 | - It uses LLM prompts defined in `PROMPTS["hi_entity_extraction"]` for entity extraction
8 | - Each chunk is processed to extract entities via `_process_single_content_entity`
9 | - Embeddings are created for all extracted entities
10 | - It also extracts relationships between entities via `_process_single_content_relation`
11 |
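A minimal, self-contained sketch of this flow. The helper names mirror the ones mentioned above, but the bodies are stand-in stubs, not the real implementations in `hirag/_op.py`:

```python
import asyncio

async def _process_single_content_entity(chunk: str) -> list[dict]:
    # Stub: in HiRAG this prompts the LLM with PROMPTS["hi_entity_extraction"]
    # and parses entity records out of the completion.
    return [{"name": chunk.split()[0].upper(), "type": "UNKNOWN",
             "description": chunk, "source": "chunk-0"}]

async def _process_single_content_relation(chunk: str) -> list[tuple]:
    # Stub: in HiRAG this prompts the LLM for relationships between
    # the entities already extracted from the chunk.
    return []

async def extract_hierarchical_entities(chunks: list[str]):
    # Every chunk is processed concurrently: first entities, then relations.
    entities = await asyncio.gather(*map(_process_single_content_entity, chunks))
    relations = await asyncio.gather(*map(_process_single_content_relation, chunks))
    return entities, relations  # embeddings of the entities are built afterwards

print(asyncio.run(extract_hierarchical_entities(["Alice founded Acme Corp."])))
```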
12 | ## Entity extraction
13 | - Happens in `_process_single_content_entity` and `_process_single_content_relation` functions within `extract_hierarchical_entities`. These functions:
14 | - Use an LLM to extract entities with structured prompts
15 | - Extract entity attributes like name, type, description, and source
16 | - Store entities in a knowledge graph and vector database
17 | - The extracted entity information is stored in the knowledge graph and processed by the `_handle_single_entity_extraction` function (line 165), which parses entity attributes from LLM output.
18 |
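For illustration, a parser in the spirit of `_handle_single_entity_extraction`, assuming GraphRAG-style record delimiters; the actual delimiters are defined in `hirag/prompt.py` and may differ:

```python
def handle_single_entity_extraction(record: str, chunk_key: str) -> dict | None:
    # Assumed record shape: ("entity"<|>NAME<|>TYPE<|>DESCRIPTION)
    fields = record.strip("()").split("<|>")
    if len(fields) < 4 or fields[0].strip('"') != "entity":
        return None  # not an entity record
    _, name, etype, desc = (f.strip().strip('"') for f in fields[:4])
    return {"entity_name": name.upper(), "entity_type": etype.upper(),
            "description": desc, "source_id": chunk_key}

print(handle_single_entity_extraction(
    '("entity"<|>"HiRAG"<|>"technology"<|>"Hierarchical RAG framework")',
    "chunk-0"))
```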
19 | ## GMM Clustering
20 | **Core Code**: Functions in `hirag/_cluster_utils.py`
21 |
22 | **Key Steps**:
23 | - Uses `sklearn.mixture.GaussianMixture` for clustering
24 | - Automatically determines the optimal number of clusters with `get_optimal_clusters`
25 | - Applies dimensionality reduction with UMAP before clustering
26 | - Returns clusters as labels and probabilities
27 |
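A runnable sketch of that pipeline (parameter values here are illustrative, not HiRAG's actual defaults; requires `umap-learn` and `scikit-learn`):

```python
import numpy as np
import umap
from sklearn.mixture import GaussianMixture

def get_optimal_clusters(x: np.ndarray, max_clusters: int = 10) -> int:
    # Pick the component count that minimizes the Bayesian Information Criterion.
    candidates = range(1, min(max_clusters, len(x)) + 1)
    bics = [GaussianMixture(n_components=n, random_state=0).fit(x).bic(x)
            for n in candidates]
    return list(candidates)[int(np.argmin(bics))]

embeddings = np.random.rand(200, 1536)           # stand-in entity embeddings
reduced = umap.UMAP(n_components=10).fit_transform(embeddings)
gmm = GaussianMixture(n_components=get_optimal_clusters(reduced)).fit(reduced)
labels = gmm.predict(reduced)                    # hard cluster labels
probs = gmm.predict_proba(reduced)               # soft assignment probabilities
```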
28 | ## Summarization of Entities
29 | - For each cluster from GMM clustering, generates a prompt with all entities in the cluster
30 | - Uses LLM to generate summary entities for the cluster
31 | - Parses the LLM response to extract new higher-level entities and relationships
32 | - Creates embeddings for these summary entities
33 | - Adds these summaries to the next layer in the hierarchy
34 |
35 | **Prompt Design**: The `summary_clusters` prompt instructs the LLM to:
36 | - "Identify at least one attribute entity for the given entity description list"
37 | - Generate entities matching types from the meta attribute list: `["organization", "person", "location", "event", "product", "technology", "industry", "mathematics", "social sciences"]`
38 |
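A sketch of how the per-cluster summarization call could be assembled. The prompt key and the meta attribute list come from these notes; the template field names, the record format, and the `llm_func` signature are assumptions:

```python
import asyncio

META_ATTRIBUTES = ["organization", "person", "location", "event", "product",
                   "technology", "industry", "mathematics", "social sciences"]

async def summarize_cluster(cluster_entities: list[dict], llm_func, prompts: dict) -> str:
    # Assemble one prompt containing every entity in the cluster.
    entity_list = "\n".join(
        f'("entity"<|>{e["entity_name"]}<|>{e["entity_type"]}<|>{e["description"]})'
        for e in cluster_entities
    )
    prompt = prompts["summary_clusters"].format(
        entity_description_list=entity_list,
        meta_attribute_list=", ".join(META_ATTRIBUTES),
    )
    # The completion is then parsed into higher-level entities/relations,
    # embedded, and added to the next layer of the hierarchy (omitted here).
    return await llm_func(prompt)

# Demo with a stub LLM and a stand-in template:
async def echo_llm(prompt: str) -> str:
    return prompt  # a real implementation would call the chat model

PROMPTS = {"summary_clusters": "Entities:\n{entity_description_list}\n"
                               "Allowed types: {meta_attribute_list}"}
print(asyncio.run(summarize_cluster(
    [{"entity_name": "UMAP", "entity_type": "technology",
      "description": "dimensionality reduction method"}], echo_llm, PROMPTS)))
```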
39 |
--------------------------------------------------------------------------------
/eval/cal_time.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | content = ""
4 | with open("../log/graphrag_mix_retrieval_time.txt", 'r') as f:
5 |     lines = f.readlines()
6 |     for line in lines:
7 |         content += str(line)
8 | # Extract the retrieval times with a regular expression
9 | pattern = r"\[Retrieval Time: ([0-9.]+) seconds\]"
10 | matches = re.findall(pattern, content)
11 |
12 | # Convert the extracted times to floats
13 | times = [float(match) for match in matches]
14 |
15 | # Compute the average retrieval time
16 | average_time = sum(times) / len(times) if times else 0
17 |
18 | print(len(times))
19 | print(f"Average Retrieval Time: {average_time:.6f} seconds")
--------------------------------------------------------------------------------
/eval/cal_tokens.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | sys.path.append("../")
6 | from hirag import HiRAG, QueryParam
7 |
8 | import logging
9 | import numpy as np
10 | import tiktoken
11 | import yaml
12 | from openai import AsyncOpenAI, OpenAI
13 | from dataclasses import dataclass
14 | from hirag.base import BaseKVStorage
15 | from hirag._utils import compute_args_hash
16 |
17 | logging.basicConfig(level=logging.WARNING)
18 | logging.getLogger("HiRAG").setLevel(logging.INFO)
19 |
20 |
21 | with open('config.yaml', 'r') as file:
22 |     config = yaml.safe_load(file)
23 |
24 | TOTAL_TOKEN_COST = 0
25 | TOTAL_API_CALL_COST = 0
26 |
27 | tokenizer = tiktoken.get_encoding("cl100k_base")
28 |
29 | if __name__ == "__main__":
30 |     # file_path = f"./datasets/{DATASET}/{DATASET}_unique_contexts.json"
31 |     # with open(file_path, mode="r") as f:
32 |     #     unique_contexts = json.load(f)
33 |     # TOTAL_TOKEN_COST += len(tokenizer.encode(str(unique_contexts[:100])))
34 |     with open("./datasets/mix/mix_kag_result_deepseek.jsonl", "r") as f:
35 |         doc = f.readlines()
36 |     full_doc = ""
37 |     for item in doc:
38 |         full_doc += json.loads(item)['answer']
39 |
40 |     TOTAL_TOKEN_COST += len(tokenizer.encode(str(full_doc)))
41 |     logging.info(f"[Total token cost: {TOTAL_TOKEN_COST}]")
42 |     # logging.info(f"[Total document num: {len(unique_contexts)}]")
--------------------------------------------------------------------------------
/eval/datasets/agriculture/agriculture_fastGraphrag_result_deepseek.jsonl:
--------------------------------------------------------------------------------
1 | {"query": "What are the steps involved in extracting and handling honey, and why is it important to strain the honey after extraction?", "answer": "No relevant information was found."}
2 | {"query": "What is the definition of a small farm according to the author?", "answer": "No relevant information was found."}
3 | {"query": "How does the Intervale in Burlington, Vermont, support new farmers?", "answer": "No relevant information was found."}
4 | {"query": "What is the significance of the rest period in grass growth as described in the book?", "answer": "The rest period in grass growth is significant as it allows the grass to recover and regrow, ensuring sustainable grazing practices. This is particularly important in systems like rational grazing, where controlled access and rest periods for the grass are managed using technologies such as electric fencing. The rest period helps in maintaining the health and productivity of the grass, which is crucial for the overall sustainability of grazing systems."}
5 | {"query": "What innovative practices does Greensgrow Farm in Philadelphia employ to make urban farming sustainable?", "answer": "No relevant information was found."}
6 | {"query": "What are the two dogmas of environmental philosophy mentioned in the book?", "answer": "No relevant information was found."}
7 | {"query": "What are the benefits of direct marketing for small farmers, as highlighted in the book?", "answer": "Direct marketing benefits small farmers by enabling them to sell their products directly to consumers, which can lead to higher profit margins, stronger relationships with customers, and greater control over pricing and branding. This approach also supports sustainable and regenerative agriculture practices, as highlighted by entities like Polyface Farm, which practices direct marketing of its products."}
8 | {"query": "Who is the series editor of \"Culture of the Land\"?", "answer": "No relevant information was found."}
9 | {"query": "Describe the role of the varroa mite in the decline of honey bee populations.", "answer": "The Varroa destructor mite plays a significant role in the decline of honey bee populations by parasitizing Apis mellifera, the European honey bee. This parasitism leads to significant colony losses, as the mites weaken the bees, making them more susceptible to diseases and viruses. The impact of Varroa destructor on bee-keeping practices has been discussed, highlighting the economic losses experienced by beekeepers like John Miller. Control methods, including the use of Apistan and Coumaphos, have been employed to treat varroa mite infestations, though resistance and negative effects on bee health have been noted. Biological control agents, such as Metarhizium anisopliae, have also been used to reduce Varroa destructor populations."}
10 | {"query": "What is the concept of \"intensity of grazing\" as defined in the book?", "answer": "The concept of 'intensity of grazing' is not explicitly defined in the provided input data."}
11 | {"query": "What are the environmental and ethical considerations in beekeeping, particularly regarding the use of pesticides and the welfare of bees?", "answer": "Environmental and ethical considerations in beekeeping include the impact of pesticides like Apistan and Coumaphos on bee health, with Apistan becoming less effective due to mite resistance and Coumaphos being harder on bees. Organic certifiers specify acceptable methods of control for pests like the wax moth and small hive beetle, indicating a move towards more sustainable practices. The use of biological control agents, such as Metarhizium anisopliae against Varroa destructor, represents an environmentally friendly approach to pest management. The welfare of bees is also a concern, with practices aimed at reducing stress and preventing colony losses, as seen in the efforts of beekeepers like John Miller and Dave Hackenberg, who have experienced significant losses due to Colony Collapse Disorder. Sustainable beekeeping practices are essential for maintaining ecological integrity and ensuring the health of bee populations."}
12 | {"query": "What are the environmental considerations in market farming?", "answer": "Environmental considerations in market farming include adopting sustainable agriculture practices such as no-till farming, crop rotation, and the use of nitrogenous fertilizers to enhance grass productivity while managing the risk of grass tetany. Farms like Polyface and Cronin Farms implement these practices to maintain ecological integrity and social sustainability. Additionally, the use of rational grazing systems with electric fencing helps optimize pasture utilization and animal health. The Rodale Institute and No-Till Center advocate for organic farming and no-till practices to reduce environmental impact and promote soil health."}
13 | {"query": "What historical agricultural practices does Howard criticize, and why?", "answer": "No relevant information was found in the provided data to answer the query about Howard's criticism of historical agricultural practices."}
14 | {"query": "What is the primary motivation for writing \"Natural Beekeeping\"?", "answer": "No relevant information was found."}
15 | {"query": "How does Growing Gardens support home gardeners in Portland, Oregon?", "answer": "No relevant information was found."}
16 | {"query": "How does the author define agripreneurship?", "answer": "No relevant information was found."}
17 | {"query": "What is the role of cover crops in farming, according to the book?", "answer": "No relevant information was found."}
18 | {"query": "How does the book describe the influence of fertilizers on grass growth?", "answer": "The book describes that nitrogenous fertilizers are used in rational grazing systems to enhance grass productivity and compensate for seasonal fluctuations in grass growth. However, improper use of nitrogenous fertilizers can increase the risk of grass tetany by affecting the mineral balance in the grass."}
19 | {"query": "How does the Stametsian Model for Permaculture integrate mushrooms into sustainable agriculture?", "answer": "The input data does not provide specific information on the Stametsian Model for Permaculture or how it integrates mushrooms into sustainable agriculture. Therefore, no relevant information was found to directly answer the user query based on the provided data."}
20 | {"query": "How does a beekeeper typically control swarming in a colony?", "answer": "No relevant information was found."}
21 | {"query": "What is the \"Law of Return\" in agriculture, according to Howard?", "answer": "No relevant information was found."}
22 | {"query": "How does the book address the issue of grass tetany in livestock?", "answer": "The input data does not provide specific information on how the book addresses the issue of grass tetany in livestock."}
23 | {"query": "What are the financial considerations for starting a market farming business?", "answer": "No relevant information was found."}
24 | {"query": "What is the central theme of \"Growing a Revolution: Bringing Our Soil Back to Life\"?", "answer": "The central theme of 'Growing a Revolution: Bringing Our Soil Back to Life' is the promotion and implementation of sustainable and regenerative agriculture practices, such as no-till farming, organic farming, and soil health improvement, to restore and maintain the vitality of the soil. This theme is supported by the work of entities like the No-Till Center, Rodale Institute, and individuals such as Gabe Brown and Jeff Moyer, who advocate for and demonstrate the benefits of these practices in enhancing soil health and agricultural sustainability."}
25 | {"query": "How does the book address the relationship between soil health and climate change?", "answer": "No relevant information was found."}
26 | {"query": "What are the three principles of conservation agriculture as outlined in the book?", "answer": "No relevant information was found."}
27 | {"query": "What are the potential downsides of beekeeping as described in the book?", "answer": "The potential downsides of beekeeping, as described in the book, include significant economic losses due to varroa mite infestations, which have devastated bee colonies. Additionally, the use of certain treatments like Coumaphos and Apistan for varroa mites has had negative effects on bee health, and mites have developed resistance to Apistan. Organic certifiers specify acceptable methods of control for pests like wax moths and small hive beetles, indicating challenges in managing these pests organically."}
28 | {"query": "What are the common reasons people get into market farming?", "answer": "Common reasons people get into market farming include a desire for sustainable and regenerative agriculture practices, direct marketing of products to consumers, and creating an incentivized, autonomous farmer model. This approach allows farmers to be paid for their ability to produce results, emphasizing local, sustainable food production and the importance of maintaining natural ecosystems."}
29 | {"query": "What is the role of draft animals in sustainable farming, as discussed in the book?", "answer": "No relevant information was found regarding the role of draft animals in sustainable farming in the provided input data."}
30 | {"query": "What are the primary motivations for people to grow their own vegetables according to the National Gardening Association?", "answer": "No relevant information was found."}
31 | {"query": "What are the three main reasons people choose to farm, as identified by the Rural Sociology Department at the University of Missouri in the 1950s?", "answer": "No relevant information was found."}
32 | {"query": "What are some of the historical chores mentioned in the book that helped shape the character and skills of young people in the past?", "answer": "No relevant information was found."}
33 | {"query": "How has the rise of the almond industry affected commercial beekeepers like John Miller?", "answer": "The rise of the almond industry has significantly impacted commercial beekeepers like John Miller by increasing the demand for migratory beekeeping services for pollination. This demand has led to the transportation of bee colonies across states, a practice pioneered by John Miller's ancestor, Nephi Ephraim Miller. However, this increased movement and concentration of bees in almond orchards have also heightened the exposure to pests and diseases, such as the Varroa destructor mite, leading to significant colony losses and economic challenges for beekeepers."}
34 | {"query": "What are the benefits of including agriculture in planned urban developments as discussed in the book?", "answer": "No relevant information was found."}
35 | {"query": "How does Ross Conrad describe the relationship between the honey bee and the plant kingdom?", "answer": "No relevant information was found."}
36 | {"query": "What is the significance of soil organic matter in agriculture, as discussed in the book?", "answer": "No relevant information was found."}
37 | {"query": "What role does agriculture play in environmental sustainability according to the book?", "answer": "Agriculture plays a crucial role in environmental sustainability by promoting practices that encourage independence in harmony with the environment, as highlighted by Storey Publishing's mission. Sustainable agriculture practices, such as no-till farming advocated by the No-Till Center and regenerative agriculture practices by Gabe Brown, aim to maintain ecological integrity and social sustainability. These practices focus on reproducing ecosystems and achieving desirable political goals, such as rural community preservation and fair treatment for small farmers, contributing to a sustainable future."}
38 | {"query": "What is the significance of crop rotation in sustainable farming?", "answer": "Crop rotation is significant in sustainable farming as it helps in maintaining soil health, reducing pest and disease buildup, and improving soil fertility. By alternating crops with different nutrient needs and pest vulnerabilities, it prevents the depletion of specific soil nutrients and disrupts the life cycles of pests and diseases, leading to a more sustainable and productive farming system."}
39 | {"query": "What is the importance of farm income for small farmers, according to the author?", "answer": "No relevant information was found."}
40 | {"query": "How does the book differentiate between continuous grazing and rational grazing?", "answer": "The book differentiates between continuous grazing and rational grazing by highlighting that continuous grazing allows animals to graze freely over a large area without control, leading to overgrazing and undergrazing in different areas. Rational grazing, on the other hand, is a system of grazing management that aims to optimize the productivity of grass and the health of grazing animals by controlling the timing and intensity of grazing. This method uses electric fencing to manage the movement of animals, ensuring that pastures are grazed evenly and have time to recover, which enhances grass productivity and compensates for seasonal fluctuations in grass growth."}
41 | {"query": "How does the author suggest managing risks in farming?", "answer": "The author suggests managing risks in farming through sustainable agriculture practices, such as no-till farming, crop rotation, and rational grazing systems. These methods aim to optimize productivity while maintaining ecological integrity and social sustainability. Additionally, the use of electric fencing and nitrogenous fertilizers in rational grazing systems helps control animal movement and enhance grass productivity, respectively. Organic farming and the avoidance of synthetic chemicals, like paradichlorobenzene, are also recommended to protect ecosystems and comply with organic certification standards."}
42 | {"query": "How does Howard view the impact of artificial fertilizers on agriculture?", "answer": "No relevant information was found."}
43 | {"query": "What role do community gardens play in the broader food system according to the book?", "answer": "Community gardens, such as those initiated by Janus Youth, play a significant role in the broader food system by addressing social issues like crime and hunger, providing therapeutic benefits, and empowering communities through gardening and selling produce. They also contribute to social sustainability by preserving rural communities and supporting small farmers."}
44 | {"query": "What role do mushrooms play in the environment, particularly in terms of waste decomposition?", "answer": "Mushrooms play a crucial role in the environment, particularly in waste decomposition. They are capable of breaking down complex organic materials, such as dead plants and animals, into simpler substances. This process not only recycles nutrients back into the soil, making them available for other plants and organisms, but also helps in maintaining the ecological balance. Mushrooms, being fungi, secrete enzymes that decompose organic matter, thus contributing to the nutrient cycle and soil health. This decomposition process is vital for the sustainability of ecosystems, as it ensures the continuous recycling of essential nutrients."}
45 | {"query": "How does the book define sustainability?", "answer": "The input data does not directly define sustainability as described in a book. However, it mentions various aspects of sustainability through entities and relationships, such as social sustainability focusing on desirable political goals, ecological integrity emphasizing the reproduction of ecosystems, and sustainable agriculture practices advocated by individuals and organizations. For a specific book's definition of sustainability, more detailed information from the sources would be required."}
46 | {"query": "How does the author suggest managing machinery on a small farm?", "answer": "No relevant information was found."}
47 | {"query": "How does the Goat Justice League advocate for urban goats in Seattle?", "answer": "No relevant information was found."}
48 | {"query": "What are some challenges faced by market farmers in terms of food safety?", "answer": "Market farmers face challenges such as increasingly onerous regulations from the FSIS, which have run small operations out of business, making it exponentially harder to launch a business in an embryonic form. They also deal with the capital-intensive and paperwork requirements that deny them the ability to bring local, ecologically friendly food to their community. Additionally, the Food Safety Modernization Act has given the FDA the right to inspect farms without a warrant, adding another layer of regulatory burden. These challenges are compounded by the need for expensive testing, which is a significant financial burden for small operations like T&E Meats, potentially consuming an entire week's profit for a single test."}
49 | {"query": "How does the annual migration of beekeepers and their bees impact the agricultural system in the United States?", "answer": "The annual migration of beekeepers and their bees, such as the practices pioneered by Nephi Ephraim Miller and continued by John Miller, significantly impacts the agricultural system in the United States by facilitating the pollination of crops across different states. This migratory beekeeping helps in the production of a wide variety of fruits, vegetables, and nuts, contributing to the diversity and abundance of the agricultural output. However, it also exposes bees to various challenges, including the spread of pests like Varroa destructor and diseases, which can lead to significant colony losses and affect the overall health of bee populations and, consequently, agricultural productivity."}
50 | {"query": "What are the potential drawbacks of using plastic hive components in beekeeping?", "answer": "No relevant information was found regarding the potential drawbacks of using plastic hive components in beekeeping within the provided input data."}
51 | {"query": "How does Ross Conrad suggest obtaining bees for beginners?", "answer": "No relevant information was found."}
52 | {"query": "What are the benefits and challenges of using top bar hives compared to Langstroth hives?", "answer": "No relevant information was found in the input data regarding the benefits and challenges of using top bar hives compared to Langstroth hives."}
53 | {"query": "What is the significance of the \"P\" Value System in mushroom strain preservation?", "answer": "No relevant information was found regarding the 'P' Value System in mushroom strain preservation in the provided input data."}
54 | {"query": "How does Howard's approach to agriculture differ from conventional industrial agriculture?", "answer": "No relevant information was found."}
55 | {"query": "What are the other hive products besides honey, and what are their potential uses?", "answer": "Besides honey, other hive products include beeswax, used in candles, cosmetics, and as a food additive; propolis, a resinous mixture used for its antimicrobial properties in health supplements and natural remedies; royal jelly, consumed for its nutritional benefits and used in cosmetics; and bee pollen, considered a superfood and used in dietary supplements. However, the input data provided does not contain specific details on these products and their uses."}
56 | {"query": "What are some strategies for marketing farm-grown products?", "answer": "Strategies for marketing farm-grown products include direct marketing to consumers, leveraging local and sustainable sourcing to appeal to eco-conscious buyers, and utilizing online platforms and social media for broader reach. Partnerships with local businesses and restaurants, like Chipotle Mexican Grill's collaboration with Polyface Farm, can also enhance visibility and credibility. Additionally, participating in farmers' markets and local food events can help in building a loyal customer base and community support."}
57 | {"query": "What are microgreens and why are they considered high-value crops?", "answer": "No relevant information was found."}
58 | {"query": "What are the advantages of part-time farming over full-time farming?", "answer": "No relevant information was found."}
59 | {"query": "What are the six vectors of contamination in mushroom cultivation, and why are they significant?", "answer": "No relevant information was found in the provided data to directly answer the question about the six vectors of contamination in mushroom cultivation and their significance."}
60 | {"query": "How does the book contrast industrial and agrarian philosophies of agriculture?", "answer": "The input data does not provide specific information contrasting industrial and agrarian philosophies of agriculture directly. However, it mentions Fukuoka's critique of scientific farming, which he believes is limited to short-term goals and will lead to failure, contrasting it with natural farming methods that emphasize working with nature. Additionally, the data highlights sustainable and regenerative agriculture practices at farms like Polyface and Dakota Lakes Research Farm, which may reflect agrarian philosophies focusing on harmony with the environment and sustainability, as opposed to industrial agriculture's focus on efficiency and scale."}
61 | {"query": "How can beekeepers treat colonies for diseases such as American foulbrood and nosema disease?", "answer": "No relevant information was found in the input data to directly answer how beekeepers can treat colonies for diseases such as American foulbrood and nosema disease."}
62 | {"query": "How does the P-Patch Community Gardening Program in Seattle manage its large network of gardens?", "answer": "No relevant information was found."}
63 | {"query": "How does the author suggest evaluating land for farming?", "answer": "The author suggests evaluating land for farming by considering the use of temporary and permanent pastures, the application of nitrogenous fertilizers to enhance grass productivity, and the implementation of rational grazing systems that utilize electric fencing to control animal movement and manage pasture utilization. This approach aims to optimize the productivity of grass and the health of grazing animals."}
64 | {"query": "How does Howard propose to combat soil erosion and alkali land formation?", "answer": "No relevant information was found."}
65 | {"query": "How can farmers manage weeds in their fields?", "answer": "Farmers can manage weeds in their fields through sustainable agriculture practices such as no-till farming, crop rotation, and the use of cover crops. These methods help suppress weed growth by improving soil health and reducing the need for chemical herbicides. Additionally, practices like rational grazing and the use of electric fencing can help manage weeds in pasture systems by controlling animal movement and grazing patterns, which in turn affects weed proliferation."}
66 | {"query": "What examples does the book provide of successful soil restoration by farmers?", "answer": "The input data does not provide specific examples of successful soil restoration by farmers as mentioned in a book. However, it mentions entities like Gabe Brown, known for his regenerative agriculture practices including no-till farming and intensive grazing, and the Rodale Institute, which focuses on organic farming research and education, including long-term field trials comparing organic and conventional farming. These entities are involved in practices that could lead to soil restoration, but direct examples from a book are not provided in the given data."}
67 | {"query": "What is the significance of the varroa mite in the beekeeping industry?", "answer": "The varroa mite (Varroa destructor) is significant in the beekeeping industry as it parasitizes honey bees (Apis mellifera), leading to significant colony losses. This has caused economic losses for beekeepers, such as John Miller, whose colonies have been devastated by varroa mite infestations. The mite's impact has also influenced bee-keeping practices and led to the development of treatments like Apistan and Coumaphos, though resistance and negative effects on bee health have been issues. Biological control agents, such as the fungus Metarhizium anisopliae, have been explored as alternatives to chemical treatments."}
68 | {"query": "How can farmers ensure they are selling high-quality produce to restaurants?", "answer": "Farmers can ensure they are selling high-quality produce to restaurants by adopting sustainable agriculture practices, such as those advocated by Beck at Dakota Lakes Research Farm and implemented at Cronin Farms by Dan Forgey. Emphasizing cleanliness and small-scale operations, as seen with T&E Meats, can also contribute to quality. Additionally, direct marketing and partnerships with restaurants that value sustainable sourcing, like Chipotle Mexican Grill's relationship with Polyface Farm, can help maintain high standards. Engaging in organic farming and adhering to regulations, such as those set by the FDA and FSIS, further ensures the quality and safety of produce."}
69 | {"query": "What are the methods for removing honey from the supers, and what precautions should be taken during this process?", "answer": "No relevant information was found in the provided input data regarding the methods for removing honey from the supers and the precautions to be taken during this process."}
70 | {"query": "What is the primary definition of grazing according to the book?", "answer": "The primary definition of grazing according to the book is not directly provided in the input data."}
71 | {"query": "How can market farmers extend their growing season?", "answer": "Market farmers can extend their growing season by using techniques such as temporary pastures with legumes like red clover for additional forage, employing electric fencing for controlled grazing management, and utilizing no-till farming practices to improve soil health and sustainability. Additionally, incorporating sustainable agriculture practices advocated by experts like Fukuoka and Joel Salatin, and leveraging local partnerships for direct marketing can also contribute to extending the growing season."}
72 | {"query": "What are the four universal laws of rational grazing according to the book?", "answer": "No relevant information was found."}
73 | {"query": "How does the Industrial Revolution impact agriculture according to Howard?", "answer": "No relevant information was found."}
74 | {"query": "How does the annual population cycle of a colony affect beekeeping practices, and what are the implications for winter management?", "answer": "The annual population cycle of a bee colony significantly impacts beekeeping practices, especially in terms of winter management. During winter, bee colonies naturally reduce their population to conserve resources, as there is less forage available. This reduction necessitates careful management by beekeepers to ensure the survival of the colony through the winter months. Beekeepers must monitor the health of the colony, manage pests like Varroa destructor, which can devastate bee populations, and ensure that the hive has sufficient food stores. The relationship between Varroa destructor and Apis mellifera (the European honey bee) is particularly critical, as infestations can lead to significant colony losses. Effective winter management strategies may include treating for pests, providing supplemental feeding, and ensuring the hive is properly insulated against cold temperatures. The input data highlights the importance of managing Varroa destructor infestations, as seen in the experiences of beekeepers like John Miller, who has faced significant economic losses due to these mites. Therefore, understanding and adapting to the annual population cycle is crucial for successful beekeeping and the survival of colonies through winter."}
75 | {"query": "What is the significance of the mycorrhizal association in plant health, as per Howard's research?", "answer": "No relevant information was found."}
76 | {"query": "What is the main mission of Storey Publishing?", "answer": "The main mission of Storey Publishing is to serve their customers by providing practical information that encourages independence in harmony with the environment."}
77 | {"query": "How much money can a market farmer expect to make per acre?", "answer": "No relevant information was found."}
78 | {"query": "How does the book suggest adapting agrarian ideals to contemporary urban life?", "answer": "No relevant information was found."}
79 | {"query": "What role does compost play in Howard's agricultural system?", "answer": "No relevant information was found."}
80 | {"query": "What does Salatin propose as a solution to the problem of food packaging waste in the airline industry?", "answer": "No relevant information was found."}
81 | {"query": "What is the role of communication in a bee colony, particularly the 'dance language'?", "answer": "No relevant information was found."}
82 | {"query": "What are the recommended steps for a beginner to get started with beekeeping, and why is it advised to start with new equipment?", "answer": "The input data does not provide specific recommended steps for a beginner to get started with beekeeping or the reasons for starting with new equipment. Therefore, no relevant information was found to directly answer the user query based on the provided data."}
83 | {"query": "How does the author suggest diversifying a farm to ensure sustainability?", "answer": "The author suggests diversifying a farm to ensure sustainability by adopting practices such as no-till farming, crop rotation, and regenerative agriculture. These methods are highlighted through examples like Cronin Farms adopting sustainable agriculture practices advocated by Beck, and Gabe Brown's regenerative agriculture practices at Brown Ranch. Additionally, the use of electric fencing for rational grazing and the application of nitrogenous fertilizers to enhance grass productivity are mentioned as part of sustainable farming strategies."}
84 | {"query": "What is the role of a CSA (Community-Supported Agriculture) in market farming?", "answer": "No relevant information was found."}
85 | {"query": "What are the nutritional needs of honey bees, and how do they meet them?", "answer": "Honey bees primarily require carbohydrates, proteins, vitamins, and minerals for their nutritional needs. They meet these needs by collecting nectar and pollen from flowers. Nectar provides carbohydrates in the form of sugars, which are converted into honey for energy. Pollen supplies proteins, fats, vitamins, and minerals essential for growth and development, especially for larvae and young bees. Additionally, honey bees may collect water and other substances to meet their hydration and nutritional requirements."}
86 | {"query": "What is the significance of Thomas Jefferson in the context of agrarian ideals and sustainability?", "answer": "Thomas Jefferson is significant in the context of agrarian ideals and sustainability as he is mentioned advocating for a constitutional amendment guaranteeing the right to choose and procure food, highlighting his belief in food freedom and the importance of sustainable, local food systems. His ideals suggest a vision of a society where individuals have the autonomy to make their own food choices, free from excessive government regulation, aligning with modern sustainability and agrarian principles."}
87 | {"query": "What are the benefits of using grain spawn in mushroom cultivation?", "answer": "Grain spawn is beneficial in mushroom cultivation as it serves as a nutrient-rich medium that supports the growth of mushroom mycelium. This method allows for the efficient colonization of the substrate by the mushroom culture, leading to faster and more uniform mushroom growth. Additionally, using grain spawn can help in maintaining the genetic integrity of the mushroom strain through cloning, ensuring consistent quality and characteristics of the cultivated mushrooms."}
88 | {"query": "What is the importance of setting goals for a farming business, according to the author?", "answer": "No relevant information was found."}
89 | {"query": "How does Salatin describe the transition from an agrarian to an industrial society in terms of food production and consumption?", "answer": "No relevant information was found."}
90 | {"query": "What strategies does the South End / Lower Roxbury Open Space Land Trust use to preserve community gardens?", "answer": "No relevant information was found."}
91 | {"query": "What is the concept of a stale seedbed in market farming?", "answer": "No relevant information was found."}
92 | {"query": "How does the book suggest we can reverse the trend of soil degradation?", "answer": "The input data does not provide specific information on how a book suggests reversing the trend of soil degradation. Relevant entities and relationships related to soil health and sustainable agriculture practices, such as no-till farming, crop rotation, and the use of cover crops, are mentioned in the context of various farms and research institutions (e.g., Menoken Farm, Dakota Lakes Research Farm, Rodale Institute). However, without direct references to a book's recommendations on reversing soil degradation, no specific answer can be provided based on the given data."}
93 | {"query": "Can mushrooms be used for bioremediation, and if so, how?", "answer": "Yes, mushrooms can be used for bioremediation. Certain species of mushrooms, such as Pleurotus spp (Oyster mushrooms), are known for their ability to break down and absorb pollutants from the environment, including heavy metals and petroleum products. This process, known as mycoremediation, utilizes the natural enzymatic capabilities of fungi to degrade or sequester contaminants in soil and water, making it an eco-friendly method for environmental cleanup."}
94 | {"query": "What technological innovations in beekeeping have had the most significant impact on the industry?", "answer": "Technological innovations in beekeeping that have had a significant impact on the industry include the use of Laminar Flow Hoods and Glove Boxes for preventing contamination in biological samples, the development of treatments for pests like Varroa Destructor and Small Hive Beetle, and the application of cloning techniques for mushroom strains which can be analogously applied to bee breeding for disease resistance. Additionally, the introduction of new media and online platforms like PlanetRadiocity.com by Radio City has facilitated the dissemination of beekeeping knowledge and practices."}
95 | {"query": "What are the two parasitic mites that have become a serious problem for honey bees in North America, and what are their impacts?", "answer": "The two parasitic mites that have become a serious problem for honey bees in North America are Varroa destructor and Aethina tumida (small hive beetle). Varroa destructor parasitizes and causes significant losses in honey bee colonies, leading to colony collapse. Aethina tumida, the small hive beetle, is a significant pest in warm southern states, affecting honey bee colonies by damaging hives and honeycombs."}
96 | {"query": "What is the primary focus of the book \"Culture of the Land: A Series in the New Agrarianism\"?", "answer": "No relevant information was found."}
97 | {"query": "What is the significance of John Miller's purchase of a Corvette in the context of his beekeeping career?", "answer": "No relevant information was found regarding the significance of John Miller's purchase of a Corvette in the context of his beekeeping career."}
98 | {"query": "What is the significance of the \"hygiene hypothesis\" and how does it relate to the modern lifestyle?", "answer": "No relevant information was found in the input data regarding the 'hygiene hypothesis' and its relation to modern lifestyle."}
99 | {"query": "How does the book challenge the conventional wisdom about industrialized agrochemical agriculture?", "answer": "The book challenges conventional wisdom about industrialized agrochemical agriculture by advocating for sustainable and natural farming methods. It contrasts scientific farming, which relies on human knowledge and action, with natural farming methods like Hinayana Natural Farming, which emphasizes working with nature rather than against it. The book also highlights the work of individuals and organizations, such as Fukuoka and the Rodale Institute, who promote organic and no-till farming practices as superior alternatives to conventional methods. These practices aim to achieve ecological integrity and social sustainability, challenging the short-term goals and potential long-term failures of industrialized agrochemical agriculture."}
100 | {"query": "What is the role of smoke in beekeeping, and how should it be used effectively?", "answer": "No relevant information was found."}
101 |
--------------------------------------------------------------------------------
/eval/datasets/agriculture/agriculture_query.jsonl:
--------------------------------------------------------------------------------
1 | {"query": "What are the steps involved in extracting and handling honey, and why is it important to strain the honey after extraction?"}
2 | {"query": "What is the definition of a small farm according to the author?"}
3 | {"query": "How does the Intervale in Burlington, Vermont, support new farmers?"}
4 | {"query": "What is the significance of the rest period in grass growth as described in the book?"}
5 | {"query": "What innovative practices does Greensgrow Farm in Philadelphia employ to make urban farming sustainable?"}
6 | {"query": "What are the two dogmas of environmental philosophy mentioned in the book?"}
7 | {"query": "What are the benefits of direct marketing for small farmers, as highlighted in the book?"}
8 | {"query": "Who is the series editor of \"Culture of the Land\"?"}
9 | {"query": "Describe the role of the varroa mite in the decline of honey bee populations."}
10 | {"query": "What is the concept of \"intensity of grazing\" as defined in the book?"}
11 | {"query": "What are the environmental and ethical considerations in beekeeping, particularly regarding the use of pesticides and the welfare of bees?"}
12 | {"query": "What are the environmental considerations in market farming?"}
13 | {"query": "What historical agricultural practices does Howard criticize, and why?"}
14 | {"query": "What is the primary motivation for writing \"Natural Beekeeping\"?"}
15 | {"query": "How does Growing Gardens support home gardeners in Portland, Oregon?"}
16 | {"query": "How does the author define agripreneurship?"}
17 | {"query": "What is the role of cover crops in farming, according to the book?"}
18 | {"query": "How does the book describe the influence of fertilizers on grass growth?"}
19 | {"query": "How does the Stametsian Model for Permaculture integrate mushrooms into sustainable agriculture?"}
20 | {"query": "How does a beekeeper typically control swarming in a colony?"}
21 | {"query": "What is the \"Law of Return\" in agriculture, according to Howard?"}
22 | {"query": "How does the book address the issue of grass tetany in livestock?"}
23 | {"query": "What are the financial considerations for starting a market farming business?"}
24 | {"query": "What is the central theme of \"Growing a Revolution: Bringing Our Soil Back to Life\"?"}
25 | {"query": "How does the book address the relationship between soil health and climate change?"}
26 | {"query": "What are the three principles of conservation agriculture as outlined in the book?"}
27 | {"query": "What are the potential downsides of beekeeping as described in the book?"}
28 | {"query": "What are the common reasons people get into market farming?"}
29 | {"query": "What is the role of draft animals in sustainable farming, as discussed in the book?"}
30 | {"query": "What are the primary motivations for people to grow their own vegetables according to the National Gardening Association?"}
31 | {"query": "What are the three main reasons people choose to farm, as identified by the Rural Sociology Department at the University of Missouri in the 1950s?"}
32 | {"query": "What are some of the historical chores mentioned in the book that helped shape the character and skills of young people in the past?"}
33 | {"query": "How has the rise of the almond industry affected commercial beekeepers like John Miller?"}
34 | {"query": "What are the benefits of including agriculture in planned urban developments as discussed in the book?"}
35 | {"query": "How does Ross Conrad describe the relationship between the honey bee and the plant kingdom?"}
36 | {"query": "What is the significance of soil organic matter in agriculture, as discussed in the book?"}
37 | {"query": "What role does agriculture play in environmental sustainability according to the book?"}
38 | {"query": "What is the significance of crop rotation in sustainable farming?"}
39 | {"query": "What is the importance of farm income for small farmers, according to the author?"}
40 | {"query": "How does the book differentiate between continuous grazing and rational grazing?"}
41 | {"query": "How does the author suggest managing risks in farming?"}
42 | {"query": "How does Howard view the impact of artificial fertilizers on agriculture?"}
43 | {"query": "What role do community gardens play in the broader food system according to the book?"}
44 | {"query": "What role do mushrooms play in the environment, particularly in terms of waste decomposition?"}
45 | {"query": "How does the book define sustainability?"}
46 | {"query": "How does the author suggest managing machinery on a small farm?"}
47 | {"query": "How does the Goat Justice League advocate for urban goats in Seattle?"}
48 | {"query": "What are some challenges faced by market farmers in terms of food safety?"}
49 | {"query": "How does the annual migration of beekeepers and their bees impact the agricultural system in the United States?"}
50 | {"query": "What are the potential drawbacks of using plastic hive components in beekeeping?"}
51 | {"query": "How does Ross Conrad suggest obtaining bees for beginners?"}
52 | {"query": "What are the benefits and challenges of using top bar hives compared to Langstroth hives?"}
53 | {"query": "What is the significance of the \"P\" Value System in mushroom strain preservation?"}
54 | {"query": "How does Howard's approach to agriculture differ from conventional industrial agriculture?"}
55 | {"query": "What are the other hive products besides honey, and what are their potential uses?"}
56 | {"query": "What are some strategies for marketing farm-grown products?"}
57 | {"query": "What are microgreens and why are they considered high-value crops?"}
58 | {"query": "What are the advantages of part-time farming over full-time farming?"}
59 | {"query": "What are the six vectors of contamination in mushroom cultivation, and why are they significant?"}
60 | {"query": "How does the book contrast industrial and agrarian philosophies of agriculture?"}
61 | {"query": "How can beekeepers treat colonies for diseases such as American foulbrood and nosema disease?"}
62 | {"query": "How does the P-Patch Community Gardening Program in Seattle manage its large network of gardens?"}
63 | {"query": "How does the author suggest evaluating land for farming?"}
64 | {"query": "How does Howard propose to combat soil erosion and alkali land formation?"}
65 | {"query": "How can farmers manage weeds in their fields?"}
66 | {"query": "What examples does the book provide of successful soil restoration by farmers?"}
67 | {"query": "What is the significance of the varroa mite in the beekeeping industry?"}
68 | {"query": "How can farmers ensure they are selling high-quality produce to restaurants?"}
69 | {"query": "What are the methods for removing honey from the supers, and what precautions should be taken during this process?"}
70 | {"query": "What is the primary definition of grazing according to the book?"}
71 | {"query": "How can market farmers extend their growing season?"}
72 | {"query": "What are the four universal laws of rational grazing according to the book?"}
73 | {"query": "How does the Industrial Revolution impact agriculture according to Howard?"}
74 | {"query": "How does the annual population cycle of a colony affect beekeeping practices, and what are the implications for winter management?"}
75 | {"query": "What is the significance of the mycorrhizal association in plant health, as per Howard's research?"}
76 | {"query": "What is the main mission of Storey Publishing?"}
77 | {"query": "How much money can a market farmer expect to make per acre?"}
78 | {"query": "How does the book suggest adapting agrarian ideals to contemporary urban life?"}
79 | {"query": "What role does compost play in Howard's agricultural system?"}
80 | {"query": "What does Salatin propose as a solution to the problem of food packaging waste in the airline industry?"}
81 | {"query": "What is the role of communication in a bee colony, particularly the 'dance language'?"}
82 | {"query": "What are the recommended steps for a beginner to get started with beekeeping, and why is it advised to start with new equipment?"}
83 | {"query": "How does the author suggest diversifying a farm to ensure sustainability?"}
84 | {"query": "What is the role of a CSA (Community-Supported Agriculture) in market farming?"}
85 | {"query": "What are the nutritional needs of honey bees, and how do they meet them?"}
86 | {"query": "What is the significance of Thomas Jefferson in the context of agrarian ideals and sustainability?"}
87 | {"query": "What are the benefits of using grain spawn in mushroom cultivation?"}
88 | {"query": "What is the importance of setting goals for a farming business, according to the author?"}
89 | {"query": "How does Salatin describe the transition from an agrarian to an industrial society in terms of food production and consumption?"}
90 | {"query": "What strategies does the South End / Lower Roxbury Open Space Land Trust use to preserve community gardens?"}
91 | {"query": "What is the concept of a stale seedbed in market farming?"}
92 | {"query": "How does the book suggest we can reverse the trend of soil degradation?"}
93 | {"query": "Can mushrooms be used for bioremediation, and if so, how?"}
94 | {"query": "What technological innovations in beekeeping have had the most significant impact on the industry?"}
95 | {"query": "What are the two parasitic mites that have become a serious problem for honey bees in North America, and what are their impacts?"}
96 | {"query": "What is the primary focus of the book \"Culture of the Land: A Series in the New Agrarianism\"?"}
97 | {"query": "What is the significance of John Miller's purchase of a Corvette in the context of his beekeeping career?"}
98 | {"query": "What is the significance of the \"hygiene hypothesis\" and how does it relate to the modern lifestyle?"}
99 | {"query": "How does the book challenge the conventional wisdom about industrialized agrochemical agriculture?"}
100 | {"query": "What is the role of smoke in beekeeping, and how should it be used effectively?"}
101 |
--------------------------------------------------------------------------------
/eval/datasets/cs/cs_query.jsonl:
--------------------------------------------------------------------------------
1 | {"query": "How does Spark Streaming enable real-time data processing?"}
2 | {"query": "What does the book suggest about the use of histograms in data analysis?"}
3 | {"query": "What are some advanced topics covered in the book related to Linux Kernel Networking?"}
4 | {"query": "What is the significance of the R tool in the context of modern optimization methods?"}
5 | {"query": "What are the key features of this text that aid in learning object-oriented concepts in Java?"}
6 | {"query": "What is the role of the RegExr tool in the book?"}
7 | {"query": "How does the text compare to other Java programming texts in terms of content and detail?"}
8 | {"query": "What role do Bayesian inference and priors play in the book?"}
9 | {"query": "What is the difference between recording a macro and writing code from scratch in VBA?"}
10 | {"query": "How does the book address the implementation of IPv6 in comparison to IPv4?"}
11 | {"query": "Can you explain the concept of standard coordinates as discussed in the book?"}
12 | {"query": "What are IP options and why might they be used?"}
13 | {"query": "How does the book approach the teaching of jargon related to regular expressions?"}
14 | {"query": "What role do netlink sockets play in Linux Kernel Networking?"}
15 | {"query": "What is the primary purpose of \"Joe Celko's SQL Programming Style\"?"}
16 | {"query": "What is the role of the tempdb database in SQL Server?"}
17 | {"query": "What audience is the text primarily intended for?"}
18 | {"query": "How does the book recommend handling the complexity of regular expressions?"}
19 | {"query": "What is a principal type in the context of type inference?"}
20 | {"query": "What are user-defined functions (UDFs) in SQL Server and how do they differ from stored procedures?"}
21 | {"query": "What are the two categories of indexes in SQL Server and what distinguishes them?"}
22 | {"query": "What caution does the book provide regarding the use of maximum likelihood estimation?"}
23 | {"query": "What is the significance of the ICMP protocol in Linux Kernel Networking?"}
24 | {"query": "What is the significance of the ALS algorithm in Spark's MLlib?"}
25 | {"query": "What does the book recommend regarding the use of proprietary data types?"}
26 | {"query": "How do you assign a macro to a button on the Quick Access Toolbar in Word?"}
27 | {"query": "What is Apache Spark and what are its key features?"}
28 | {"query": "What does the dollar sign ($) signify in regular expressions?"}
29 | {"query": "How does the book approach the topic of data encoding schemes?"}
30 | {"query": "What are the three main techniques used for semantic definitions in programming languages?"}
31 | {"query": "What are stored procedures (sprocs) and what advantages do they offer over sending individual SQL statements?"}
32 | {"query": "What is the primary purpose of VBA in Office applications?"}
33 | {"query": "What is the role of confluence in the operational semantics of programming languages?"}
34 | {"query": "How does the MovieLens dataset contribute to building recommendation engines?"}
35 | {"query": "What is the primary goal of the book \"Introducing Regular Expressions\"?"}
36 | {"query": "What tools or methodologies does the text use to help readers understand and design programs?"}
37 | {"query": "How does the FOR XML clause in SQL Server facilitate the conversion of relational data into XML format?"}
38 | {"query": "What role do examples and exercises play in the learning process according to the text?"}
39 | {"query": "What is the significance of the correlation coefficient in the book?"}
40 | {"query": "What are the three main approaches to handle multi-objective tasks discussed in the book?"}
41 | {"query": "What is a view in SQL Server and what are its primary uses?"}
42 | {"query": "How can you debug a macro in the Visual Basic Editor?"}
43 | {"query": "How does the book differentiate between probability and statistics?"}
44 | {"query": "What does the book consider as the biggest hurdle in learning SQL?"}
45 | {"query": "What are the four types of operators in VBA?"}
46 | {"query": "What is the book's stance on the use of jargon in regular expressions?"}
47 | {"query": "How does the book advocate for the use of views in SQL?"}
48 | {"query": "What are some of the tools and languages covered in the book for working with regular expressions?"}
49 | {"query": "What is the significance of the Option Explicit statement in VBA?"}
50 | {"query": "What is an object in the context of VBA?"}
51 | {"query": "What is the purpose of the Object Browser in the Visual Basic Editor?"}
52 | {"query": "What is the rationale behind using full reserved words in SQL according to the book?"}
53 | {"query": "Can you name some popular modern optimization methods discussed in the book?"}
54 | {"query": "What fundamental shift in thinking does the book encourage for effective SQL programming?"}
55 | {"query": "How does the author approach the topic of statistical significance?"}
56 | {"query": "What is the primary purpose of the text \"Guide to Java: A Concise Introduction to Programming\"?"}
57 | {"query": "How can you customize the Visual Basic Editor in Office applications?"}
58 | {"query": "What is the significance of the QED editor in the history of regular expressions?"}
59 | {"query": "How does the book address the issue of infeasible solutions in optimization problems?"}
60 | {"query": "What are the main components of a machine learning system designed with Spark?"}
61 | {"query": "What is the purpose of the caret (^) in regular expressions?"}
62 | {"query": "What is the significance of the `fix` construct in PCF (Programming language for computable functions)?"}
63 | {"query": "What does the book suggest as a strategy for testing SQL?"}
64 | {"query": "What is the purpose of normalization in database design and what are its benefits?"}
65 | {"query": "What is the difference between a variable and a constant in VBA?"}
66 | {"query": "How does the concept of \"environment\" differ between denotational and operational semantics?"}
67 | {"query": "How can you ensure that a macro runs automatically when an application starts?"}
68 | {"query": "What is the significance of the XML data type introduced in SQL Server 2005?"}
69 | {"query": "What is the significance of the `DEoptim` package in R for optimization tasks?"}
70 | {"query": "How does the author suggest handling categorical data in the context of plotting?"}
71 | {"query": "How does the text address the potential for errors in programming?"}
72 | {"query": "What is the role of the Immediate window in the Visual Basic Editor?"}
73 | {"query": "What is the concept of Pareto front in multi-objective optimization?"}
74 | {"query": "How does the text handle the introduction of complex topics like inheritance and polymorphism?"}
75 | {"query": "What is the role of the `optim` function in R when dealing with optimization problems?"}
76 | {"query": "What are the three main types of quantifiers discussed in the book?"}
77 | {"query": "What are the three major types of relationships in database design and give an example of each?"}
78 | {"query": "What naming convention does the book recommend for tables and views?"}
79 | {"query": "What is the primary goal of the book \"Modern Optimization with R\"?"}
80 | {"query": "How can you run Spark on Amazon EC2?"}
81 | {"query": "Describe the structure and function of the IPv4 header."}
82 | {"query": "How does the book suggest handling special characters in names?"}
83 | {"query": "What are the challenges in defining a denotational semantics for a language with side effects like references and assignments?"}
84 | {"query": "How does the Macro Recorder work in Word and Excel?"}
85 | {"query": "What are the two types of procedures in VBA?"}
86 | {"query": "How does the use of de Bruijn indices simplify the interpretation of terms in programming languages?"}
87 | {"query": "How does Spark differ from Hadoop in terms of performance?"}
88 | {"query": "How does the model database function as a template in SQL Server?"}
89 | {"query": "What is the primary purpose of the Linux Kernel Networking stack as described in the book?"}
90 | {"query": "How does the fixed point theorem play a role in the semantics of programming languages?"}
91 | {"query": "Explain the process of IPv4 fragmentation and defragmentation."}
92 | {"query": "What is the primary purpose of the master database in SQL Server?"}
93 | {"query": "What are some of the practical applications of Markov chains and Hidden Markov Models discussed in the book?"}
94 | {"query": "What is the significance of the \"dotall\" option in regular expressions?"}
95 | {"query": "How can you run a macro from the Visual Basic Editor?"}
96 | {"query": "What is the book's stance on using triggers in SQL programming?"}
97 | {"query": "What are the challenges in using naive Bayes models with numerical features?"}
98 | {"query": "What is the difference between call by name and call by value reduction strategies?"}
99 | {"query": "How does the book encourage the reader to engage with the R code examples?"}
100 | {"query": "How does the book introduce the concept of alternation in regular expressions?"}
101 |
--------------------------------------------------------------------------------
/eval/extract_context.py:
--------------------------------------------------------------------------------
1 | """
2 | Target: Extract unique contexts from original Q&A datasets.
3 | """
4 | import os
5 | import json
6 | import glob
7 | import argparse
8 |
9 |
10 | def extract_unique_contexts(input_directory, output_directory):
11 | os.makedirs(output_directory, exist_ok=True)
12 |
13 | jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
14 | print(f"Found {len(jsonl_files)} JSONL files.")
15 |
16 | for file_path in jsonl_files:
17 | filename = os.path.basename(file_path)
18 | name, ext = os.path.splitext(filename)
19 | output_filename = f"{name}_unique_contexts.json"
20 | output_path = os.path.join(output_directory, output_filename)
21 |
22 | unique_contexts_dict = {}
23 |
24 | print(f"Processing file: {filename}")
25 |
26 | try:
27 | with open(file_path, "r", encoding="utf-8") as infile:
28 | for line_number, line in enumerate(infile, start=1):
29 | line = line.strip()
30 | if not line:
31 | continue
32 | try:
33 | json_obj = json.loads(line)
34 | context = json_obj.get("context")
35 | if context and context not in unique_contexts_dict:
36 | unique_contexts_dict[context] = None
37 | except json.JSONDecodeError as e:
38 | print(
39 | f"JSON decoding error in file {filename} at line {line_number}: {e}"
40 | )
41 | except FileNotFoundError:
42 | print(f"File not found: {filename}")
43 | continue
44 | except Exception as e:
45 | print(f"An error occurred while processing file {filename}: {e}")
46 | continue
47 |
48 | unique_contexts_list = list(unique_contexts_dict.keys())
49 | print(
50 | f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
51 | )
52 |
53 | try:
54 | with open(output_path, "w", encoding="utf-8") as outfile:
55 | json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
56 | print(f"Unique `context` entries have been saved to: {output_filename}")
57 | except Exception as e:
58 | print(f"An error occurred while saving to the file {output_filename}: {e}")
59 |
60 | print("All files have been processed.")
61 |
62 |
63 | if __name__ == "__main__":
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
66 | parser.add_argument(
67 | "-o", "--output_dir", type=str, default="../datasets/unique_contexts"
68 | )
69 |
70 | args = parser.parse_args()
71 |
72 | extract_unique_contexts(args.input_dir, args.output_dir)
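73 |
74 | # Example usage (mirroring the argparse defaults above):
75 | #   python extract_context.py -i ../datasets -o ../datasets/unique_contexts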
--------------------------------------------------------------------------------
/eval/extract_query.py:
--------------------------------------------------------------------------------
1 | """
2 | This script extracts the queries from an UltraDomain dataset.
3 |
4 | Example Usage:
5 | python extract_query.py -i xxx/mix.jsonl -o xxx/mix_query.jsonl
6 |
7 | Example Output File (xxx/mix_query.jsonl):
8 | {"query": "This is a query"}
9 | {"query": "Another query"}
10 | """
11 |
12 | import json
13 | import argparse
14 |
15 | def extract_query(input_file, output_file):
16 | print(f"Processing file: {input_file}")
17 |
18 | try:
19 | with open(input_file, "r", encoding="utf-8") as infile:
20 | with open(output_file, "w", encoding="utf-8") as outfile:
21 | for line_number, line in enumerate(infile, start=1):
22 | line = line.strip()
23 | if not line:
24 | continue
25 | try:
26 | json_obj = json.loads(line)
27 | query = json_obj.get("input")
28 | outfile.write(json.dumps({"query": query}, ensure_ascii=False) + "\n")
29 | except json.JSONDecodeError as e:
30 | print(f"JSON decoding error in file {input_file} at line {line_number}: {e}")
31 | except FileNotFoundError:
32 | print(f"File not found: {input_file}")
33 | except Exception as e:
34 | print(f"An error occurred while processing file {input_file}: {e}")
35 |
36 | print(f"Finished processing file: {input_file}")
37 |
38 | if __name__ == "__main__":
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument("-i", "--input_file", type=str, required=True, help="Path to the input file")
41 | parser.add_argument("-o", "--output_file", type=str, required=True, help="Path to the output file")
42 |
43 | args = parser.parse_args()
44 |
45 | extract_query(args.input_file, args.output_file)
46 |
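47 | # Note: records without an "input" field are written out as {"query": null};
48 | # filter such lines downstream if the source dataset is incomplete.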
--------------------------------------------------------------------------------
/eval/insert_context_deepseek.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | sys.path.append("../")
6 | from hirag import HiRAG, QueryParam
7 | import os
8 | import logging
9 | import numpy as np
10 | import tiktoken
11 | import yaml
12 | from openai import AsyncOpenAI, OpenAI
13 | from dataclasses import dataclass
14 | from hirag.base import BaseKVStorage
15 | from hirag._utils import compute_args_hash
16 |
17 | logging.basicConfig(level=logging.WARNING)
18 | logging.getLogger("HiRAG").setLevel(logging.INFO)
19 |
20 | DATASET = "legal"
21 | file_path = f"./datasets/{DATASET}/{DATASET}_unique_contexts.json"
22 | WORKING_DIR = f"./datasets/{DATASET}/work_dir_deepseekV2.5_hi"
23 |
24 | with open('config.yaml', 'r') as file:
25 | config = yaml.safe_load(file)
26 |
27 | # Extract configurations
28 | MODEL = config['deepseek']['model']
29 | DEEPSEEK_API_KEY = config['deepseek']['api_key']
30 | DEEPSEEK_URL = config['deepseek']['base_url']
31 |
32 | GLM_MODEL = config['glm']['model']
33 | GLM_API_KEY = config['glm']['api_key']
34 | GLM_URL = config['glm']['base_url']
35 |
36 | OPENAI_MODEL = config['openai']['model']
37 | OPENAI_API_KEY = config['openai']['api_key']
38 | OPENAI_URL = config['openai']['base_url']
39 |
40 | TOTAL_TOKEN_COST = 0
41 | TOTAL_API_CALL_COST = 0
42 |
43 | tokenizer = tiktoken.get_encoding("cl100k_base")
44 |
45 | @dataclass
46 | class EmbeddingFunc:
47 | embedding_dim: int
48 | max_token_size: int
49 | func: callable
50 |
51 | async def __call__(self, *args, **kwargs) -> np.ndarray:
52 | return await self.func(*args, **kwargs)
53 |
54 | def wrap_embedding_func_with_attrs(**kwargs):
55 | """Wrap a function with attributes"""
56 |
57 | def final_decro(func) -> EmbeddingFunc:
58 | new_func = EmbeddingFunc(**kwargs, func=func)
59 | return new_func
60 |
61 | return final_decro
62 |
63 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
64 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
65 | model_name = "embedding-3"
66 | client = OpenAI(
67 | api_key=GLM_API_KEY,
68 | base_url="https://open.bigmodel.cn/api/paas/v4/"
69 | )
70 | embedding = client.embeddings.create(
71 | input=texts,
72 | model=model_name,
73 | )
74 | final_embedding = [d.embedding for d in embedding.data]
75 | return np.array(final_embedding)
76 |
77 |
78 | async def deepseek_model_if_cache(
79 | prompt, system_prompt=None, history_messages=[], **kwargs
80 | ) -> str:
81 | global TOTAL_TOKEN_COST
82 | global TOTAL_API_CALL_COST
83 |
84 | openai_async_client = AsyncOpenAI(
85 | api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_URL
86 | )
87 | messages = []
88 | if system_prompt:
89 | messages.append({"role": "system", "content": system_prompt})
90 |
91 |     # Return the cached response if one exists -------------------
92 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
93 | messages.extend(history_messages)
94 | messages.append({"role": "user", "content": prompt})
95 | if hashing_kv is not None:
96 | args_hash = compute_args_hash(MODEL, messages)
97 | if_cache_return = await hashing_kv.get_by_id(args_hash)
98 | if if_cache_return is not None:
99 | return if_cache_return["return"]
100 | # -----------------------------------------------------
101 |     retry_time = 3
102 |     response = None
103 |     while retry_time > 0 and response is None:
104 |         try:
105 |             # logging token cost
106 |             cur_token_cost = len(tokenizer.encode(messages[0]['content']))
107 |             TOTAL_TOKEN_COST += cur_token_cost
108 |             # logging api call cost
109 |             TOTAL_API_CALL_COST += 1
110 |             # request (retry up to retry_time times on failure)
111 |             response = await openai_async_client.chat.completions.create(
112 |                 model=MODEL, messages=messages, **kwargs
113 |             )
114 |         except Exception as e:
115 |             print(f"Retry for Error: {e}")
116 |             retry_time -= 1
117 |     if response is None:
118 |         return ""
119 |
120 |     # Cache the response before returning -------------------
121 | if hashing_kv is not None:
122 | await hashing_kv.upsert(
123 | {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
124 | )
125 | # -----------------------------------------------------
126 | return response.choices[0].message.content
127 |
128 |
129 | if __name__ == "__main__":
130 | graph_func = HiRAG(
131 | working_dir=WORKING_DIR,
132 | enable_llm_cache=True,
133 | embedding_func=GLM_embedding,
134 |         best_model_func=deepseek_model_if_cache,
135 |         cheap_model_func=deepseek_model_if_cache,
136 | enable_hierachical_mode=True,
137 | embedding_func_max_async=8,
138 | enable_naive_rag=True)
139 |
140 | with open(file_path, mode="r") as f:
141 | unique_contexts = json.load(f)[:100]
142 | graph_func.insert(unique_contexts)
143 |     # use print: the root logger is set to WARNING, so logging.info() is silently dropped
144 |     print(f"[Total token cost: {TOTAL_TOKEN_COST}]")
145 |     print(f"[Total api call cost: {TOTAL_API_CALL_COST}]")
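146 |
147 | # Note: run this script from the eval/ directory so the relative config.yaml
148 | # and ./datasets paths resolve correctly.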
--------------------------------------------------------------------------------
/eval/insert_context_glm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | sys.path.append("../")
6 | from hirag import HiRAG, QueryParam
7 | import os
8 | import logging
9 | import numpy as np
10 | import tiktoken
11 | import yaml
12 | from openai import AsyncOpenAI, OpenAI
13 | from dataclasses import dataclass
14 | from hirag.base import BaseKVStorage
15 | from hirag._utils import compute_args_hash
16 |
17 | logging.basicConfig(level=logging.WARNING)
18 | logging.getLogger("HiRAG").setLevel(logging.INFO)
19 |
20 | DATASET = "agriculture"
21 | file_path = f"./datasets/{DATASET}/{DATASET}_unique_contexts.json"
22 | WORKING_DIR = f"./datasets/{DATASET}/work_dir_deepseek_hi_clustercase"
23 |
24 | with open('config.yaml', 'r') as file:
25 | config = yaml.safe_load(file)
26 |
27 | MODEL = config['glm']['model']
28 | GLM_API_KEY = config['glm']['api_key']
29 | TOTAL_TOKEN_COST = 0
30 | TOTAL_API_CALL_COST = 0
31 | tokenizer = tiktoken.get_encoding("cl100k_base")
32 |
33 |
34 | @dataclass
35 | class EmbeddingFunc:
36 | embedding_dim: int
37 | max_token_size: int
38 | func: callable
39 |
40 | async def __call__(self, *args, **kwargs) -> np.ndarray:
41 | return await self.func(*args, **kwargs)
42 |
43 | def wrap_embedding_func_with_attrs(**kwargs):
44 | """Wrap a function with attributes"""
45 |
46 | def final_decro(func) -> EmbeddingFunc:
47 | new_func = EmbeddingFunc(**kwargs, func=func)
48 | return new_func
49 |
50 | return final_decro
51 |
52 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
53 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
54 | model_name = "embedding-3"
55 | client = OpenAI(
56 | api_key=GLM_API_KEY,
57 | base_url="https://open.bigmodel.cn/api/paas/v4/"
58 | )
59 | embedding = client.embeddings.create(
60 | input=texts,
61 | model=model_name,
62 | )
63 | final_embedding = [d.embedding for d in embedding.data]
64 | return np.array(final_embedding)
65 |
66 |
67 | async def glm_model_if_cache(
68 | prompt, system_prompt=None, history_messages=[], **kwargs
69 | ) -> str:
70 | global TOTAL_TOKEN_COST
71 | global TOTAL_API_CALL_COST
72 |
73 | openai_async_client = AsyncOpenAI(
74 | api_key=GLM_API_KEY, base_url="https://open.bigmodel.cn/api/paas/v4"
75 | )
76 | messages = []
77 | if system_prompt:
78 | messages.append({"role": "system", "content": system_prompt})
79 |
80 |     # Return the cached response if one exists -------------------
81 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
82 | messages.extend(history_messages)
83 | messages.append({"role": "user", "content": prompt})
84 |
85 | # logging token cost
86 | cur_token_cost = len(tokenizer.encode(messages[0]['content']))
87 | TOTAL_TOKEN_COST += cur_token_cost
88 | # logging api call cost
89 | TOTAL_API_CALL_COST += 1
90 |
91 | if hashing_kv is not None:
92 | args_hash = compute_args_hash(MODEL, messages)
93 | if_cache_return = await hashing_kv.get_by_id(args_hash)
94 | if if_cache_return is not None:
95 | return if_cache_return["return"]
96 | # -----------------------------------------------------
97 | try:
98 | # request
99 | response = await openai_async_client.chat.completions.create(
100 | model=MODEL, messages=messages, **kwargs
101 | )
102 | except Exception as e:
103 | logging.info(e)
104 | return "<|COMPLETE|>"
105 |
106 |     # Cache the response before returning -------------------
107 | if hashing_kv is not None:
108 | await hashing_kv.upsert(
109 | {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
110 | )
111 | # -----------------------------------------------------
112 | return response.choices[0].message.content
113 |
114 |
115 | if __name__ == "__main__":
116 | graph_func = HiRAG(
117 | working_dir=WORKING_DIR,
118 | enable_llm_cache=True,
119 | embedding_func=GLM_embedding,
120 | best_model_func=glm_model_if_cache,
121 | cheap_model_func=glm_model_if_cache,
122 | enable_hierachical_mode=True,
123 | embedding_func_max_async=8,
124 | enable_naive_rag=True)
125 |
126 | with open(file_path, mode="r") as f:
127 | unique_contexts = json.load(f)
128 | graph_func.insert(unique_contexts)
129 |     # use print: the root logger is set to WARNING, so logging.info() is silently dropped
130 |     print(f"[Total token cost: {TOTAL_TOKEN_COST}]")
131 |     print(f"[Total api call cost: {TOTAL_API_CALL_COST}]")
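132 |
133 | # Unlike insert_context_deepseek.py, this script indexes every unique context
134 | # (there is no [:100] slice), so expect proportionally more API calls.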
--------------------------------------------------------------------------------
/eval/insert_context_openai.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | from hirag import HiRAG, QueryParam
5 | os.environ["OPENAI_API_KEY"] = "***"
6 |
7 | DATASET = "mix"
8 | file_path = f"./datasets/{DATASET}/{DATASET}_unique_contexts.json"
9 |
10 | graph_func = HiRAG(
11 | working_dir=f"./datasets/{DATASET}/work_dir_hi",
12 | enable_hierachical_mode=True,
13 | embedding_func_max_async=4,
14 | enable_naive_rag=True)
15 |
16 | with open(file_path, mode="r") as f:
17 | unique_contexts = json.load(f)
18 | graph_func.insert(unique_contexts)
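19 |
20 | # Replace the "***" placeholder above with a real OpenAI API key (or export
21 | # OPENAI_API_KEY in the environment) before running this script.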
--------------------------------------------------------------------------------
/eval/test_deepseek.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | import argparse
6 | sys.path.append("../")
7 | import os
8 | import logging
9 | import numpy as np
10 | import tiktoken
11 | import yaml
12 | from hirag import HiRAG, QueryParam
13 | from openai import AsyncOpenAI, OpenAI
14 | from dataclasses import dataclass
15 | from hirag.base import BaseKVStorage
16 | from hirag._utils import compute_args_hash
17 | from tqdm import tqdm
18 |
19 | WORKING_DIR = "./datasets/cs/work_dir_deepseek_hi"
20 | MAX_QUERIES = 100
21 | TOTAL_TOKEN_COST = 0
22 | TOTAL_API_CALL_COST = 0
23 |
24 | with open('config.yaml', 'r') as file:
25 | config = yaml.safe_load(file)
26 |
27 | # Extract configurations
28 | MODEL = config['deepseek']['model']
29 | DEEPSEEK_API_KEY = config['deepseek']['api_key']
30 | DEEPSEEK_URL = config['deepseek']['base_url']
31 |
32 | GLM_MODEL = config['glm']['model']
33 | GLM_API_KEY = config['glm']['api_key']
34 | GLM_URL = config['glm']['base_url']
35 |
36 | OPENAI_MODEL = config['openai']['model']
37 | OPENAI_API_KEY = config['openai']['api_key']
38 | OPENAI_URL = config['openai']['base_url']
39 | tokenizer = tiktoken.get_encoding("cl100k_base")
40 |
41 | @dataclass
42 | class EmbeddingFunc:
43 | embedding_dim: int
44 | max_token_size: int
45 | func: callable
46 |
47 | async def __call__(self, *args, **kwargs) -> np.ndarray:
48 | return await self.func(*args, **kwargs)
49 |
50 | def wrap_embedding_func_with_attrs(**kwargs):
51 | """Wrap a function with attributes"""
52 |
53 | def final_decro(func) -> EmbeddingFunc:
54 | new_func = EmbeddingFunc(**kwargs, func=func)
55 | return new_func
56 |
57 | return final_decro
58 |
59 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
60 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
61 | model_name = "embedding-3"
62 | client = OpenAI(
63 | api_key=GLM_API_KEY,
64 | base_url="https://open.bigmodel.cn/api/paas/v4/"
65 | )
66 | embedding = client.embeddings.create(
67 | input=texts,
68 | model=model_name,
69 | )
70 | final_embedding = [d.embedding for d in embedding.data]
71 | return np.array(final_embedding)
72 |
73 |
74 | async def deepseek_model_if_cache(
75 | prompt, system_prompt=None, history_messages=[], **kwargs
76 | ) -> str:
77 | global TOTAL_TOKEN_COST
78 | global TOTAL_API_CALL_COST
79 |
80 | openai_async_client = AsyncOpenAI(
81 | api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_URL
82 | )
83 | messages = []
84 | if system_prompt:
85 | messages.append({"role": "system", "content": system_prompt})
86 |
87 |     # Return the cached response if one exists -------------------
88 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
89 | messages.extend(history_messages)
90 | messages.append({"role": "user", "content": prompt})
91 | if hashing_kv is not None:
92 | args_hash = compute_args_hash(MODEL, messages)
93 | if_cache_return = await hashing_kv.get_by_id(args_hash)
94 | if if_cache_return is not None:
95 | return if_cache_return["return"]
96 | # -----------------------------------------------------
97 | # logging token cost
98 | cur_token_cost = len(tokenizer.encode(messages[0]['content']))
99 | TOTAL_TOKEN_COST += cur_token_cost
100 | response = await openai_async_client.chat.completions.create(
101 | model=MODEL, messages=messages, **kwargs
102 | )
103 |
104 |     # Cache the response before returning -------------------
105 | if hashing_kv is not None:
106 | await hashing_kv.upsert(
107 | {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
108 | )
109 | # -----------------------------------------------------
110 | return response.choices[0].message.content
111 |
112 |
113 |
114 | if __name__ == "__main__":
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument("-d", "--dataset", type=str, default="legal")
117 | parser.add_argument("-m", "--mode", type=str, default="hi", help="hi / naive / hi_global / hi_local / hi_bridge / hi_nobridge")
118 | args = parser.parse_args()
119 |
120 | DATASET = args.dataset
121 |     top_k = 10
122 | if DATASET == "mix":
123 | MAX_QUERIES = 130
124 | elif DATASET == "cs" or DATASET == "agriculture" or DATASET == "legal":
125 | MAX_QUERIES = 100
126 | input_path = f"./datasets/{DATASET}/{DATASET}.jsonl"
127 | output_path = f"./datasets/{DATASET}/{DATASET}_{args.mode}_result_deepseek_pro.jsonl"
128 | graph_func = HiRAG(
129 | working_dir=WORKING_DIR,
130 | enable_llm_cache=False,
131 | embedding_func=GLM_embedding,
132 |         best_model_func=deepseek_model_if_cache,
133 |         cheap_model_func=deepseek_model_if_cache,
134 | enable_hierachical_mode=True,
135 | embedding_func_max_async=8,
136 | enable_naive_rag=True)
137 |
138 | query_list = []
139 | with open(input_path, encoding="utf-8", mode="r") as f: # get context
140 | lines = f.readlines()
141 | for item in lines:
142 | item_dict = json.loads(item)
143 | query_list.append(item_dict["input"])
144 | query_list = query_list[:MAX_QUERIES]
145 | answer_list = []
146 |
147 | print(f"Perform {args.mode} search:")
148 | for query in tqdm(query_list):
149 | logging.info(f"Q: {query}")
150 | retry = 3
151 | while retry:
152 | try:
153 |                 answer = graph_func.query(query=query, param=QueryParam(mode=args.mode, top_k=top_k))
154 | retry = 0
155 | except Exception as e:
156 | print(e)
157 | answer = "Error"
158 | retry -= 1
159 | logging.info(f"A: {answer} \n ################################################################################################")
160 | answer_list.append(answer)
161 | logging.info(f"[Token Cost: {TOTAL_TOKEN_COST}]")
162 |
163 | result_to_write = []
164 | for query, answer in zip(query_list, answer_list):
165 | result_to_write.append({"query": query, "answer": answer})
166 | with open(output_path, "w") as f:
167 | for item in result_to_write:
168 | f.write(json.dumps(item) + "\n")
169 |
170 |
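171 | # Example: python test_deepseek.py -d cs -m hi
172 | # (assumes WORKING_DIR has already been indexed and config.yaml is present)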
--------------------------------------------------------------------------------
/eval/test_glm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import time
5 | import argparse
6 | sys.path.append("../")
7 | import os
8 | import logging
9 | import numpy as np
10 | import tiktoken
11 | import yaml
12 | from hirag import HiRAG, QueryParam
13 | from openai import AsyncOpenAI, OpenAI
14 | from dataclasses import dataclass
15 | from hirag.base import BaseKVStorage
16 | from hirag._utils import compute_args_hash
17 | from tqdm import tqdm
18 |
19 | WORKING_DIR = "./datasets/cs/work_dir_deepseek_hi"
20 |
21 | with open('config.yaml', 'r') as file:
22 | config = yaml.safe_load(file)
23 |
24 | MODEL = config['glm']['model']
25 | GLM_API_KEY = config['glm']['api_key']
26 | MAX_QUERIES = 100
27 | TOTAL_TOKEN_COST = 0
28 | TOTAL_API_CALL_COST = 0
29 | tokenizer = tiktoken.get_encoding("cl100k_base")
30 |
31 |
32 | @dataclass
33 | class EmbeddingFunc:
34 | embedding_dim: int
35 | max_token_size: int
36 | func: callable
37 |
38 | async def __call__(self, *args, **kwargs) -> np.ndarray:
39 | return await self.func(*args, **kwargs)
40 |
41 | def wrap_embedding_func_with_attrs(**kwargs):
42 | """Wrap a function with attributes"""
43 |
44 | def final_decro(func) -> EmbeddingFunc:
45 | new_func = EmbeddingFunc(**kwargs, func=func)
46 | return new_func
47 |
48 | return final_decro
49 |
50 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
51 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
52 | model_name = "embedding-3"
53 | client = OpenAI(
54 | api_key=GLM_API_KEY,
55 | base_url="https://open.bigmodel.cn/api/paas/v4/"
56 | )
57 | embedding = client.embeddings.create(
58 | input=texts,
59 | model=model_name,
60 | )
61 | final_embedding = [d.embedding for d in embedding.data]
62 | return np.array(final_embedding)
63 |
64 |
65 |
66 | async def glm_model_if_cache(
67 | prompt, system_prompt=None, history_messages=[], **kwargs
68 | ) -> str:
69 | global TOTAL_TOKEN_COST
70 | global TOTAL_API_CALL_COST
71 |
72 | openai_async_client = AsyncOpenAI(
73 | api_key=GLM_API_KEY, base_url="https://open.bigmodel.cn/api/paas/v4"
74 | )
75 | messages = []
76 | if system_prompt:
77 | messages.append({"role": "system", "content": system_prompt})
78 |
79 |     # Return the cached response if one exists -------------------
80 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
81 | messages.extend(history_messages)
82 | messages.append({"role": "user", "content": prompt})
83 | if hashing_kv is not None:
84 | args_hash = compute_args_hash(MODEL, messages)
85 | if_cache_return = await hashing_kv.get_by_id(args_hash)
86 | if if_cache_return is not None:
87 | return if_cache_return["return"]
88 | # -----------------------------------------------------
89 | try:
90 | # logging token cost
91 | cur_token_cost = len(tokenizer.encode(messages[0]['content']))
92 | TOTAL_TOKEN_COST += cur_token_cost
93 | response = await openai_async_client.chat.completions.create(
94 | model=MODEL, messages=messages, **kwargs
95 | )
96 | except Exception as e:
97 | logging.info(e)
98 | return "ERROR"
99 |
100 |     # Cache the response before returning -------------------
101 | if hashing_kv is not None:
102 | await hashing_kv.upsert(
103 | {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
104 | )
105 | # -----------------------------------------------------
106 | return response.choices[0].message.content
107 |
108 |
109 |
110 | if __name__ == "__main__":
111 | parser = argparse.ArgumentParser()
112 | parser.add_argument("-d", "--dataset", type=str, default="cs")
113 | parser.add_argument("-m", "--mode", type=str, default="hi", help="hi / naive / hi_global / hi_local / hi_bridge / hi_nobridge")
114 | args = parser.parse_args()
115 |
116 |     if args.mode == "naive":
117 |         mode = True
118 |     else:  # the original `or "local"` test was always truthy; all non-naive modes disable naive RAG
119 |         mode = False
120 |
121 | DATASET = args.dataset
122 | if DATASET == "mix":
123 | MAX_QUERIES = 130
124 | elif DATASET == "cs" or DATASET == "agriculture" or DATASET == "legal":
125 | MAX_QUERIES = 100
126 |     top_k = 20
127 | input_path = f"./datasets/{DATASET}/{DATASET}.jsonl"
128 | output_path = f"./datasets/{DATASET}/{DATASET}_{args.mode}_result_glm.jsonl"
129 | graph_func = HiRAG(
130 | working_dir=WORKING_DIR,
131 | enable_llm_cache=False,
132 | embedding_func=GLM_embedding,
133 | best_model_func=glm_model_if_cache,
134 | cheap_model_func=glm_model_if_cache,
135 | enable_hierachical_mode=True,
136 | embedding_func_max_async=8,
137 | enable_naive_rag=mode)
138 |
139 | query_list = []
140 | with open(input_path, encoding="utf-8", mode="r") as f: # get context
141 | lines = f.readlines()
142 | for item in lines:
143 | item_dict = json.loads(item)
144 | query_list.append(item_dict["input"])
145 | query_list = query_list[:MAX_QUERIES]
146 | answer_list = []
147 |
148 | print(f"Perform {args.mode} search:")
149 | for query in tqdm(query_list):
150 | logging.info(f"Q: {query}")
151 |         answer = graph_func.query(query=query, param=QueryParam(mode=args.mode, top_k=top_k))
152 | logging.info(f"A: {answer} \n ################################################################################################")
153 | answer_list.append(answer)
154 | logging.info(f"[Token Cost: {TOTAL_TOKEN_COST}]")
155 |
156 | result_to_write = []
157 | for query, answer in zip(query_list, answer_list):
158 | result_to_write.append({"query": query, "answer": answer})
159 | with open(output_path, "w") as f:
160 | for item in result_to_write:
161 | f.write(json.dumps(item) + "\n")
162 |
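163 | # Example: python test_glm.py -d cs -m naive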
--------------------------------------------------------------------------------
/eval/test_openai.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 | from hirag import HiRAG, QueryParam
6 | os.environ["OPENAI_API_KEY"] = "***"
7 | MAX_QUERIES = 100
8 |
9 | if __name__ == "__main__":
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("-d", "--dataset", type=str, default="mix")
12 | parser.add_argument("-m", "--mode", type=str, default="hi", help="hi / naive / hi_global / hi_local / hi_bridge / hi_nobridge")
13 | args = parser.parse_args()
14 |
15 |     if args.mode == "naive":
16 |         mode = True
17 |     else:  # the original `or "local"` test was always truthy; all non-naive modes disable naive RAG
18 |         mode = False
19 |
20 | DATASET = args.dataset
21 | input_path = f"./datasets/{DATASET}/{DATASET}.jsonl"
22 | output_path = f"./datasets/{DATASET}/{DATASET}_{args.mode}_result.jsonl"
23 | graph_func = HiRAG(
24 | working_dir=f"./datasets/{DATASET}/work_dir",
25 | enable_hierachical_mode=False,
26 | embedding_func_max_async=4,
27 | enable_naive_rag=mode)
28 |
29 | query_list = []
30 | with open(input_path, encoding="utf-8", mode="r") as f: # get context
31 | lines = f.readlines()
32 | for item in lines:
33 | item_dict = json.loads(item)
34 | query_list.append(item_dict["input"])
35 | query_list = query_list[:MAX_QUERIES]
36 | answer_list = []
37 |
38 | print(f"Perform {args.mode} search:")
39 | for query in tqdm(query_list):
40 | tqdm.write(f"Q: {query}")
41 | answer = graph_func.query(query=query, param=QueryParam(mode=args.mode))
42 | tqdm.write(f"A: {answer} \n ################################################################################################")
43 | answer_list.append(answer)
44 |
45 | result_to_write = []
46 | for query, answer in zip(query_list, answer_list):
47 | result_to_write.append({"query": query, "answer": answer})
48 | with open(output_path, "w") as f:
49 | for item in result_to_write:
50 | f.write(json.dumps(item) + "\n")
51 |
52 |
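53 | # Example: python test_openai.py -d mix -m hi
54 | # (set a real key in place of the "***" placeholder above before running)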
--------------------------------------------------------------------------------
/hi_Search_deepseek.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import yaml
5 | from hirag import HiRAG, QueryParam
6 | from openai import AsyncOpenAI, OpenAI
7 | from dataclasses import dataclass
8 | from hirag.base import BaseKVStorage
9 | from hirag._utils import compute_args_hash
10 | from tqdm import tqdm
11 |
12 | # Load configuration from YAML file
13 | with open('config.yaml', 'r') as file:
14 | config = yaml.safe_load(file)
15 |
16 | # Extract configurations
17 | GLM_API_KEY = config['glm']['api_key']
18 | MODEL = config['deepseek']['model']
19 | DEEPSEEK_API_KEY = config['deepseek']['api_key']
20 | DEEPSEEK_URL = config['deepseek']['base_url']
21 | GLM_URL = config['glm']['base_url']
22 |
23 |
24 | @dataclass
25 | class EmbeddingFunc:
26 | embedding_dim: int
27 | max_token_size: int
28 | func: callable
29 |
30 | async def __call__(self, *args, **kwargs) -> np.ndarray:
31 | return await self.func(*args, **kwargs)
32 |
33 | def wrap_embedding_func_with_attrs(**kwargs):
34 | """Wrap a function with attributes"""
35 |
36 | def final_decro(func) -> EmbeddingFunc:
37 | new_func = EmbeddingFunc(**kwargs, func=func)
38 | return new_func
39 |
40 | return final_decro
41 |
42 | @wrap_embedding_func_with_attrs(embedding_dim=config['model_params']['glm_embedding_dim'], max_token_size=config['model_params']['max_token_size'])
43 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
44 |     model_name = config['glm']['embedding_model']
45 | client = OpenAI(
46 | api_key=GLM_API_KEY,
47 | base_url=GLM_URL
48 | )
49 | embedding = client.embeddings.create(
50 | input=texts,
51 | model=model_name,
52 | )
53 | final_embedding = [d.embedding for d in embedding.data]
54 | return np.array(final_embedding)
55 |
56 |
57 | async def deepseek_model_if_cache(
58 | prompt, system_prompt=None, history_messages=[], **kwargs
59 | ) -> str:
60 | openai_async_client = AsyncOpenAI(
61 | api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_URL
62 | )
63 | messages = []
64 | if system_prompt:
65 | messages.append({"role": "system", "content": system_prompt})
66 |
67 |     # Return the cached response if one is already stored ----------------
68 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
69 | messages.extend(history_messages)
70 | messages.append({"role": "user", "content": prompt})
71 | if hashing_kv is not None:
72 | args_hash = compute_args_hash(MODEL, messages)
73 | if_cache_return = await hashing_kv.get_by_id(args_hash)
74 | if if_cache_return is not None:
75 | return if_cache_return["return"]
76 | # -----------------------------------------------------
77 |
78 | response = await openai_async_client.chat.completions.create(
79 | model=MODEL, messages=messages, **kwargs
80 | )
81 |
82 |     # Cache the response for future identical calls -----------------------
83 | if hashing_kv is not None:
84 | await hashing_kv.upsert(
85 | {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
86 | )
87 | # -----------------------------------------------------
88 | return response.choices[0].message.content
89 |
90 |
91 | graph_func = HiRAG(
92 | working_dir=config['hirag']['working_dir'],
93 | enable_llm_cache=config['hirag']['enable_llm_cache'],
94 | embedding_func=GLM_embedding,
95 |     best_model_func=deepseek_model_if_cache,
96 |     cheap_model_func=deepseek_model_if_cache,
97 | enable_hierachical_mode=config['hirag']['enable_hierachical_mode'],
98 | embedding_batch_num=config['hirag']['embedding_batch_num'],
99 | embedding_func_max_async=config['hirag']['embedding_func_max_async'],
100 | enable_naive_rag=config['hirag']['enable_naive_rag'])
101 |
102 | # comment this out if the working directory has already been indexed
103 | with open("your .txt file path") as f:
104 | graph_func.insert(f.read())
105 |
106 | print("Perform hi search:")
107 | print(graph_func.query("What are the top themes in this story?", param=QueryParam(mode="hi")))
108 |
--------------------------------------------------------------------------------
/hi_Search_glm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import yaml
5 | from hirag import HiRAG, QueryParam
6 | from openai import AsyncOpenAI, OpenAI
7 | from dataclasses import dataclass
8 | from hirag.base import BaseKVStorage
9 | from hirag._utils import compute_args_hash
10 | from tqdm import tqdm
11 |
12 | # Load configuration from YAML file
13 | with open('config.yaml', 'r') as file:
14 | config = yaml.safe_load(file)
15 |
16 | # Extract configurations
17 | GLM_API_KEY = config['glm']['api_key']
18 | MODEL = config['glm']['model']
19 | GLM_URL = config['glm']['base_url']
20 |
21 | @dataclass
22 | class EmbeddingFunc:
23 | embedding_dim: int
24 | max_token_size: int
25 | func: callable
26 |
27 | async def __call__(self, *args, **kwargs) -> np.ndarray:
28 | return await self.func(*args, **kwargs)
29 |
30 | def wrap_embedding_func_with_attrs(**kwargs):
31 | """Wrap a function with attributes"""
32 |
33 | def final_decro(func) -> EmbeddingFunc:
34 | new_func = EmbeddingFunc(**kwargs, func=func)
35 | return new_func
36 |
37 | return final_decro
38 |
39 | @wrap_embedding_func_with_attrs(embedding_dim=config['model_params']['glm_embedding_dim'], max_token_size=config['model_params']['max_token_size'])
40 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
41 | model_name = config['glm']['embedding_model']
42 | client = OpenAI(
43 | api_key=GLM_API_KEY,
44 | base_url=GLM_URL
45 | )
46 | embedding = client.embeddings.create(
47 | input=texts,
48 | model=model_name,
49 | )
50 | final_embedding = [d.embedding for d in embedding.data]
51 | return np.array(final_embedding)
52 |
53 |
54 |
55 | async def glm_model_if_cache(
56 | prompt, system_prompt=None, history_messages=[], **kwargs
57 | ) -> str:
58 | openai_async_client = AsyncOpenAI(
59 | api_key=GLM_API_KEY, base_url=GLM_URL
60 | )
61 | messages = []
62 | if system_prompt:
63 | messages.append({"role": "system", "content": system_prompt})
64 |
65 |     # Return the cached response if one is already stored ----------------
66 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
67 | messages.extend(history_messages)
68 | messages.append({"role": "user", "content": prompt})
69 | if hashing_kv is not None:
70 | args_hash = compute_args_hash(MODEL, messages)
71 | if_cache_return = await hashing_kv.get_by_id(args_hash)
72 | if if_cache_return is not None:
73 | return if_cache_return["return"]
74 | # -----------------------------------------------------
75 | try:
76 | response = await openai_async_client.chat.completions.create(
77 | model=MODEL, messages=messages, **kwargs
78 | )
79 | except Exception as e:
80 | logging.info(e)
81 | return "ERROR"
82 |
83 |     # Cache the response for future identical calls -----------------------
84 | if hashing_kv is not None:
85 | await hashing_kv.upsert(
86 | {args_hash: {"return": response.choices[0].message.content, "model": MODEL}}
87 | )
88 | # -----------------------------------------------------
89 | return response.choices[0].message.content
90 |
91 |
92 | graph_func = HiRAG(
93 | working_dir=config['hirag']['working_dir'],
94 | enable_llm_cache=config['hirag']['enable_llm_cache'],
95 | embedding_func=GLM_embedding,
96 | best_model_func=glm_model_if_cache,
97 | cheap_model_func=glm_model_if_cache,
98 | enable_hierachical_mode=config['hirag']['enable_hierachical_mode'],
99 | embedding_func_max_async=config['hirag']['embedding_func_max_async'],
100 | enable_naive_rag=config['hirag']['enable_naive_rag'])
101 |
102 | # comment this out if the working directory has already been indexed
103 | with open("your .txt file path") as f:
104 | graph_func.insert(f.read())
105 |
106 | print("Perform hi search:")
107 | print(graph_func.query("What are the top themes in this story?", param=QueryParam(mode="hi")))
108 |
--------------------------------------------------------------------------------
/hi_Search_openai.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import yaml
5 | from hirag import HiRAG, QueryParam
6 | from openai import AsyncOpenAI, OpenAI
7 | from dataclasses import dataclass
8 | from hirag.base import BaseKVStorage
9 | from hirag._utils import compute_args_hash
10 |
11 | # Load configuration from YAML file
12 | with open('config.yaml', 'r') as file:
13 | config = yaml.safe_load(file)
14 |
15 | # Extract configurations
16 | OPENAI_EMBEDDING_MODEL = config['openai']['embedding_model']
17 | OPENAI_MODEL = config['openai']['model']
18 | OPENAI_API_KEY = config['openai']['api_key']
19 | OPENAI_URL = config['openai']['base_url']
20 | GLM_API_KEY = config['glm']['api_key']
21 | GLM_MODEL = config['glm']['model']
22 | GLM_URL = config['glm']['base_url']
23 |
24 | @dataclass
25 | class EmbeddingFunc:
26 | embedding_dim: int
27 | max_token_size: int
28 | func: callable
29 |
30 | async def __call__(self, *args, **kwargs) -> np.ndarray:
31 | return await self.func(*args, **kwargs)
32 |
33 | def wrap_embedding_func_with_attrs(**kwargs):
34 | """Wrap a function with attributes"""
35 |
36 | def final_decro(func) -> EmbeddingFunc:
37 | new_func = EmbeddingFunc(**kwargs, func=func)
38 | return new_func
39 |
40 | return final_decro
41 |
42 | @wrap_embedding_func_with_attrs(embedding_dim=config['model_params']['openai_embedding_dim'], max_token_size=config['model_params']['max_token_size'])
43 | async def OPENAI_embedding(texts: list[str]) -> np.ndarray:
44 | openai_async_client = AsyncOpenAI(base_url=OPENAI_URL, api_key=OPENAI_API_KEY)
45 | response = await openai_async_client.embeddings.create(
46 | model=OPENAI_EMBEDDING_MODEL, input=texts, encoding_format="float"
47 | )
48 | return np.array([dp.embedding for dp in response.data])
49 |
50 | @wrap_embedding_func_with_attrs(embedding_dim=config['model_params']['glm_embedding_dim'], max_token_size=config['model_params']['max_token_size'])
51 | async def GLM_embedding(texts: list[str]) -> np.ndarray:
52 | model_name = config['glm']['embedding_model']
53 | client = OpenAI(
54 | api_key=GLM_API_KEY,
55 | base_url=GLM_URL
56 | )
57 | embedding = client.embeddings.create(
58 | input=texts,
59 | model=model_name,
60 | )
61 | final_embedding = [d.embedding for d in embedding.data]
62 | return np.array(final_embedding)
63 |
64 | async def OPENAI_model_if_cache(
65 | prompt, system_prompt=None, history_messages=[], **kwargs
66 | ) -> str:
67 | openai_async_client = AsyncOpenAI(
68 | api_key=OPENAI_API_KEY, base_url=OPENAI_URL
69 | )
70 | messages = []
71 | if system_prompt:
72 | messages.append({"role": "system", "content": system_prompt})
73 |
74 |     # Return the cached response if one is already stored ----------------
75 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
76 | messages.extend(history_messages)
77 | messages.append({"role": "user", "content": prompt})
78 | if hashing_kv is not None:
79 | args_hash = compute_args_hash(OPENAI_MODEL, messages)
80 | if_cache_return = await hashing_kv.get_by_id(args_hash)
81 | if if_cache_return is not None:
82 | return if_cache_return["return"]
83 | # -----------------------------------------------------
84 |
85 | response = await openai_async_client.chat.completions.create(
86 | model=OPENAI_MODEL, messages=messages, **kwargs
87 | )
88 |
89 |     # Cache the response for future identical calls -----------------------
90 | if hashing_kv is not None:
91 | await hashing_kv.upsert(
92 | {args_hash: {"return": response.choices[0].message.content, "model": OPENAI_MODEL}}
93 | )
94 | # -----------------------------------------------------
95 | return response.choices[0].message.content
96 |
97 |
98 | graph_func = HiRAG(working_dir=config['hirag']['working_dir'],
99 | enable_llm_cache=config['hirag']['enable_llm_cache'],
100 | embedding_func=OPENAI_embedding,
101 | best_model_func=OPENAI_model_if_cache,
102 | cheap_model_func=OPENAI_model_if_cache,
103 | enable_hierachical_mode=config['hirag']['enable_hierachical_mode'],
104 | embedding_batch_num=config['hirag']['embedding_batch_num'],
105 | embedding_func_max_async=config['hirag']['embedding_func_max_async'],
106 | enable_naive_rag=config['hirag']['enable_naive_rag'])
107 |
108 | # comment this out if the working directory has already been indexed
109 | with open("your .txt file path") as f:
110 | graph_func.insert(f.read())
111 |
112 |
113 | print("Perform hi search:")
114 | print(graph_func.query("What are the top themes in this story?", param=QueryParam(mode="hi")))
115 |
--------------------------------------------------------------------------------
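The three `hi_Search_*` scripts above all read the same `config.yaml`. A minimal sketch of the keys they look up (key names come straight from the `config[...]` accesses; every value below is a placeholder to replace with your own endpoints, models, and dimensions):

```yaml
deepseek:
  model: your-deepseek-chat-model              # placeholder
  api_key: "***"
  base_url: https://your-deepseek-endpoint/v1  # placeholder
glm:
  model: your-glm-chat-model                   # placeholder
  embedding_model: embedding-3
  api_key: "***"
  base_url: https://your-glm-endpoint/v4       # placeholder
openai:
  model: gpt-4o
  embedding_model: text-embedding-3-small
  api_key: "***"
  base_url: https://api.openai.com/v1
model_params:
  openai_embedding_dim: 1536   # matches text-embedding-3-small
  glm_embedding_dim: 2048      # placeholder; must match your embedding model
  max_token_size: 8192
hirag:
  working_dir: ./work_dir
  enable_llm_cache: true
  enable_hierachical_mode: true
  embedding_batch_num: 64
  embedding_func_max_async: 8
  enable_naive_rag: false
```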
/hirag/__init__.py:
--------------------------------------------------------------------------------
1 | from .hirag import HiRAG, QueryParam
2 |
3 | __version__ = "0.1.0"
4 | __author__ = "Haoyu Huang"
5 | __url__ = "https://github.com/hhy-huang/HiRAG"
--------------------------------------------------------------------------------
/hirag/_cluster_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import re
4 | import numpy as np
5 | import tiktoken
6 | import umap
7 | import copy
8 | import asyncio
9 | from abc import ABC, abstractmethod
10 | from typing import List, Optional
11 | from sklearn.mixture import GaussianMixture
12 | from tqdm import tqdm
13 | from collections import Counter, defaultdict
14 | from .base import (
15 | BaseGraphStorage,
16 | BaseKVStorage,
17 | BaseVectorStorage
18 | )
19 | from ._utils import split_string_by_multi_markers, clean_str, is_float_regex
20 | from .prompt import GRAPH_FIELD_SEP, PROMPTS
21 |
22 | # Initialize logging
23 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
24 |
25 | # Set a random seed for reproducibility
26 | RANDOM_SEED = 224
27 | random.seed(RANDOM_SEED)
28 |
29 |
30 | def global_cluster_embeddings(
31 | embeddings: np.ndarray,
32 | dim: int,
33 | n_neighbors: int = 15,
34 | metric: str = "cosine",
35 | ) -> np.ndarray:
36 | if n_neighbors is None:
37 | n_neighbors = int((len(embeddings) - 1) ** 0.5)
38 | reduced_embeddings = umap.UMAP(
39 | n_neighbors=n_neighbors, n_components=dim, metric=metric
40 | ).fit_transform(embeddings)
41 | return reduced_embeddings
42 |
43 |
44 | def local_cluster_embeddings(
45 | embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine"
46 | ) -> np.ndarray:
47 | reduced_embeddings = umap.UMAP(
48 | n_neighbors=num_neighbors, n_components=dim, metric=metric
49 | ).fit_transform(embeddings)
50 | return reduced_embeddings
51 |
52 |
53 | def fit_gaussian_mixture(n_components, embeddings, random_state):
54 | gm = GaussianMixture(
55 | n_components=n_components,
56 | random_state=random_state,
57 | n_init=5,
58 | init_params='k-means++'
59 | )
60 | gm.fit(embeddings)
61 | return gm.bic(embeddings)
62 |
63 |
64 | def get_optimal_clusters(embeddings, max_clusters=50, random_state=0, rel_tol=1e-3):
65 | max_clusters = min(len(embeddings), max_clusters)
66 | n_clusters = np.arange(1, max_clusters)
67 | bics = []
68 | prev_bic = float('inf')
69 | for n in tqdm(n_clusters):
70 | bic = fit_gaussian_mixture(n, embeddings, random_state)
71 | # print(bic)
72 | bics.append(bic)
73 | # early stop
74 | if (abs(prev_bic - bic) / abs(prev_bic)) < rel_tol:
75 | break
76 | prev_bic = bic
77 | optimal_clusters = n_clusters[np.argmin(bics)]
78 | return optimal_clusters
79 |
80 |
81 | def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0):
82 | n_clusters = get_optimal_clusters(embeddings)
83 | gm = GaussianMixture(
84 | n_components=n_clusters,
85 | random_state=random_state,
86 | n_init=5,
87 | init_params='k-means++')
88 | gm.fit(embeddings)
89 | probs = gm.predict_proba(embeddings) # [num, cluster_num]
90 | labels = [np.where(prob > threshold)[0] for prob in probs]
91 | return labels, n_clusters
92 |
93 |
94 | def perform_clustering(
95 | embeddings: np.ndarray, dim: int, threshold: float, verbose: bool = False
96 | ) -> List[np.ndarray]:
97 |     reduced_embeddings_global = global_cluster_embeddings(embeddings, min(dim, len(embeddings) - 2))
98 | global_clusters, n_global_clusters = GMM_cluster( # (num, 2)
99 | reduced_embeddings_global, threshold
100 | )
101 |
102 | if verbose:
103 | logging.info(f"Global Clusters: {n_global_clusters}")
104 |
105 | all_clusters = [[] for _ in range(len(embeddings))]
106 | embedding_to_index = {tuple(embedding): idx for idx, embedding in enumerate(embeddings)}
107 | for i in tqdm(range(n_global_clusters)):
108 | global_cluster_embeddings_ = embeddings[
109 | np.array([i in gc for gc in global_clusters])
110 | ]
111 | if verbose:
112 | logging.info(
113 | f"Nodes in Global Cluster {i}: {len(global_cluster_embeddings_)}"
114 | )
115 | if len(global_cluster_embeddings_) == 0:
116 | continue
117 |
118 | # embedding indices
119 | indices = [
120 | embedding_to_index[tuple(embedding)]
121 | for embedding in global_cluster_embeddings_
122 | ]
123 |
124 | # update
125 | for idx in indices:
126 | all_clusters[idx].append(i)
127 |
128 | all_clusters = [np.array(cluster) for cluster in all_clusters]
129 |
130 | if verbose:
131 | logging.info(f"Total Clusters: {len(n_global_clusters)}")
132 | return all_clusters
133 |
134 |
135 | async def _handle_single_entity_extraction(
136 | record_attributes: list[str],
137 | chunk_key: str,
138 | ):
139 | if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
140 | return None
141 | # add this record as a node in the G
142 | entity_name = clean_str(record_attributes[1].upper())
143 | if not entity_name.strip():
144 | return None
145 | entity_type = clean_str(record_attributes[2].upper())
146 | entity_description = clean_str(record_attributes[3])
147 | entity_source_id = chunk_key
148 | return dict(
149 | entity_name=entity_name,
150 | entity_type=entity_type,
151 | description=entity_description,
152 | source_id=entity_source_id,
153 | )
154 |
155 |
156 | async def _handle_single_relationship_extraction(
157 | record_attributes: list[str],
158 | chunk_key: str,
159 | ):
160 | if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
161 | return None
162 | # add this record as edge
163 | source = clean_str(record_attributes[1].upper())
164 | target = clean_str(record_attributes[2].upper())
165 | edge_description = clean_str(record_attributes[3])
166 | edge_source_id = chunk_key
167 | weight = (
168 | float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
169 | )
170 | return dict(
171 | src_id=source,
172 | tgt_id=target,
173 | weight=weight,
174 | description=edge_description,
175 | source_id=edge_source_id,
176 | )
177 |
178 |
179 | class ClusteringAlgorithm(ABC):
180 | @abstractmethod
181 | def perform_clustering(self, embeddings: np.ndarray, **kwargs) -> List[List[int]]:
182 | pass
183 |
184 |
185 | class Hierarchical_Clustering(ClusteringAlgorithm):
186 | async def perform_clustering(
187 | self,
188 | entity_vdb: BaseVectorStorage,
189 | global_config: dict,
190 | entities: dict,
191 | layers: int = 50,
192 | max_length_in_cluster: int = 60000,
193 | tokenizer=tiktoken.get_encoding("cl100k_base"),
194 | reduction_dimension: int = 2,
195 | cluster_threshold: float = 0.1,
196 | verbose: bool = False,
197 | threshold: float = 0.98, # 0.99
198 |             threshold_change_rate: float = 0.05
199 | ) -> List[dict]:
200 | use_llm_func: callable = global_config["best_model_func"]
201 | # Get the embeddings from the nodes
202 | nodes = list(entities.values())
203 | embeddings = np.array([x["embedding"] for x in nodes])
204 |
205 | hierarchical_clusters = [nodes]
206 | pre_cluster_sparsity = 0.01
207 | for layer in range(layers):
208 | logging.info(f"############ Layer[{layer}] Clustering ############")
209 | # Perform the clustering
210 | clusters = perform_clustering(
211 | embeddings, dim=reduction_dimension, threshold=cluster_threshold
212 | )
213 | # Initialize an empty list to store the clusters of nodes
214 | node_clusters = []
215 | # Iterate over each unique label in the clusters
216 | unique_clusters = np.unique(np.concatenate(clusters))
217 | logging.info(f"[Clustered Label Num: {len(unique_clusters)} / Last Layer Total Entity Num: {len(nodes)}]")
218 | # calculate the number of nodes belong to each cluster
219 | cluster_sizes = Counter(np.concatenate(clusters))
220 | # calculate cluster sparsity
221 | cluster_sparsity = 1 - sum([x * (x - 1) for x in cluster_sizes.values()])/(len(nodes) * (len(nodes) - 1))
222 | cluster_sparsity_change_rate = (abs(cluster_sparsity - pre_cluster_sparsity) / (pre_cluster_sparsity + 1e-8))
223 | pre_cluster_sparsity = cluster_sparsity
224 |             logging.info(f"[Cluster Sparsity: {cluster_sparsity * 100:.2f}%]")
225 | # stop if there will be no improvements on clustering
226 | if cluster_sparsity >= threshold:
227 | logging.info(f"[Stop Clustering at Layer{layer} with Cluster Sparsity {cluster_sparsity}]")
228 | break
229 |             if cluster_sparsity_change_rate <= threshold_change_rate:
230 |                 logging.info(f"[Stop Clustering at Layer{layer} with Cluster Sparsity Change Rate {cluster_sparsity_change_rate * 100:.2f}%]")
231 |                 break
232 | # summarize
233 | for label in unique_clusters:
234 | # Get the indices of the nodes that belong to this cluster
235 | indices = [i for i, cluster in enumerate(clusters) if label in cluster]
236 | # Add the corresponding nodes to the node_clusters list
237 | cluster_nodes = [nodes[i] for i in indices]
238 | # Base case: if the cluster only has one node, do not attempt to recluster it
239 | logging.info(f"[Label{str(int(label))} Size: {len(cluster_nodes)}]")
240 | if len(cluster_nodes) == 1:
241 | node_clusters += cluster_nodes
242 | continue
243 | # Calculate the total length of the text in the nodes
244 | total_length = sum(
245 | [len(tokenizer.encode(node["description"])) + len(tokenizer.encode(node["entity_name"])) for node in cluster_nodes]
246 | )
247 | base_discount = 0.8
248 | discount_times = 0
249 | # If the total length exceeds the maximum allowed length, reduce the node size
250 | while total_length > max_length_in_cluster:
251 | logging.info(
252 | f"Reducing cluster size with {base_discount * 100 * (base_discount**discount_times):.2f}% of entities"
253 | )
254 |
255 | # for node in cluster_nodes:
256 | # description = node["description"]
257 | # node['description'] = description[:int(len(description) * base_discount)]
258 |
259 | # Randomly select 80% of the nodes
260 | num_to_select = max(1, int(len(cluster_nodes) * base_discount)) # Ensure at least one node is selected
261 | cluster_nodes = random.sample(cluster_nodes, num_to_select)
262 |
263 | # Recalculate the total length
264 | total_length = sum(
265 | [len(tokenizer.encode(node["description"])) + len(tokenizer.encode(node["entity_name"])) for node in cluster_nodes]
266 | )
267 | discount_times += 1
268 | # summarize and generate new entities
269 | entity_description_list = [f"({x['entity_name']}, {x['description']})" for x in cluster_nodes]
270 | context_base_summarize = dict(
271 | tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
272 | record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
273 | completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
274 | meta_attribute_list=PROMPTS["META_ENTITY_TYPES"],
275 | entity_description_list=",".join(entity_description_list)
276 | )
277 | summarize_prompt = PROMPTS["summary_clusters"]
278 | hint_prompt = summarize_prompt.format(**context_base_summarize)
279 | summarize_result = await use_llm_func(hint_prompt)
280 | chunk_key = ""
281 | # resolve results
282 | records = split_string_by_multi_markers( # split entities from result --> list of entities
283 | summarize_result,
284 | [context_base_summarize["record_delimiter"], context_base_summarize["completion_delimiter"]],
285 | )
286 | maybe_nodes = defaultdict(list)
287 | maybe_edges = defaultdict(list)
288 | for record in records:
289 | record = re.search(r"\((.*)\)", record)
290 | if record is None:
291 | continue
292 | record = record.group(1)
293 | record_attributes = split_string_by_multi_markers( # split entity
294 | record, [context_base_summarize["tuple_delimiter"]]
295 | )
296 | if_entities = await _handle_single_entity_extraction( # get the name, type, desc, source_id of entity--> dict
297 | record_attributes, chunk_key
298 | )
299 | if if_entities is not None:
300 | maybe_nodes[if_entities["entity_name"]].append(if_entities)
301 | continue
302 |
303 | if_relation = await _handle_single_relationship_extraction(
304 | record_attributes, chunk_key
305 | )
306 | if if_relation is not None:
307 | maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
308 | if_relation
309 | )
310 | # fetch all entities from results
311 | entity_results = (dict(maybe_nodes), dict(maybe_edges))
312 | all_entities_relations = {}
313 | for item in entity_results:
314 | for k, v in item.items():
315 |                     # keep only the first extracted record per key
316 |                     all_entities_relations[k] = v[0]
317 | # fetch embeddings
318 |             entity_descriptions = [v["description"] for k, v in all_entities_relations.items()]
319 |             entity_sequence_embeddings = []
320 |             embeddings_batch_size = 64
321 |             num_embeddings_batches = (len(entity_descriptions) + embeddings_batch_size - 1) // embeddings_batch_size
322 |             for i in range(num_embeddings_batches):
323 |                 start_index = i * embeddings_batch_size
324 |                 end_index = min((i + 1) * embeddings_batch_size, len(entity_descriptions))
325 |                 batch = entity_descriptions[start_index:end_index]
326 | result = await entity_vdb.embedding_func(batch)
327 | entity_sequence_embeddings.extend(result)
328 | entity_embeddings = entity_sequence_embeddings
329 | for (k, v), x in zip(all_entities_relations.items(), entity_embeddings):
330 | value = v
331 | value["embedding"] = x
332 | all_entities_relations[k] = value
333 | # append the attribute entities of current clustered set to results
334 | all_entities_relations = [v for k, v in all_entities_relations.items()]
335 | node_clusters += all_entities_relations
336 | hierarchical_clusters.append(node_clusters)
337 | # update nodes to be clustered in the next layer
338 | nodes = copy.deepcopy([x for x in node_clusters if "entity_name" in x.keys()])
339 | # filter the duplicate entities
340 | seen = set()
341 | unique_nodes = []
342 | for item in nodes:
343 | entity_name = item['entity_name']
344 | if entity_name not in seen:
345 | seen.add(entity_name)
346 | unique_nodes.append(item)
347 | nodes = unique_nodes
348 | embeddings = np.array([x["embedding"] for x in unique_nodes])
349 | # stop if the number of deduplicated cluster is too small
350 | if len(embeddings) <= 2:
351 | logging.info(f"[Stop Clustering at Layer{layer} with entity num {len(embeddings)}]")
352 | break
353 | return hierarchical_clusters
--------------------------------------------------------------------------------
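The layer-wise early stop in `Hierarchical_Clustering.perform_clustering` is driven by the cluster-sparsity formula above. A self-contained sketch of that formula (with hypothetical single-label assignments; in the code itself a node may carry several labels, so the sizes come from a concatenation) shows why sparsity climbs toward 1.0 as clusters shrink to singletons:

```python
from collections import Counter

def cluster_sparsity(labels: list[int]) -> float:
    # 1 - (ordered same-cluster pairs) / (all ordered pairs of nodes)
    n = len(labels)
    sizes = Counter(labels)
    return 1 - sum(s * (s - 1) for s in sizes.values()) / (n * (n - 1))

print(cluster_sparsity([0, 0, 0, 1]))  # one dominant cluster -> 0.5
print(cluster_sparsity([0, 1, 2, 3]))  # all singletons       -> 1.0
```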
/hirag/_llm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from openai import AsyncOpenAI, AsyncAzureOpenAI, APIConnectionError, RateLimitError
4 |
5 | from tenacity import (
6 | retry,
7 | stop_after_attempt,
8 | wait_exponential,
9 | retry_if_exception_type,
10 | )
11 | import os
12 |
13 | from ._utils import compute_args_hash, wrap_embedding_func_with_attrs
14 | from .base import BaseKVStorage
15 |
16 | global_openai_async_client = None
17 | global_azure_openai_async_client = None
18 |
19 |
20 | def get_openai_async_client_instance():
21 | global global_openai_async_client
22 | if global_openai_async_client is None:
23 | global_openai_async_client = AsyncOpenAI()
24 | return global_openai_async_client
25 |
26 |
27 | def get_azure_openai_async_client_instance():
28 | global global_azure_openai_async_client
29 | if global_azure_openai_async_client is None:
30 | global_azure_openai_async_client = AsyncAzureOpenAI()
31 | return global_azure_openai_async_client
32 |
33 |
34 | @retry(
35 | stop=stop_after_attempt(5),
36 | wait=wait_exponential(multiplier=1, min=4, max=10),
37 | retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
38 | )
39 | async def openai_complete_if_cache(
40 | model, prompt, system_prompt=None, history_messages=[], **kwargs
41 | ) -> str:
42 | openai_async_client = get_openai_async_client_instance()
43 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
44 | messages = []
45 | if system_prompt:
46 | messages.append({"role": "system", "content": system_prompt})
47 | messages.extend(history_messages)
48 | messages.append({"role": "user", "content": prompt})
49 | if hashing_kv is not None:
50 | args_hash = compute_args_hash(model, messages)
51 | if_cache_return = await hashing_kv.get_by_id(args_hash)
52 | if if_cache_return is not None:
53 | return if_cache_return["return"]
54 |
55 | response = await openai_async_client.chat.completions.create(
56 | model=model, messages=messages, **kwargs
57 | )
58 |
59 | if hashing_kv is not None:
60 | await hashing_kv.upsert(
61 | {args_hash: {"return": response.choices[0].message.content, "model": model}}
62 | )
63 | await hashing_kv.index_done_callback()
64 | return response.choices[0].message.content
65 |
66 |
67 | async def gpt_4o_complete(
68 | prompt, system_prompt=None, history_messages=[], **kwargs
69 | ) -> str:
70 | return await openai_complete_if_cache(
71 | "gpt-4o",
72 | prompt,
73 | system_prompt=system_prompt,
74 | history_messages=history_messages,
75 | **kwargs,
76 | )
77 |
78 | async def gpt_35_turbo_complete(
79 | prompt, system_prompt=None, history_messages=[], **kwargs
80 | ) -> str:
81 | return await openai_complete_if_cache(
82 | "gpt-3.5-turbo",
83 | prompt,
84 | system_prompt=system_prompt,
85 | history_messages=history_messages,
86 | **kwargs,
87 | )
88 |
89 |
90 | async def gpt_4o_mini_complete(
91 | prompt, system_prompt=None, history_messages=[], **kwargs
92 | ) -> str:
93 | return await openai_complete_if_cache(
94 | "gpt-4o-mini",
95 | prompt,
96 | system_prompt=system_prompt,
97 | history_messages=history_messages,
98 | **kwargs,
99 | )
100 |
101 |
102 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
103 | @retry(
104 | stop=stop_after_attempt(5),
105 | wait=wait_exponential(multiplier=1, min=4, max=10),
106 | retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
107 | )
108 | async def openai_embedding(texts: list[str]) -> np.ndarray:
109 | openai_async_client = get_openai_async_client_instance()
110 | response = await openai_async_client.embeddings.create(
111 | model="text-embedding-3-small", input=texts, encoding_format="float"
112 | )
113 | return np.array([dp.embedding for dp in response.data])
114 |
115 |
116 | @retry(
117 | stop=stop_after_attempt(3),
118 | wait=wait_exponential(multiplier=1, min=4, max=10),
119 | retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
120 | )
121 | async def azure_openai_complete_if_cache(
122 | deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs
123 | ) -> str:
124 | azure_openai_client = get_azure_openai_async_client_instance()
125 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
126 | messages = []
127 | if system_prompt:
128 | messages.append({"role": "system", "content": system_prompt})
129 | messages.extend(history_messages)
130 | messages.append({"role": "user", "content": prompt})
131 | if hashing_kv is not None:
132 | args_hash = compute_args_hash(deployment_name, messages)
133 | if_cache_return = await hashing_kv.get_by_id(args_hash)
134 | if if_cache_return is not None:
135 | return if_cache_return["return"]
136 |
137 | response = await azure_openai_client.chat.completions.create(
138 | model=deployment_name, messages=messages, **kwargs
139 | )
140 |
141 | if hashing_kv is not None:
142 | await hashing_kv.upsert(
143 | {
144 | args_hash: {
145 | "return": response.choices[0].message.content,
146 | "model": deployment_name,
147 | }
148 | }
149 | )
150 | await hashing_kv.index_done_callback()
151 | return response.choices[0].message.content
152 |
153 |
154 | async def azure_gpt_4o_complete(
155 | prompt, system_prompt=None, history_messages=[], **kwargs
156 | ) -> str:
157 | return await azure_openai_complete_if_cache(
158 | "gpt-4o",
159 | prompt,
160 | system_prompt=system_prompt,
161 | history_messages=history_messages,
162 | **kwargs,
163 | )
164 |
165 |
166 | async def azure_gpt_4o_mini_complete(
167 | prompt, system_prompt=None, history_messages=[], **kwargs
168 | ) -> str:
169 | return await azure_openai_complete_if_cache(
170 | "gpt-4o-mini",
171 | prompt,
172 | system_prompt=system_prompt,
173 | history_messages=history_messages,
174 | **kwargs,
175 | )
176 |
177 |
178 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
179 | @retry(
180 | stop=stop_after_attempt(3),
181 | wait=wait_exponential(multiplier=1, min=4, max=10),
182 | retry=retry_if_exception_type((RateLimitError, APIConnectionError)),
183 | )
184 | async def azure_openai_embedding(texts: list[str]) -> np.ndarray:
185 | azure_openai_client = get_azure_openai_async_client_instance()
186 | response = await azure_openai_client.embeddings.create(
187 | model="text-embedding-3-small", input=texts, encoding_format="float"
188 | )
189 | return np.array([dp.embedding for dp in response.data])
190 |
--------------------------------------------------------------------------------
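A minimal usage sketch for the completion helpers in `_llm.py`, assuming `OPENAI_API_KEY` is set in the environment (the `AsyncOpenAI()` client is constructed with defaults); since no `hashing_kv` is passed here, the call bypasses the response cache:

```python
import asyncio

from hirag._llm import gpt_4o_mini_complete

async def main():
    answer = await gpt_4o_mini_complete(
        "Summarize the Leiden algorithm in one sentence.",
        system_prompt="You are a terse assistant.",
    )
    print(answer)

asyncio.run(main())
```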
/hirag/_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union, Literal
2 |
3 | class SeparatorSplitter:
4 | def __init__(
5 | self,
6 | separators: Optional[List[List[int]]] = None,
7 | keep_separator: Union[bool, Literal["start", "end"]] = "end",
8 | chunk_size: int = 4000,
9 | chunk_overlap: int = 200,
10 | length_function: callable = len,
11 | ):
12 | self._separators = separators or []
13 | self._keep_separator = keep_separator
14 | self._chunk_size = chunk_size
15 | self._chunk_overlap = chunk_overlap
16 | self._length_function = length_function
17 |
18 | def split_tokens(self, tokens: List[int]) -> List[List[int]]:
19 | splits = self._split_tokens_with_separators(tokens)
20 | return self._merge_splits(splits)
21 |
22 | def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
23 | splits = []
24 | current_split = []
25 | i = 0
26 | while i < len(tokens):
27 | separator_found = False
28 | for separator in self._separators:
29 | if tokens[i:i+len(separator)] == separator:
30 | if self._keep_separator in [True, "end"]:
31 | current_split.extend(separator)
32 | if current_split:
33 | splits.append(current_split)
34 | current_split = []
35 | if self._keep_separator == "start":
36 | current_split.extend(separator)
37 | i += len(separator)
38 | separator_found = True
39 | break
40 | if not separator_found:
41 | current_split.append(tokens[i])
42 | i += 1
43 | if current_split:
44 | splits.append(current_split)
45 | return [s for s in splits if s]
46 |
47 | def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
48 | if not splits:
49 | return []
50 |
51 | merged_splits = []
52 | current_chunk = []
53 |
54 | for split in splits:
55 | if not current_chunk:
56 | current_chunk = split
57 | elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
58 | current_chunk.extend(split)
59 | else:
60 | merged_splits.append(current_chunk)
61 | current_chunk = split
62 |
63 | if current_chunk:
64 | merged_splits.append(current_chunk)
65 |
66 | if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
67 | return self._split_chunk(merged_splits[0])
68 |
69 | if self._chunk_overlap > 0:
70 | return self._enforce_overlap(merged_splits)
71 |
72 | return merged_splits
73 |
74 | def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
75 | result = []
76 | for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
77 | new_chunk = chunk[i:i + self._chunk_size]
78 |             if len(new_chunk) > self._chunk_overlap:  # only add the chunk if it is longer than the overlap
79 | result.append(new_chunk)
80 | return result
81 |
82 | def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
83 | result = []
84 | for i, chunk in enumerate(chunks):
85 | if i == 0:
86 | result.append(chunk)
87 | else:
88 | overlap = chunks[i-1][-self._chunk_overlap:]
89 | new_chunk = overlap + chunk
90 | if self._length_function(new_chunk) > self._chunk_size:
91 | new_chunk = new_chunk[:self._chunk_size]
92 | result.append(new_chunk)
93 | return result
94 |
95 |
--------------------------------------------------------------------------------
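A small usage sketch for `SeparatorSplitter`: it operates on token ids rather than strings, so separators must be encoded with the same tokenizer as the text. The tiny `chunk_size` below is only there to force a visible split:

```python
import tiktoken

from hirag._splitter import SeparatorSplitter

enc = tiktoken.get_encoding("cl100k_base")
splitter = SeparatorSplitter(
    separators=[enc.encode("\n\n")],  # split on blank lines, kept at chunk end
    chunk_size=12,
    chunk_overlap=3,
)
tokens = enc.encode("First paragraph.\n\nSecond, noticeably longer paragraph of text.")
chunks = splitter.split_tokens(tokens)
print([enc.decode(chunk) for chunk in chunks])
```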
/hirag/_storage/__init__.py:
--------------------------------------------------------------------------------
1 | from .gdb_networkx import NetworkXStorage
2 | from .gdb_neo4j import Neo4jStorage
3 | from .vdb_nanovectordb import NanoVectorDBStorage
4 | from .kv_json import JsonKVStorage
5 |
--------------------------------------------------------------------------------
/hirag/_storage/gdb_neo4j.py:
--------------------------------------------------------------------------------
1 | import json
2 | import asyncio
3 | from collections import defaultdict
4 | from neo4j import AsyncGraphDatabase
5 | from dataclasses import dataclass
6 | from typing import Union
7 | from ..base import BaseGraphStorage, SingleCommunitySchema
8 | from .._utils import logger
9 | from ..prompt import GRAPH_FIELD_SEP
10 |
11 | neo4j_lock = asyncio.Lock()
12 |
13 |
14 | def make_path_idable(path):
15 | return path.replace(".", "_").replace("/", "__").replace("-", "_")
16 |
17 |
18 | @dataclass
19 | class Neo4jStorage(BaseGraphStorage):
20 | def __post_init__(self):
21 | self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None)
22 | self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None)
23 | self.namespace = (
24 | f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}"
25 | )
26 | logger.info(f"Using the label {self.namespace} for Neo4j as identifier")
27 | if self.neo4j_url is None or self.neo4j_auth is None:
28 | raise ValueError("Missing neo4j_url or neo4j_auth in addon_params")
29 | self.async_driver = AsyncGraphDatabase.driver(
30 | self.neo4j_url, auth=self.neo4j_auth
31 | )
32 |
33 | # async def create_database(self):
34 | # async with self.async_driver.session() as session:
35 | # try:
36 | # constraints = await session.run("SHOW CONSTRAINTS")
37 | # # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error
38 | # # so have to check if the constrain exists
39 | # constrain_exists = False
40 |
41 | # async for record in constraints:
42 | # if (
43 | # self.namespace in record["labelsOrTypes"]
44 | # and "id" in record["properties"]
45 | # and record["type"] == "UNIQUENESS"
46 | # ):
47 | # constrain_exists = True
48 | # break
49 | # if not constrain_exists:
50 | # await session.run(
51 | # f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE"
52 | # )
53 | # logger.info(f"Add constraint for namespace: {self.namespace}")
54 |
55 | # except Exception as e:
56 | # logger.error(f"Error accessing or setting up the database: {str(e)}")
57 | # raise
58 |
59 | async def _init_workspace(self):
60 | await self.async_driver.verify_authentication()
61 | await self.async_driver.verify_connectivity()
62 | # TODOLater: create database if not exists always cause an error when async
63 | # await self.create_database()
64 |
65 | async def index_start_callback(self):
66 | logger.info("Init Neo4j workspace")
67 | await self._init_workspace()
68 |
69 | async def has_node(self, node_id: str) -> bool:
70 | async with self.async_driver.session() as session:
71 | result = await session.run(
72 | f"MATCH (n:{self.namespace}) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists",
73 | node_id=node_id,
74 | )
75 | record = await result.single()
76 | return record["exists"] if record else False
77 |
78 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
79 | async with self.async_driver.session() as session:
80 | result = await session.run(
81 | f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) "
82 | "WHERE s.id = $source_id AND t.id = $target_id "
83 | "RETURN COUNT(r) > 0 AS exists",
84 | source_id=source_node_id,
85 | target_id=target_node_id,
86 | )
87 | record = await result.single()
88 | return record["exists"] if record else False
89 |
90 | async def node_degree(self, node_id: str) -> int:
91 | async with self.async_driver.session() as session:
92 | result = await session.run(
93 | f"MATCH (n:{self.namespace}) WHERE n.id = $node_id "
94 | f"RETURN COUNT {{(n)-[]-(:{self.namespace})}} AS degree",
95 | node_id=node_id,
96 | )
97 | record = await result.single()
98 | return record["degree"] if record else 0
99 |
100 | async def edge_degree(self, src_id: str, tgt_id: str) -> int:
101 | async with self.async_driver.session() as session:
102 | result = await session.run(
103 | f"MATCH (s:{self.namespace}), (t:{self.namespace}) "
104 | "WHERE s.id = $src_id AND t.id = $tgt_id "
105 | f"RETURN COUNT {{(s)-[]-(:{self.namespace})}} + COUNT {{(t)-[]-(:{self.namespace})}} AS degree",
106 | src_id=src_id,
107 | tgt_id=tgt_id,
108 | )
109 | record = await result.single()
110 | return record["degree"] if record else 0
111 |
112 | async def get_node(self, node_id: str) -> Union[dict, None]:
113 | async with self.async_driver.session() as session:
114 | result = await session.run(
115 | f"MATCH (n:{self.namespace}) WHERE n.id = $node_id RETURN properties(n) AS node_data",
116 | node_id=node_id,
117 | )
118 | record = await result.single()
119 | raw_node_data = record["node_data"] if record else None
120 | if raw_node_data is None:
121 | return None
122 | raw_node_data["clusters"] = json.dumps(
123 | [
124 | {
125 | "level": index,
126 | "cluster": cluster_id,
127 | }
128 | for index, cluster_id in enumerate(
129 | raw_node_data.get("communityIds", [])
130 | )
131 | ]
132 | )
133 | return raw_node_data
134 |
135 | async def get_edge(
136 | self, source_node_id: str, target_node_id: str
137 | ) -> Union[dict, None]:
138 | async with self.async_driver.session() as session:
139 | result = await session.run(
140 | f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) "
141 | "WHERE s.id = $source_id AND t.id = $target_id "
142 | "RETURN properties(r) AS edge_data",
143 | source_id=source_node_id,
144 | target_id=target_node_id,
145 | )
146 | record = await result.single()
147 | return record["edge_data"] if record else None
148 |
149 | async def get_node_edges(
150 | self, source_node_id: str
151 | ) -> Union[list[tuple[str, str]], None]:
152 | async with self.async_driver.session() as session:
153 | result = await session.run(
154 | f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) WHERE s.id = $source_id "
155 | "RETURN s.id AS source, t.id AS target",
156 | source_id=source_node_id,
157 | )
158 | edges = []
159 | async for record in result:
160 | edges.append((record["source"], record["target"]))
161 | return edges
162 |
163 | async def upsert_node(self, node_id: str, node_data: dict[str, str]):
164 | node_type = node_data.get("entity_type", "UNKNOWN").strip('"')
165 | async with self.async_driver.session() as session:
166 | await session.run(
167 | f"MERGE (n:{self.namespace}:{node_type} {{id: $node_id}}) "
168 | "SET n += $node_data",
169 | node_id=node_id,
170 | node_data=node_data,
171 | )
172 |
173 | async def upsert_edge(
174 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
175 | ):
176 | edge_data.setdefault("weight", 0.0)
177 | async with self.async_driver.session() as session:
178 | await session.run(
179 | f"MATCH (s:{self.namespace}), (t:{self.namespace}) "
180 | "WHERE s.id = $source_id AND t.id = $target_id "
181 | "MERGE (s)-[r:RELATED]->(t) " # Added relationship type 'RELATED'
182 | "SET r += $edge_data",
183 | source_id=source_node_id,
184 | target_id=target_node_id,
185 | edge_data=edge_data,
186 | )
187 |
188 | async def clustering(self, algorithm: str):
189 | if algorithm != "leiden":
190 | raise ValueError(
191 | f"Clustering algorithm {algorithm} not supported in Neo4j implementation"
192 | )
193 |
194 | random_seed = self.global_config["graph_cluster_seed"]
195 | max_level = self.global_config["max_graph_cluster_size"]
196 | async with self.async_driver.session() as session:
197 | try:
198 | # Project the graph with undirected relationships
199 | await session.run(
200 | f"""
201 | CALL gds.graph.project(
202 | 'graph_{self.namespace}',
203 | ['{self.namespace}'],
204 | {{
205 | RELATED: {{
206 | orientation: 'UNDIRECTED',
207 | properties: ['weight']
208 | }}
209 | }}
210 | )
211 | """
212 | )
213 |
214 | # Run Leiden algorithm
215 | result = await session.run(
216 | f"""
217 | CALL gds.leiden.write(
218 | 'graph_{self.namespace}',
219 | {{
220 | writeProperty: 'communityIds',
221 | includeIntermediateCommunities: True,
222 | relationshipWeightProperty: "weight",
223 | maxLevels: {max_level},
224 | tolerance: 0.0001,
225 | gamma: 1.0,
226 | theta: 0.01,
227 | randomSeed: {random_seed}
228 | }}
229 | )
230 | YIELD communityCount, modularities;
231 | """
232 | )
233 | result = await result.single()
234 | community_count: int = result["communityCount"]
235 | modularities = result["modularities"]
236 | logger.info(
237 | f"Performed graph clustering with {community_count} communities and modularities {modularities}"
238 | )
239 | finally:
240 | # Drop the projected graph
241 | await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')")
242 |
243 | async def community_schema(self) -> dict[str, SingleCommunitySchema]:
244 | results = defaultdict(
245 | lambda: dict(
246 | level=None,
247 | title=None,
248 | edges=set(),
249 | nodes=set(),
250 | chunk_ids=set(),
251 | occurrence=0.0,
252 | sub_communities=[],
253 | )
254 | )
255 |
256 | async with self.async_driver.session() as session:
257 | # Fetch community data
258 | result = await session.run(
259 | f"""
260 | MATCH (n:{self.namespace})
261 | WITH n, n.communityIds AS communityIds, [(n)-[]-(m:{self.namespace}) | m.id] AS connected_nodes
262 | RETURN n.id AS node_id, n.source_id AS source_id,
263 | communityIds AS cluster_key,
264 | connected_nodes
265 | """
266 | )
267 |
268 | # records = await result.fetch()
269 |
270 | max_num_ids = 0
271 | async for record in result:
272 | for index, c_id in enumerate(record["cluster_key"]):
273 | node_id = str(record["node_id"])
274 | source_id = record["source_id"]
275 | level = index
276 | cluster_key = str(c_id)
277 | connected_nodes = record["connected_nodes"]
278 |
279 | results[cluster_key]["level"] = level
280 | results[cluster_key]["title"] = f"Cluster {cluster_key}"
281 | results[cluster_key]["nodes"].add(node_id)
282 | results[cluster_key]["edges"].update(
283 | [
284 | tuple(sorted([node_id, str(connected)]))
285 | for connected in connected_nodes
286 | if connected != node_id
287 | ]
288 | )
289 | chunk_ids = source_id.split(GRAPH_FIELD_SEP)
290 | results[cluster_key]["chunk_ids"].update(chunk_ids)
291 | max_num_ids = max(
292 | max_num_ids, len(results[cluster_key]["chunk_ids"])
293 | )
294 |
295 | # Process results
296 | for k, v in results.items():
297 | v["edges"] = [list(e) for e in v["edges"]]
298 | v["nodes"] = list(v["nodes"])
299 | v["chunk_ids"] = list(v["chunk_ids"])
300 | v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
301 |
302 | # Compute sub-communities (this is a simplified approach)
303 | for cluster in results.values():
304 | cluster["sub_communities"] = [
305 | sub_key
306 | for sub_key, sub_cluster in results.items()
307 | if sub_cluster["level"] > cluster["level"]
308 | and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"]))
309 | ]
310 |
311 | return dict(results)
312 |
313 | async def index_done_callback(self):
314 | await self.async_driver.close()
315 |
316 | async def _debug_delete_all_node_edges(self):
317 | async with self.async_driver.session() as session:
318 | try:
319 | # Delete all relationships in the namespace
320 | await session.run(f"MATCH (n:{self.namespace})-[r]-() DELETE r")
321 |
322 | # Delete all nodes in the namespace
323 | await session.run(f"MATCH (n:{self.namespace}) DELETE n")
324 |
325 | logger.info(
326 | f"All nodes and edges in namespace '{self.namespace}' have been deleted."
327 | )
328 | except Exception as e:
329 | logger.error(f"Error deleting nodes and edges: {str(e)}")
330 | raise
331 |
--------------------------------------------------------------------------------
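`Neo4jStorage` pulls its connection details from `addon_params` in the global config. An illustrative shape (both values are placeholders; the `neo4j_auth` tuple follows the `neo4j` driver's `(user, password)` convention accepted by `AsyncGraphDatabase.driver`):

```python
addon_params = {
    "neo4j_url": "bolt://localhost:7687",      # placeholder
    "neo4j_auth": ("neo4j", "your-password"),  # placeholder (user, password)
}
```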
/hirag/_storage/gdb_networkx.py:
--------------------------------------------------------------------------------
1 | import html
2 | import json
3 | import os
4 | from collections import defaultdict
5 | from dataclasses import dataclass
6 | from typing import Any, Union, cast
7 | import networkx as nx
8 | import numpy as np
9 |
10 | from .._utils import logger
11 | from ..base import (
12 | BaseGraphStorage,
13 | SingleCommunitySchema,
14 | )
15 | from ..prompt import GRAPH_FIELD_SEP
16 |
17 |
18 | @dataclass
19 | class NetworkXStorage(BaseGraphStorage):
20 | @staticmethod
21 | def load_nx_graph(file_name) -> nx.Graph:
22 | if os.path.exists(file_name):
23 | return nx.read_graphml(file_name)
24 | return None
25 |
26 | @staticmethod
27 | def write_nx_graph(graph: nx.Graph, file_name):
28 | logger.info(
29 | f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
30 | )
31 | nx.write_graphml(graph, file_name)
32 |
33 | @staticmethod
34 | def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
35 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
36 | Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
37 | """
38 | from graspologic.utils import largest_connected_component
39 |
40 | graph = graph.copy()
41 | graph = cast(nx.Graph, largest_connected_component(graph))
42 | node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
43 | graph = nx.relabel_nodes(graph, node_mapping)
44 | return NetworkXStorage._stabilize_graph(graph)
45 |
46 | @staticmethod
47 | def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
48 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
49 | Ensure an undirected graph with the same relationships will always be read the same way.
50 | """
51 | fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
52 |
53 | sorted_nodes = graph.nodes(data=True)
54 | sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
55 |
56 | fixed_graph.add_nodes_from(sorted_nodes)
57 | edges = list(graph.edges(data=True))
58 |
59 | if not graph.is_directed():
60 |
61 | def _sort_source_target(edge):
62 | source, target, edge_data = edge
63 |                 if source > target:
64 |                     # swap in place so the smaller id always comes first
65 |                     source, target = target, source
66 |
67 | return source, target, edge_data
68 |
69 | edges = [_sort_source_target(edge) for edge in edges]
70 |
71 | def _get_edge_key(source: Any, target: Any) -> str:
72 | return f"{source} -> {target}"
73 |
74 | edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
75 |
76 | fixed_graph.add_edges_from(edges)
77 | return fixed_graph
78 |
79 | def __post_init__(self):
80 | self._graphml_xml_file = os.path.join(
81 | self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
82 | )
83 | preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
84 | if preloaded_graph is not None:
85 | logger.info(
86 | f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
87 | )
88 | self._graph = preloaded_graph or nx.Graph()
89 | self._clustering_algorithms = {
90 | "leiden": self._leiden_clustering,
91 | }
92 | self._node_embed_algorithms = {
93 | "node2vec": self._node2vec_embed,
94 | }
95 |
96 | async def index_done_callback(self):
97 | NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
98 |
99 | async def has_node(self, node_id: str) -> bool:
100 | return self._graph.has_node(node_id)
101 |
102 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
103 | return self._graph.has_edge(source_node_id, target_node_id)
104 |
105 | async def get_node(self, node_id: str) -> Union[dict, None]:
106 | return self._graph.nodes.get(node_id)
107 |
108 | async def node_degree(self, node_id: str) -> int:
109 | # [numberchiffre]: node_id not part of graph returns `DegreeView({})` instead of 0
110 | return self._graph.degree(node_id) if self._graph.has_node(node_id) else 0
111 |
112 | async def edge_degree(self, src_id: str, tgt_id: str) -> int:
113 | return (self._graph.degree(src_id) if self._graph.has_node(src_id) else 0) + (
114 | self._graph.degree(tgt_id) if self._graph.has_node(tgt_id) else 0
115 | )
116 |
117 | async def get_edge(
118 | self, source_node_id: str, target_node_id: str
119 | ) -> Union[dict, None]:
120 | return self._graph.edges.get((source_node_id, target_node_id))
121 |
122 | async def get_node_edges(self, source_node_id: str):
123 | if self._graph.has_node(source_node_id):
124 | return list(self._graph.edges(source_node_id))
125 | return None
126 |
127 | async def upsert_node(self, node_id: str, node_data: dict[str, str]):
128 | self._graph.add_node(node_id, **node_data)
129 |
130 | async def upsert_edge(
131 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
132 | ):
133 | self._graph.add_edge(source_node_id, target_node_id, **edge_data)
134 |
135 | async def clustering(self, algorithm: str):
136 | if algorithm not in self._clustering_algorithms:
137 | raise ValueError(f"Clustering algorithm {algorithm} not supported")
138 | await self._clustering_algorithms[algorithm]()
139 |
140 | async def community_schema(self) -> dict[str, SingleCommunitySchema]:
141 | results = defaultdict(
142 | lambda: dict(
143 | level=None,
144 | title=None,
145 | edges=set(),
146 | nodes=set(),
147 | chunk_ids=set(),
148 | occurrence=0.0,
149 | sub_communities=[],
150 | )
151 | )
152 | max_num_ids = 0
153 | levels = defaultdict(set)
154 | for node_id, node_data in self._graph.nodes(data=True):
155 | if "clusters" not in node_data:
156 | continue
157 | clusters = json.loads(node_data["clusters"])
158 | this_node_edges = self._graph.edges(node_id)
159 |
160 | for cluster in clusters:
161 | level = cluster["level"]
162 | cluster_key = str(cluster["cluster"])
163 | levels[level].add(cluster_key)
164 | results[cluster_key]["level"] = level
165 | results[cluster_key]["title"] = f"Cluster {cluster_key}"
166 | results[cluster_key]["nodes"].add(node_id)
167 | results[cluster_key]["edges"].update(
168 | [tuple(sorted(e)) for e in this_node_edges]
169 | )
170 | results[cluster_key]["chunk_ids"].update(
171 | node_data["source_id"].split(GRAPH_FIELD_SEP)
172 | )
173 | max_num_ids = max(max_num_ids, len(results[cluster_key]["chunk_ids"]))
174 |
175 | ordered_levels = sorted(levels.keys())
176 | for i, curr_level in enumerate(ordered_levels[:-1]):
177 | next_level = ordered_levels[i + 1]
178 | this_level_comms = levels[curr_level]
179 | next_level_comms = levels[next_level]
180 | # compute the sub-communities by nodes intersection
181 | for comm in this_level_comms:
182 | results[comm]["sub_communities"] = [
183 | c
184 | for c in next_level_comms
185 | if results[c]["nodes"].issubset(results[comm]["nodes"])
186 | ]
187 |
188 | for k, v in results.items():
189 | v["edges"] = list(v["edges"])
190 | v["edges"] = [list(e) for e in v["edges"]]
191 | v["nodes"] = list(v["nodes"])
192 | v["chunk_ids"] = list(v["chunk_ids"])
193 | v["occurrence"] = len(v["chunk_ids"]) / max_num_ids
194 | return dict(results)
195 |
196 | def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
197 | for node_id, clusters in cluster_data.items():
198 | self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
199 |
200 | async def _leiden_clustering(self):
201 |         """Cluster the graph with the hierarchical_leiden function
202 |         from the graspologic library; the Leiden algorithm is
203 |         invoked from the HiRAG.ainsert method.
204 |         """
205 |         from graspologic.partition import hierarchical_leiden
206 | graph = NetworkXStorage.stable_largest_connected_component(self._graph)
207 | community_mapping = hierarchical_leiden(
208 | graph,
209 | max_cluster_size=self.global_config["max_graph_cluster_size"],
210 | random_seed=self.global_config["graph_cluster_seed"],
211 | )
212 |
213 | node_communities: dict[str, list[dict[str, str]]] = defaultdict(list)
214 | __levels = defaultdict(set)
215 | for partition in community_mapping:
216 | level_key = partition.level
217 | cluster_id = partition.cluster
218 | node_communities[partition.node].append(
219 | {"level": level_key, "cluster": cluster_id}
220 | )
221 | __levels[level_key].add(cluster_id)
222 | node_communities = dict(node_communities)
223 | __levels = {k: len(v) for k, v in __levels.items()}
224 | logger.info(f"Each level has communities: {dict(__levels)}")
225 | self._cluster_data_to_subgraphs(node_communities)
226 |
227 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
228 | if algorithm not in self._node_embed_algorithms:
229 | raise ValueError(f"Node embedding algorithm {algorithm} not supported")
230 | return await self._node_embed_algorithms[algorithm]()
231 |
232 | async def _node2vec_embed(self):
233 | from graspologic import embed
234 |
235 | embeddings, nodes = embed.node2vec_embed(
236 | self._graph,
237 | **self.global_config["node2vec_params"],
238 | )
239 |
240 | nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
241 | return embeddings, nodes_ids
242 |
--------------------------------------------------------------------------------
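For reference, `_leiden_clustering` serializes each node's community memberships through `_cluster_data_to_subgraphs` into the `clusters` node attribute, and `community_schema` parses the same shape back. An illustrative value (the cluster ids are made up):

```python
import json

# what _cluster_data_to_subgraphs stores on a single node
clusters_attr = json.dumps([
    {"level": 0, "cluster": 3},   # coarse community at level 0
    {"level": 1, "cluster": 12},  # finer sub-community at level 1
])
```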
/hirag/_storage/kv_json.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 |
4 | from .._utils import load_json, logger, write_json
5 | from ..base import (
6 | BaseKVStorage,
7 | )
8 |
9 |
10 | @dataclass
11 | class JsonKVStorage(BaseKVStorage):
12 | def __post_init__(self):
13 | working_dir = self.global_config["working_dir"]
14 | self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
15 | self._data = load_json(self._file_name) or {}
16 | logger.info(f"Loaded KV store {self.namespace} with {len(self._data)} records")
17 |
18 | async def all_keys(self) -> list[str]:
19 | return list(self._data.keys())
20 |
21 | async def index_done_callback(self):
22 | write_json(self._data, self._file_name)
23 |
24 | async def get_by_id(self, id):
25 | return self._data.get(id, None)
26 |
27 | async def get_by_ids(self, ids, fields=None):
28 | if fields is None:
29 | return [self._data.get(id, None) for id in ids]
30 | return [
31 | (
32 | {k: v for k, v in self._data[id].items() if k in fields}
33 | if self._data.get(id, None)
34 | else None
35 | )
36 | for id in ids
37 | ]
38 |
39 | async def filter_keys(self, data: list[str]) -> set[str]:
40 | return set([s for s in data if s not in self._data])
41 |
42 | async def upsert(self, data: dict[str, dict]):
43 | self._data.update(data)
44 |
45 | async def drop(self):
46 | self._data = {}
47 |
--------------------------------------------------------------------------------
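A minimal sketch of using `JsonKVStorage` on its own; the namespace and the temporary working directory are illustrative:

```python
import asyncio
import tempfile

from hirag._storage import JsonKVStorage

async def main():
    kv = JsonKVStorage(namespace="demo", global_config={"working_dir": tempfile.mkdtemp()})
    await kv.upsert({"doc-1": {"content": "hello"}})
    print(await kv.get_by_id("doc-1"))               # {'content': 'hello'}
    print(await kv.filter_keys(["doc-1", "doc-2"]))  # {'doc-2'} (unseen keys only)
    await kv.index_done_callback()                   # persists kv_store_demo.json

asyncio.run(main())
```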
/hirag/_storage/vdb_nanovectordb.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from dataclasses import dataclass
4 | import numpy as np
5 | from nano_vectordb import NanoVectorDB
6 |
7 | from .._utils import logger
8 | from ..base import BaseVectorStorage
9 |
10 |
11 | @dataclass
12 | class NanoVectorDBStorage(BaseVectorStorage):
13 | cosine_better_than_threshold: float = 0.2
14 |
15 | def __post_init__(self):
16 |
17 | self._client_file_name = os.path.join(
18 | self.global_config["working_dir"], f"vdb_{self.namespace}.json"
19 | )
20 | self._max_batch_size = self.global_config["embedding_batch_num"]
21 | self._client = NanoVectorDB(
22 | self.embedding_func.embedding_dim, storage_file=self._client_file_name
23 | )
24 | self.cosine_better_than_threshold = self.global_config.get(
25 | "query_better_than_threshold", self.cosine_better_than_threshold
26 | )
27 |
28 | async def upsert(self, data: dict[str, dict]):
29 | logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
30 | if not len(data):
31 | logger.warning("You are inserting empty data into the vector DB")
32 | return []
33 | list_data = [
34 | {
35 | "__id__": k,
36 | **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
37 | }
38 | for k, v in data.items()
39 | ]
40 | contents = [v["content"] for v in data.values()]
41 | batches = [
42 | contents[i : i + self._max_batch_size]
43 | for i in range(0, len(contents), self._max_batch_size)
44 | ]
45 | embeddings_list = await asyncio.gather(
46 | *[self.embedding_func(batch) for batch in batches]
47 | )
48 | embeddings = np.concatenate(embeddings_list)
49 | for i, d in enumerate(list_data):
50 | d["__vector__"] = embeddings[i]
51 | results = self._client.upsert(datas=list_data)
52 | return results
53 |
54 | async def query(self, query: str, top_k=5):
55 | embedding = await self.embedding_func([query])
56 | embedding = embedding[0]
57 | results = self._client.query(
58 | query=embedding,
59 | top_k=top_k,
60 | better_than_threshold=self.cosine_better_than_threshold,
61 | )
62 | results = [
63 | {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
64 | ]
65 | return results
66 |
67 | async def index_done_callback(self):
68 | self._client.save()
69 |
--------------------------------------------------------------------------------
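A minimal sketch of `NanoVectorDBStorage` with a stubbed embedding function, so it runs without an API key; the dimension and config values are illustrative:

```python
import asyncio
import tempfile

import numpy as np

from hirag._storage import NanoVectorDBStorage
from hirag._utils import wrap_embedding_func_with_attrs

@wrap_embedding_func_with_attrs(embedding_dim=8, max_token_size=512)
async def fake_embedding(texts: list[str]) -> np.ndarray:
    return np.random.rand(len(texts), 8)  # stand-in for a real embedder

async def main():
    vdb = NanoVectorDBStorage(
        namespace="demo",
        global_config={"working_dir": tempfile.mkdtemp(), "embedding_batch_num": 32},
        embedding_func=fake_embedding,
    )
    await vdb.upsert({"chunk-1": {"content": "hierarchical retrieval"}})
    print(await vdb.query("retrieval", top_k=1))

asyncio.run(main())
```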
/hirag/_utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import html
3 | import json
4 | import logging
5 | import os
6 | import re
7 | import numbers
8 | from dataclasses import dataclass
9 | from functools import wraps
10 | from hashlib import md5
11 | from typing import Any, Union
12 |
13 | import numpy as np
14 | import tiktoken
15 |
16 | logger = logging.getLogger("HiRAG")
17 | ENCODER = None
18 |
19 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
20 | try:
21 | # If there is already an event loop, use it.
22 | loop = asyncio.get_event_loop()
23 | except RuntimeError:
24 | # If in a sub-thread, create a new event loop.
25 | logger.info("Creating a new event loop in a sub-thread.")
26 | loop = asyncio.new_event_loop()
27 | asyncio.set_event_loop(loop)
28 | return loop
29 |
30 |
31 | def extract_first_complete_json(s: str):
32 | """Extract the first complete JSON object from the string using a stack to track braces."""
33 | stack = []
34 | first_json_start = None
35 |
36 | for i, char in enumerate(s):
37 | if char == '{':
38 | stack.append(i)
39 | if first_json_start is None:
40 | first_json_start = i
41 | elif char == '}':
42 | if stack:
43 | start = stack.pop()
44 | if not stack:
45 | first_json_str = s[first_json_start:i+1]
46 | try:
47 | # Attempt to parse the JSON string
48 | return json.loads(first_json_str.replace("\n", ""))
49 | except json.JSONDecodeError as e:
50 | logger.error(f"JSON decoding failed: {e}. Attempted string: {first_json_str[:50]}...")
51 | return None
52 | finally:
53 | first_json_start = None
54 | logger.warning("No complete JSON object found in the input string.")
55 | return None
56 |
57 | def parse_value(value: str):
58 | """Convert a string value to its appropriate type (int, float, bool, None, or keep as string). Work as a more broad 'eval()'"""
59 | value = value.strip()
60 |
61 | if value == "null":
62 | return None
63 | elif value == "true":
64 | return True
65 | elif value == "false":
66 | return False
67 | else:
68 | # Try to convert to int or float
69 | try:
70 | if '.' in value: # If there's a dot, it might be a float
71 | return float(value)
72 | else:
73 | return int(value)
74 | except ValueError:
75 | # If conversion fails, return the value as-is (likely a string)
76 | return value.strip('"') # Remove surrounding quotes if they exist
77 |
78 | def extract_values_from_json(json_string, keys=["reasoning", "answer", "data"], allow_no_quotes=False):
79 | """Extract key values from a non-standard or malformed JSON string, handling nested objects."""
80 | extracted_values = {}
81 |
82 | # Enhanced pattern to match both quoted and unquoted values, as well as nested objects
83 | regex_pattern = r'(?P<key>"?\w+"?)\s*:\s*(?P<value>{[^}]*}|".*?"|[^,}]+)'
84 |
85 | for match in re.finditer(regex_pattern, json_string, re.DOTALL):
86 | key = match.group('key').strip('"') # Strip quotes from key
87 | value = match.group('value').strip()
88 |
89 | # If the value is another nested JSON (starts with '{' and ends with '}'), recursively parse it
90 | if value.startswith('{') and value.endswith('}'):
91 | extracted_values[key] = extract_values_from_json(value)
92 | else:
93 | # Parse the value into the appropriate type (int, float, bool, etc.)
94 | extracted_values[key] = parse_value(value)
95 |
96 | if not extracted_values:
97 | logger.warning("No values could be extracted from the string.")
98 |
99 | return extracted_values
100 |
101 |
102 | def convert_response_to_json(response: str) -> dict:
103 | """Convert response string to JSON, with error handling and fallback to non-standard JSON extraction."""
104 | prediction_json = extract_first_complete_json(response)
105 |
106 | if prediction_json is None:
107 | logger.info("Attempting to extract values from a non-standard JSON string...")
108 | prediction_json = extract_values_from_json(response, allow_no_quotes=True)
109 |
110 | if not prediction_json:
111 | logger.error("Unable to extract meaningful data from the response.")
112 | else:
113 | logger.info("JSON data successfully extracted.")
114 |
115 | return prediction_json
116 |
117 |
118 |
119 |
120 | def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
121 | global ENCODER
122 | if ENCODER is None:
123 | ENCODER = tiktoken.encoding_for_model(model_name)
124 | tokens = ENCODER.encode(content)
125 | return tokens
126 |
127 |
128 | def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
129 | global ENCODER
130 | if ENCODER is None:
131 | ENCODER = tiktoken.encoding_for_model(model_name)
132 | content = ENCODER.decode(tokens)
133 | return content
134 |
135 |
136 | def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
137 | """Truncate a list of data by token size"""
138 | if max_token_size <= 0:
139 | return []
140 | tokens = 0
141 | for i, data in enumerate(list_data):
142 | tokens += len(encode_string_by_tiktoken(key(data)))
143 | if tokens > max_token_size:
144 | return list_data[:i]
145 | return list_data
146 |
147 |
148 | def compute_mdhash_id(content, prefix: str = ""):
149 | return prefix + md5(content.encode()).hexdigest()
150 |
151 |
152 | def write_json(json_obj, file_name):
153 | with open(file_name, "w", encoding="utf-8") as f:
154 | json.dump(json_obj, f, indent=2, ensure_ascii=False)
155 |
156 |
157 | def load_json(file_name):
158 | if not os.path.exists(file_name):
159 | return None
160 | with open(file_name, encoding="utf-8") as f:
161 | return json.load(f)
162 |
163 |
164 | # Typing this precisely would be messy, so keep the *args signature loose
165 | def pack_user_ass_to_openai_messages(*args: str):
166 | roles = ["user", "assistant"]
167 | return [
168 | {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
169 | ]
170 |
171 |
172 | def is_float_regex(value):
173 | return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
174 |
175 |
176 | def compute_args_hash(*args):
177 | return md5(str(args).encode()).hexdigest()
178 |
179 |
180 | def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
181 | """Split a string by multiple markers"""
182 | if not markers:
183 | return [content]
184 | results = re.split("|".join(re.escape(marker) for marker in markers), content)
185 | return [r.strip() for r in results if r.strip()]
186 |
187 |
188 | def enclose_string_with_quotes(content: Any) -> str:
189 | """Enclose a string with quotes"""
190 | if isinstance(content, numbers.Number):
191 | return str(content)
192 | content = str(content)
193 | content = content.strip().strip("'").strip('"')
194 | return f'"{content}"'
195 |
196 |
197 | def list_of_list_to_csv(data: list[list]):
198 | return "\n".join(
199 | [
200 | ",\t".join([f"{enclose_string_with_quotes(data_dd)}" for data_dd in data_d])
201 | for data_d in data
202 | ]
203 | )
204 |
205 |
206 | # -----------------------------------------------------------------------------------
207 | # Adapted from the utils functions of the official GraphRAG implementation:
208 | # https://github.com/microsoft/graphrag
209 | def clean_str(input: Any) -> str:
210 | """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
211 | # If we get non-string input, just give it back
212 | if not isinstance(input, str):
213 | return input
214 |
215 | result = html.unescape(input.strip())
216 | # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
217 | return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
218 |
219 |
220 | # Utils types -----------------------------------------------------------------------
221 | @dataclass
222 | class EmbeddingFunc:
223 | embedding_dim: int
224 | max_token_size: int
225 | func: callable
226 |
227 | async def __call__(self, *args, **kwargs) -> np.ndarray:
228 | return await self.func(*args, **kwargs)
229 |
230 |
231 | # Decorators ------------------------------------------------------------------------
232 | def limit_async_func_call(max_size: int, waiting_time: float = 0.0001):
233 | """Limit the maximum number of concurrent calls to an async function"""
234 |
235 | def final_decro(func):
236 | """Not using asyncio.Semaphore, to avoid requiring nest-asyncio"""
237 | __current_size = 0
238 |
239 | @wraps(func)
240 | async def wait_func(*args, **kwargs):
241 | nonlocal __current_size
242 | while __current_size >= max_size:
243 | await asyncio.sleep(waiting_time)
244 | __current_size += 1
245 | try: result = await func(*args, **kwargs)
246 | finally: __current_size -= 1  # release the slot even if func raises
247 | return result
248 |
249 | return wait_func
250 |
251 | return final_decro
252 |
253 |
254 | def wrap_embedding_func_with_attrs(**kwargs):
255 | """Wrap a function with attributes"""
256 |
257 | def final_decro(func) -> EmbeddingFunc:
258 | new_func = EmbeddingFunc(**kwargs, func=func)
259 | return new_func
260 |
261 | return final_decro
262 |
--------------------------------------------------------------------------------
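A minimal sketch exercising a few of the helpers above; the malformed-JSON input and the sample rows are illustrative:

```python
from hirag._utils import (
    compute_mdhash_id,
    convert_response_to_json,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
)

print(compute_mdhash_id("some chunk text", prefix="chunk-"))      # chunk-<md5 hex>
print(split_string_by_multi_markers("a<SEP>b<SEP>c", ["<SEP>"]))  # ['a', 'b', 'c']

# Strict parsing fails on the missing closing brace, so the regex fallback kicks in.
print(convert_response_to_json('{"answer": 42, "reasoning": "short"'))

rows = [{"content": "alpha"}, {"content": "beta"}, {"content": "gamma"}]
print(truncate_list_by_token_size(rows, key=lambda r: r["content"], max_token_size=2))
```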
/hirag/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import TypedDict, Union, Literal, Generic, TypeVar
3 | from ._utils import EmbeddingFunc
4 | import numpy as np
5 |
6 |
7 | @dataclass
8 | class QueryParam:
9 | mode: Literal["hi_global", "hi_local", "hi_bridge", "hi_nobridge", "naive", "hi"] = "hi"
10 | only_need_context: bool = False
11 | response_type: str = "Multiple Paragraphs"
12 | level: int = 2
13 | top_k: int = 20 # retrieve top-k entities
14 | top_m: int = 10 # retrieve top-m entities in each retrieved community
15 | # naive search
16 | naive_max_token_for_text_unit: int = 10000
17 | # hi search
18 | max_token_for_text_unit: int = 20000
19 | max_token_for_local_context: int = 20000
20 | max_token_for_bridge_knowledge: int = 12500
21 | max_token_for_community_report: int = 12500
22 | community_single_one: bool = False
23 |
24 |
25 | TextChunkSchema = TypedDict(
26 | "TextChunkSchema",
27 | {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
28 | )
29 |
30 | SingleCommunitySchema = TypedDict(
31 | "SingleCommunitySchema",
32 | {
33 | "level": int,
34 | "title": str,
35 | "edges": list[list[str, str]],
36 | "nodes": list[str],
37 | "chunk_ids": list[str],
38 | "occurrence": float,
39 | "sub_communities": list[str],
40 | },
41 | )
42 |
43 |
44 | class CommunitySchema(SingleCommunitySchema):
45 | report_string: str
46 | report_json: dict
47 |
48 |
49 | T = TypeVar("T")
50 |
51 |
52 | @dataclass
53 | class StorageNameSpace:
54 | namespace: str
55 | global_config: dict
56 |
57 | async def index_start_callback(self):
58 | """commit the storage operations after indexing"""
59 | pass
60 |
61 | async def index_done_callback(self):
62 | """commit the storage operations after indexing"""
63 | pass
64 |
65 | async def query_done_callback(self):
66 | """commit the storage operations after querying"""
67 | pass
68 |
69 |
70 | @dataclass
71 | class BaseVectorStorage(StorageNameSpace):
72 | embedding_func: EmbeddingFunc
73 | meta_fields: set = field(default_factory=set)
74 |
75 | async def query(self, query: str, top_k: int) -> list[dict]:
76 | raise NotImplementedError
77 |
78 | async def upsert(self, data: dict[str, dict]):
79 | """Use 'content' field from value for embedding, use key as id.
80 | If embedding_func is None, use 'embedding' field from value
81 | """
82 | raise NotImplementedError
83 |
84 |
85 | @dataclass
86 | class BaseKVStorage(Generic[T], StorageNameSpace):
87 | async def all_keys(self) -> list[str]:
88 | raise NotImplementedError
89 |
90 | async def get_by_id(self, id: str) -> Union[T, None]:
91 | raise NotImplementedError
92 |
93 | async def get_by_ids(
94 | self, ids: list[str], fields: Union[set[str], None] = None
95 | ) -> list[Union[T, None]]:
96 | raise NotImplementedError
97 |
98 | async def filter_keys(self, data: list[str]) -> set[str]:
99 | """return un-exist keys"""
100 | raise NotImplementedError
101 |
102 | async def upsert(self, data: dict[str, T]):
103 | raise NotImplementedError
104 |
105 | async def drop(self):
106 | raise NotImplementedError
107 |
108 |
109 | @dataclass
110 | class BaseGraphStorage(StorageNameSpace):
111 | async def has_node(self, node_id: str) -> bool:
112 | raise NotImplementedError
113 |
114 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
115 | raise NotImplementedError
116 |
117 | async def node_degree(self, node_id: str) -> int:
118 | raise NotImplementedError
119 |
120 | async def edge_degree(self, src_id: str, tgt_id: str) -> int:
121 | raise NotImplementedError
122 |
123 | async def get_node(self, node_id: str) -> Union[dict, None]:
124 | raise NotImplementedError
125 |
126 | async def get_edge(
127 | self, source_node_id: str, target_node_id: str
128 | ) -> Union[dict, None]:
129 | raise NotImplementedError
130 |
131 | async def get_node_edges(
132 | self, source_node_id: str
133 | ) -> Union[list[tuple[str, str]], None]:
134 | raise NotImplementedError
135 |
136 | async def upsert_node(self, node_id: str, node_data: dict[str, str]):
137 | raise NotImplementedError
138 |
139 | async def upsert_edge(
140 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
141 | ):
142 | raise NotImplementedError
143 |
144 | async def clustering(self, algorithm: str):
145 | raise NotImplementedError
146 |
147 | async def community_schema(self) -> dict[str, SingleCommunitySchema]:
148 | """Return the community representation with report and nodes"""
149 | raise NotImplementedError
150 |
151 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
152 | raise NotImplementedError("Node embedding is not used in HiRAG.")
153 |
--------------------------------------------------------------------------------
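A minimal sketch of implementing the `BaseKVStorage` contract with a plain in-memory dict; illustrative only, since HiRAG ships `JsonKVStorage` for this role:

```python
import asyncio
from dataclasses import dataclass, field

from hirag.base import BaseKVStorage

@dataclass
class InMemoryKVStorage(BaseKVStorage):
    _data: dict = field(default_factory=dict)

    async def all_keys(self) -> list[str]:
        return list(self._data)

    async def get_by_id(self, id):
        return self._data.get(id)

    async def get_by_ids(self, ids, fields=None):
        return [self._data.get(i) for i in ids]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return {k for k in data if k not in self._data}

    async def upsert(self, data: dict):
        self._data.update(data)

    async def drop(self):
        self._data = {}

kv = InMemoryKVStorage(namespace="demo", global_config={})
asyncio.run(kv.upsert({"doc-1": {"content": "hi"}}))
print(asyncio.run(kv.get_by_id("doc-1")))  # {'content': 'hi'}
```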
/hirag/hirag.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from dataclasses import asdict, dataclass, field
4 | from datetime import datetime
5 | from functools import partial
6 | from typing import Callable, Dict, List, Optional, Type, Union, cast
7 |
8 | import tiktoken
9 |
10 |
11 | from ._llm import (
12 | gpt_4o_complete,
13 | gpt_4o_mini_complete,
14 | gpt_35_turbo_complete,
15 | openai_embedding,
16 | azure_gpt_4o_complete,
17 | azure_openai_embedding,
18 | azure_gpt_4o_mini_complete,
19 | )
20 | from ._op import (
21 | chunking_by_token_size,
22 | extract_entities,
23 | extract_hierarchical_entities,
24 | generate_community_report,
25 | get_chunks,
26 | hierarchical_query,
27 | hierarchical_bridge_query,
28 | hierarchical_local_query,
29 | hierarchical_global_query,
30 | hierarchical_nobridge_query,
31 | naive_query,
32 | )
33 | from ._storage import (
34 | JsonKVStorage,
35 | NanoVectorDBStorage,
36 | NetworkXStorage,
37 | )
38 | from ._utils import (
39 | EmbeddingFunc,
40 | compute_mdhash_id,
41 | limit_async_func_call,
42 | convert_response_to_json,
43 | always_get_an_event_loop,
44 | logger,
45 | )
46 | from .base import (
47 | BaseGraphStorage,
48 | BaseKVStorage,
49 | BaseVectorStorage,
50 | StorageNameSpace,
51 | QueryParam,
52 | )
53 |
54 |
55 | @dataclass
56 | class HiRAG:
57 | working_dir: str = field(
58 | default_factory=lambda: f"./hirag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
59 | )
60 | # graph mode
61 | enable_local: bool = True
62 | enable_naive_rag: bool = False
63 | enable_hierachical_mode: bool = True
64 |
65 | # text chunking
66 | chunk_func: Callable[
67 | [
68 | list[list[int]],
69 | List[str],
70 | tiktoken.Encoding,
71 | Optional[int],
72 | Optional[int],
73 | ],
74 | List[Dict[str, Union[str, int]]],
75 | ] = chunking_by_token_size
76 | chunk_token_size: int = 1200
77 | chunk_overlap_token_size: int = 100
78 | tiktoken_model_name: str = "gpt-4o"
79 |
80 | # entity extraction
81 | entity_extract_max_gleaning: int = 1
82 | entity_summary_to_max_tokens: int = 500
83 |
84 | # graph clustering
85 | graph_cluster_algorithm: str = "leiden"
86 | max_graph_cluster_size: int = 10
87 | graph_cluster_seed: int = 0xDEADBEEF
88 |
89 | # node embedding
90 | node_embedding_algorithm: str = "node2vec"
91 | node2vec_params: dict = field(
92 | default_factory=lambda: {
93 | "dimensions": 1536,
94 | "num_walks": 10,
95 | "walk_length": 40,
96 | "num_walks": 10,
97 | "window_size": 2,
98 | "iterations": 3,
99 | "random_seed": 3,
100 | }
101 | )
102 |
103 | # community reports
104 | special_community_report_llm_kwargs: dict = field(
105 | default_factory=lambda: {"response_format": {"type": "json_object"}}
106 | )
107 |
108 | # text embedding
109 | embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
110 | embedding_batch_num: int = 32
111 | embedding_func_max_async: int = 8
112 | query_better_than_threshold: float = 0.2
113 |
114 | # LLM
115 | using_azure_openai: bool = False
116 | # best_model_func: callable = gpt_35_turbo_complete
117 | best_model_func: callable = gpt_4o_mini_complete
118 | best_model_max_token_size: int = 32768
119 | best_model_max_async: int = 8
120 | cheap_model_func: callable = gpt_35_turbo_complete
121 | cheap_model_max_token_size: int = 32768
122 | cheap_model_max_async: int = 8
123 |
124 | # entity extraction
125 | entity_extraction_func: callable = extract_entities
126 | hierarchical_entity_extraction_func: callable = extract_hierarchical_entities
127 |
128 | # storage
129 | key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
130 | vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
131 | vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
132 | graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
133 | enable_llm_cache: bool = True
134 |
135 | # extension
136 | always_create_working_dir: bool = True
137 | addon_params: dict = field(default_factory=dict)
138 | convert_response_to_json_func: callable = convert_response_to_json
139 |
140 | def __post_init__(self):
141 | _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
142 | logger.debug(f"HiRAG init with param:\n\n {_print_config}\n")
143 |
144 | if self.using_azure_openai:
145 | # Swap the default OpenAI functions for their Azure OpenAI equivalents
146 | if self.best_model_func == gpt_4o_complete:
147 | self.best_model_func = azure_gpt_4o_complete
148 | if self.cheap_model_func == gpt_4o_mini_complete:
149 | self.cheap_model_func = azure_gpt_4o_mini_complete
150 | if self.embedding_func == openai_embedding:
151 | self.embedding_func = azure_openai_embedding
152 | logger.info(
153 | "Switched the default openai funcs to Azure OpenAI if you didn't set any of it"
154 | )
155 |
156 | if not os.path.exists(self.working_dir) and self.always_create_working_dir:
157 | logger.info(f"Creating working directory {self.working_dir}")
158 | os.makedirs(self.working_dir)
159 |
160 | self.full_docs = self.key_string_value_json_storage_cls(
161 | namespace="full_docs", global_config=asdict(self)
162 | )
163 |
164 | self.text_chunks = self.key_string_value_json_storage_cls(
165 | namespace="text_chunks", global_config=asdict(self)
166 | )
167 |
168 | self.llm_response_cache = (
169 | self.key_string_value_json_storage_cls(
170 | namespace="llm_response_cache", global_config=asdict(self)
171 | )
172 | if self.enable_llm_cache
173 | else None
174 | )
175 |
176 | self.community_reports = self.key_string_value_json_storage_cls(
177 | namespace="community_reports", global_config=asdict(self)
178 | )
179 | self.chunk_entity_relation_graph = self.graph_storage_cls(
180 | namespace="chunk_entity_relation", global_config=asdict(self)
181 | )
182 |
183 | self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
184 | self.embedding_func
185 | )
186 | self.entities_vdb = (
187 | self.vector_db_storage_cls(
188 | namespace="entities",
189 | global_config=asdict(self),
190 | embedding_func=self.embedding_func,
191 | meta_fields={"entity_name"},
192 | )
193 | if self.enable_local
194 | else None
195 | )
196 | self.chunks_vdb = (
197 | self.vector_db_storage_cls(
198 | namespace="chunks",
199 | global_config=asdict(self),
200 | embedding_func=self.embedding_func,
201 | )
202 | if self.enable_naive_rag
203 | else None
204 | )
205 |
206 | self.best_model_func = limit_async_func_call(self.best_model_max_async)(
207 | partial(self.best_model_func, hashing_kv=self.llm_response_cache)
208 | )
209 | self.cheap_model_func = limit_async_func_call(self.cheap_model_max_async)(
210 | partial(self.cheap_model_func, hashing_kv=self.llm_response_cache)
211 | )
212 |
213 | def insert(self, string_or_strings):
214 | loop = always_get_an_event_loop()
215 | return loop.run_until_complete(self.ainsert(string_or_strings))
216 |
217 | def query(self, query: str, param: QueryParam = QueryParam()):
218 | loop = always_get_an_event_loop()
219 | return loop.run_until_complete(self.aquery(query, param))
220 |
221 | async def aquery(self, query: str, param: QueryParam = QueryParam()):
222 | if param.mode == "naive" and not self.enable_naive_rag:
223 | raise ValueError("enable_naive_rag is False, cannot query in naive mode")
224 | if param.mode == "hi" and not self.enable_hierachical_mode:
225 | raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical mode")
226 | if param.mode == "hi_nobridge" and not self.enable_hierachical_mode:
227 | raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_nobridge mode")
228 | if param.mode == "hi_bridge" and not self.enable_hierachical_mode:
229 | raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_bridge mode")
230 | if param.mode == "hi_local" and not self.enable_hierachical_mode:
231 | raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_local mode")
232 | if param.mode == "hi_global" and not self.enable_hierachical_mode:
233 | raise ValueError("enable_hierachical_mode is False, cannot query in hierarchical_global mode")
234 |
235 | if param.mode == "hi": # retrieve with hierarchical knowledge
236 | response = await hierarchical_query(
237 | query,
238 | self.chunk_entity_relation_graph,
239 | self.entities_vdb,
240 | self.community_reports,
241 | self.text_chunks,
242 | param,
243 | asdict(self),
244 | )
245 | elif param.mode == "hi_bridge": # retrieve with only bridge knowledge
246 | response = await hierarchical_bridge_query(
247 | query,
248 | self.chunk_entity_relation_graph,
249 | self.entities_vdb,
250 | self.community_reports,
251 | self.text_chunks,
252 | param,
253 | asdict(self),
254 | )
255 | elif param.mode == "hi_local": # retrieve with only local knowledge
256 | response = await hierarchical_local_query(
257 | query,
258 | self.chunk_entity_relation_graph,
259 | self.entities_vdb,
260 | self.community_reports,
261 | self.text_chunks,
262 | param,
263 | asdict(self),
264 | )
265 | elif param.mode == "hi_global": # retrieve with only global knowledge
266 | response = await hierarchical_global_query(
267 | query,
268 | self.chunk_entity_relation_graph,
269 | self.entities_vdb,
270 | self.community_reports,
271 | self.text_chunks,
272 | param,
273 | asdict(self),
274 | )
275 | elif param.mode == "hi_nobridge": # retrieve with no bridge knowledge
276 | response = await hierarchical_nobridge_query(
277 | query,
278 | self.chunk_entity_relation_graph,
279 | self.entities_vdb,
280 | self.community_reports,
281 | self.text_chunks,
282 | param,
283 | asdict(self),
284 | )
285 | elif param.mode == "naive": # retrieve with only text units
286 | response = await naive_query(
287 | query,
288 | self.chunks_vdb,
289 | self.text_chunks,
290 | param,
291 | asdict(self),
292 | )
293 | else:
294 | raise ValueError(f"Unknown mode {param.mode}")
295 | await self._query_done()
296 | return response
297 |
298 | async def ainsert(self, string_or_strings):
299 | await self._insert_start()
300 | try:
301 | if isinstance(string_or_strings, str):
302 | string_or_strings = [string_or_strings]
303 | # ---------- new docs
304 | new_docs = { # dict: {hash: ori_content}
305 | compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
306 | for c in string_or_strings
307 | }
308 | _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) # filter out docs that are already in the storage
309 | new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
310 | if not len(new_docs):
311 | logger.warning("All docs are already in the storage")
312 | return
313 | logger.info(f"[New Docs] inserting {len(new_docs)} docs")
314 |
315 | # ---------- chunking
316 |
317 | inserting_chunks = get_chunks(
318 | new_docs=new_docs,
319 | chunk_func=self.chunk_func,
320 | overlap_token_size=self.chunk_overlap_token_size,
321 | max_token_size=self.chunk_token_size,
322 | )
323 |
324 | _add_chunk_keys = await self.text_chunks.filter_keys(
325 | list(inserting_chunks.keys())
326 | )
327 | inserting_chunks = {
328 | k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
329 | }
330 | if not len(inserting_chunks):
331 | logger.warning("All chunks are already in the storage")
332 | return
333 | logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
334 | if self.enable_naive_rag:
335 | logger.info("Insert chunks for naive RAG")
336 | await self.chunks_vdb.upsert(inserting_chunks)
337 |
338 | # TODO: no incremental update for communities now, so just drop all
339 | await self.community_reports.drop() # empty the data
340 |
341 | # ---------- extract/summary entity and upsert to graph
342 | if not self.enable_hierachical_mode:
343 | logger.info("[Entity Extraction]...")
344 | maybe_new_kg = await self.entity_extraction_func(
345 | inserting_chunks,
346 | knwoledge_graph_inst=self.chunk_entity_relation_graph,
347 | entity_vdb=self.entities_vdb,
348 | global_config=asdict(self),
349 | )
350 | else:
351 | logger.info("[Hierachical Entity Extraction]...")
352 | maybe_new_kg = await self.hierarchical_entity_extraction_func(
353 | inserting_chunks,
354 | knowledge_graph_inst=self.chunk_entity_relation_graph,
355 | entity_vdb=self.entities_vdb,
356 | global_config=asdict(self),
357 | )
358 | if maybe_new_kg is None:
359 | logger.warning("No new entities found")
360 | return
361 | self.chunk_entity_relation_graph = maybe_new_kg
362 | # ---------- update clusterings of graph
363 | logger.info("[Community Report]...")
364 | await self.chunk_entity_relation_graph.clustering(
365 | self.graph_cluster_algorithm # use leiden
366 | )
367 | await generate_community_report(
368 | self.community_reports, self.chunk_entity_relation_graph, asdict(self)
369 | )
370 |
371 | # ---------- commit upsertings and indexing
372 | await self.full_docs.upsert(new_docs)
373 | await self.text_chunks.upsert(inserting_chunks)
374 | finally:
375 | await self._insert_done()
376 |
377 | async def _insert_start(self):
378 | tasks = []
379 | for storage_inst in [
380 | self.chunk_entity_relation_graph,
381 | ]:
382 | if storage_inst is None:
383 | continue
384 | tasks.append(cast(StorageNameSpace, storage_inst).index_start_callback())
385 | await asyncio.gather(*tasks)
386 |
387 | async def _insert_done(self):
388 | tasks = []
389 | for storage_inst in [
390 | self.full_docs,
391 | self.text_chunks,
392 | self.llm_response_cache,
393 | self.community_reports,
394 | self.entities_vdb,
395 | self.chunks_vdb,
396 | self.chunk_entity_relation_graph,
397 | ]:
398 | if storage_inst is None:
399 | continue
400 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
401 | await asyncio.gather(*tasks)
402 |
403 | async def _query_done(self):
404 | tasks = []
405 | for storage_inst in [self.llm_response_cache]:
406 | if storage_inst is None:
407 | continue
408 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
409 | await asyncio.gather(*tasks)
410 |
--------------------------------------------------------------------------------
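A minimal sketch of querying an already-indexed working directory with the retrieval modes declared in `QueryParam`; the path and question are illustrative, and LLM credentials are assumed to be configured:

```python
from hirag import HiRAG, QueryParam

rag = HiRAG(working_dir="./your_work_dir", enable_naive_rag=True)

for mode in ["hi", "hi_local", "hi_global", "hi_bridge", "hi_nobridge", "naive"]:
    print(mode, rag.query("What is discussed?", param=QueryParam(mode=mode)))

# only_need_context=True returns the retrieved context instead of a generated answer.
print(rag.query("What is discussed?", param=QueryParam(mode="hi", only_need_context=True)))
```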
/imgs/hirag_ds.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhy-huang/HiRAG/4a22510ad84fcc2b6562b0b8cb45341e85dcba2d/imgs/hirag_ds.drawio.png
--------------------------------------------------------------------------------
/imgs/hirag_ds_trans.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhy-huang/HiRAG/4a22510ad84fcc2b6562b0b8cb45341e85dcba2d/imgs/hirag_ds_trans.drawio.png
--------------------------------------------------------------------------------
/imgs/hirag_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhy-huang/HiRAG/4a22510ad84fcc2b6562b0b8cb45341e85dcba2d/imgs/hirag_icon.png
--------------------------------------------------------------------------------
/imgs/hirag_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hhy-huang/HiRAG/4a22510ad84fcc2b6562b0b8cb45341e85dcba2d/imgs/hirag_image.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Retrieval-Augmented Generation with Hierarchical Knowledge (HiRAG)
2 | This is the repo for the paper [Retrieval-Augmented Generation with Hierarchical Knowledge](https://arxiv.org/abs/2503.10150).
3 |
4 | ## Model Pipeline
5 |
6 | 
7 |
8 | ## Install
9 |
10 | ```bash
11 | # remember clone this repo first
12 | cd HiRAG
13 | pip install -e .
14 | ```
15 |
16 | ## Quick Start
17 |
18 | You can use the following code to perform a query with HiRAG (assuming `from hirag import HiRAG, QueryParam` after installation).
19 |
20 | ```python
21 | graph_func = HiRAG(
22 | working_dir="./your_work_dir",
23 | enable_llm_cache=True,
24 | enable_hierachical_mode=True,
25 | embedding_batch_num=6,
26 | embedding_func_max_async=8, # according to your machine
27 | enable_naive_rag=True
28 | )
29 | # indexing
30 | with open("path_to_your_context.txt", "r") as f:
31 | graph_func.insert(f.read())
32 | # retrieval & generation
33 | print("Perform hi search:")
34 | print(graph_func.query("The question you want to ask?", param=QueryParam(mode="hi")))
35 | ```
36 |
37 | If you want to use HiRAG with DeepSeek, ChatGLM, or another third-party LLM API, see the examples in `./hi_Search_deepseek.py`, `./hi_Search_glm.py`, and `./hi_Search_openai.py`. The API keys and LLM configurations can be set in `./config.yaml`.
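For example, a custom completion function can be wired in through the `best_model_func` and `cheap_model_func` fields. A minimal sketch, assuming the same call signature as the built-in completion functions; `my_llm_complete` is illustrative (see `./hi_Search_deepseek.py` for a complete version):

```python
from hirag import HiRAG

async def my_llm_complete(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    # Call your LLM endpoint here and return the generated text.
    ...

graph_func = HiRAG(
    working_dir="./your_work_dir",
    best_model_func=my_llm_complete,
    cheap_model_func=my_llm_complete,
)
```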
38 |
39 |
40 | ## Evaluation
41 |
42 | We take the procedure on the Mix dataset as an example.
43 |
44 | ```shell
45 | cd ./HiRAG/eval
46 | ```
47 |
48 | 1. Extract context from original QA datasets.
49 | ```shell
50 | python extract_context.py -i ./datasets/mix -o ./datasets/mix
51 | ```
52 |
53 | 2. Insert context to Graph Database.
54 | ```shell
55 | python insert_context_deepseek.py
56 | ```
57 |
58 | Note that the script `insert_context_deepseek.py` targets generation with the DeepSeek-v3 API; you can replace it with `insert_context_openai.py` or `insert_context_glm.py`.
59 |
60 | 3. Test with different versions of HiRAG.
61 | ```shell
62 | # there are different retrieval options
63 | # If you want to employ HiRAG approach, just run:
64 | python test_deepseek.py -d mix -m hi
65 | # If you want to employ naive RAG approach, just run:
66 | python test_deepseek.py -d mix -m naive
67 | # If you want to employ HiRAG approach w/o bridge, just run:
68 | python test_deepseek.py -d mix -m hi_nobridge
69 | # If you want to employ HiRAG approach with retrieving only local knowledge, just run:
70 | python test_deepseek.py -d mix -m hi_local
71 | # If you want to employ HiRAG approach with retrieving only global knowledge, just run:
72 | python test_deepseek.py -d mix -m hi_global
73 | # If you want to employ HiRAG approach with retrieving only bridge knowledge, just run:
74 | python test_deepseek.py -d mix -m hi_bridge
75 | ```
76 |
77 | Note that the dataset `mix` can be replaced with any other dataset from the [Hugging Face link](https://huggingface.co/datasets/TommyChien/UltraDomain/tree/main). The script `test_deepseek.py` targets generation with the DeepSeek-v3 API; you can replace it with `test_openai.py` or `test_glm.py`.
78 |
79 | 4. Evaluate the generated answers.
80 |
81 | First step, request for evaluations.
82 | ```shell
83 | python batch_eval.py -m request -api openai
84 | python batch_eval.py -m request -api deepseek
85 | ```
86 |
87 | Second step, get the results.
88 | ```shell
89 | python batch_eval.py -m result -api openai
90 | python batch_eval.py -m result -api deepseek
91 | ```
92 |
93 | ## Results
94 |
95 | ### Compare with Naive RAG:
96 |
97 | With the config `output_file` set as `f"./datasets/{DATASET}/{DATASET}_eval_hi_naive.jsonl"`, just run the command:
98 | ```shell
99 | python batch_eval.py -m result -api openai
100 | ```
101 |
102 | | Dataset | Dimension | NaiveRAG % | HiRAG % |
103 | |----------:|:--------:|--------------:|----------------:|
104 | | Mix||||
105 | | |Comprehensiveness| 16.6| **83.4**|
106 | | |Empowerment| 11.6| **88.4**|
107 | | |Diversity| 12.7| **87.3**|
108 | | |Overall| 12.4| **87.6**|
109 | | CS||||
110 | | |Comprehensiveness| 30.0| **70.0**|
111 | | |Empowerment| 29.0| **71.0**|
112 | | |Diversity| 14.5| **85.5**|
113 | | |Overall| 26.5| **73.5**|
114 | | Legal||||
115 | | |Comprehensiveness| 32.5| **67.5**|
116 | | |Empowerment| 25.0| **75.0**|
117 | | |Diversity| 22.0| **78.0**|
118 | | |Overall| 22.5| **74.5**|
119 | | Agriculture||||
120 | | |Comprehensiveness| 34.0| **66.0**|
121 | | |Empowerment| 31.0| **69.0**|
122 | | |Diversity| 21.0| **79.0**|
123 | | |Overall| 28.5| **71.5**|
124 |
125 |
126 | ### Compare with GraphRAG:
127 |
128 | With the config `output_file` set as `f"./datasets/{DATASET}/{DATASET}_eval_hi_graphrag.jsonl"`, just run the command:
129 | ```shell
130 | python batch_eval.py -m result -api openai
131 | ```
132 |
133 | | Dataset | Dimension | GraphRAG % | HiRAG % |
134 | |----------:|:--------:|--------------:|----------------:|
135 | | Mix||||
136 | | |Comprehensiveness| 42.1| **57.9**|
137 | | |Empowerment| 35.1| **64.9**|
138 | | |Diversity| 40.5| **59.5**|
139 | | |Overall| 35.9| **64.1**|
140 | | CS||||
141 | | |Comprehensiveness| 40.5| **59.5**|
142 | | |Empowerment| 38.5| **61.5**|
143 | | |Diversity| 30.5| **69.5**|
144 | | |Overall| 36.0| **64.0**|
145 | | Legal||||
146 | | |Comprehensiveness| 48.5| **51.5**|
147 | | |Empowerment| 43.5| **56.5**|
148 | | |Diversity| 47.0| **53.0**|
149 | | |Overall| 45.5| **54.5**|
150 | | Agriculture||||
151 | | |Comprehensiveness| 49.0| **51.0**|
152 | | |Empowerment| 48.5| **51.5**|
153 | | |Diversity| 45.5| **54.5**|
154 | | |Overall| 46.0| **54.0**|
155 |
156 | ### Compare with LightRAG:
157 |
158 | With the config `output_file` set as `f"./datasets/{DATASET}/{DATASET}_eval_hi_lightrag.jsonl"`, just run the command:
159 | ```shell
160 | python batch_eval.py -m result -api openai
161 | ```
162 |
163 | | Dataset | Dimension | LightRAG % | HiRAG % |
164 | |----------:|:--------:|--------------:|----------------:|
165 | | Mix||||
166 | | |Comprehensiveness| 36.8| **63.2**|
167 | | |Empowerment| 34.9| **65.1**|
168 | | |Diversity| 34.1| **65.9**|
169 | | |Overall| 34.1| **65.9**|
170 | | CS||||
171 | | |Comprehensiveness| 44.5| **55.5**|
172 | | |Empowerment| 41.5| **58.5**|
173 | | |Diversity| 33.0| **67.0**|
174 | | |Overall| 41.0| **59.0**|
175 | | Legal||||
176 | | |Comprehensiveness| 49.0| **51.0**|
177 | | |Empowerment| 43.5| **56.5**|
178 | | |Diversity| **63.0**| 37.0|
179 | | |Overall| 48.0| **52.0**|
180 | | Agriculture||||
181 | | |Comprehensiveness| 38.5| **61.5**|
182 | | |Empowerment| 36.5| **63.5**|
183 | | |Diversity| 37.5| **62.5**|
184 | | |Overall| 38.5| **61.5**|
185 |
186 | ### Compare with FastGraphRAG:
187 |
188 | With the config `output_file` set as `f"./datasets/{DATASET}/{DATASET}_eval_hi_fastgraphrag.jsonl"`, just run the command:
189 | ```shell
190 | python batch_eval.py -m result -api openai
191 | ```
192 |
193 | | Dataset | Dimension | FastGraphRAG % | HiRAG % |
194 | |----------:|:--------:|--------------:|----------------:|
195 | | Mix||||
196 | | |Comprehensiveness| 0.8| **99.2**|
197 | | |Empowerment| 0.8| **99.2**|
198 | | |Diversity| 0.8| **99.2**|
199 | | |Overall| 0.8| **99.2**|
200 | | CS||||
201 | | |Comprehensiveness| 0.0| **100.0**|
202 | | |Empowerment| 0.0| **100.0**|
203 | | |Diversity| 0.5| **99.5**|
204 | | |Overall| 0.0| **100.0**|
205 | | Legal||||
206 | | |Comprehensiveness| 1.0| **99.0**|
207 | | |Empowerment| 0.0| **100.0**|
208 | | |Diversity| 1.5| **98.5**|
209 | | |Overall| 0.0| **100.0**|
210 | | Agriculture||||
211 | | |Comprehensiveness| 0.0| **100.0**|
212 | | |Empowerment| 0.0| **100.0**|
213 | | |Diversity| 0.0| **100.0**|
214 | | |Overall| 0.0| **100.0**|
215 |
216 | ### Compare with KAG:
217 |
218 | With the config `output_file` set as `f"./datasets/{DATASET}/{DATASET}_eval_hi_kag.jsonl"`, just run the command:
219 | ```shell
220 | python batch_eval.py -m result -api openai
221 | ```
222 |
223 | | Dataset | Dimension | KAG % | HiRAG % |
224 | |----------:|:--------:|--------------:|----------------:|
225 | | Mix||||
226 | | |Comprehensiveness| 2.3| **97.7**|
227 | | |Empowerment| 3.5| **96.5**|
228 | | |Diversity| 3.8| **96.2**|
229 | | |Overall| 2.3| **97.7**|
230 | | CS||||
231 | | |Comprehensiveness| 1.0| **99.0**|
232 | | |Empowerment| 4.5| **95.5**|
233 | | |Diversity| 5.0| **95.0**|
234 | | |Overall| 1.5| **98.5**|
235 | | Legal||||
236 | | |Comprehensiveness| 16.5| **83.5**|
237 | | |Empowerment| 9.0| **91.0**|
238 | | |Diversity| 11.0| **89.0**|
239 | | |Overall| 8.5| **91.5**|
240 | | Agriculture||||
241 | | |Comprehensiveness| 5.0| **95.0**|
242 | | |Empowerment| 5.0| **95.0**|
243 | | |Diversity| 3.5| **96.5**|
244 | | |Overall| 0.0| **100.0**|
245 |
246 | ## Acknowledgement
247 | We gratefully acknowledge the use of the following open-source projects in our work:
248 | - [nano-graphrag](https://github.com/gusye1234/nano-graphrag): a simple, easy-to-hack GraphRAG implementation
249 |
250 | - [RAPTOR](https://github.com/parthsarthi03/raptor): a novel approach to retrieval-augmented language models by constructing a recursive tree structure from documents.
251 |
252 | ## Cite Us
253 | ```
254 | @misc{huang2025retrievalaugmentedgenerationhierarchicalknowledge,
255 | title={Retrieval-Augmented Generation with Hierarchical Knowledge},
256 | author={Haoyu Huang and Yongfeng Huang and Junjie Yang and Zhenyu Pan and Yongqiang Chen and Kaili Ma and Hongzhi Chen and James Cheng},
257 | year={2025},
258 | eprint={2503.10150},
259 | archivePrefix={arXiv},
260 | primaryClass={cs.CL},
261 | url={https://arxiv.org/abs/2503.10150},
262 | }
263 | ```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | graspologic==3.4.1
2 | jsonlines==4.0.0
3 | matplotlib==3.10.0
4 | nano_vectordb==0.0.2
5 | neo4j==5.25.0
6 | nest_asyncio==1.6.0
7 | networkx==3.3
8 | numpy==1.26.4
9 | openai==1.61.1
10 | pandas==2.2.3
11 | pydantic==2.10.6
12 | ragas==0.2.2
13 | scikit_learn==1.6.1
14 | seaborn==0.13.2
15 | setuptools==73.0.1
16 | tenacity==9.0.0
17 | tiktoken==0.7.0
18 | tqdm==4.66.5
19 | transformers==4.47.1
20 | umap_learn==0.5.6
21 | xxhash==3.5.0
22 | future==1.0.0
23 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | from setuptools import find_packages
3 |
4 | with open("readme.md", "r") as fh:
5 | long_description = fh.read()
6 |
7 |
8 | vars2find = ["__author__", "__version__", "__url__"]
9 | vars2readme = {}
10 | with open("./hirag/__init__.py") as f:
11 | for line in f.readlines():
12 | for v in vars2find:
13 | if line.startswith(v):
14 | line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
15 | vars2readme[v] = line.split("=")[1]
16 |
17 | deps = []
18 | with open("./requirements.txt") as f:
19 | for line in f.readlines():
20 | if not line.strip():
21 | continue
22 | deps.append(line.strip())
23 |
24 | setuptools.setup(
25 | name="HiRAG",
26 | url=vars2readme["__url__"],
27 | version=vars2readme["__version__"],
28 | author=vars2readme["__author__"],
29 | description="HiRAG implementation",
30 | long_description=long_description,
31 | long_description_content_type="text/markdown",
32 | packages=find_packages(),
33 | classifiers=[
34 | "Programming Language :: Python :: 3",
35 | "License :: OSI Approved :: MIT License",
36 | "Operating System :: OS Independent",
37 | ],
38 | python_requires=">=3.9",
39 | install_requires=deps,
40 | )
41 |
--------------------------------------------------------------------------------