├── .gitignore ├── LICENSE ├── README.md ├── assets ├── Hyper-RAG.pdf ├── extract.svg ├── fw.svg ├── hg.svg ├── many_llms_all.svg ├── many_llms_sp.svg ├── multi_domain.svg ├── speed_all.svg ├── vis-QA.jpg └── vis-hg.jpg ├── config_temp.py ├── evaluate ├── evaluate_by_scoring.py └── evaluate_by_selection.py ├── examples ├── hyperrag_demo.py └── mock_data.txt ├── hyperrag ├── __init__.py ├── base.py ├── hyperrag.py ├── llm.py ├── operate.py ├── prompt.py ├── storage.py └── utils.py ├── reproduce ├── Step_0.py ├── Step_1.py ├── Step_2_extract_question.py └── Step_3_response_question.py ├── requirements.txt └── web-ui ├── .gitignore ├── README.md ├── backend ├── README.md ├── db.py ├── hyperdb │ ├── __init__.py │ ├── _global.py │ ├── base.py │ └── hypergraph.py ├── hypergraph_A_Christmas_Carol.hgdb ├── hypergraph_wukong.hgdb ├── main.py └── requirements.txt └── frontend ├── .commitlintrc.cjs ├── .env.mock ├── .env.production ├── .eslintignore ├── .eslintrc.cjs ├── .prettierignore ├── .prettierrc ├── .stylelintignore ├── .stylelintrc.cjs ├── README.md ├── config ├── defaultSettings.ts ├── mock │ └── user.ts ├── proxy.ts └── routes │ ├── index.tsx │ └── routers.jsx ├── index.html ├── package.json ├── public └── logo.png ├── server.js ├── src ├── 404.tsx ├── App.tsx ├── ErrorPage.tsx ├── _defaultProps.tsx ├── assets │ ├── react.svg │ └── show.png ├── components │ ├── NotFound │ │ ├── index.tsx │ │ └── type.d.ts │ ├── errorBoundary.jsx │ └── loading │ │ ├── index.module.less │ │ └── index.tsx ├── layout │ └── BasicLayout.tsx ├── main.tsx ├── pages │ ├── Home │ │ └── index.jsx │ └── Hyper │ │ ├── DB │ │ └── index.jsx │ │ ├── Files │ │ └── index.jsx │ │ └── Graph │ │ ├── data.js │ │ └── index.jsx ├── store │ └── globalUser.ts ├── utils │ └── index.js └── vite-env.d.ts ├── tsconfig.json └── vite.config.ts /.gitignore: -------------------------------------------------------------------------------- 1 | /my_config.py 2 | .DS_Store 3 | caches/ 4 | datasets/ 5 | .vscode/ 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # UV 103 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | #uv.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 121 | .pdm.toml 122 | .pdm-python 123 | .pdm-build/ 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 173 | #.idea/ 174 | 175 | # Ruff stuff: 176 | .ruff_cache/ 177 | 178 | # PyPI configuration file 179 | .pypirc 180 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2025] [Yifan Feng] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 4 | 5 |

Hyper-RAG

6 | 7 |

8 | Github top language 9 | 10 | Github language count 11 | 12 | Repository size 13 | 14 | License 15 | 16 | 17 | 18 | 19 | 20 | Github stars 21 |

22 | 23 |

24 | About   |   25 | Features   |   26 | Installation   |   27 | Quick Start   |   28 | Evaluation   |   29 | License   |   30 | Author 31 |

32 | 33 |
34 | 35 | 36 |
37 | Overall Performance 38 |
39 | 40 | We show that Hyper-RAG is a powerful RAG method that can enhance the performance of various LLMs and outperform other SOTA RAG methods on the NeurologyCrop dataset. **Our paper is available here**. 41 | 42 | ## :dart: About 43 | 44 |
45 | Abstract 46 | Large language models (LLMs) have transformed various sectors, including education, finance, and medicine, by enhancing content generation and decision-making processes. However, their integration into the medical field is cautious due to hallucinations, instances where generated content deviates from factual accuracy, potentially leading to adverse outcomes. To address this, we introduce Hyper-RAG, a hypergraph-driven Retrieval-Augmented Generation method that comprehensively captures both pairwise and beyond-pairwise correlations in domain-specific knowledge, thereby mitigating hallucinations. Experiments on the NeurologyCrop dataset with six prominent LLMs demonstrated that Hyper-RAG improves accuracy by an average of 12.3% over direct LLM use and outperforms Graph RAG and Light RAG by 6.3% and 6.0%, respectively. Additionally, Hyper-RAG maintained stable performance with increasing query complexity, unlike existing methods which declined. Further validation across nine diverse datasets showed a 35.5% performance improvement over Light RAG using a selection-based assessment. The lightweight variant, Hyper-RAG-Lite, achieved twice the retrieval speed and a 3.3\% performance boost compared with Light RAG. These results confirm Hyper-RAG's effectiveness in enhancing LLM reliability and reducing hallucinations, making it a robust solution for high-stakes applications like medical diagnostics. 47 |
48 | 49 |
50 | 51 |
52 | Framework 53 |
54 | Schematic diagram of the proposed Hyper-RAG architecture. a, The patient poses a question. b, A knowledge base is constructed from relevant domain-specific corpora. c, Responses are generated directly using LLMs. d, Hyper-RAG generates responses by first retrieving relevant prior knowledge from the knowledge base and then inputting this knowledge, along with the patient's question, into the LLMs to formulate the reply. A minimal code-level sketch of this retrieve-then-generate flow is shown below. 55 | 56 |
57 |
58 | 59 |
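In code terms, panel d of the diagram above corresponds to a two-step retrieve-then-generate call. The snippet below is a hedged sketch only: it assumes a `HyperRAG` instance `rag` has already been built as in `examples/hyperrag_demo.py`, the question string is purely illustrative, and `only_need_context` (a `QueryParam` field defined in `hyperrag/base.py`) is used here just to inspect the retrieved prior knowledge.

```python
from hyperrag import QueryParam

question = "What symptoms are described for this condition?"  # illustrative only

# Step 1 (retrieval only): assemble the prior-knowledge context from the
# hypergraph knowledge base without generating a final reply.
context = rag.query(question, param=QueryParam(mode="hyper", only_need_context=True))

# Step 2 (full pipeline): retrieve the prior knowledge and feed it, together
# with the question, to the LLM to formulate the reply.
answer = rag.query(question, param=QueryParam(mode="hyper"))
print(answer)
```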
60 | More details about hypergraph modeling 61 |
62 | Hypergraph 63 | Example of hypergraph modeling for the entity space. A hypergraph can model beyond-pairwise relationships among entities, which is more expressive than the purely pairwise relationships of traditional graph modeling. With hypergraphs, we avoid the information loss caused by reducing everything to pairwise relations (see the code sketch below). 64 |
65 |
66 |
67 | Extract Hypergraph 68 | Illustration of Entity and Correlation Extraction from Raw Corpus: Dark brown boxes represent entities, blue arrows denote low-order correlations between entities, and red arrows indicate high-order correlations. Yellow boxes contain the original descriptions of the respective entities or their correlations. 69 |
70 |
71 | 72 |
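To make the pairwise vs. beyond-pairwise distinction concrete, here is a tiny, self-contained sketch using plain Python data structures. It is an illustration only, not the repository's actual storage schema; the entity names come from *A Christmas Carol*, which is also used as demo data elsewhere in this repo.

```python
# Entities extracted from a passage, each with its own description.
entities = {
    "Scrooge": "miserly protagonist of the story",
    "Ghost of Christmas Past": "first spirit that visits Scrooge",
    "Fezziwig": "Scrooge's former employer",
}

# Low-order (pairwise) modeling: the joint event is split into separate edges,
# and the shared context of the three entities is lost.
pairwise_edges = [
    ("Scrooge", "Ghost of Christmas Past"),
    ("Scrooge", "Fezziwig"),
    ("Ghost of Christmas Past", "Fezziwig"),
]

# Beyond-pairwise (hypergraph) modeling: one hyperedge links all three entities
# and carries a single description of their joint relationship.
hyperedges = {
    frozenset({"Scrooge", "Ghost of Christmas Past", "Fezziwig"}):
        "the spirit shows Scrooge a Christmas party at Fezziwig's warehouse",
}

for members, description in hyperedges.items():
    print(sorted(members), "->", description)
```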
73 | 74 | ## :sparkles: Why Hyper-RAG is More Powerful 75 | 76 | :heavy_check_mark: **Comprehensive Relationship Modeling with Hypergraphs**: Utilizes hypergraphs to thoroughly model the associations within the raw corpus data, capturing more complex relationships than traditional graph-based data organization.\ 77 | :heavy_check_mark: **Native Hypergraph-DB Integration**: Employs the native hypergraph database, Hypergraph-DB, as the foundation, supporting rapid retrieval of higher-order associations.\ 78 | :heavy_check_mark: **Superior Performance**: Hyper-RAG outperforms Graph RAG and Light RAG by 6.3% and 6.0%, respectively.\ 79 | :heavy_check_mark: **Broad Validation**: Across nine diverse datasets, Hyper-RAG shows a 35.5% performance improvement over Light RAG based on a selection-based assessment.\ 80 | :heavy_check_mark: **Efficiency**: The lightweight variant, Hyper-RAG-Lite, achieves twice the retrieval speed and a 3.3% performance boost compared to Light RAG. 81 | 82 | ## :rocket: Installation 83 | 84 | 85 | ```bash 86 | # Clone this project 87 | git clone https://github.com/iMoonLab/Hyper-RAG.git 88 | 89 | # Access 90 | cd Hyper-RAG 91 | 92 | # Install dependencies 93 | pip install -r requirements.txt 94 | ``` 95 | 96 | ## :white_check_mark: Quick Start 97 | 98 | ### Configure your LLM API 99 | Copy the `config_temp.py` file to `my_config.py` in the root folder and set your LLM `URL` and `KEY`. 100 | 101 | ```python 102 | LLM_BASE_URL = "Yours xxx" 103 | LLM_API_KEY = "Yours xxx" 104 | LLM_MODEL = "gpt-4o-mini" 105 | 106 | EMB_BASE_URL = "Yours xxx" 107 | EMB_API_KEY = "Yours xxx" 108 | EMB_MODEL = "text-embedding-3-small" 109 | EMB_DIM = 1536 110 | ``` 111 | 112 | ### Run the toy example 113 | 114 | ```bash 115 | python examples/hyperrag_demo.py 116 | ``` 117 | 118 | ### Or Run by Steps 119 | 120 | 1. Prepare the data. You can download the dataset from here. Put the dataset in the root directory. Then run the following command to preprocess the data. 121 | 122 | ```bash 123 | python reproduce/Step_0.py 124 | ``` 125 | 126 | 2. Build the knowledge hypergraph and the entity and relation vector databases with the following command. 127 | 128 | ```bash 129 | python reproduce/Step_1.py 130 | ``` 131 | 132 | 3. Extract questions from the original datasets with the following command. 133 | 134 | ```bash 135 | python reproduce/Step_2_extract_question.py 136 | ``` 137 | 138 | The extracted questions are saved in the `caches/{data_name}/questions` folder. 139 | 140 | 4. Run Hyper-RAG to answer those questions with the following command. 141 | 142 | ```bash 143 | python reproduce/Step_3_response_question.py 144 | ``` 145 | 146 | The responses are saved in the `caches/{data_name}/response` folder. 147 | 148 | You can also change the `mode` parameter to `hyper` or `hyper-lite` to run Hyper-RAG or Hyper-RAG-Lite. 149 | 150 | 151 | ### Hypergraph Visualization 152 | We provide a web-based visualization tool for hypergraphs and a lightweight Hyper-RAG QA system. For more information, please refer to [Hyper-RAG Web-UI](./web-ui/README.md). 153 | 154 | *Note: The web UI is still under development and may not be fully functional. We welcome any contributions to improve it.* 155 | ![vis-qa](./assets/vis-QA.jpg) 156 | ![vis-hg](./assets/vis-hg.jpg) 157 | 158 | 159 | 160 | ## :checkered_flag: Evaluation 161 | In this work, we propose two evaluation strategies: the **selection-based** and the **scoring-based** evaluation.
162 | 163 | ### Scoring-based evaluation 164 | Scoring-Based Assessment is designed to facilitate the comparative evaluation of multiple model outputs by quantifying their performance across several dimensions. This approach allows for a nuanced assessment of model capabilities by scoring each answer on several key metrics. A notable limitation, however, is its reliance on reference answers; in our preprocessing steps, we use the source chunks from which each question is derived as the reference answers. 165 | 166 | You can run this evaluation method with the following command. 167 | 168 | ```bash 169 | python evaluate/evaluate_by_scoring.py 170 | ``` 171 | The results of this evaluation are shown in the following figure, and a condensed sketch of the scoring loop follows it. 172 |
173 | Scoring-based evaluation 174 |
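For reference, the scoring flow reduces to the loop sketched below: build a rubric prompt from the question, the model's answer, and its reference chunk, ask the judge LLM, and pull the five `"Score"` fields out of the JSON-style reply. This is a condensed, hedged sketch of what `evaluate/evaluate_by_scoring.py` does; `llm` stands for any callable that sends a prompt to your judge model and returns its text.

```python
import re

CRITERIA = ["Comprehensiveness", "Diversity", "Empowerment", "Logical", "Readability"]

def score_one(llm, query: str, answer: str, reference: str) -> list[float]:
    # Rubric prompt (abridged): the real script spells out a 5-level rubric per criterion.
    prompt = (
        f"Evaluate the answer on {', '.join(CRITERIA)} and reply in JSON with a "
        f'"Score" (0-100) for each criterion.\n\nDocuments:\n{reference}\n\n'
        f"Question:\n{query}\n\nAnswer:\n{answer}"
    )
    response = llm(prompt)
    # Same extraction pattern as fetch_scoring_results in the script.
    return [float(s) for s in re.findall(r'"Score":\s*(?:"?(\d+)"?)', response)]
```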
175 | 176 | 177 | ### Selection-based evaluation 178 | Selection-Based Assessment is tailored for scenarios where preliminary candidate models are available, enabling a comparative evaluation through a binary choice mechanism. This method does not require reference answers, making it suitable for diverse and open-ended questions. However, its limitation lies in its comparative nature, as it only allows for the evaluation of two models at a time. 179 | 180 | You can run this evaluation method with the following command. 181 | 182 | ```bash 183 | python evaluate/evaluate_by_selection.py 184 | ``` 185 | The results of this evaluation are shown in the following figure, and a condensed sketch of the win-rate tally follows it. 186 |
187 | Selection-based evaluation 188 |
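The selection flow is a head-to-head tally, sketched below in condensed form from `evaluate/evaluate_by_selection.py`: for each question the judge picks a winner per criterion, and model A's win rate is averaged over all questions. The `debiased` helper is hypothetical; it implements the position-bias correction recommended in the script's comments by averaging an A-first run with a B-first run.

```python
import re

CRITERIA = ["Comprehensiveness", "Empowerment", "Accuracy", "Relevance",
            "Coherence", "Clarity", "Logical", "Flexibility"]

def win_rates(responses: list[str]) -> list[float]:
    """Fraction of judgments in which 'Answer 1' (model A) wins each criterion."""
    wins = [0] * len(CRITERIA)
    for response in responses:
        winners = re.findall(r'"Winner":\s*"([^"]+)"', response)
        for i, winner in enumerate(winners[: len(CRITERIA)]):
            wins[i] += winner.lower() == "answer 1"
    return [w / len(responses) for w in wins]

def debiased(a_first: list[float], b_first: list[float]) -> list[float]:
    """Average the A-vs-B and B-vs-A orderings to cancel position bias (hypothetical helper)."""
    return [(a + (1.0 - b)) / 2.0 for a, b in zip(a_first, b_first)]
```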
189 | 190 | 191 | ### Efficiency Analysis 192 | We conducted an efficiency analysis of our Hyper-RAG method using GPT-4o mini on the NeurologyCrop dataset, comparing it with standard RAG, Graph RAG, and Light RAG. To ensure fairness by excluding network latency, we measured only the local retrieval time for relevant knowledge and the construction of the prior knowledge prompt. While standard RAG focuses on the direct retrieval of chunk embeddings, Graph RAG, Light RAG, and Hyper-RAG also include retrieval from node and correlation vector databases and the time for one layer of graph or hypergraph information diffusion. We averaged the response times over 50 questions from the dataset for each method. The results are shown in the following figure. 193 | 194 |
195 | Efficiency analysis 196 |
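If you want to reproduce a comparable measurement, the snippet below shows one way to time only the local retrieval and prior-knowledge prompt construction. It is a rough sketch under stated assumptions, not the paper's benchmarking code: it assumes a ready `HyperRAG` instance `rag`, a list of question strings, and that `QueryParam(only_need_context=True)` returns the assembled context without calling the LLM.

```python
import time
from hyperrag import QueryParam

def avg_retrieval_seconds(rag, questions: list[str], mode: str = "hyper") -> float:
    """Average per-question time for local retrieval + prior-knowledge prompt construction."""
    start = time.perf_counter()
    for question in questions:
        rag.query(question, param=QueryParam(mode=mode, only_need_context=True))
    return (time.perf_counter() - start) / len(questions)
```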
197 | 198 | ## :memo: License 199 | 200 | This project is licensed under the Apache License 2.0. For more details, see the [LICENSE](LICENSE) file. 201 | 202 | Hyper-RAG is maintained by [iMoon-Lab](http://moon-lab.tech/), Tsinghua University. 203 | Made with :heart: by Yifan Feng, Hao Hu, Xingliang Hou, Shiquan Liu, Yifan Zhang, Xizhe Yu. 204 | 205 | If you have any questions, please feel free to contact us via email: [Yifan Feng](mailto:evanfeng97@gmail.com). 206 | 207 | This repo benefits from [LightRAG](https://github.com/HKUDS/LightRAG) and [Hypergraph-DB](https://github.com/iMoonLab/Hypergraph-DB). Thanks for their wonderful work. 208 | 209 |   210 | 211 | ## 🌟Citation 212 | ``` 213 | @misc{feng2025hyperrag, 214 | title={Hyper-RAG: Combating LLM Hallucinations using Hypergraph-Driven Retrieval-Augmented Generation}, 215 | author={Yifan Feng and Hao Hu and Xingliang Hou and Shiquan Liu and Shihui Ying and Shaoyi Du and Han Hu and Yue Gao}, 216 | year={2025}, 217 | eprint={2504.08758}, 218 | archivePrefix={arXiv}, 219 | primaryClass={cs.IR}, 220 | url={https://arxiv.org/abs/2504.08758}, 221 | } 222 | ``` 223 | 224 | Back to top 225 | -------------------------------------------------------------------------------- /assets/Hyper-RAG.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/assets/Hyper-RAG.pdf -------------------------------------------------------------------------------- /assets/vis-QA.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/assets/vis-QA.jpg -------------------------------------------------------------------------------- /assets/vis-hg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/assets/vis-hg.jpg -------------------------------------------------------------------------------- /config_temp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | LLM_BASE_URL = "xxx" 4 | LLM_API_KEY = "xxx" 5 | LLM_MODEL = "gpt-4o-mini" 6 | 7 | EMB_BASE_URL = "xxx" 8 | EMB_API_KEY = "xxx" 9 | EMB_MODEL = "text-embedding-3-small" 10 | EMB_DIM = 1536 11 | -------------------------------------------------------------------------------- /evaluate/evaluate_by_scoring.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | 6 | import re 7 | import json 8 | import numpy as np 9 | from tqdm import tqdm 10 | from openai import OpenAI 11 | from my_config import LLM_API_KEY, LLM_BASE_URL, LLM_MODEL 12 | 13 | 14 | def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str: 15 | openai_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL) 16 | 17 | messages = [] 18 | if system_prompt: 19 | messages.append({"role": "system", "content": system_prompt}) 20 | messages.extend(history_messages) 21 | messages.append({"role": "user", "content": prompt}) 22 | 23 | response = openai_client.chat.completions.create( 24 | model=LLM_MODEL, messages=messages, **kwargs 25 | ) 26 | return response.choices[0].message.content 27 | 28 | 29 | def extract_queries_and_answers(file_path): 30 | with
open(file_path, "r", encoding="utf-8") as file: 31 | query_list = json.load(file) 32 | queries = [i["query"] for i in query_list] 33 | answers = [i["result"] for i in query_list] 34 | return queries, answers 35 | 36 | 37 | def extarct_queries_and_refs(file_path: Path): 38 | ref_file_path = file_path.with_stem(f"{file_path.stem}_ref") 39 | with open(file_path, "r", encoding="utf-8") as file: 40 | queries = json.load(file) 41 | with open(ref_file_path, "r", encoding="utf-8") as file: 42 | refs = json.load(file) 43 | return queries, refs 44 | 45 | 46 | def exam_by_scoring(queries, answers, refs): 47 | 48 | responses = [] 49 | sys_prompt = """ 50 | ---Role--- 51 | You are an expert tasked with evaluating answers to the questions by using the relevant documents based on five criteria:**Comprehensiveness**, **Diversity**,**Empowerment**, **Logical**,and **Readability** . 52 | 53 | """ 54 | for query, answer, reference in tqdm( 55 | zip(queries, answers, refs), desc="Evaluating answers", total=len(queries) 56 | ): 57 | prompt = f""" 58 | You will evaluate tht answers to the questions by using the relevant documents based on five criteria:**Comprehensiveness**, **Diversity**,**Empowerment**, **Logical**,and **Readability** . 59 | 60 | - **Comprehensiveness** - 61 | Measure whether the answer comprehensively covers all key aspects of the question and whether there are omissions. 62 | Level | score range | description 63 | Level 1 | 0-20 | The answer is extremely one-sided, leaving out key parts or important aspects of the question. 64 | Level 2 | 20-40 | The answer has some content, but it misses many important aspects of the question and is not comprehensive enough. 65 | Level 3 | 40-60 | The answer is more comprehensive, covering the main aspects of the question, but there are still some omissions. 66 | Level 4 | 60-80 | The answer is comprehensive, covering most aspects of the question, with few omissions. 67 | Level 5 | 80-100 | The answer is extremely comprehensive, covering all aspects of the question with no omissions, enabling the reader to gain a complete understanding. 68 | 69 | - **Diversity** - 70 | Measure the richness of the answer content, including not only the direct answer to the question, but also the background knowledge related to the question, extended information, case studies, etc. 71 | Level | score range | description 72 | Level 1 | 0-20 | The answer is extremely sparse, providing only direct answers to questions without additional information or expansion of relevant knowledge. 73 | Level 2 | 20-40 | The answer provides a direct answer to the question, but contains only a small amount of relevant knowledge expansion, the content is relatively thin. 74 | Level 3 | 40-60 | In addition to the direct answers, the answer also provides some relevant background knowledge or supplementary information. 75 | Level 4 | 60-80 | The answer is rich in content, not only answering the question, but also providing more relevant background knowledge, supplementary information or expanded content, so that readers can understand the question more comprehensively. 76 | Level 5 | 80-100 | In addition to the direct answers, the answer also provides a lot of relevant knowledge, expanded content and in-depth analysis, so that readers can get a comprehensive and in-depth understanding. 77 | 78 | - **Empowerment** - 79 | Measure the credibility of the answer and whether it convinces the reader that it is correct. 
High confidence answers often cite authoritative sources or provide sufficient evidence. 80 | Level | score range | description 81 | Level 1 | 0-20 | The answer lacks credibility, contains obvious errors or false information, and fails to convince the reader. 82 | Level 2 | 20-40 | The answer has some credibility, but some of the information is not accurate or lacks support, which may cause readers to doubt. 83 | Level 3 | 40-60 | The answer is credible and provides some supporting information, but there are still some areas that are not clear or authoritative. 84 | Level 4 | 60-80 | The answer is highly credible, providing sufficient supporting information (such as quotes, data, etc.), so that readers can be more convinced. 85 | Level 5 | 80-100 | The answer is highly credible, providing sufficient and authoritative supporting information, so that the reader is completely convinced of their correctness. 86 | 87 | - **Logical** - 88 | Measure whether the answer are coherent, clear, and easy to understand. 89 | Level | score range | description 90 | Level 1 | 0-20 | The answer is illogical, incoherent, and difficult to understand. 91 | Level 2 | 20-40 | The answer has some logic, but it is incoherent and difficult to understand in parts. 92 | Level 3 | 40-60 | The answer is logically clear and the sentences are basically coherent, but there are still a few logical loopholes or unclear places. 93 | Level 4 | 60-80 | The answer is logical, coherent, coherent, and easy to understand. 94 | Level 5 | 80-100 | The answer is extremely logical, fluent and well-organized, making it easy for the reader to follow the author's thoughts. 95 | 96 | - **Readability** - 97 | Measure whether the answer is well organized, clear in format, and easy to read. 98 | Level | score range | description 99 | Level 1 | 0-20 | The format of the answer is confused, the writing is poorly organized and difficult to read. 100 | Level 2 | 20-40 | There are some problems in the format of the answer, the organizational structure of the text is not clear enough, and it is difficult to read. 101 | Level 3 | 40-60 | The format of the answer is basically clear, the writing structure is good, but there is still room for improvement. 102 | Level 4 | 60-80 | The format of the answer is clear, the writing is well organized and the reading is smooth. 103 | Level 5 | 80-100 | The format of the answer is very clear, the writing structure is great, the reading experience is excellent, the format is standardized and easy to understand. 104 | 105 | For each indicator, please give the problem a corresponding Level based on the description of the indicator, and then give a score according to the score range of the level. 106 | 107 | 108 | 109 | Here are the relevant documents: 110 | {reference} 111 | 112 | Here are the questions: 113 | {query} 114 | 115 | Here are the answers: 116 | {answer} 117 | 118 | 119 | Evaluate all the answers using the six criteria listed above, for each criterion, provide a summary description, give a Level based on the description of the indicator, and then give a score based on the score range of the level. 
120 | 121 | Output your evaluation in the following JSON format: 122 | 123 | {{ 124 | "Comprehensiveness": {{ 125 | "Explanation": "Provide explanation here" 126 | "Level": "A level range 1 to 5" # This should be a single number, not a range 127 | "Score": "A value range 0 to 100" # This should be a single number, not a range 128 | }}, 129 | "Diversity": {{ 130 | "Explanation": "Provide explanation here" 131 | "Level": "A level range 1 to 5" # This should be a single number, not a range 132 | "Score": "A value range 0 to 100" # This should be a single number, not a range 133 | }}, 134 | "Empowerment": {{ 135 | "Explanation": "Provide explanation here" 136 | "Level": "A level range 1 to 5" # This should be a single number, not a range 137 | "Score": "A value range 0 to 100" # This should be a single number, not a range 138 | }} 139 | "Logical": {{ 140 | "Explanation": "Provide explanation here" 141 | "Level": "A level range 1 to 5" # This should be a single number, not a range 142 | "Score": "A value range 0 to 100" # This should be a single number, not a range 143 | }} 144 | "Readability": {{ 145 | "Explanation": "Provide explanation here" 146 | "Level": "A level range 1 to 5" # This should be a single number, not a range 147 | "Score": "A value range 0 to 100" # This should be a single number, not a range 148 | }} 149 | 150 | }} 151 | 152 | """ 153 | response = llm_model_func(prompt, sys_prompt) 154 | responses.append(response) 155 | print(f"{len(responses)} responses evaluated.\n") 156 | 157 | return responses 158 | 159 | 160 | def fetch_scoring_results(responses): 161 | metric_name_list = [ 162 | "Comprehensiveness", 163 | "Diversity", 164 | "Empowerment", 165 | "Logical", 166 | "Readability", 167 | "Averaged Score", 168 | ] 169 | total_scores = [0] * 5 170 | for i, response in enumerate(responses): 171 | scores = re.findall(r'"Score":\s*(?:"?(\d+)"?)', response) 172 | for i in range(5): 173 | total_scores[i] += float(scores[i]) 174 | 175 | total_scores = np.array(total_scores) 176 | total_scores = total_scores / len(responses) 177 | total_scores = np.append(total_scores, np.mean(total_scores)) 178 | for metric_name, score in zip(metric_name_list, total_scores): 179 | print(f"{metric_name:20}: {score:.2f}") 180 | 181 | 182 | if __name__ == "__main__": 183 | data_name = "mix" 184 | mode, question_stage = "naive", 2 185 | WORKING_DIR = Path("caches") / data_name 186 | RESPONSE_DIR = WORKING_DIR / "response" 187 | question_file_path = WORKING_DIR / "questions" / f"{question_stage}_stage.json" 188 | answer_file_path = RESPONSE_DIR / f"{mode}_{question_stage}_stage_result.json" 189 | 190 | # extract questions, answers and references 191 | raw_queries, raw_refs = extarct_queries_and_refs(question_file_path) 192 | queries, answers = extract_queries_and_answers(answer_file_path) 193 | assert len(queries) == len(raw_queries) 194 | assert len(queries) == len(raw_refs) 195 | assert len(queries) == len(answers) 196 | 197 | # evaluate the answers 198 | responses = exam_by_scoring(raw_queries, answers, raw_refs) 199 | 200 | # save the results to a JSON file 201 | OUT_DIR = WORKING_DIR / "evalation" 202 | OUT_DIR.mkdir(parents=True, exist_ok=True) 203 | output_file_path = OUT_DIR / f"scoring_{question_stage}_stage_question.json" 204 | with open(output_file_path, "w", encoding="utf-8") as f: 205 | json.dump(responses, f, indent=4) 206 | print(f"Scoring-based evaluation results written to {output_file_path}\n\n") 207 | 208 | # calculate the scores 209 | print( 210 | f"Scoring-based evaluation for 
{question_stage}-stage questions of {mode} model:" 211 | ) 212 | fetch_scoring_results(responses) 213 | -------------------------------------------------------------------------------- /evaluate/evaluate_by_selection.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | 6 | import re 7 | import json 8 | import numpy as np 9 | from tqdm import tqdm 10 | from openai import OpenAI 11 | from my_config import LLM_API_KEY, LLM_BASE_URL, LLM_MODEL 12 | 13 | 14 | def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str: 15 | openai_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL) 16 | 17 | messages = [] 18 | if system_prompt: 19 | messages.append({"role": "system", "content": system_prompt}) 20 | messages.extend(history_messages) 21 | messages.append({"role": "user", "content": prompt}) 22 | 23 | response = openai_client.chat.completions.create( 24 | model=LLM_MODEL, messages=messages, **kwargs 25 | ) 26 | return response.choices[0].message.content 27 | 28 | 29 | def extract_queries_and_answers(file_path): 30 | with open(file_path, "r", encoding="utf-8") as file: 31 | query_list = json.load(file) 32 | queries = [i["query"] for i in query_list] 33 | answers = [i["result"] for i in query_list] 34 | return queries, answers 35 | 36 | 37 | # Deal with one by one 38 | def exam_by_selection(queries, A_answers, B_answers): 39 | 40 | responses = [] 41 | sys_prompt = """ 42 | ---Role--- 43 | You will evaluate two answers to the same question based on eight criteria: *Comprehensiveness**, **Empowerment**, **Accuracy**,**Relevance**,**Coherence**, 44 | **Clarity**,**Logical**,and **Flexibility**. 45 | """ 46 | for query, answer1, answer2 in tqdm( 47 | zip(queries, A_answers, B_answers), 48 | desc="Evaluating answers", 49 | total=len(queries), 50 | ): 51 | prompt = f""" 52 | You will evaluate two answers to the same question by using the relevant documents based on eight criteria:*Comprehensiveness**, **Empowerment**, **Accuracy**,**Relevance**,**Coherence**, 53 | **Clarity**,**Logical**,and **Flexibility**. 54 | 55 | - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question? 56 | - **Empowerment**: How well does the answer help the reader understand and make informed judgments about the topic? 57 | - **Accuracy**: How well does the answer align with factual truth and avoid hallucination based on the retrieved context? 58 | - **Relevance**: How precisely does the answer address the core aspects of the question without including unnecessary information? 59 | - **Coherence**: How well does the system integrate and synthesize information from multiple sources into a logically flowing response? 60 | - **Clarity**: How well does the system provide complete information while avoiding unnecessary verbosity and redundancy? 61 | - **Logical**: How well does the system maintain consistent logical arguments without contradicting itself across the response? 62 | - **Flexibility**: How well does the system handle various question formats, tones, and levels of complexity? 63 | 64 | For each criterion, choose the better answer (either Answer 1 or Answer 2) and explain why. Then, select an overall winner based on these ten categories. 
65 | 66 | 67 | 68 | Here are the questions: 69 | {query} 70 | 71 | Here are the two answers: 72 | 73 | **Answer 1:** 74 | {answer1} 75 | 76 | **Answer 2:** 77 | {answer2} 78 | 79 | Evaluate both answers using the eight criteria listed above and provide detailed explanations for each criterion. 80 | 81 | Output your evaluation in the following JSON format: 82 | 83 | {{ 84 | "Comprehensiveness": {{ 85 | "Winner": "[Answer 1 or Answer 2]", 86 | "Explanation": "[Provide explanation here]" 87 | }}, 88 | "Empowerment": {{ 89 | "Winner": "[Answer 1 or Answer 2]", 90 | "Explanation": "[Provide explanation here]" 91 | }}, 92 | "Accuracy": {{ 93 | "Winner": "[Answer 1 or Answer 2]", 94 | "Explanation": "[Provide explanation here]" 95 | }}, 96 | "Relevance": {{ 97 | "Winner": "[Answer 1 or Answer 2]", 98 | "Explanation": "[Provide explanation here]" 99 | }}, 100 | "Coherence": {{ 101 | "Winner": "[Answer 1 or Answer 2]", 102 | "Explanation": "[Provide explanation here]" 103 | }}, 104 | "Clarity": {{ 105 | "Winner": "[Answer 1 or Answer 2]", 106 | "Explanation": "[Provide explanation here]" 107 | }}, 108 | "Logical": {{ 109 | "Winner": "[Answer 1 or Answer 2]", 110 | "Explanation": "[Provide explanation here]" 111 | }}, 112 | "Flexibility": {{ 113 | "Winner": "[Answer 1 or Answer 2]", 114 | "Explanation": "[Provide explanation here]" 115 | }}, 116 | }} 117 | 118 | """ 119 | response = llm_model_func(prompt, sys_prompt) 120 | responses.append(response) 121 | print(f"{len(responses)} responses evaluated.\n") 122 | 123 | return responses 124 | 125 | 126 | def fetch_selection_results(responses): 127 | metric_name_list = [ 128 | "Comprehensiveness", 129 | "Empowerment", 130 | "Accuracy", 131 | "Relevance", 132 | "Coherence", 133 | "Clarity", 134 | "Logical", 135 | "Flexibility", 136 | "Averaged Score", 137 | ] 138 | total_scores = [0] * 8 139 | for i, response in enumerate(responses): 140 | # response = response.replace('```json\\n', '').replace('```', '').strip() 141 | # response = response.strip('"').replace('\\n', '\n').replace('\\"', '"') 142 | scores = re.findall(r'"Winner":\s*"([^"]+)"', response) 143 | for i in range(8): 144 | if scores[i].lower() == "answer 1": 145 | total_scores[i] += 1 146 | 147 | total_scores = np.array(total_scores) 148 | total_scores = total_scores / len(responses) 149 | total_scores = np.append(total_scores, np.mean(total_scores)) 150 | for metric_name, score in zip(metric_name_list, total_scores): 151 | print(f"{metric_name:20}: {score:.2f} vs. {1 - score:.2f}") 152 | 153 | 154 | if __name__ == "__main__": 155 | data_name = "mix" 156 | question_stage = 2 157 | # Note: we noticted that the position of the answer (first position or second position) 158 | # will effect the results. Thus, we suggest to average the results of 159 | # (A_mode vs. B_mode) and (B_mode vs. A_mode) as the final results. 
160 | A_mode, B_mode = "hyper", "naive" 161 | WORKING_DIR = Path("caches") / data_name 162 | RESPONSE_DIR = WORKING_DIR / "response" 163 | A_answer_file_path = RESPONSE_DIR / f"{A_mode}_{question_stage}_stage_result.json" 164 | B_answer_file_path = RESPONSE_DIR / f"{B_mode}_{question_stage}_stage_result.json" 165 | 166 | # extract questions, answers and references 167 | A_queries, A_answers = extract_queries_and_answers(A_answer_file_path) 168 | B_queries, B_answers = extract_queries_and_answers(B_answer_file_path) 169 | assert len(A_queries) == len(B_queries) 170 | assert len(A_answers) == len(B_answers) 171 | 172 | # evaluate the answers 173 | responses = exam_by_selection(A_queries, A_answers, B_answers) 174 | 175 | # save the results to a JSON file 176 | OUT_DIR = WORKING_DIR / "evalation" 177 | OUT_DIR.mkdir(parents=True, exist_ok=True) 178 | output_file_path = ( 179 | OUT_DIR / f"selection_{question_stage}_stage_question_{A_mode}_vs_{B_mode}.json" 180 | ) 181 | with open(output_file_path, "w", encoding="utf-8") as f: 182 | json.dump(responses, f, indent=4) 183 | print(f"Selection-based evaluation results written to {output_file_path}\n\n") 184 | 185 | # calculate the scores 186 | print( 187 | f"Selection-based evaluation for {question_stage}-stage questions of {A_mode} vs. {B_mode}:" 188 | ) 189 | fetch_selection_results(responses) 190 | -------------------------------------------------------------------------------- /examples/hyperrag_demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | 6 | import time 7 | import numpy as np 8 | 9 | from hyperrag import HyperRAG, QueryParam 10 | from hyperrag.utils import EmbeddingFunc 11 | from hyperrag.llm import openai_embedding, openai_complete_if_cache 12 | 13 | from my_config import LLM_API_KEY, LLM_BASE_URL, LLM_MODEL 14 | from my_config import EMB_API_KEY, EMB_BASE_URL, EMB_MODEL, EMB_DIM 15 | 16 | 17 | async def llm_model_func( 18 | prompt, system_prompt=None, history_messages=[], **kwargs 19 | ) -> str: 20 | return await openai_complete_if_cache( 21 | LLM_MODEL, 22 | prompt, 23 | system_prompt=system_prompt, 24 | history_messages=history_messages, 25 | api_key=LLM_API_KEY, 26 | base_url=LLM_BASE_URL, 27 | **kwargs, 28 | ) 29 | 30 | 31 | async def embedding_func(texts: list[str]) -> np.ndarray: 32 | return await openai_embedding( 33 | texts, 34 | model=EMB_MODEL, 35 | api_key=EMB_API_KEY, 36 | base_url=EMB_BASE_URL, 37 | ) 38 | 39 | 40 | def insert_texts_with_retry(rag, texts, retries=3, delay=5): 41 | for _ in range(retries): 42 | try: 43 | rag.insert(texts) 44 | return 45 | except Exception as e: 46 | print( 47 | f"Error occurred during insertion: {e}. Retrying in {delay} seconds..." 
48 | ) 49 | time.sleep(delay) 50 | raise RuntimeError("Failed to insert texts after multiple retries.") 51 | 52 | 53 | if __name__ == "__main__": 54 | data_name = "mock" 55 | WORKING_DIR = Path("caches") / data_name 56 | WORKING_DIR.mkdir(parents=True, exist_ok=True) 57 | rag = HyperRAG( 58 | working_dir=WORKING_DIR, 59 | llm_model_func=llm_model_func, 60 | embedding_func=EmbeddingFunc( 61 | embedding_dim=EMB_DIM, max_token_size=8192, func=embedding_func 62 | ), 63 | ) 64 | 65 | # read the text file 66 | mock_data_file_path = Path("examples/mock_data.txt") 67 | with open(mock_data_file_path, "r", encoding="utf-8") as file: 68 | texts = file.read() 69 | 70 | # Insert the text into the RAG 71 | insert_texts_with_retry(rag, texts) 72 | 73 | # Perform different types of queries and handle potential errors 74 | try: 75 | print("\n\n\nPerforming Naive RAG...") 76 | print( 77 | rag.query( 78 | "What are the top themes in this story?", 79 | param=QueryParam(mode="naive") 80 | ) 81 | ) 82 | except Exception as e: 83 | print(f"Error performing naive-rag search: {e}") 84 | 85 | try: 86 | print("\n\n\nPerforming Hyper-RAG...") 87 | print( 88 | rag.query( 89 | "What are the top themes in this story?", 90 | param=QueryParam(mode="hyper") 91 | ) 92 | ) 93 | except Exception as e: 94 | print(f"Error performing hyper-rag search: {e}") 95 | 96 | try: 97 | print("\n\n\nPerforming Hyper-RAG-Lite...") 98 | print( 99 | rag.query( 100 | "What are the top themes in this story?", 101 | param=QueryParam(mode="hyper-lite"), 102 | ) 103 | ) 104 | except Exception as e: 105 | print(f"Error performing hyper-rag-lite search: {e}") 106 | -------------------------------------------------------------------------------- /hyperrag/__init__.py: -------------------------------------------------------------------------------- 1 | from .hyperrag import HyperRAG, QueryParam 2 | 3 | 4 | __version__ = "0.0.1" 5 | 6 | __all__ = { 7 | HyperRAG, 8 | QueryParam, 9 | } 10 | -------------------------------------------------------------------------------- /hyperrag/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import TypedDict, Union, Literal, Generic, TypeVar, Any, Tuple, List, Set, Optional, Dict 3 | 4 | from .utils import EmbeddingFunc 5 | 6 | TextChunkSchema = TypedDict( 7 | "TextChunkSchema", 8 | {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}, 9 | ) 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | @dataclass 15 | class QueryParam: 16 | mode: Literal["hyper", "hyper-lite", "naive"] = "hyper-query" 17 | only_need_context: bool = False 18 | response_type: str = "Multiple Paragraphs" 19 | # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode. 20 | top_k: int = 60 21 | # Number of tokens for the original chunks. 
22 | max_token_for_text_unit: int = 1600 23 | # Number of tokens for the entity descriptions 24 | max_token_for_entity_context: int = 300 25 | # Number of tokens for the relationship descriptions 26 | max_token_for_relation_context: int = 1600 27 | 28 | 29 | @dataclass 30 | class StorageNameSpace: 31 | namespace: str 32 | global_config: dict 33 | 34 | async def index_done_callback(self): 35 | """commit the storage operations after indexing""" 36 | pass 37 | 38 | async def query_done_callback(self): 39 | """commit the storage operations after querying""" 40 | pass 41 | 42 | 43 | @dataclass 44 | class BaseVectorStorage(StorageNameSpace): 45 | embedding_func: EmbeddingFunc 46 | meta_fields: set = field(default_factory=set) 47 | 48 | async def query(self, query: str, top_k: int) -> list[dict]: 49 | raise NotImplementedError 50 | 51 | async def upsert(self, data: dict[str, dict]): 52 | """Use 'content' field from value for embedding, use key as id. 53 | If embedding_func is None, use 'embedding' field from value 54 | """ 55 | raise NotImplementedError 56 | 57 | 58 | @dataclass 59 | class BaseKVStorage(Generic[T], StorageNameSpace): 60 | async def all_keys(self) -> list[str]: 61 | raise NotImplementedError 62 | 63 | async def get_by_id(self, id: str) -> Union[T, None]: 64 | raise NotImplementedError 65 | 66 | async def get_by_ids( 67 | self, ids: list[str], fields: Union[set[str], None] = None 68 | ) -> list[Union[T, None]]: 69 | raise NotImplementedError 70 | 71 | async def filter_keys(self, data: list[str]) -> set[str]: 72 | """return un-exist keys""" 73 | raise NotImplementedError 74 | 75 | async def upsert(self, data: dict[str, T]): 76 | raise NotImplementedError 77 | 78 | async def drop(self): 79 | raise NotImplementedError 80 | 81 | """ 82 | The BaseHypergraphStorage based on hypergraph-DB 83 | """ 84 | @dataclass 85 | class BaseHypergraphStorage(StorageNameSpace): 86 | async def has_vertex(self, v_id: Any) -> bool: 87 | raise NotImplementedError 88 | 89 | async def has_hyperedge(self, e_tuple: Union[List, Set, Tuple]) -> bool: 90 | raise NotImplementedError 91 | 92 | async def get_vertex(self, v_id: str, default: Any = None) : 93 | raise NotImplementedError 94 | 95 | async def get_hyperedge(self, e_tuple: Union[List, Set, Tuple], default: Any = None) : 96 | raise NotImplementedError 97 | 98 | async def get_all_vertices(self): 99 | raise NotImplementedError 100 | 101 | async def get_all_hyperedges(self): 102 | raise NotImplementedError 103 | 104 | async def get_num_of_vertices(self): 105 | raise NotImplementedError 106 | 107 | async def get_num_of_hyperedges(self): 108 | raise NotImplementedError 109 | 110 | async def upsert_vertex(self, v_id: Any, v_data: Optional[Dict] = None) : 111 | raise NotImplementedError 112 | 113 | async def upsert_hyperedge(self, e_tuple: Union[List, Set, Tuple], e_data: Optional[Dict] = None) : 114 | raise NotImplementedError 115 | 116 | async def remove_vertex(self, v_id: Any) : 117 | raise NotImplementedError 118 | 119 | async def remove_hyperedge(self, e_tuple: Union[List, Set, Tuple]) : 120 | raise NotImplementedError 121 | 122 | async def vertex_degree(self, v_id: Any) -> int: 123 | raise NotImplementedError 124 | 125 | async def hyperedge_degree(self, e_tuple: Union[List, Set, Tuple]) -> int: 126 | raise NotImplementedError 127 | 128 | async def get_nbr_e_of_vertex(self, e_tuple: Union[List, Set, Tuple]) -> list: 129 | raise NotImplementedError 130 | 131 | async def get_nbr_v_of_hyperedge(self, v_id: Any, exclude_self=True) -> list: 132 | raise 
NotImplementedError 133 | 134 | async def get_nbr_v_of_vertex(self, v_id: Any, exclude_self=True) -> list: 135 | raise NotImplementedError -------------------------------------------------------------------------------- /hyperrag/hyperrag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from dataclasses import asdict, dataclass, field 4 | from datetime import datetime 5 | from functools import partial 6 | from typing import Type, cast 7 | 8 | from .operate import ( 9 | chunking_by_token_size, 10 | extract_entities, 11 | hyper_query_lite, 12 | hyper_query, 13 | naive_query, 14 | ) 15 | from .llm import ( 16 | gpt_4o_mini_complete, 17 | openai_embedding, 18 | ) 19 | 20 | from .storage import ( 21 | JsonKVStorage, 22 | NanoVectorDBStorage, 23 | HypergraphStorage, 24 | ) 25 | 26 | 27 | from .utils import ( 28 | EmbeddingFunc, 29 | compute_mdhash_id, 30 | limit_async_func_call, 31 | convert_response_to_json, 32 | logger, 33 | set_logger, 34 | ) 35 | from .base import ( 36 | BaseKVStorage, 37 | BaseVectorStorage, 38 | StorageNameSpace, 39 | QueryParam, 40 | BaseHypergraphStorage, 41 | ) 42 | 43 | 44 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop: 45 | try: 46 | return asyncio.get_event_loop() 47 | 48 | except RuntimeError: 49 | logger.info("Creating a new event loop in main thread.") 50 | loop = asyncio.new_event_loop() 51 | asyncio.set_event_loop(loop) 52 | 53 | return loop 54 | 55 | 56 | @dataclass 57 | class HyperRAG: 58 | working_dir: str = field( 59 | default_factory=lambda: f"./HyperRAG_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}" 60 | ) 61 | print(working_dir) 62 | 63 | current_log_level = logger.level 64 | log_level: str = field(default=current_log_level) 65 | 66 | # text chunking 67 | chunk_token_size: int = 1200 68 | chunk_overlap_token_size: int = 100 69 | tiktoken_model_name: str = "gpt-4o-mini" 70 | 71 | # entity extraction 72 | entity_extract_max_gleaning: int = 1 73 | entity_summary_to_max_tokens: int = 500 74 | entity_additional_properties_to_max_tokens: int = 250 75 | relation_summary_to_max_tokens: int = 750 76 | relation_keywords_to_max_tokens: int = 100 77 | 78 | embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding) 79 | embedding_batch_num: int = 32 80 | embedding_func_max_async: int = 16 81 | 82 | # LLM 83 | llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete# 84 | # llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it' 85 | llm_model_name: str = "" 86 | llm_model_max_token_size: int = 32768 87 | llm_model_max_async: int = 16 88 | llm_model_kwargs: dict = field(default_factory=dict) 89 | 90 | # storage 91 | key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage 92 | vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage 93 | vector_db_storage_cls_kwargs: dict = field(default_factory=dict) 94 | hypergraph_storage_cls: Type[BaseHypergraphStorage] = HypergraphStorage 95 | enable_llm_cache: bool = True 96 | 97 | # extension 98 | addon_params: dict = field(default_factory=dict) 99 | convert_response_to_json_func: callable = convert_response_to_json 100 | 101 | def __post_init__(self): 102 | log_file = os.path.join(self.working_dir, "HyperRAG.log") 103 | set_logger(log_file) 104 | logger.setLevel(self.log_level) 105 | 106 | logger.info(f"Logger initialized for working directory: {self.working_dir}") 107 | 108 | _print_config = ",\n ".join([f"{k} = {v}" for k, v 
in asdict(self).items()]) 109 | logger.debug(f"HyperRAG init with param:\n {_print_config}\n") 110 | 111 | if not os.path.exists(self.working_dir): 112 | logger.info(f"Creating working directory {self.working_dir}") 113 | os.makedirs(self.working_dir) 114 | 115 | self.full_docs = self.key_string_value_json_storage_cls( 116 | namespace="full_docs", global_config=asdict(self) 117 | ) 118 | 119 | self.text_chunks = self.key_string_value_json_storage_cls( 120 | namespace="text_chunks", global_config=asdict(self) 121 | ) 122 | 123 | self.llm_response_cache = ( 124 | self.key_string_value_json_storage_cls( 125 | namespace="llm_response_cache", global_config=asdict(self) 126 | ) 127 | if self.enable_llm_cache 128 | else None 129 | ) 130 | """ 131 | download from hgdb_path 132 | """ 133 | self.chunk_entity_relation_hypergraph = self.hypergraph_storage_cls( 134 | namespace="chunk_entity_relation", global_config=asdict(self) 135 | ) 136 | 137 | self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( 138 | self.embedding_func 139 | ) 140 | 141 | self.entities_vdb = self.vector_db_storage_cls( 142 | namespace="entities", 143 | global_config=asdict(self), 144 | embedding_func=self.embedding_func, 145 | meta_fields={"entity_name"}, 146 | ) 147 | self.relationships_vdb = self.vector_db_storage_cls( 148 | namespace="relationships", 149 | global_config=asdict(self), 150 | embedding_func=self.embedding_func, 151 | meta_fields={"id_set"}, 152 | ) 153 | self.chunks_vdb = self.vector_db_storage_cls( 154 | namespace="chunks", 155 | global_config=asdict(self), 156 | embedding_func=self.embedding_func, 157 | ) 158 | 159 | self.llm_model_func = limit_async_func_call(self.llm_model_max_async)( 160 | partial( 161 | self.llm_model_func, 162 | hashing_kv=self.llm_response_cache, 163 | **self.llm_model_kwargs, 164 | ) 165 | ) 166 | 167 | def insert(self, string_or_strings): 168 | loop = always_get_an_event_loop() 169 | return loop.run_until_complete(self.ainsert(string_or_strings)) 170 | 171 | async def ainsert(self, string_or_strings): 172 | try: 173 | if isinstance(string_or_strings, str): 174 | string_or_strings = [string_or_strings] 175 | 176 | new_docs = { 177 | compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()} 178 | for c in string_or_strings 179 | } 180 | _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys())) 181 | new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} 182 | if not len(new_docs): 183 | logger.warning("All docs are already in the storage") 184 | return 185 | # ---------------------------------------------------------------------------- 186 | logger.info(f"[New Docs] inserting {len(new_docs)} docs") 187 | 188 | inserting_chunks = {} 189 | for doc_key, doc in new_docs.items(): 190 | chunks = { 191 | compute_mdhash_id(dp["content"], prefix="chunk-"): { 192 | **dp, 193 | "full_doc_id": doc_key, 194 | } 195 | for dp in chunking_by_token_size( 196 | doc["content"], 197 | overlap_token_size=self.chunk_overlap_token_size, 198 | max_token_size=self.chunk_token_size, 199 | tiktoken_model=self.tiktoken_model_name, 200 | ) 201 | } 202 | inserting_chunks.update(chunks) 203 | _add_chunk_keys = await self.text_chunks.filter_keys( 204 | list(inserting_chunks.keys()) 205 | ) 206 | inserting_chunks = { 207 | k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys 208 | } 209 | if not len(inserting_chunks): 210 | logger.warning("All chunks are already in the storage") 211 | return 212 | # 
---------------------------------------------------------------------------- 213 | logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks") 214 | 215 | await self.chunks_vdb.upsert(inserting_chunks) 216 | # ---------------------------------------------------------------------------- 217 | logger.info("[Entity Extraction]...") 218 | maybe_new_kg = await extract_entities( 219 | inserting_chunks, 220 | knowledge_hypergraph_inst=self.chunk_entity_relation_hypergraph, 221 | entity_vdb=self.entities_vdb, 222 | relationships_vdb=self.relationships_vdb, 223 | global_config=asdict(self), 224 | ) 225 | if maybe_new_kg is None: 226 | logger.warning("No new entities and relationships found") 227 | return 228 | # ---------------------------------------------------------------------------- 229 | self.chunk_entity_relation_hypergraph = maybe_new_kg 230 | await self.full_docs.upsert(new_docs) 231 | await self.text_chunks.upsert(inserting_chunks) 232 | finally: 233 | await self._insert_done() 234 | 235 | async def _insert_done(self): 236 | tasks = [] 237 | for storage_inst in [ 238 | self.full_docs, 239 | self.text_chunks, 240 | self.llm_response_cache, 241 | self.entities_vdb, 242 | self.relationships_vdb, 243 | self.chunks_vdb, 244 | self.chunk_entity_relation_hypergraph, 245 | ]: 246 | if storage_inst is None: 247 | continue 248 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback()) 249 | await asyncio.gather(*tasks) 250 | 251 | def query(self, query: str, param: QueryParam = QueryParam()): 252 | loop = always_get_an_event_loop() 253 | return loop.run_until_complete(self.aquery(query, param)) 254 | 255 | async def aquery(self, query: str, param: QueryParam = QueryParam()): 256 | 257 | if param.mode == "hyper": 258 | response = await hyper_query( 259 | query, 260 | self.chunk_entity_relation_hypergraph, 261 | self.entities_vdb, 262 | self.relationships_vdb, 263 | self.text_chunks, 264 | param, 265 | asdict(self), 266 | ) 267 | elif param.mode == "hyper-lite": 268 | response = await hyper_query_lite( 269 | query, 270 | self.chunk_entity_relation_hypergraph, 271 | self.entities_vdb, 272 | self.text_chunks, 273 | param, 274 | asdict(self), 275 | ) 276 | elif param.mode == "naive": 277 | response = await naive_query( 278 | query, 279 | self.chunks_vdb, 280 | self.text_chunks, 281 | param, 282 | asdict(self), 283 | ) 284 | else: 285 | raise ValueError(f"Unknown mode {param.mode}") 286 | await self._query_done() 287 | return response 288 | 289 | async def _query_done(self): 290 | tasks = [] 291 | for storage_inst in [self.llm_response_cache]: 292 | if storage_inst is None: 293 | continue 294 | tasks.append(cast(StorageNameSpace, storage_inst).query_done_callback()) 295 | await asyncio.gather(*tasks) 296 | -------------------------------------------------------------------------------- /hyperrag/llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from functools import lru_cache 4 | import json 5 | import aioboto3 6 | import aiohttp 7 | import numpy as np 8 | 9 | from openai import ( 10 | AsyncOpenAI, 11 | APIConnectionError, 12 | RateLimitError, 13 | Timeout, 14 | AsyncAzureOpenAI, 15 | ) 16 | 17 | import base64 18 | import struct 19 | 20 | from tenacity import ( 21 | retry, 22 | stop_after_attempt, 23 | wait_exponential, 24 | retry_if_exception_type, 25 | ) 26 | from pydantic import BaseModel, Field 27 | from typing import List, Dict, Callable, Any 28 | from .base import BaseKVStorage 29 | from .utils import 
compute_args_hash, wrap_embedding_func_with_attrs 30 | 31 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 32 | 33 | 34 | @retry( 35 | stop=stop_after_attempt(3), 36 | wait=wait_exponential(multiplier=1, min=4, max=10), 37 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 38 | ) 39 | async def openai_complete_if_cache( 40 | model, 41 | prompt, 42 | system_prompt=None, 43 | history_messages=[], 44 | base_url=None, 45 | api_key=None, 46 | **kwargs, 47 | ) -> str: 48 | if api_key: 49 | os.environ["OPENAI_API_KEY"] = api_key 50 | 51 | openai_async_client = ( 52 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 53 | ) 54 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 55 | messages = [] 56 | if system_prompt is not None: 57 | messages.append({"role": "system", "content": system_prompt}) 58 | messages.extend(history_messages) 59 | messages.append({"role": "user", "content": prompt}) 60 | if hashing_kv is not None: 61 | args_hash = compute_args_hash(model, messages) 62 | if_cache_return = await hashing_kv.get_by_id(args_hash) 63 | if if_cache_return is not None: 64 | return if_cache_return["return"] 65 | 66 | response = await openai_async_client.chat.completions.create( 67 | model=model, messages=messages, **kwargs 68 | ) 69 | 70 | if hashing_kv is not None: 71 | await hashing_kv.upsert( 72 | {args_hash: {"return": response.choices[0].message.content, "model": model}} 73 | ) 74 | return response.choices[0].message.content 75 | 76 | 77 | @retry( 78 | stop=stop_after_attempt(3), 79 | wait=wait_exponential(multiplier=1, min=4, max=10), 80 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 81 | ) 82 | async def azure_openai_complete_if_cache( 83 | model, 84 | prompt, 85 | system_prompt=None, 86 | history_messages=[], 87 | base_url=None, 88 | api_key=None, 89 | **kwargs, 90 | ): 91 | if api_key: 92 | os.environ["AZURE_OPENAI_API_KEY"] = api_key 93 | if base_url: 94 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url 95 | 96 | openai_async_client = AsyncAzureOpenAI( 97 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), 98 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 99 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"), 100 | ) 101 | 102 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 103 | messages = [] 104 | if system_prompt: 105 | messages.append({"role": "system", "content": system_prompt}) 106 | messages.extend(history_messages) 107 | if prompt is not None: 108 | messages.append({"role": "user", "content": prompt}) 109 | if hashing_kv is not None: 110 | args_hash = compute_args_hash(model, messages) 111 | if_cache_return = await hashing_kv.get_by_id(args_hash) 112 | if if_cache_return is not None: 113 | return if_cache_return["return"] 114 | 115 | response = await openai_async_client.chat.completions.create( 116 | model=model, messages=messages, **kwargs 117 | ) 118 | 119 | if hashing_kv is not None: 120 | await hashing_kv.upsert( 121 | {args_hash: {"return": response.choices[0].message.content, "model": model}} 122 | ) 123 | return response.choices[0].message.content 124 | 125 | 126 | class BedrockError(Exception): 127 | """Generic error for issues related to Amazon Bedrock""" 128 | 129 | 130 | @retry( 131 | stop=stop_after_attempt(5), 132 | wait=wait_exponential(multiplier=1, max=60), 133 | retry=retry_if_exception_type((BedrockError)), 134 | ) 135 | async def bedrock_complete_if_cache( 136 | model, 137 | prompt, 138 | system_prompt=None, 139 | history_messages=[], 140 | 
aws_access_key_id=None, 141 | aws_secret_access_key=None, 142 | aws_session_token=None, 143 | **kwargs, 144 | ) -> str: 145 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( 146 | "AWS_ACCESS_KEY_ID", aws_access_key_id 147 | ) 148 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( 149 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key 150 | ) 151 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get( 152 | "AWS_SESSION_TOKEN", aws_session_token 153 | ) 154 | 155 | # Fix message history format 156 | messages = [] 157 | for history_message in history_messages: 158 | message = copy.copy(history_message) 159 | message["content"] = [{"text": message["content"]}] 160 | messages.append(message) 161 | 162 | # Add user prompt 163 | messages.append({"role": "user", "content": [{"text": prompt}]}) 164 | 165 | # Initialize Converse API arguments 166 | args = {"modelId": model, "messages": messages} 167 | 168 | # Define system prompt 169 | if system_prompt: 170 | args["system"] = [{"text": system_prompt}] 171 | 172 | # Map and set up inference parameters 173 | inference_params_map = { 174 | "max_tokens": "maxTokens", 175 | "top_p": "topP", 176 | "stop_sequences": "stopSequences", 177 | } 178 | if inference_params := list( 179 | set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"]) 180 | ): 181 | args["inferenceConfig"] = {} 182 | for param in inference_params: 183 | args["inferenceConfig"][inference_params_map.get(param, param)] = ( 184 | kwargs.pop(param) 185 | ) 186 | 187 | hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) 188 | if hashing_kv is not None: 189 | args_hash = compute_args_hash(model, messages) 190 | if_cache_return = await hashing_kv.get_by_id(args_hash) 191 | if if_cache_return is not None: 192 | return if_cache_return["return"] 193 | 194 | # Call model via Converse API 195 | session = aioboto3.Session() 196 | async with session.client("bedrock-runtime") as bedrock_async_client: 197 | try: 198 | response = await bedrock_async_client.converse(**args, **kwargs) 199 | except Exception as e: 200 | raise BedrockError(e) 201 | 202 | if hashing_kv is not None: 203 | await hashing_kv.upsert( 204 | { 205 | args_hash: { 206 | "return": response["output"]["message"]["content"][0]["text"], 207 | "model": model, 208 | } 209 | } 210 | ) 211 | 212 | return response["output"]["message"]["content"][0]["text"] 213 | 214 | 215 | async def gpt_4o_complete( 216 | prompt, system_prompt=None, history_messages=[], **kwargs 217 | ) -> str: 218 | 219 | return await openai_complete_if_cache( 220 | "gpt-4o", 221 | prompt, 222 | system_prompt=system_prompt, 223 | history_messages=history_messages, 224 | 225 | **kwargs, 226 | ) 227 | 228 | 229 | async def gpt_4o_mini_complete( 230 | prompt, system_prompt=None, history_messages=[], **kwargs 231 | ) -> str: 232 | return await openai_complete_if_cache( 233 | "gpt-4o-mini", 234 | prompt, 235 | system_prompt=system_prompt, 236 | history_messages=history_messages, 237 | **kwargs, 238 | ) 239 | 240 | 241 | async def azure_openai_complete( 242 | prompt, system_prompt=None, history_messages=[], **kwargs 243 | ) -> str: 244 | return await azure_openai_complete_if_cache( 245 | "conversation-4o-mini", 246 | prompt, 247 | system_prompt=system_prompt, 248 | history_messages=history_messages, 249 | **kwargs, 250 | ) 251 | 252 | 253 | async def bedrock_complete( 254 | prompt, system_prompt=None, history_messages=[], **kwargs 255 | ) -> str: 256 | return await bedrock_complete_if_cache( 257 | "anthropic.claude-3-haiku-20240307-v1:0", 258 | prompt, 
259 | system_prompt=system_prompt, 260 | history_messages=history_messages, 261 | **kwargs, 262 | ) 263 | 264 | 265 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) 266 | @retry( 267 | stop=stop_after_attempt(3), 268 | wait=wait_exponential(multiplier=1, min=4, max=60), 269 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 270 | ) 271 | async def openai_embedding( 272 | texts: list[str], 273 | model: str = "text-embedding-3-small", 274 | base_url: str = None, 275 | api_key: str = None, 276 | ) -> np.ndarray: 277 | if api_key: 278 | os.environ["OPENAI_API_KEY"] = api_key 279 | 280 | openai_async_client = ( 281 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) 282 | ) 283 | response = await openai_async_client.embeddings.create( 284 | model=model, input=texts, encoding_format="float" 285 | ) 286 | return np.array([dp.embedding for dp in response.data]) 287 | 288 | 289 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) 290 | @retry( 291 | stop=stop_after_attempt(3), 292 | wait=wait_exponential(multiplier=1, min=4, max=10), 293 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 294 | ) 295 | async def azure_openai_embedding( 296 | texts: list[str], 297 | model: str = "text-embedding-3-small", 298 | base_url: str = None, 299 | api_key: str = None, 300 | ) -> np.ndarray: 301 | if api_key: 302 | os.environ["AZURE_OPENAI_API_KEY"] = api_key 303 | if base_url: 304 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url 305 | 306 | openai_async_client = AsyncAzureOpenAI( 307 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), 308 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 309 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"), 310 | ) 311 | 312 | response = await openai_async_client.embeddings.create( 313 | model=model, input=texts, encoding_format="float" 314 | ) 315 | return np.array([dp.embedding for dp in response.data]) 316 | 317 | 318 | @retry( 319 | stop=stop_after_attempt(3), 320 | wait=wait_exponential(multiplier=1, min=4, max=60), 321 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), 322 | ) 323 | async def siliconcloud_embedding( 324 | texts: list[str], 325 | model: str = "netease-youdao/bce-embedding-base_v1", 326 | base_url: str = "https://api.siliconflow.cn/v1/embeddings", 327 | max_token_size: int = 512, 328 | api_key: str = None, 329 | ) -> np.ndarray: 330 | if api_key and not api_key.startswith("Bearer "): 331 | api_key = "Bearer " + api_key 332 | 333 | headers = {"Authorization": api_key, "Content-Type": "application/json"} 334 | 335 | truncate_texts = [text[0:max_token_size] for text in texts] 336 | 337 | payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"} 338 | 339 | base64_strings = [] 340 | async with aiohttp.ClientSession() as session: 341 | async with session.post(base_url, headers=headers, json=payload) as response: 342 | content = await response.json() 343 | if "code" in content: 344 | raise ValueError(content) 345 | base64_strings = [item["embedding"] for item in content["data"]] 346 | 347 | embeddings = [] 348 | for string in base64_strings: 349 | decode_bytes = base64.b64decode(string) 350 | n = len(decode_bytes) // 4 351 | float_array = struct.unpack("<" + "f" * n, decode_bytes) 352 | embeddings.append(float_array) 353 | return np.array(embeddings) 354 | 355 | 356 | # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192) 357 | # @retry( 358 | # 
stop=stop_after_attempt(3), 359 | # wait=wait_exponential(multiplier=1, min=4, max=10), 360 | # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions 361 | # ) 362 | async def bedrock_embedding( 363 | texts: list[str], 364 | model: str = "amazon.titan-embed-text-v2:0", 365 | aws_access_key_id=None, 366 | aws_secret_access_key=None, 367 | aws_session_token=None, 368 | ) -> np.ndarray: 369 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get( 370 | "AWS_ACCESS_KEY_ID", aws_access_key_id 371 | ) 372 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get( 373 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key 374 | ) 375 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get( 376 | "AWS_SESSION_TOKEN", aws_session_token 377 | ) 378 | 379 | session = aioboto3.Session() 380 | async with session.client("bedrock-runtime") as bedrock_async_client: 381 | if (model_provider := model.split(".")[0]) == "amazon": 382 | embed_texts = [] 383 | for text in texts: 384 | if "v2" in model: 385 | body = json.dumps( 386 | { 387 | "inputText": text, 388 | # 'dimensions': embedding_dim, 389 | "embeddingTypes": ["float"], 390 | } 391 | ) 392 | elif "v1" in model: 393 | body = json.dumps({"inputText": text}) 394 | else: 395 | raise ValueError(f"Model {model} is not supported!") 396 | 397 | response = await bedrock_async_client.invoke_model( 398 | modelId=model, 399 | body=body, 400 | accept="application/json", 401 | contentType="application/json", 402 | ) 403 | 404 | response_body = await response.get("body").json() 405 | 406 | embed_texts.append(response_body["embedding"]) 407 | elif model_provider == "cohere": 408 | body = json.dumps( 409 | {"texts": texts, "input_type": "search_document", "truncate": "NONE"} 410 | ) 411 | 412 | response = await bedrock_async_client.invoke_model( 413 | model=model, 414 | body=body, 415 | accept="application/json", 416 | contentType="application/json", 417 | ) 418 | 419 | response_body = json.loads(response.get("body").read()) 420 | 421 | embed_texts = response_body["embeddings"] 422 | else: 423 | raise ValueError(f"Model provider '{model_provider}' is not supported!") 424 | 425 | return np.array(embed_texts) 426 | 427 | 428 | class Model(BaseModel): 429 | """ 430 | This is a Pydantic model class named 'Model' that is used to define a custom language model. 431 | 432 | Attributes: 433 | gen_func (Callable[[Any], str]): A callable function that generates the response from the language model. 434 | The function should take any argument and return a string. 435 | kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function. 436 | This could include parameters such as the model name, API key, etc. 437 | 438 | Example usage: 439 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}) 440 | 441 | In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model. 442 | The 'kwargs' dictionary contains the model name and API key to be passed to the function. 443 | """ 444 | 445 | gen_func: Callable[[Any], str] = Field( 446 | ..., 447 | description="A function that generates the response from the llm. The response must be a string", 448 | ) 449 | kwargs: Dict[str, Any] = Field( 450 | ..., 451 | description="The arguments to pass to the callable function. Eg. 
the api key, model name, etc", 452 | ) 453 | 454 | class Config: 455 | arbitrary_types_allowed = True 456 | 457 | 458 | class MultiModel: 459 | """ 460 | Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier. 461 | Could also be used for spliting across diffrent models or providers. 462 | 463 | Attributes: 464 | models (List[Model]): A list of language models to be used. 465 | 466 | Usage example: 467 | ```python 468 | models = [ 469 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}), 470 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}), 471 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}), 472 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}), 473 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}), 474 | ] 475 | multi_model = MultiModel(models) 476 | rag = LightRAG( 477 | llm_model_func=multi_model.llm_model_func 478 | / ..other args 479 | ) 480 | ``` 481 | """ 482 | 483 | def __init__(self, models: List[Model]): 484 | self._models = models 485 | self._current_model = 0 486 | 487 | def _next_model(self): 488 | self._current_model = (self._current_model + 1) % len(self._models) 489 | return self._models[self._current_model] 490 | 491 | async def llm_model_func( 492 | self, prompt, system_prompt=None, history_messages=[], **kwargs 493 | ) -> str: 494 | kwargs.pop("model", None) # stop from overwriting the custom model name 495 | next_model = self._next_model() 496 | args = dict( 497 | prompt=prompt, 498 | system_prompt=system_prompt, 499 | history_messages=history_messages, 500 | **kwargs, 501 | **next_model.kwargs, 502 | ) 503 | 504 | return await next_model.gen_func(**args) 505 | 506 | 507 | if __name__ == "__main__": 508 | import asyncio 509 | 510 | async def main(): 511 | result = await gpt_4o_mini_complete("How are you?") 512 | print(result) 513 | 514 | asyncio.run(main()) 515 | -------------------------------------------------------------------------------- /hyperrag/storage.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import os 4 | from dataclasses import dataclass 5 | from typing import Any, Union, cast, List, Set, Tuple, Optional, Dict 6 | import numpy as np 7 | from nano_vectordb import NanoVectorDB 8 | from hyperdb import HypergraphDB 9 | from .utils import load_json, logger, write_json 10 | from .base import ( 11 | BaseKVStorage, 12 | BaseVectorStorage, 13 | BaseHypergraphStorage 14 | ) 15 | 16 | 17 | @dataclass 18 | class JsonKVStorage(BaseKVStorage): 19 | def __post_init__(self): 20 | working_dir = self.global_config["working_dir"] 21 | self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") 22 | self._data = load_json(self._file_name) or {} 23 | logger.info(f"Load KV {self.namespace} with {len(self._data)} data") 24 | 25 | async def all_keys(self) -> list[str]: 26 | return list(self._data.keys()) 27 | 28 | async def index_done_callback(self): 29 | write_json(self._data, self._file_name) 30 | 31 | async def get_by_id(self, id): 32 | return self._data.get(id, None) 33 | 34 | async def get_by_ids(self, ids, fields=None): 35 | if fields is None: 36 | return 
[self._data.get(id, None) for id in ids] 37 | return [ 38 | ( 39 | {k: v for k, v in self._data[id].items() if k in fields} 40 | if self._data.get(id, None) 41 | else None 42 | ) 43 | for id in ids 44 | ] 45 | 46 | async def filter_keys(self, data: list[str]) -> set[str]: 47 | return set([s for s in data if s not in self._data]) 48 | 49 | async def upsert(self, data: dict[str, dict]): 50 | left_data = {k: v for k, v in data.items() if k not in self._data} 51 | self._data.update(left_data) 52 | return left_data 53 | 54 | async def drop(self): 55 | self._data = {} 56 | 57 | 58 | @dataclass 59 | class NanoVectorDBStorage(BaseVectorStorage): 60 | cosine_better_than_threshold: float = 0.2 61 | 62 | def __post_init__(self): 63 | self._client_file_name = os.path.join( 64 | self.global_config["working_dir"], f"vdb_{self.namespace}.json" 65 | ) 66 | self._max_batch_size = self.global_config["embedding_batch_num"] 67 | self._client = NanoVectorDB( 68 | self.embedding_func.embedding_dim, storage_file=self._client_file_name 69 | ) 70 | self.cosine_better_than_threshold = self.global_config.get( 71 | "cosine_better_than_threshold", self.cosine_better_than_threshold 72 | ) 73 | 74 | async def upsert(self, data: dict[str, dict]): 75 | logger.info(f"Inserting {len(data)} vectors to {self.namespace}") 76 | if not len(data): 77 | logger.warning("You insert an empty data to vector DB") 78 | return [] 79 | list_data = [ 80 | { 81 | "__id__": k, 82 | **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, 83 | } 84 | for k, v in data.items() 85 | ] 86 | contents = [v["content"] for v in data.values()] 87 | batches = [ 88 | contents[i : i + self._max_batch_size] 89 | for i in range(0, len(contents), self._max_batch_size) 90 | ] 91 | embeddings_list = await asyncio.gather( 92 | *[self.embedding_func(batch) for batch in batches] 93 | ) 94 | embeddings = np.concatenate(embeddings_list) 95 | for i, d in enumerate(list_data): 96 | d["__vector__"] = embeddings[i] 97 | results = self._client.upsert(datas=list_data) 98 | return results 99 | 100 | async def query(self, query: str, top_k=5): 101 | embedding = await self.embedding_func([query]) 102 | embedding = embedding[0] 103 | results = self._client.query( 104 | query=embedding, 105 | top_k=top_k, 106 | better_than_threshold=self.cosine_better_than_threshold, 107 | ) 108 | results = [ 109 | {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results 110 | ] 111 | return results 112 | 113 | async def index_done_callback(self): 114 | self._client.save() 115 | 116 | 117 | @dataclass 118 | class HypergraphStorage(BaseHypergraphStorage): 119 | 120 | @staticmethod 121 | def load_hypergraph(file_name) -> HypergraphDB: 122 | if os.path.exists(file_name): 123 | pre_hypergraph = HypergraphDB() 124 | pre_hypergraph.load(file_name) 125 | return pre_hypergraph 126 | return None 127 | 128 | @staticmethod 129 | def write_hypergraph(hypergraph: HypergraphDB, file_name): 130 | logger.info( 131 | f"Writing hypergraph with {hypergraph.num_v} vertices, {hypergraph.num_e} hyperedges" 132 | ) 133 | hypergraph.save(file_name) 134 | 135 | def __post_init__(self): 136 | self._hgdb_file = os.path.join( 137 | self.global_config["working_dir"], f"hypergraph_{self.namespace}.hgdb" 138 | ) 139 | preloaded_hypergraph = HypergraphStorage.load_hypergraph(self._hgdb_file) 140 | if preloaded_hypergraph is not None: 141 | logger.info( 142 | f"Loaded hypergraph from {self._hgdb_file} with {preloaded_hypergraph.num_v} vertices, {preloaded_hypergraph.num_e} hyperedges" 143 | ) 144 | 
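# If no hypergraph snapshot was found on disk, fall back to a fresh, empty HypergraphDB.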
self._hg = preloaded_hypergraph or HypergraphDB() 145 | 146 | async def index_done_callback(self): 147 | HypergraphStorage.write_hypergraph(self._hg, self._hgdb_file) 148 | 149 | async def has_vertex(self, v_id: Any) -> bool: 150 | return self._hg.has_v(v_id) 151 | 152 | async def has_hyperedge(self, e_tuple: Union[List, Set, Tuple]) -> bool: 153 | return self._hg.has_e(e_tuple) 154 | 155 | async def get_vertex(self, v_id: str, default: Any = None) : 156 | return self._hg.v(v_id) 157 | 158 | async def get_hyperedge(self, e_tuple: Union[List, Set, Tuple], default: Any = None) : 159 | return self._hg.e(e_tuple) 160 | 161 | async def get_all_vertices(self): 162 | return self._hg.all_v 163 | 164 | async def get_all_hyperedges(self): 165 | return self._hg.all_e 166 | 167 | async def get_num_of_vertices(self): 168 | return self._hg.num_v 169 | 170 | async def get_num_of_hyperedges(self): 171 | return self._hg.num_e 172 | 173 | async def upsert_vertex(self, v_id: Any, v_data: Optional[Dict] = None) : 174 | return self._hg.add_v(v_id, v_data) 175 | 176 | async def upsert_hyperedge(self, e_tuple: Union[List, Set, Tuple], e_data: Optional[Dict] = None) : 177 | return self._hg.add_e(e_tuple, e_data) 178 | 179 | async def remove_vertex(self, v_id: Any) : 180 | return self._hg.remove_v(v_id) 181 | 182 | async def remove_hyperedge(self, e_tuple: Union[List, Set, Tuple]) : 183 | return self._hg.remove_e(e_tuple) 184 | 185 | async def vertex_degree(self, v_id: Any) -> int: 186 | return self._hg.degree_v(v_id) 187 | 188 | async def hyperedge_degree(self, e_tuple: Union[List, Set, Tuple]) -> int: 189 | return self._hg.degree_e(e_tuple) 190 | 191 | async def get_nbr_e_of_vertex(self, e_tuple: Union[List, Set, Tuple]) -> list: 192 | """ 193 | Return the incident hyperedges of the vertex. 194 | """ 195 | return self._hg.nbr_e_of_v(e_tuple) 196 | 197 | async def get_nbr_v_of_hyperedge(self, v_id: Any, exclude_self=True) -> list: 198 | """ 199 | Return the incident vertices of the hyperedge. 200 | """ 201 | return self._hg.nbr_v_of_e(v_id) 202 | 203 | async def get_nbr_v_of_vertex(self, v_id: Any, exclude_self=True) -> list: 204 | """ 205 | Return the neighbors of the vertex. 
206 | """ 207 | return self._hg.nbr_v(v_id) 208 | -------------------------------------------------------------------------------- /hyperrag/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import html 3 | import io 4 | import csv 5 | import json 6 | import logging 7 | import os 8 | import re 9 | from dataclasses import dataclass 10 | from functools import wraps 11 | from hashlib import md5 12 | from typing import Any, Union, List 13 | import xml.etree.ElementTree as ET 14 | 15 | import numpy as np 16 | import tiktoken 17 | 18 | ENCODER = None 19 | 20 | logger = logging.getLogger("hyper_rag") 21 | 22 | 23 | def set_logger(log_file: str): 24 | logger.setLevel(logging.DEBUG) 25 | 26 | file_handler = logging.FileHandler(log_file) 27 | file_handler.setLevel(logging.DEBUG) 28 | 29 | formatter = logging.Formatter( 30 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 31 | ) 32 | file_handler.setFormatter(formatter) 33 | 34 | if not logger.handlers: 35 | logger.addHandler(file_handler) 36 | 37 | 38 | @dataclass 39 | class EmbeddingFunc: 40 | embedding_dim: int 41 | max_token_size: int 42 | func: callable 43 | 44 | async def __call__(self, *args, **kwargs) -> np.ndarray: 45 | return await self.func(*args, **kwargs) 46 | 47 | 48 | def locate_json_string_body_from_string(content: str) -> Union[str, None]: 49 | """Locate the JSON string body from a string""" 50 | maybe_json_str = re.search(r"{.*}", content, re.DOTALL) 51 | if maybe_json_str is not None: 52 | return maybe_json_str.group(0) 53 | else: 54 | return None 55 | 56 | 57 | def convert_response_to_json(response: str) -> dict: 58 | json_str = locate_json_string_body_from_string(response) 59 | assert json_str is not None, f"Unable to parse JSON from response: {response}" 60 | try: 61 | data = json.loads(json_str) 62 | return data 63 | except json.JSONDecodeError as e: 64 | logger.error(f"Failed to parse JSON: {json_str}") 65 | raise e from None 66 | 67 | 68 | def compute_args_hash(*args): 69 | return md5(str(args).encode()).hexdigest() 70 | 71 | 72 | def compute_mdhash_id(content, prefix: str = ""): 73 | return prefix + md5(content.encode()).hexdigest() 74 | 75 | 76 | def limit_async_func_call(max_size: int, waitting_time: float = 0.0001): 77 | """Add restriction of maximum async calling times for a async func""" 78 | 79 | def final_decro(func): 80 | """Not using async.Semaphore to aovid use nest-asyncio""" 81 | __current_size = 0 82 | 83 | @wraps(func) 84 | async def wait_func(*args, **kwargs): 85 | nonlocal __current_size 86 | while __current_size >= max_size: 87 | await asyncio.sleep(waitting_time) 88 | __current_size += 1 89 | result = await func(*args, **kwargs) 90 | __current_size -= 1 91 | return result 92 | 93 | return wait_func 94 | 95 | return final_decro 96 | 97 | 98 | def wrap_embedding_func_with_attrs(**kwargs): 99 | """Wrap a function with attributes""" 100 | 101 | def final_decro(func) -> EmbeddingFunc: 102 | new_func = EmbeddingFunc(**kwargs, func=func) 103 | return new_func 104 | 105 | return final_decro 106 | 107 | 108 | def load_json(file_name): 109 | if not os.path.exists(file_name): 110 | return None 111 | with open(file_name, encoding="utf-8") as f: 112 | return json.load(f) 113 | 114 | 115 | def write_json(json_obj, file_name): 116 | with open(file_name, "w", encoding="utf-8") as f: 117 | json.dump(json_obj, f, indent=2, ensure_ascii=False) 118 | 119 | 120 | def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"): 121 | global ENCODER 
122 | if ENCODER is None: 123 | ENCODER = tiktoken.encoding_for_model(model_name) 124 | tokens = ENCODER.encode(content) 125 | return tokens 126 | 127 | 128 | def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"): 129 | global ENCODER 130 | if ENCODER is None: 131 | ENCODER = tiktoken.encoding_for_model(model_name) 132 | content = ENCODER.decode(tokens) 133 | return content 134 | 135 | 136 | def pack_user_ass_to_openai_messages(*args: str): 137 | roles = ["user", "assistant"] 138 | return [ 139 | {"role": roles[i % 2], "content": content} for i, content in enumerate(args) #if content is not None 140 | ] 141 | 142 | 143 | def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: 144 | """Split a string by multiple markers""" 145 | if not markers: 146 | return [content] 147 | results = re.split("|".join(re.escape(marker) for marker in markers), content) 148 | return [r.strip() for r in results if r.strip()] 149 | 150 | 151 | # Refer the utils functions of the official GraphRAG implementation: 152 | # https://github.com/microsoft/graphrag 153 | def clean_str(input: Any) -> str: 154 | """Clean an input string by removing HTML escapes, control characters, and other unwanted characters.""" 155 | # If we get non-string input, just give it back 156 | if not isinstance(input, str): 157 | return input 158 | 159 | result = html.unescape(input.strip()) 160 | # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python 161 | return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) 162 | 163 | 164 | def is_float_regex(value): 165 | return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) 166 | 167 | 168 | def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int): 169 | """Truncate a list of data by token size""" 170 | if max_token_size <= 0: 171 | return [] 172 | tokens = 0 173 | for i, data in enumerate(list_data): 174 | tokens += len(encode_string_by_tiktoken(key(data))) 175 | if tokens > max_token_size: 176 | return list_data[:i] 177 | return list_data 178 | 179 | 180 | def list_of_list_to_csv(data: List[List[str]]) -> str: 181 | output = io.StringIO() 182 | writer = csv.writer(output) 183 | writer.writerows(data) 184 | return output.getvalue() 185 | 186 | 187 | def csv_string_to_list(csv_string: str) -> List[List[str]]: 188 | output = io.StringIO(csv_string) 189 | reader = csv.reader(output) 190 | return [row for row in reader] 191 | 192 | 193 | def save_data_to_file(data, file_name): 194 | with open(file_name, "w", encoding="utf-8") as f: 195 | json.dump(data, f, ensure_ascii=False, indent=4) 196 | 197 | 198 | def xml_to_json(xml_file): 199 | try: 200 | tree = ET.parse(xml_file) 201 | root = tree.getroot() 202 | 203 | # Print the root element's tag and attributes to confirm the file has been correctly loaded 204 | print(f"Root element: {root.tag}") 205 | print(f"Root attributes: {root.attrib}") 206 | 207 | data = {"nodes": [], "edges": []} 208 | 209 | # Use namespace 210 | namespace = {"": "http://graphml.graphdrawing.org/xmlns"} 211 | 212 | for node in root.findall(".//node", namespace): 213 | node_data = { 214 | "id": node.get("id").strip('"'), 215 | "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') 216 | if node.find("./data[@key='d0']", namespace) is not None 217 | else "", 218 | "description": node.find("./data[@key='d1']", namespace).text 219 | if node.find("./data[@key='d1']", namespace) is not None 220 | else "", 221 | "source_id": node.find("./data[@key='d2']", 
namespace).text 222 | if node.find("./data[@key='d2']", namespace) is not None 223 | else "", 224 | } 225 | data["nodes"].append(node_data) 226 | 227 | for edge in root.findall(".//edge", namespace): 228 | edge_data = { 229 | "source": edge.get("source").strip('"'), 230 | "target": edge.get("target").strip('"'), 231 | "weight": float(edge.find("./data[@key='d3']", namespace).text) 232 | if edge.find("./data[@key='d3']", namespace) is not None 233 | else 0.0, 234 | "description": edge.find("./data[@key='d4']", namespace).text 235 | if edge.find("./data[@key='d4']", namespace) is not None 236 | else "", 237 | "keywords": edge.find("./data[@key='d5']", namespace).text 238 | if edge.find("./data[@key='d5']", namespace) is not None 239 | else "", 240 | "source_id": edge.find("./data[@key='d6']", namespace).text 241 | if edge.find("./data[@key='d6']", namespace) is not None 242 | else "", 243 | } 244 | data["edges"].append(edge_data) 245 | 246 | # Print the number of nodes and edges found 247 | print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges") 248 | 249 | return data 250 | except ET.ParseError as e: 251 | print(f"Error parsing XML file: {e}") 252 | return None 253 | except Exception as e: 254 | print(f"An error occurred: {e}") 255 | return None 256 | 257 | 258 | def process_combine_contexts(hl, ll): 259 | header = None 260 | list_hl = csv_string_to_list(hl.strip()) 261 | list_ll = csv_string_to_list(ll.strip()) 262 | 263 | if list_hl: 264 | header = list_hl[0] 265 | list_hl = list_hl[1:] 266 | if list_ll: 267 | header = list_ll[0] 268 | list_ll = list_ll[1:] 269 | if header is None: 270 | return "" 271 | 272 | if list_hl: 273 | list_hl = [",".join(item[1:]) for item in list_hl if item] 274 | if list_ll: 275 | list_ll = [",".join(item[1:]) for item in list_ll if item] 276 | 277 | combined_sources_set = set(filter(None, list_hl + list_ll)) 278 | 279 | combined_sources = [",\t".join(header)] 280 | 281 | for i, item in enumerate(combined_sources_set, start=1): 282 | combined_sources.append(f"{i},\t{item}") 283 | 284 | combined_sources = "\n".join(combined_sources) 285 | 286 | return combined_sources 287 | 288 | 289 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop: 290 | """ 291 | Ensure that there is always an event loop available. 292 | 293 | This function tries to get the current event loop. If the current event loop is closed or does not exist, 294 | it creates a new event loop and sets it as the current event loop. 295 | 296 | Returns: 297 | asyncio.AbstractEventLoop: The current or newly created event loop. 
298 | """ 299 | try: 300 | # Try to get the current event loop 301 | current_loop = asyncio.get_event_loop() 302 | if current_loop.is_closed(): 303 | raise RuntimeError("Event loop is closed.") 304 | return current_loop 305 | 306 | except RuntimeError: 307 | # If no event loop exists or it is closed, create a new one 308 | logger.info("Creating a new event loop in main thread.") 309 | new_loop = asyncio.new_event_loop() 310 | asyncio.set_event_loop(new_loop) 311 | return new_loop -------------------------------------------------------------------------------- /reproduce/Step_0.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from pathlib import Path 4 | 5 | 6 | def extract_unique_contexts(input_directory, output_directory): 7 | in_dir, out_dir = Path(input_directory), Path(output_directory) 8 | out_dir.mkdir(parents=True, exist_ok=True) 9 | 10 | jsonl_files = list(in_dir.glob("*.jsonl")) 11 | print(f"Found {len(jsonl_files)} JSONL files.") 12 | 13 | for file_path in jsonl_files: 14 | output_path = out_dir / f"{file_path.stem}_unique_contexts.json" 15 | if output_path.exists(): 16 | continue 17 | 18 | unique_contexts_dict = {} 19 | 20 | print(f"Processing file: {file_path.name}") 21 | 22 | try: 23 | with open(file_path, "r", encoding="utf-8") as infile: 24 | for line_number, line in enumerate(infile, start=1): 25 | line = line.strip() 26 | if not line: 27 | continue 28 | try: 29 | json_obj = json.loads(line) 30 | context = json_obj.get("context") 31 | if context and context not in unique_contexts_dict: 32 | unique_contexts_dict[context] = None 33 | except json.JSONDecodeError as e: 34 | print( 35 | f"JSON decoding error in file {file_path.name} at line {line_number}: {e}" 36 | ) 37 | except FileNotFoundError: 38 | print(f"File not found: {file_path.name}") 39 | continue 40 | except Exception as e: 41 | print(f"An error occurred while processing file {file_path.name}: {e}") 42 | continue 43 | 44 | unique_contexts_list = list(unique_contexts_dict.keys()) 45 | print( 46 | f"There are {len(unique_contexts_list)} unique `context` entries in the file {file_path.name}." 
47 | ) 48 | 49 | try: 50 | with open(output_path, "w", encoding="utf-8") as outfile: 51 | json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4) 52 | print(f"Unique `context` entries have been saved to: {output_path.name}") 53 | except Exception as e: 54 | print(f"An error occurred while saving to the file {output_path.name}: {e}") 55 | 56 | print("All files have been processed.") 57 | 58 | 59 | if __name__ == "__main__": 60 | data_name = 'mix' 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument( 63 | "-i", "--input_dir", type=str, default=f"datasets/{data_name}" 64 | ) 65 | parser.add_argument( 66 | "-o", "--output_dir", type=str, default=f"caches/{data_name}/contexts" 67 | ) 68 | 69 | args = parser.parse_args() 70 | 71 | extract_unique_contexts(args.input_dir, args.output_dir) 72 | -------------------------------------------------------------------------------- /reproduce/Step_1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | 6 | import time 7 | import numpy as np 8 | 9 | from hyperrag import HyperRAG 10 | from hyperrag.utils import EmbeddingFunc 11 | from hyperrag.llm import openai_embedding, openai_complete_if_cache 12 | 13 | from my_config import LLM_API_KEY, LLM_BASE_URL, LLM_MODEL 14 | from my_config import EMB_API_KEY, EMB_BASE_URL, EMB_MODEL, EMB_DIM 15 | 16 | 17 | async def llm_model_func( 18 | prompt, system_prompt=None, history_messages=[], **kwargs 19 | ) -> str: 20 | return await openai_complete_if_cache( 21 | LLM_MODEL, 22 | prompt, 23 | system_prompt=system_prompt, 24 | history_messages=history_messages, 25 | api_key=LLM_API_KEY, 26 | base_url=LLM_BASE_URL, 27 | **kwargs, 28 | ) 29 | 30 | 31 | async def embedding_func(texts: list[str]) -> np.ndarray: 32 | return await openai_embedding( 33 | texts, 34 | model=EMB_MODEL, 35 | api_key=EMB_API_KEY, 36 | base_url=EMB_BASE_URL, 37 | ) 38 | 39 | 40 | def insert_text(rag, file_path, retries=0, max_retries=3): 41 | with open(file_path, "r", encoding="utf-8") as f: 42 | unique_contexts = f.read() 43 | 44 | while retries < max_retries: 45 | try: 46 | rag.insert(unique_contexts) 47 | break 48 | except Exception as e: 49 | retries += 1 50 | print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}") 51 | time.sleep(10) 52 | if retries == max_retries: 53 | print("Insertion failed after exceeding the maximum number of retries") 54 | 55 | 56 | if __name__ == "__main__": 57 | data_name = "mix" 58 | WORKING_DIR = Path("caches") / data_name 59 | WORKING_DIR.mkdir(parents=True, exist_ok=True) 60 | rag = HyperRAG( 61 | working_dir=WORKING_DIR, 62 | llm_model_func=llm_model_func, 63 | embedding_func=EmbeddingFunc( 64 | embedding_dim=EMB_DIM, max_token_size=8192, func=embedding_func 65 | ), 66 | ) 67 | insert_text(rag, f"caches/{data_name}/contexts/{data_name}_unique_contexts.json") 68 | -------------------------------------------------------------------------------- /reproduce/Step_2_extract_question.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | import json 4 | import tiktoken 5 | import numpy as np 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | from openai import OpenAI 9 | 10 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 11 | 12 | from my_config import LLM_API_KEY, LLM_BASE_URL 13 | 14 | # suggest using more powerful LLMs to extract questions like 
gpt-4o 15 | LLM_MODEL = "gpt-4o" 16 | # LLM_MODEL = "gpt-4o-mini" 17 | 18 | 19 | def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str: 20 | openai_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL) 21 | 22 | messages = [] 23 | if system_prompt: 24 | messages.append({"role": "system", "content": system_prompt}) 25 | messages.extend(history_messages) 26 | messages.append({"role": "user", "content": prompt}) 27 | 28 | response = openai_client.chat.completions.create( 29 | model=LLM_MODEL, messages=messages, **kwargs 30 | ) 31 | return response.choices[0].message.content 32 | 33 | 34 | question_prompt = { 35 | # one-stage question 36 | 1: """ 37 | You are a professional teacher, and you are now asked to design a question that meets the requirements based on the reference. 38 | ################ 39 | Reference: 40 | Given the following fragment of a data set: 41 | {context} 42 | ################ 43 | Requirements: 44 | 1. This question should be of the question-and-answer (QA) type, and no answer is required. 45 | 2. This question mainly tests the details of the information and knowledge in the reference. Avoid general and macro question. 46 | 3. The question must not include any conjunctions such as "specifically", "particularly", "and", "or", "and how", "and what" or similar phrases that imply additional inquiries. 47 | 4. The question must focus on a single aspect or detail from the reference, avoiding the combination of multiple inquiries. 48 | 5. Please design question from the professional perspective and domain factors covered by the reference. 49 | 6. This question need to be meaningful and difficult, avoiding overly simplistic inquiries. 50 | 7. This question should be based on the complete context, so that the respondent knows what you are asking and doesn't get confused. 51 | 8. State the question directly in a single sentence, without statements like "How in this reference?" or "What about this data set?" or "as described in the reference." 52 | ################ 53 | Output the content of question in the following structure: 54 | {{ 55 | "Question": [question description], 56 | }} 57 | """, 58 | # two-stage question 59 | 2: """ 60 | You are a professional teacher, and your task is to design a single question that contains two interconnected sub-questions, 61 | demonstrating a progressive relationship based on the reference. 62 | ################ 63 | Reference: 64 | Given the following fragment of a data set: 65 | {context} 66 | ################ 67 | Requirements: 68 | 1. This question should be of the question-and-answer (QA) type, and no answer is required. 69 | 2. The question must include two sub-questions connected by transitional phrases such as "and" or "specifically," indicating progression. 70 | 3. Focus on testing the details of the information and knowledge in the reference. Avoid general and macro questions. 71 | 4. Design the question from a professional perspective, considering the domain factors covered by the reference. 72 | 5. Ensure the question is meaningful and challenging, avoiding trivial inquiries. 73 | 6. The question should be based on the complete context, ensuring clarity for the respondent. 74 | 7. State the question directly in a single sentence, without introductory phrases like "How in this reference?" or "What about this data set?". 
75 | ################ 76 | Output the content of the question in the following structure: 77 | {{ 78 | "Question": [question description], 79 | }} 80 | """, 81 | # three-stage question 82 | 3: """ 83 | You are a professional teacher, and your task is to design a single question that contains three interconnected sub-questions, 84 | demonstrating a progressive relationship based on the reference. 85 | ################ 86 | Reference: 87 | Given the following fragment of a data set: 88 | {context} 89 | ################ 90 | Requirements: 91 | 1. This question should be of the question-and-answer (QA) type, and no answer is required. 92 | 2. The question must include three sub-questions connected by transitional phrases such as "and" or "specifically," indicating progression. 93 | 3. Focus on testing the details of the information and knowledge in the reference. Avoid general and macro questions. 94 | 4. Design the question from a professional perspective, considering the domain factors covered by the reference. 95 | 5. Ensure the question is meaningful and challenging, avoiding trivial inquiries. 96 | 6. The question should be based on the complete context, ensuring clarity for the respondent. 97 | 7. State the question directly in a single sentence, without introductory phrases like "How in this reference?" or "What about this data set?". 98 | ################ 99 | Output the content of the question in the following structure: 100 | {{ 101 | "Question": [question description], 102 | }} 103 | """, 104 | } 105 | 106 | encoding = tiktoken.encoding_for_model("gpt-4o") 107 | 108 | 109 | if __name__ == "__main__": 110 | data_name = "mix" 111 | question_stage = 2 112 | WORKING_DIR = Path("caches") / data_name 113 | # number of question stages to extract, which can be 1, 2, or 3 114 | len_big_chunks = 3 115 | question_list, reference_list = [], [] 116 | with open( 117 | f"caches/{data_name}/contexts/{data_name}_unique_contexts.json", 118 | mode="r", 119 | encoding="utf-8", 120 | ) as f: 121 | unique_contexts = json.load(f) 122 | 123 | cnt, max_cnt = 0, 5 124 | max_idx = max(len(unique_contexts) - len_big_chunks - 1, 1) 125 | 126 | with tqdm( 127 | total=max_cnt, desc=f"Extracting {question_stage}-stage questions" 128 | ) as pbar: 129 | while cnt < max_cnt: 130 | # randomly select a context 131 | idx = np.random.randint(0, max_idx) 132 | big_chunks = unique_contexts[idx : idx + len_big_chunks] 133 | context = "".join(big_chunks) 134 | 135 | prompt = question_prompt[question_stage].format(context=context) 136 | response = llm_model_func(prompt) 137 | 138 | question = re.findall(r'"Question": "(.*?)"', response) 139 | if len(question) == 0: 140 | print("No question found in the response.") 141 | continue 142 | 143 | question_list.append(question[0]) 144 | reference_list.append(context) 145 | 146 | cnt += 1 147 | pbar.update(1) 148 | 149 | # save the questions and references to a JSON file 150 | prefix = f"caches/{data_name}/questions/{question_stage}_stage" 151 | question_file_path = Path(f"{prefix}.json") 152 | ref_file_path = Path(f"{prefix}_ref.json") 153 | question_file_path.parent.mkdir(parents=True, exist_ok=True) 154 | with open(f"{prefix}.json", "w", encoding="utf-8") as f: 155 | json.dump(question_list, f, ensure_ascii=False, indent=4) 156 | with open(f"{prefix}_ref.json", "w", encoding="utf-8") as f: 157 | json.dump(reference_list, f, ensure_ascii=False, indent=4) 158 | 159 | print(f"questions written to {question_file_path}") 160 | 
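# Minimal usage sketch (illustrative only, assuming the default data_name and question_stage above):
# the two files written here are parallel JSON lists, so they can be reloaded with
#   import json
#   questions = json.load(open(f"{prefix}.json", encoding="utf-8"))
#   references = json.load(open(f"{prefix}_ref.json", encoding="utf-8"))
# where each questions[i] was generated from the context stored in references[i].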
-------------------------------------------------------------------------------- /reproduce/Step_3_response_question.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 5 | 6 | import json 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from hyperrag import HyperRAG, QueryParam 11 | from hyperrag.utils import always_get_an_event_loop, EmbeddingFunc 12 | from hyperrag.llm import openai_embedding, openai_complete_if_cache 13 | 14 | from my_config import LLM_API_KEY, LLM_BASE_URL, LLM_MODEL 15 | from my_config import EMB_API_KEY, EMB_BASE_URL, EMB_MODEL, EMB_DIM 16 | 17 | 18 | async def llm_model_func( 19 | prompt, system_prompt=None, history_messages=[], **kwargs 20 | ) -> str: 21 | return await openai_complete_if_cache( 22 | LLM_MODEL, 23 | prompt, 24 | system_prompt=system_prompt, 25 | history_messages=history_messages, 26 | api_key=LLM_API_KEY, 27 | base_url=LLM_BASE_URL, 28 | **kwargs, 29 | ) 30 | 31 | 32 | async def embedding_func(texts: list[str]) -> np.ndarray: 33 | return await openai_embedding( 34 | texts, 35 | model=EMB_MODEL, 36 | api_key=EMB_API_KEY, 37 | base_url=EMB_BASE_URL, 38 | ) 39 | 40 | 41 | def extract_queries(file_path): 42 | with open(file_path, "r", encoding="utf-8") as file: 43 | query_list = json.load(file) 44 | return query_list 45 | 46 | 47 | async def process_query(query_text, rag_instance, query_param): 48 | try: 49 | result = await rag_instance.aquery(query_text, param=query_param) 50 | return {"query": query_text, "result": result}, None 51 | except Exception as e: 52 | print("error", e) 53 | return None, {"query": query_text, "error": str(e)} 54 | 55 | 56 | def run_queries_and_save_to_json( 57 | queries, rag_instance, query_param, output_file, error_file 58 | ): 59 | loop = always_get_an_event_loop() 60 | 61 | with open(output_file, "a", encoding="utf-8") as result_file, open( 62 | error_file, "a", encoding="utf-8" 63 | ) as err_file: 64 | result_file.write("[\n") 65 | first_entry = True 66 | 67 | for query_text in tqdm(queries, desc="Processing queries", unit="query"): 68 | result, error = loop.run_until_complete( 69 | process_query(query_text, rag_instance, query_param) 70 | ) 71 | if result: 72 | if not first_entry: 73 | result_file.write(",\n") 74 | json.dump(result, result_file, ensure_ascii=False, indent=4) 75 | first_entry = False 76 | elif error: 77 | json.dump(error, err_file, ensure_ascii=False, indent=4) 78 | err_file.write("\n") 79 | 80 | result_file.write("\n]") 81 | 82 | 83 | if __name__ == "__main__": 84 | data_name = "mix" 85 | question_stage = 2 86 | WORKING_DIR = Path("caches") / data_name 87 | # input questions 88 | question_file_path = Path( 89 | WORKING_DIR / f"questions/{question_stage}_stage.json" 90 | ) 91 | queries = extract_queries(question_file_path) 92 | # init HyperRAG 93 | rag = HyperRAG( 94 | working_dir=WORKING_DIR, 95 | llm_model_func=llm_model_func, 96 | embedding_func=EmbeddingFunc( 97 | embedding_dim=EMB_DIM, max_token_size=8192, func=embedding_func 98 | ), 99 | ) 100 | # configure query parameters 101 | mode = "naive" 102 | # mode = "hyper" 103 | # mode = "hyper-lite" 104 | query_param = QueryParam(mode=mode) 105 | 106 | OUT_DIR = WORKING_DIR / "response" 107 | OUT_DIR.mkdir(parents=True, exist_ok=True) 108 | run_queries_and_save_to_json( 109 | queries, 110 | rag, 111 | query_param, 112 | OUT_DIR / f"{mode}_{question_stage}_stage_result.json", 113 | OUT_DIR / 
f"{mode}_{question_stage}_stage_errors.json", 114 | ) 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | aioboto3 3 | aiohttp 4 | numpy 5 | nano-vectordb 6 | openai 7 | tenacity 8 | tiktoken 9 | xxhash 10 | hypergraph-db 11 | -------------------------------------------------------------------------------- /web-ui/.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | pnpm-debug.log* 8 | lerna-debug.log* 9 | 10 | node_modules 11 | dist 12 | dist-ssr 13 | *.local 14 | 15 | # Editor directories and files 16 | .vscode/* 17 | !.vscode/extensions.json 18 | .idea 19 | .DS_Store 20 | *.suo 21 | *.ntvs* 22 | *.njsproj 23 | *.sln 24 | *.sw? 25 | pnpm-lock.yaml 26 | __pycache__ -------------------------------------------------------------------------------- /web-ui/README.md: -------------------------------------------------------------------------------- 1 | # 前端项目 2 | 3 | ## 安装依赖 4 | 5 | ```bash 6 | cd frontend 7 | pnpm install 8 | ``` 9 | 10 | ## 脚本描述 11 | 12 | ### 开发启动 13 | 14 | ```bash 15 | npm run dev 16 | ``` 17 | 18 | ### 打包 19 | 20 | ```bash 21 | npm run build 22 | ``` 23 | 24 | # 后端服务 25 | 26 | ``` 27 | pip install "fastapi[standard]" 28 | ``` 29 | 30 | ## 开发启动 31 | 32 | ```bash 33 | fastapi dev main.py 34 | ``` 35 | 36 | Server started at http://127.0.0.1:8000 37 | 38 | Documentation at http://127.0.0.1:8000/docs 39 | 40 | nohup fastapi run main.py > logout.log 2>&1 & 41 | -------------------------------------------------------------------------------- /web-ui/backend/README.md: -------------------------------------------------------------------------------- 1 | ## 开发 2 | 3 | fastapi dev main.py 4 | 5 | Server started at http://127.0.0.1:8000 6 | 7 | Documentation at http://127.0.0.1:8000/docs 8 | -------------------------------------------------------------------------------- /web-ui/backend/db.py: -------------------------------------------------------------------------------- 1 | from hyperdb import HypergraphDB 2 | 3 | hg = HypergraphDB(storage_file="hypergraph_wukong.hgdb") 4 | 5 | # 声明函数 6 | def get_hypergraph(): 7 | # 声明变量 赋值 hg.all_v 8 | all_v = hg.all_v 9 | # 声明变量 赋值 hg.all_e 10 | all_e = hg.all_e 11 | 12 | return get_all_detail(all_v, all_e) 13 | 14 | def get_vertices(): 15 | """ 16 | 获取vertices列表 17 | """ 18 | all_v = hg.all_v 19 | return all_v 20 | 21 | def getFrequentVertices(): 22 | """ 23 | 获取频繁的vertices列表 24 | """ 25 | all_v = hg.all_v 26 | 27 | frequent_vertices = [] 28 | 29 | edges = get_hyperedges() 30 | for v in all_v: 31 | count = 0 32 | for e in edges: 33 | if v in e: 34 | count += 1 35 | if count >= 2: 36 | frequent_vertices.append(v) 37 | 38 | return frequent_vertices 39 | 40 | def get_vertice(vertex_id: str): 41 | """ 42 | 获取指定vertex的json 43 | """ 44 | vertex = hg.v(vertex_id) 45 | return vertex 46 | 47 | def get_hyperedges(): 48 | """ 49 | 获取hyperedges列表 50 | """ 51 | all_e = hg.all_e 52 | 53 | hyperedges = [] 54 | for e in all_e: 55 | hyperedges.append('|*|'.join(e)) 56 | 57 | return hyperedges 58 | 59 | def get_hyperedge(hyperedge_id: str): 60 | """ 61 | 获取指定hyperedge的json 62 | """ 63 | hyperedge = hg.e(hyperedge_id) 64 | 65 | return hyperedge 66 | 67 | def get_vertice_neighbor_inner(vertex_id: str): 68 | """ 69 | 获取指定vertex的neighbor 70 | 71 | todo: 查不到会报错 CLERGYMAN 72 | """ 73 | 
try: 74 | n = hg.nbr_v(vertex_id) 75 | 76 | n.add(vertex_id) 77 | 78 | e = hg.nbr_e_of_v(vertex_id) 79 | except Exception: 80 | # 如果报错,返回空列表 81 | n = [] 82 | e = [] 83 | 84 | return (n,e) 85 | 86 | def get_vertice_neighbor(vertex_id: str): 87 | """ 88 | 获取指定vertex的neighbor 89 | 90 | todo: 查不到会报错 CLERGYMAN 91 | """ 92 | n, e = get_vertice_neighbor_inner(vertex_id) 93 | 94 | return get_all_detail(n, e) 95 | 96 | 97 | def get_all_detail(all_v, all_e): 98 | """ 99 | 获取所有详情 100 | """ 101 | # 循环遍历 all_v 每个元素 赋值为 hg.v 102 | nodes = {} 103 | for v in all_v: 104 | nodes[v] = hg.v(v) 105 | 106 | hyperedges = {} 107 | for e in all_e: 108 | data = hg.e(e) 109 | # data的 keywords 赋值 110 | data['keywords'] = data['keywords'].replace("", ",") 111 | hyperedges['|#|'.join(e)] = data 112 | 113 | return { "vertices": nodes , "edges": hyperedges } 114 | 115 | def get_hyperedge_neighbor_server(hyperedge_id: str): 116 | """ 117 | 获取指定hyperedge的neighbor 118 | """ 119 | nodes = hyperedge_id.split("|#|") 120 | print(hyperedge_id) 121 | vertices = set() 122 | hyperedges = set() 123 | for node in nodes: 124 | n, e = get_vertice_neighbor_inner(node) 125 | # 这里的 n 是一个集合 126 | # 这里的 e 是一个集合 127 | # vertexs 增加n 128 | # hyperedges 增加e 129 | vertices.update(n) 130 | hyperedges.update(e) 131 | 132 | return get_all_detail(vertices, hyperedges) -------------------------------------------------------------------------------- /web-ui/backend/hyperdb/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseHypergraphDB 2 | from .hypergraph import HypergraphDB 3 | 4 | from ._global import AUTHOR_EMAIL 5 | 6 | __version__ = "0.1.3" 7 | 8 | __all__ = {"AUTHOR_EMAIL", "BaseHypergraphDB", "HypergraphDB"} 9 | -------------------------------------------------------------------------------- /web-ui/backend/hyperdb/_global.py: -------------------------------------------------------------------------------- 1 | AUTHOR_EMAIL = "evanfeng97@gmail.com" 2 | -------------------------------------------------------------------------------- /web-ui/backend/hyperdb/base.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from dataclasses import dataclass, field 3 | from functools import cached_property 4 | from typing import Union, Tuple, List, Set, Dict, Any, Optional 5 | 6 | 7 | @dataclass 8 | class BaseHypergraphDB: 9 | r""" 10 | Base class for hypergraph database. 11 | """ 12 | 13 | storage_file: Union[str, Path] = field(default="my_hypergraph.hgdb", compare=False) 14 | 15 | def save(self, file_path: Union[str, Path]): 16 | r""" 17 | Save the hypergraph to a file. 18 | 19 | Args: 20 | ``file_path`` (``Union[str, Path]``): The file path to save the hypergraph. 21 | """ 22 | raise NotImplementedError 23 | 24 | def save_as(self, format: str, file_path: Union[str, Path]): 25 | r""" 26 | Save the hypergraph to a specific format. 27 | 28 | Args: 29 | ``format`` (``str``): The export format (e.g., "json", "csv", "graphml"). 30 | ``file_path`` (``Union[str, Path]``): The file path to export the hypergraph. 31 | """ 32 | raise NotImplementedError 33 | 34 | @staticmethod 35 | def load(self, file_path: Union[str, Path]): 36 | r""" 37 | Load the hypergraph from a file. 38 | 39 | Args: 40 | ``file_path`` (``Union[str, Path]``): The file path to load the hypergraph from. 
41 | """ 42 | raise NotImplementedError 43 | 44 | def load_from(self, format: str, file_path: Union[str, Path]): 45 | r""" 46 | Load a hypergraph from a specific format. 47 | 48 | Args: 49 | ``format`` (``str``): The import format (e.g., "json", "csv", "graphml"). 50 | ``file_path`` (``Union[str, Path]``): The file path to import the hypergraph from. 51 | """ 52 | raise NotImplementedError 53 | 54 | def _clear_cache(self): 55 | r""" 56 | Clear the cache. 57 | """ 58 | raise NotImplementedError 59 | 60 | def v(self, v_id: Any, default: Any = None) -> dict: 61 | r""" 62 | Return the vertex data. 63 | 64 | Args: 65 | ``v_id`` (``Any``): The vertex id. 66 | ``default`` (``Any``): The default value if the vertex does not exist. 67 | """ 68 | raise NotImplementedError 69 | 70 | def e(self, e_tuple: Union[List, Set, Tuple], default: Any = None) -> dict: 71 | r""" 72 | Return the hyperedge data. 73 | 74 | Args: 75 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 76 | ``default`` (``Any``): The default value if the hyperedge does not exist. 77 | """ 78 | raise NotImplementedError 79 | 80 | def encode_e(self, e_tuple: Union[List, Set, Tuple]) -> Tuple: 81 | r""" 82 | Sort and check the hyperedge tuple. 83 | 84 | Args: 85 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 86 | """ 87 | raise NotImplementedError 88 | 89 | @cached_property 90 | def all_v(self) -> List[str]: 91 | r""" 92 | Return a list of all vertices in the hypergraph. 93 | """ 94 | raise NotImplementedError 95 | 96 | @cached_property 97 | def all_e(self) -> List[Tuple]: 98 | r""" 99 | Return a list of all hyperedges in the hypergraph. 100 | """ 101 | raise NotImplementedError 102 | 103 | @cached_property 104 | def num_v(self) -> int: 105 | r""" 106 | Return the number of vertices in the hypergraph. 107 | """ 108 | raise NotImplementedError 109 | 110 | @cached_property 111 | def num_e(self) -> int: 112 | r""" 113 | Return the number of hyperedges in the hypergraph. 114 | """ 115 | raise NotImplementedError 116 | 117 | def add_v(self, v_id: Any, v_data: Optional[Dict] = None): 118 | r""" 119 | Add a vertex to the hypergraph. 120 | 121 | Args: 122 | ``v_id`` (``Any``): The vertex id. 123 | ``v_data`` (``Dict``, optional): The vertex data. Defaults to None. 124 | """ 125 | raise NotImplementedError 126 | 127 | def add_e(self, e_tuple: Tuple, e_data: Optional[Dict] = None): 128 | r""" 129 | Add a hyperedge to the hypergraph. 130 | 131 | Args: 132 | ``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 133 | ``e_data`` (``Dict``, optional): The hyperedge data. 134 | """ 135 | raise NotImplementedError 136 | 137 | def remove_v(self, v_id: Any): 138 | r""" 139 | Remove a vertex from the hypergraph. 140 | 141 | Args: 142 | ``v_id`` (``Any``): The vertex id. 143 | """ 144 | raise NotImplementedError 145 | 146 | def remove_e(self, e_tuple: Tuple): 147 | r""" 148 | Remove a hyperedge from the hypergraph. 149 | 150 | Args: 151 | ``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 152 | """ 153 | raise NotImplementedError 154 | 155 | def update_v(self, v_id: Any): 156 | r""" 157 | Update the vertex data. 158 | 159 | Args: 160 | ``v_id`` (``Any``): The vertex id. 161 | """ 162 | raise NotImplementedError 163 | 164 | def update_e(self, e_tuple: Tuple): 165 | r""" 166 | Update the hyperedge data. 167 | 168 | Args: 169 | ``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 
170 | """ 171 | raise NotImplementedError 172 | 173 | def has_v(self, v_id: Any) -> bool: 174 | r""" 175 | Return True if the vertex exists in the hypergraph. 176 | 177 | Args: 178 | ``v_id`` (``Any``): The vertex id. 179 | """ 180 | raise NotImplementedError 181 | 182 | def has_e(self, e_tuple: Tuple) -> bool: 183 | r""" 184 | Return True if the hyperedge exists in the hypergraph. 185 | 186 | Args: 187 | ``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 188 | """ 189 | raise NotImplementedError 190 | 191 | def degree_v(self, v_id: Any) -> int: 192 | r""" 193 | Return the degree of the vertex. 194 | 195 | Args: 196 | ``v_id`` (``Any``): The vertex id. 197 | """ 198 | raise NotImplementedError 199 | 200 | def degree_e(self, e_tuple: Tuple) -> int: 201 | r""" 202 | Return the degree of the hyperedge. 203 | 204 | Args: 205 | ``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 206 | """ 207 | raise NotImplementedError 208 | 209 | def nbr_e_of_v(self, v_id: Any) -> list: 210 | r""" 211 | Return the hyperedge neighbors of the vertex. 212 | 213 | Args: 214 | ``v_id`` (``Any``): The vertex id. 215 | """ 216 | raise NotImplementedError 217 | 218 | def nbr_v_of_e(self, e_tuple: Tuple) -> list: 219 | r""" 220 | Return the vertex neighbors of the hyperedge. 221 | 222 | Args: 223 | ``e_tuple`` (``Tuple``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 224 | """ 225 | raise NotImplementedError 226 | 227 | def nbr_v(self, v_id: Any) -> list: 228 | r""" 229 | Return the vertex neighbors of the vertex. 230 | 231 | Args: 232 | ``v_id`` (``Any``): The vertex id. 233 | """ 234 | raise NotImplementedError 235 | 236 | def draw( 237 | self, 238 | ): 239 | r""" 240 | Draw the hypergraph. 241 | """ 242 | raise NotImplementedError 243 | 244 | def sub(self, v_name_list: List[str]): 245 | r""" 246 | Return the sub-hypergraph. 247 | 248 | Args: 249 | ``v_name_list`` (``List[str]``): The list of vertex ids. 250 | """ 251 | raise NotImplementedError 252 | 253 | def sub_from_v(self, v_id: Any, depth: int): 254 | r""" 255 | Return the sub-hypergraph from the vertex. 256 | 257 | Args: 258 | ``v_id`` (``Any``): The vertex id. 259 | ``depth`` (``int``): The depth of the sub-hypergraph. 260 | """ 261 | raise NotImplementedError 262 | 263 | def query_v(self, filters: Dict[str, Any]) -> List[str]: 264 | r""" 265 | Query and return vertices that match the given filters. 266 | 267 | Args: 268 | ``filters`` (``Dict[str, Any]``): A dictionary of conditions to filter vertices. 269 | """ 270 | raise NotImplementedError 271 | 272 | def query_e(self, filters: Dict[str, Any]) -> List[Tuple]: 273 | r""" 274 | Query and return hyperedges that match the given filters. 275 | 276 | Args: 277 | ``filters`` (``Dict[str, Any]``): A dictionary of conditions to filter hyperedges. 278 | """ 279 | raise NotImplementedError 280 | 281 | def stats(self) -> dict: 282 | r""" 283 | Return basic statistics of the hypergraph. 
284 | """ 285 | raise NotImplementedError 286 | -------------------------------------------------------------------------------- /web-ui/backend/hyperdb/hypergraph.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | from pathlib import Path 3 | from copy import deepcopy 4 | from collections import defaultdict 5 | from collections.abc import Hashable 6 | from functools import cached_property 7 | from dataclasses import dataclass, field 8 | from typing import Tuple, List, Any, Union, Set, Dict, Optional 9 | 10 | 11 | from hyperdb.base import BaseHypergraphDB 12 | 13 | 14 | @dataclass 15 | class HypergraphDB(BaseHypergraphDB): 16 | r""" 17 | Hypergraph database. 18 | """ 19 | 20 | _v_data: Dict[str, Any] = field(default_factory=dict) 21 | _e_data: Dict[Tuple, Any] = field(default_factory=dict) 22 | _v_inci: Dict[str, Set[Tuple]] = field(default_factory=lambda: defaultdict(set)) 23 | 24 | def __post_init__(self): 25 | assert isinstance(self.storage_file, (str, Path)) 26 | if isinstance(self.storage_file, str): 27 | self.storage_file = Path(self.storage_file) 28 | if self.storage_file.exists(): 29 | self.load(self.storage_file) 30 | 31 | def load(self, storage_file: Path) -> dict: 32 | r""" 33 | Load the hypergraph database from the storage file. 34 | """ 35 | try: 36 | with open(storage_file, "rb") as f: 37 | data = pkl.load(f) 38 | self._v_data = data.get("v_data", {}) 39 | self._v_inci = data.get("v_inci", {}) 40 | self._e_data = data.get("e_data", {}) 41 | return True 42 | except Exception as e: 43 | return False 44 | 45 | def save(self, storage_file: Path) -> dict: 46 | r""" 47 | Save the hypergraph database to the storage file. 48 | """ 49 | data = { 50 | "v_data": self._v_data, 51 | "v_inci": self._v_inci, 52 | "e_data": self._e_data, 53 | } 54 | try: 55 | with open(storage_file, "wb") as f: 56 | pkl.dump(data, f) 57 | return True 58 | except Exception as e: 59 | return False 60 | 61 | def _clear_cache(self): 62 | r""" 63 | Clear the cached properties. 64 | """ 65 | self.__dict__.pop("all_v", None) 66 | self.__dict__.pop("all_e", None) 67 | self.__dict__.pop("num_v", None) 68 | self.__dict__.pop("num_e", None) 69 | 70 | def v(self, v_id: str, default: Any = None) -> dict: 71 | r""" 72 | Return the vertex data. 73 | 74 | Args: 75 | ``v_id`` (``str``): The vertex id. 76 | ``default`` (``Any``): The default value if the vertex does not exist. 77 | """ 78 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 79 | try: 80 | return self._v_data[v_id] 81 | except KeyError: 82 | return default 83 | 84 | def e(self, e_tuple: Union[List, Set, Tuple], default: Any = None) -> dict: 85 | r""" 86 | Return the hyperedge data. 87 | 88 | Args: 89 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 90 | ``default`` (``Any``): The default value if the hyperedge does not exist. 91 | """ 92 | assert isinstance( 93 | e_tuple, (set, list, tuple) 94 | ), "The hyperedge must be a set, list, or tuple of vertex ids." 95 | e_tuple = self.encode_e(e_tuple) 96 | try: 97 | return self._e_data[e_tuple] 98 | except KeyError: 99 | return default 100 | 101 | def encode_e(self, e_tuple: Union[List, Set, Tuple]) -> Tuple: 102 | r""" 103 | Sort and check the hyperedge tuple. 104 | 105 | Args: 106 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 
107 | """ 108 | assert isinstance( 109 | e_tuple, (list, set, tuple) 110 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 111 | tmp = sorted(list(set(e_tuple))) 112 | for v_id in tmp: 113 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 114 | assert ( 115 | v_id in self._v_data 116 | ), f"The vertex {v_id} does not exist in the hypergraph." 117 | return tuple(tmp) 118 | 119 | @cached_property 120 | def all_v(self) -> List[str]: 121 | r""" 122 | Return a list of all vertices in the hypergraph. 123 | """ 124 | return set(self._v_data.keys()) 125 | 126 | @cached_property 127 | def all_e(self) -> List[Tuple]: 128 | r""" 129 | Return a list of all hyperedges in the hypergraph. 130 | """ 131 | return set(self._e_data.keys()) 132 | 133 | @cached_property 134 | def num_v(self) -> int: 135 | r""" 136 | Return the number of vertices in the hypergraph. 137 | """ 138 | return len(self._v_data) 139 | 140 | @cached_property 141 | def num_e(self) -> int: 142 | r""" 143 | Return the number of hyperedges in the hypergraph. 144 | """ 145 | return len(self._e_data) 146 | 147 | def add_v(self, v_id: Any, v_data: Optional[Dict] = None): 148 | r""" 149 | Add a vertex to the hypergraph. 150 | 151 | Args: 152 | ``v_id`` (``Any``): The vertex id. 153 | ``v_data`` (``dict``, optional): The vertex data. 154 | """ 155 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 156 | if v_data is not None: 157 | assert isinstance(v_data, dict), "The vertex data must be a dictionary." 158 | else: 159 | v_data = {} 160 | if v_id not in self._v_data: 161 | self._v_data[v_id] = v_data 162 | self._v_inci[v_id] = set() 163 | else: 164 | self._v_data[v_id].update(v_data) 165 | self._clear_cache() 166 | 167 | def add_e(self, e_tuple: Union[List, Set, Tuple], e_data: Optional[Dict] = None): 168 | r""" 169 | Add a hyperedge to the hypergraph. 170 | 171 | Args: 172 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 173 | ``e_data`` (``dict``, optional): The hyperedge data. 174 | """ 175 | assert isinstance( 176 | e_tuple, (list, set, tuple) 177 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 178 | if e_data is not None: 179 | assert isinstance(e_data, dict), "The hyperedge data must be a dictionary." 180 | else: 181 | e_data = {} 182 | e_tuple = self.encode_e(e_tuple) 183 | if e_tuple not in self._e_data: 184 | self._e_data[e_tuple] = e_data 185 | for v in e_tuple: 186 | self._v_inci[v].add(e_tuple) 187 | else: 188 | self._e_data[e_tuple].update(e_data) 189 | self._clear_cache() 190 | 191 | def remove_v(self, v_id: Any): 192 | r""" 193 | Remove a vertex from the hypergraph. 194 | 195 | Args: 196 | ``v_id`` (``Any``): The vertex id. 197 | """ 198 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 199 | assert ( 200 | v_id in self._v_data 201 | ), f"The vertex {v_id} does not exist in the hypergraph." 
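The statements that follow implement vertex removal: every hyperedge incident to the removed vertex is shrunk and keeps its data, edges that drop below two vertices are discarded, and, as the inline todo notes, a shrunken edge that collides with an existing one overwrites it rather than merging the two. A small sketch of that corner case (not repo code; it assumes the bundled `hyperdb` package is importable and uses a throwaway `demo.hgdb` name):

```python
from hyperdb import HypergraphDB

hg = HypergraphDB(storage_file="demo.hgdb")
for v in ("a", "b", "c"):
    hg.add_v(v)
hg.add_e(("a", "b", "c"), {"keywords": "demo"})
hg.add_e(("a", "b"), {"keywords": "pair"})

hg.remove_v("c")
print(hg.all_e)          # {('a', 'b')}: the 3-ary edge shrank into the existing pair
print(hg.e(("a", "b")))  # {'keywords': 'demo'}: the shrunken edge's data replaced 'pair'
```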
202 | del self._v_data[v_id] 203 | old_e_tuples, new_e_tuples = [], [] 204 | for e_tuple in self._v_inci[v_id]: 205 | new_e_tuple = self.encode_e(set(e_tuple) - {v_id}) 206 | if len(new_e_tuple) >= 2: 207 | # todo: maybe new e tuple existing in hg, need to merge to hyperedge information 208 | self._e_data[new_e_tuple] = deepcopy(self._e_data[e_tuple]) 209 | del self._e_data[e_tuple] 210 | old_e_tuples.append(e_tuple) 211 | new_e_tuples.append(new_e_tuple) 212 | del self._v_inci[v_id] 213 | for old_e_tuple, new_e_tuple in zip(old_e_tuples, new_e_tuples): 214 | for _v_id in old_e_tuple: 215 | if _v_id != v_id: 216 | self._v_inci[_v_id].remove(old_e_tuple) 217 | if len(new_e_tuple) >= 2: 218 | self._v_inci[_v_id].add(new_e_tuple) 219 | self._clear_cache() 220 | 221 | def remove_e(self, e_tuple: Union[List, Set, Tuple]): 222 | r""" 223 | Remove a hyperedge from the hypergraph. 224 | 225 | Args: 226 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 227 | """ 228 | assert isinstance( 229 | e_tuple, (list, set, tuple) 230 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 231 | e_tuple = self.encode_e(e_tuple) 232 | assert ( 233 | e_tuple in self._e_data 234 | ), f"The hyperedge {e_tuple} does not exist in the hypergraph." 235 | for v in e_tuple: 236 | self._v_inci[v].remove(e_tuple) 237 | del self._e_data[e_tuple] 238 | self._clear_cache() 239 | 240 | def update_v(self, v_id: Any, v_data: dict): 241 | r""" 242 | Update the vertex data. 243 | 244 | Args: 245 | ``v_id`` (``Any``): The vertex id. 246 | ``v_data`` (``dict``): The vertex data. 247 | """ 248 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 249 | assert isinstance(v_data, dict), "The vertex data must be a dictionary." 250 | assert ( 251 | v_id in self._v_data 252 | ), f"The vertex {v_id} does not exist in the hypergraph." 253 | self._v_data[v_id].update(v_data) 254 | self._clear_cache() 255 | 256 | def update_e(self, e_tuple: Union[List, Set, Tuple], e_data: dict): 257 | r""" 258 | Update the hyperedge data. 259 | 260 | Args: 261 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 262 | ``e_data`` (``dict``): The hyperedge data. 263 | """ 264 | assert isinstance( 265 | e_tuple, (list, set, tuple) 266 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 267 | assert isinstance(e_data, dict), "The hyperedge data must be a dictionary." 268 | e_tuple = self.encode_e(e_tuple) 269 | assert ( 270 | e_tuple in self._e_data 271 | ), f"The hyperedge {e_tuple} does not exist in the hypergraph." 272 | self._e_data[e_tuple].update(e_data) 273 | self._clear_cache() 274 | 275 | def has_v(self, v_id: Any) -> bool: 276 | r""" 277 | Check if the vertex exists. 278 | 279 | Args: 280 | ``v_id`` (``Any``): The vertex id. 281 | """ 282 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 283 | return v_id in self._v_data 284 | 285 | def has_e(self, e_tuple: Union[List, Set, Tuple]) -> bool: 286 | r""" 287 | Check if the hyperedge exists. 288 | 289 | Args: 290 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 291 | """ 292 | assert isinstance( 293 | e_tuple, (list, set, tuple) 294 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 
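A detail that the `has_e` body below relies on: `encode_e` normalises every hyperedge to a sorted, de-duplicated tuple, and its `AssertionError` for unknown member vertices is swallowed so that `has_e` returns `False` instead of raising. A minimal sketch (not repo code; it assumes the bundled `hyperdb` package is importable):

```python
from hyperdb import HypergraphDB

hg = HypergraphDB(storage_file="demo.hgdb")
hg.add_v("a")
hg.add_v("b")
hg.add_e(("b", "a", "a"), {"keywords": "same edge"})  # stored as ('a', 'b')

print(hg.has_e(("a", "b")))        # True: member order and duplicates do not matter
print(hg.has_e(("a", "missing")))  # False: the AssertionError from encode_e is caught
print(hg.all_e)                    # {('a', 'b')}
```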
295 | try: 296 | e_tuple = self.encode_e(e_tuple) 297 | except AssertionError: 298 | return False 299 | return e_tuple in self._e_data 300 | 301 | def degree_v(self, v_id: Any) -> int: 302 | r""" 303 | Return the degree of the vertex. 304 | 305 | Args: 306 | ``v_id`` (``Any``): The vertex id. 307 | """ 308 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 309 | assert ( 310 | v_id in self._v_data 311 | ), f"The vertex {v_id} does not exist in the hypergraph." 312 | return len(self._v_inci[v_id]) 313 | 314 | def degree_e(self, e_tuple: Union[List, Set, Tuple]) -> int: 315 | r""" 316 | Return the degree of the hyperedge. 317 | 318 | Args: 319 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 320 | """ 321 | assert isinstance( 322 | e_tuple, (list, set, tuple) 323 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 324 | e_tuple = self.encode_e(e_tuple) 325 | assert ( 326 | e_tuple in self._e_data 327 | ), f"The hyperedge {e_tuple} does not exist in the hypergraph." 328 | return len(e_tuple) 329 | 330 | def nbr_e_of_v(self, v_id: Any) -> list: 331 | r""" 332 | Return the incident hyperedges of the vertex. 333 | 334 | Args: 335 | ``v_id`` (``Any``): The vertex id. 336 | """ 337 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 338 | assert ( 339 | v_id in self._v_data 340 | ), f"The vertex {v_id} does not exist in the hypergraph." 341 | return set(self._v_inci[v_id]) 342 | 343 | def nbr_v_of_e(self, e_tuple: Union[List, Set, Tuple]) -> list: 344 | r""" 345 | Return the incident vertices of the hyperedge. 346 | 347 | Args: 348 | ``e_tuple`` (``Union[List, Set, Tuple]``): The hyperedge tuple: (v1_name, v2_name, ..., vn_name). 349 | """ 350 | assert isinstance( 351 | e_tuple, (list, set, tuple) 352 | ), "The hyperedge must be a list, set, or tuple of vertex ids." 353 | e_tuple = self.encode_e(e_tuple) 354 | assert ( 355 | e_tuple in self._e_data 356 | ), f"The hyperedge {e_tuple} does not exist in the hypergraph." 357 | return set(e_tuple) 358 | 359 | def nbr_v(self, v_id: Any, exclude_self=True) -> list: 360 | r""" 361 | Return the neighbors of the vertex. 362 | 363 | Args: 364 | ``v_id`` (``Any``): The vertex id. 365 | """ 366 | assert isinstance(v_id, Hashable), "The vertex id must be hashable." 367 | assert ( 368 | v_id in self._v_data 369 | ), f"The vertex {v_id} does not exist in the hypergraph." 
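While the body of `nbr_v` continues below, the other piece the web-ui backend depends on is persistence: `save` and `load` simply pickle the three internal dicts (`_v_data`, `_v_inci`, `_e_data`), which appears to be how the shipped `hypergraph_wukong.hgdb` and `hypergraph_A_Christmas_Carol.hgdb` files were produced. A short round-trip sketch (not repo code; run it in an empty directory, and note that the file name and data fields are illustrative):

```python
from hyperdb import HypergraphDB

hg = HypergraphDB(storage_file="demo.hgdb")
hg.add_v("scrooge", {"entity_type": "PERSON"})
hg.add_v("fred", {"entity_type": "PERSON"})
hg.add_e(("fred", "scrooge"), {"keywords": "family"})
hg.save("demo.hgdb")                          # pickles _v_data / _v_inci / _e_data

hg2 = HypergraphDB(storage_file="demo.hgdb")  # __post_init__ reloads the pickle
print(hg2.num_v, hg2.num_e)                   # 2 1
print(hg2.nbr_v("scrooge"))                   # {'fred'}
```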
370 | nbrs = set() 371 | for e_tuple in self._v_inci[v_id]: 372 | nbrs.update(e_tuple) 373 | if exclude_self: 374 | nbrs.remove(v_id) 375 | return set(nbrs) 376 | -------------------------------------------------------------------------------- /web-ui/backend/hypergraph_A_Christmas_Carol.hgdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/web-ui/backend/hypergraph_A_Christmas_Carol.hgdb -------------------------------------------------------------------------------- /web-ui/backend/hypergraph_wukong.hgdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/web-ui/backend/hypergraph_wukong.hgdb -------------------------------------------------------------------------------- /web-ui/backend/main.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | from fastapi import FastAPI 7 | from fastapi.middleware.cors import CORSMiddleware 8 | from db import get_hypergraph, getFrequentVertices, get_hyperedges, get_vertice, get_vertice_neighbor, get_hyperedge_neighbor_server 9 | 10 | app = FastAPI() 11 | 12 | app.add_middleware( 13 | CORSMiddleware, 14 | allow_origins=["*"], 15 | allow_credentials=True, 16 | allow_methods=["*"], 17 | allow_headers=["*"], 18 | ) 19 | 20 | @app.get("/") 21 | async def root(): 22 | return {"message": "Hyper-RAG"} 23 | 24 | 25 | @app.get("/db") 26 | async def db(): 27 | """ 28 | 获取全部数据json 29 | """ 30 | data = get_hypergraph() 31 | return data 32 | 33 | @app.get("/db/vertices") 34 | async def get_vertices_function(): 35 | """ 36 | 获取vertices列表 37 | """ 38 | data = getFrequentVertices() 39 | return data 40 | 41 | @app.get("/db/hyperedges") 42 | async def get_hypergraph_function(): 43 | """ 44 | 获取hyperedges列表 45 | """ 46 | data = get_hyperedges() 47 | return data 48 | 49 | @app.get("/db/vertices/{vertex_id}") 50 | async def get_vertex(vertex_id: str): 51 | """ 52 | 获取指定vertex的json 53 | """ 54 | vertex_id = vertex_id.replace("%20", " ") 55 | data = get_vertice(vertex_id) 56 | return data 57 | 58 | @app.get("/db/vertices_neighbor/{vertex_id}") 59 | async def get_vertex_neighbor(vertex_id: str): 60 | """ 61 | 获取指定vertex的neighbor 62 | """ 63 | vertex_id = vertex_id.replace("%20", " ") 64 | data = get_vertice_neighbor(vertex_id) 65 | return data 66 | 67 | @app.get("/db/hyperedge_neighbor/{hyperedge_id}") 68 | async def get_hyperedge_neighbor(hyperedge_id: str): 69 | """ 70 | 获取指定hyperedge的neighbor 71 | """ 72 | hyperedge_id = hyperedge_id.replace("%20", " ") 73 | hyperedge_id = hyperedge_id.replace("*", "#") 74 | print(hyperedge_id) 75 | data = get_hyperedge_neighbor_server(hyperedge_id) 76 | return data 77 | 78 | def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str: 79 | openai_client = OpenAI(api_key="your_api_key", base_url="your_api_url") 80 | 81 | messages = [] 82 | if system_prompt: 83 | messages.append({"role": "system", "content": system_prompt}) 84 | messages.extend(history_messages) 85 | messages.append({"role": "user", "content": prompt}) 86 | 87 | response = openai_client.chat.completions.create( 88 | model="your_model", messages=messages, **kwargs 89 | ) 90 | return response.choices[0].message.content 91 | 92 | from pydantic import BaseModel 93 | class Message(BaseModel): 94 | message: str 95 | 96 | @app.post("/process_message") 97 | 
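# NOTE: `llm_model_func` above instantiates `OpenAI(...)` but main.py as shown never imports it;
# a `from openai import OpenAI` at the top, plus real values for the "your_api_key",
# "your_api_url" and "your_model" placeholders, appears to be required before this
# endpoint can return model responses.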
async def process_message(msg: Message): 98 | user_message = msg.message 99 | try: 100 | response_message = llm_model_func(prompt=user_message) 101 | except Exception as e: 102 | return {"response": str(e)} 103 | return {"response": response_message} -------------------------------------------------------------------------------- /web-ui/backend/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest -------------------------------------------------------------------------------- /web-ui/frontend/.commitlintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | extends: ['@commitlint/config-conventional'], 3 | rules: { 4 | 'type-enum': [ 5 | 2, 6 | 'always', 7 | ['feat', 'fix', 'refactor', 'docs', 'chore', 'style', 'revert'] 8 | ], 9 | 'type-case': [0], 10 | 'type-empty': [0], 11 | 'scope-empty': [0], 12 | 'scope-case': [0], 13 | 'subject-full-stop': [0, 'never'], 14 | 'subject-case': [0, 'never'], 15 | 'header-max-length': [0, 'always', 72] 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /web-ui/frontend/.env.mock: -------------------------------------------------------------------------------- 1 | VITE_MODE='mock' -------------------------------------------------------------------------------- /web-ui/frontend/.env.production: -------------------------------------------------------------------------------- 1 | VITE_MODE='production' 2 | VITE_SERVER_URL='http://hyper.dappwind.com:8000' 3 | VITE_APP_URL='' 4 | VITE_APP_NAME='' 5 | -------------------------------------------------------------------------------- /web-ui/frontend/.eslintignore: -------------------------------------------------------------------------------- 1 | **/*.svg 2 | package.json 3 | /dist 4 | .dockerignore 5 | .eslintignore 6 | *.png 7 | *.toml 8 | docker 9 | .editorconfig 10 | Dockerfile* 11 | .gitignore 12 | .prettierignore 13 | LICENSE 14 | .eslintcache 15 | *.lock 16 | yarn-error.log 17 | .history 18 | CNAME 19 | /build 20 | /public 21 | /src/pages/Hyper/Graph/index.tsx 22 | /server -------------------------------------------------------------------------------- /web-ui/frontend/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | env: { 3 | browser: true, 4 | es2021: true, 5 | node: true 6 | }, 7 | extends: [ 8 | 'eslint:recommended', 9 | 'plugin:react/recommended', 10 | 'plugin:@typescript-eslint/recommended', 11 | 'eslint-config-prettier' 12 | ], 13 | overrides: [], 14 | parser: '@typescript-eslint/parser', 15 | parserOptions: { 16 | ecmaVersion: 'latest', 17 | sourceType: 'module' 18 | }, 19 | plugins: ['react', '@typescript-eslint', 'prettier'], 20 | rules: { 21 | 'react/react-in-jsx-scope': 'off', 22 | // semi: ['warn', 'never'], // 禁止尾部使用分号 23 | 'no-debugger': 'warn', // 禁止出现debugger 24 | 'no-duplicate-case': 'warn', // 禁止出现重复case 25 | 'no-empty': 'warn', // 禁止出现空语句块 26 | // 'no-extra-parens': 'warn', // 禁止不必要的括号 27 | 'no-func-assign': 'warn', // 禁止对Function声明重新赋值 28 | 'no-unreachable': 'warn', // 禁止出现[return|throw]之后的代码块 29 | 'no-else-return': 'warn', // 禁止if语句中return语句之后有else块 30 | 'no-empty-function': 'warn', // 禁止出现空的函数块 31 | 'no-lone-blocks': 'warn', // 禁用不必要的嵌套块 32 | 'no-multi-spaces': 'warn', // 禁止使用多个空格 33 | 'no-redeclare': 'warn', // 禁止多次声明同一变量 34 | 'no-return-assign': 'warn', // 禁止在return语句中使用赋值语句 35 | 'no-return-await': 'warn', // 禁用不必要的[return/await] 36 | 
'no-self-compare': 'warn', // 禁止自身比较表达式 37 | 'no-useless-catch': 'warn', // 禁止不必要的catch子句 38 | 'no-useless-return': 'warn', // 禁止不必要的return语句 39 | 'no-mixed-spaces-and-tabs': 'warn', // 禁止空格和tab的混合缩进 40 | 'no-multiple-empty-lines': 'warn', // 禁止出现多行空行 41 | 'no-trailing-spaces': 'warn', // 禁止一行结束后面不要有空格 42 | 'no-useless-call': 'warn', // 禁止不必要的.call()和.apply() 43 | 'no-var': 'warn', // 禁止出现var用let和const代替 44 | 'no-delete-var': 'off', // 允许出现delete变量的使用 45 | 'no-shadow': 'off', // 允许变量声明与外层作用域的变量同名 46 | 'dot-notation': 'warn', // 要求尽可能地使用点号 47 | 'default-case': 'warn', // 要求switch语句中有default分支 48 | eqeqeq: 'warn', // 要求使用 === 和 !== 49 | curly: 'warn', // 要求所有控制语句使用一致的括号风格 50 | 'space-before-blocks': 'warn', // 要求在块之前使用一致的空格 51 | 'space-in-parens': 'warn', // 要求在圆括号内使用一致的空格 52 | 'space-infix-ops': 'warn', // 要求操作符周围有空格 53 | 'space-unary-ops': 'warn', // 要求在一元操作符前后使用一致的空格 54 | 'switch-colon-spacing': 'warn', // 要求在switch的冒号左右有空格 55 | 'arrow-spacing': 'warn', // 要求箭头函数的箭头前后使用一致的空格 56 | 'array-bracket-spacing': 'warn', // 要求数组方括号中使用一致的空格 57 | 'brace-style': 'warn', // 要求在代码块中使用一致的大括号风格 58 | // indent: ['warn', 2], // 要求使用JS一致缩进4个空格 59 | 'max-depth': ['warn', 4], // 要求可嵌套的块的最大深度4 60 | 'max-statements': ['warn', 100], // 要求函数块最多允许的的语句数量20 61 | 'max-nested-callbacks': ['warn', 3], // 要求回调函数最大嵌套深度3 62 | 'max-statements-per-line': ['warn', { max: 1 }], // 要求每一行中所允许的最大语句数量 63 | // quotes: ['warn', 'single', 'avoid-escape'], // 要求统一使用单引号符号 64 | '@typescript-eslint/no-explicit-any': 'off', 65 | '@typescript-eslint/no-unused-vars': 'off', 66 | 'react/prop-types': 'off', 67 | 'react/display-name': 'off', 68 | '@typescript-eslint/explicit-module-boundary-types': 'off' 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /web-ui/frontend/.prettierignore: -------------------------------------------------------------------------------- 1 | **/*.svg 2 | package.json 3 | /dist 4 | .dockerignore 5 | .eslintignore 6 | *.png 7 | *.toml 8 | docker 9 | .editorconfig 10 | Dockerfile* 11 | .gitignore 12 | .prettierignore 13 | LICENSE 14 | .eslintcache 15 | *.lock 16 | yarn-error.log 17 | .history 18 | CNAME 19 | /build 20 | /public -------------------------------------------------------------------------------- /web-ui/frontend/.prettierrc: -------------------------------------------------------------------------------- 1 | { 2 | "tabWidth": 2, 3 | "semi": false, 4 | "trailingComma": "none", 5 | "singleQuote": true, 6 | "printWidth": 100, 7 | "arrowParens": "avoid", 8 | "bracketSpacing": true, 9 | 10 | "endOfLine": "auto", 11 | "useTabs": false, 12 | "quoteProps": "as-needed", 13 | "jsxSingleQuote": false, 14 | "jsxBracketSameLine": false, 15 | "rangeStart": 0, 16 | "requirePragma": false, 17 | "insertPragma": false, 18 | "proseWrap": "preserve", 19 | "htmlWhitespaceSensitivity": "css" 20 | } -------------------------------------------------------------------------------- /web-ui/frontend/.stylelintignore: -------------------------------------------------------------------------------- 1 | *.js 2 | *.tsx 3 | *.ts 4 | *.json 5 | *.png 6 | *.eot 7 | *.ttf 8 | *.woff 9 | *.css -------------------------------------------------------------------------------- /web-ui/frontend/.stylelintrc.cjs: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | extends: ['stylelint-config-standard'], 3 | rules: { 4 | 'selector-class-pattern': null, 5 | 'color-function-notation': null, 6 | 'at-rule-no-unknown': null, 7 | 
'alpha-value-notation': null 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /web-ui/frontend/README.md: -------------------------------------------------------------------------------- 1 | ## 安装依赖 2 | 3 | ```bash 4 | pnpm install 5 | ``` 6 | 7 | ## 脚本描述 8 | 9 | ### 开发启动 10 | ```bash 11 | # mock模式启动 12 | npm run dev 13 | ``` 14 | 15 | ### 打包 16 | 17 | ```bash 18 | npm run build 19 | ``` 20 | -------------------------------------------------------------------------------- /web-ui/frontend/config/defaultSettings.ts: -------------------------------------------------------------------------------- 1 | import { ProSettings } from '@ant-design/pro-components' 2 | 3 | /** prolayput 设置 */ 4 | const Settings: ProSettings | undefined = { 5 | fixSiderbar: true, 6 | layout: 'mix', 7 | title: 'Hyper-RAG', 8 | logo: '/logo.png' 9 | } 10 | 11 | export default Settings 12 | -------------------------------------------------------------------------------- /web-ui/frontend/config/mock/user.ts: -------------------------------------------------------------------------------- 1 | import { MockMethod } from 'vite-plugin-mock' 2 | 3 | export default [ 4 | { 5 | url: '/api/v1/admin/login', 6 | method: 'post', 7 | response: ({ body }: any) => { 8 | const resObj: Global.ResultType = { 9 | code: 200, 10 | message: '操作成功', 11 | data: { 12 | tokenHead: 'Bearer ', 13 | token: 14 | 'eyJhbGciOiJIUzUxMiJ9.eyJzdWIiOiJhZG1pbiIsImNyZWF0ZWQiOjE2ODkyMjY5MzczNDYsImV4cCI6MTY4OTgzMTczN30.b5D3MhMRhKZDC9iXYxrW29IXdDUch6hSx9G2h9c5iJsayvAE1bm0DJZe4dp32y95yOy98UJrYesN52-cFgpI9Q' 15 | } 16 | } 17 | return resObj 18 | } 19 | }, 20 | { 21 | url: '/api/v1/admin/info', 22 | method: 'get', 23 | response: ({ body }: any) => { 24 | const resObj: Global.ResultType = { 25 | code: 200, 26 | message: '操作成功', 27 | data: {} 28 | } 29 | return resObj 30 | } 31 | } 32 | ] as MockMethod[] 33 | -------------------------------------------------------------------------------- /web-ui/frontend/config/proxy.ts: -------------------------------------------------------------------------------- 1 | export default { 2 | development: { 3 | '/api/': { 4 | target: 'http://127.0.0.1:3060', 5 | changeOrigin: true, 6 | rewrite: (path: string) => path.replace('^/', '') 7 | } 8 | } 9 | } as any 10 | -------------------------------------------------------------------------------- /web-ui/frontend/config/routes/index.tsx: -------------------------------------------------------------------------------- 1 | import { MenuDataItem } from '@ant-design/pro-components' 2 | import { createBrowserRouter, RouteObject, createHashRouter } from 'react-router-dom' 3 | import { routers } from './routers' 4 | 5 | export type RouteType = { 6 | /** 是否隐藏菜单布局 */ 7 | hideLayout?: boolean 8 | /** 在菜单栏是否显示 */ 9 | hideInMenu?: boolean 10 | /** 权限控制 true 则都控制 */ 11 | permissionObj?: { 12 | /** 是否进行页面权限控制,控制取后端数据 */ 13 | isPagePermission?: boolean 14 | /** 判断token是否存在控制 */ 15 | isToken?: boolean 16 | } & true 17 | children?: RouteType[] 18 | } & Partial & 19 | RouteObject 20 | 21 | /** 只给最低层级套 Permission 组件 */ 22 | const renderElement = (item: RouteType) => { 23 | if (item?.element) { 24 | if (item?.children) { 25 | return item?.element 26 | } 27 | return ( 28 | // 29 | item?.element 30 | // 31 | ) 32 | } 33 | return undefined 34 | } 35 | 36 | const reduceRoute: (params: RouteType[]) => RouteType[] = (routesParams: RouteType[]) => { 37 | return routesParams?.map(item => { 38 | let curRouter = item 39 | if (!item?.children) { 40 | curRouter = { 41 | 
...curRouter, 42 | element: renderElement(item) 43 | } 44 | } 45 | if (item?.children) { 46 | curRouter = { 47 | ...curRouter, 48 | children: reduceRoute(item?.children) as any 49 | } 50 | } 51 | return curRouter 52 | }) 53 | } 54 | 55 | const relRouters = reduceRoute(routers) 56 | 57 | export const router = createHashRouter(relRouters) 58 | -------------------------------------------------------------------------------- /web-ui/frontend/config/routes/routers.jsx: -------------------------------------------------------------------------------- 1 | import NotFoundPage from '@/404' 2 | import App from '@/App' 3 | import ErrorPage from '@/ErrorPage' 4 | import Home from '@/pages/Home' 5 | import Files from '@/pages/Hyper/Files' 6 | import Graph from '@/pages/Hyper/Graph' 7 | import { HomeFilled, SmileFilled, FileAddOutlined, QuestionCircleOutlined, DeploymentUnitOutlined, DatabaseOutlined, SettingOutlined } from '@ant-design/icons' 8 | import { Navigate } from 'react-router-dom' 9 | 10 | export const routers = [ 11 | { 12 | path: '/', 13 | element: 14 | }, 15 | { 16 | path: '/', 17 | element: , 18 | errorElement: , 19 | icon: , 20 | children: [ 21 | { 22 | path: '/Hyper/show', 23 | name: '超图展示', 24 | icon: , 25 | // permissionObj: true, 26 | element: 27 | }, 28 | { 29 | path: '/Hyper/qa', 30 | name: '检索问答', 31 | icon: , 32 | // permissionObj: true, 33 | element: 34 | }, 35 | { 36 | path: '/Hyper/files', 37 | name: '文档上传', 38 | icon: , 39 | element: , 40 | }, 41 | { 42 | path: '/Hyper/DB', 43 | name: 'HypergraphDB', 44 | icon: , 45 | // permissionObj: true, 46 | element: 47 | }, 48 | { 49 | path: '/Setting', 50 | name: 'key设置', 51 | icon: , 52 | // permissionObj: true, 53 | element: 54 | }, 55 | ] 56 | }, 57 | { path: '*', element: } 58 | ] 59 | -------------------------------------------------------------------------------- /web-ui/frontend/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Vite + React + TS 8 | 9 | 10 |
11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /web-ui/frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "hyper-rag-web", 3 | "private": true, 4 | "version": "0.0.0", 5 | "scripts": { 6 | "dev": "vite --mode mock", 7 | "server": "node server.js", 8 | "start:mock": "vite --mode mock", 9 | "start:production": "vite --mode production", 10 | "build": "vite build", 11 | "preview": "vite preview", 12 | "lint:script": "eslint --ext .js,.jsx,.ts,.tsx --fix ./", 13 | "lint:style": "stylelint --fix **/*.{css,less,scss}" 14 | }, 15 | "lint-staged": { 16 | "*.{js,jsx,ts,tsx}": [ 17 | "eslint --fix", 18 | "prettier --write", 19 | "git add ." 20 | ], 21 | "*.{json.md,xml,svg,html,js,jsx}": "prettier --write", 22 | "*.less": [ 23 | "stylelint --fix --custom-syntax postcss-less", 24 | "git add ." 25 | ] 26 | }, 27 | "dependencies": { 28 | "@ant-design/icons": "^5.0.1", 29 | "@ant-design/pro-components": "^2.6.48", 30 | "@ant-design/x": "^1.1.0", 31 | "@antv/g6": "^5.0.44", 32 | "@antv/graphin": "^3.0.4", 33 | "@types/qs": "^6.9.7", 34 | "@websee/core": "^4.0.2", 35 | "@websee/performance": "^4.0.2", 36 | "@websee/recordscreen": "^4.0.2", 37 | "ahooks": "^3.7.5", 38 | "antd": "^5.15.0", 39 | "antd-style": "^3.7.1", 40 | "axios": "^1.3.4", 41 | "dayjs": "^1.11.10", 42 | "koa": "^2.16.1", 43 | "koa-static": "^5.0.0", 44 | "less": "^4.1.3", 45 | "lodash": "^4.17.21", 46 | "mobx": "^6.8.0", 47 | "mobx-react": "^7.6.0", 48 | "postcss": "^8.4.21", 49 | "postcss-less": "^6.0.0", 50 | "react": "^18.2.0", 51 | "react-dom": "^18.2.0", 52 | "react-router-dom": "^6.8.2" 53 | }, 54 | "devDependencies": { 55 | "@babel/core": "^7.21.0", 56 | "@babel/eslint-parser": "^7.19.1", 57 | "@commitlint/cli": "^17.4.4", 58 | "@commitlint/config-conventional": "^17.4.4", 59 | "@types/lodash": "^4.14.195", 60 | "@types/node": "^18.14.5", 61 | "@types/react": "^18.2.56", 62 | "@types/react-dom": "^18.2.19", 63 | "@typescript-eslint/eslint-plugin": "^7.0.2", 64 | "@typescript-eslint/parser": "^7.0.2", 65 | "@vitejs/plugin-react": "^3.1.0", 66 | "@vitejs/plugin-react-swc": "^3.5.0", 67 | "eslint": "^8.56.0", 68 | "eslint-config-prettier": "^8.6.0", 69 | "eslint-plugin-prettier": "^4.2.1", 70 | "eslint-plugin-react": "^7.32.2", 71 | "eslint-plugin-react-hooks": "^4.6.0", 72 | "eslint-plugin-react-refresh": "^0.4.5", 73 | "mockjs": "^1.1.0", 74 | "prettier": "^2.8.4", 75 | "qs": "^6.11.0", 76 | "stylelint": "^15.2.0", 77 | "stylelint-config-standard": "^30.0.1", 78 | "typescript": "^5.2.2", 79 | "vite": "^5.1.4", 80 | "vite-plugin-eslint": "^1.8.1", 81 | "vite-plugin-mock": "^2.9.6" 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /web-ui/frontend/public/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/web-ui/frontend/public/logo.png -------------------------------------------------------------------------------- /web-ui/frontend/server.js: -------------------------------------------------------------------------------- 1 | const Koa = require('koa'); 2 | const path = require('path'); 3 | const serve = require('koa-static'); 4 | const app = new Koa(); 5 | 6 | 7 | const home = serve(path.join(__dirname) + '/dist/', { 8 | gzip: true, 9 | }); 10 | 11 | app.use(async (ctx, next) => { 12 | console.log(new Date(), ctx.request.url); 13 | 
await next(); 14 | }) 15 | 16 | app.use(home); 17 | app.listen(5000); 18 | console.log('server is running at http://localhost:5000'); -------------------------------------------------------------------------------- /web-ui/frontend/src/404.tsx: -------------------------------------------------------------------------------- 1 | import NotFound from './components/NotFound' 2 | 3 | const NotFoundPage = () => { 4 | return 5 | } 6 | export default NotFoundPage 7 | -------------------------------------------------------------------------------- /web-ui/frontend/src/App.tsx: -------------------------------------------------------------------------------- 1 | import BasicLayout from './layout/BasicLayout' 2 | 3 | const App = () => { 4 | return ( 5 |
10 | 11 |
12 | ) 13 | } 14 | 15 | export default App 16 | -------------------------------------------------------------------------------- /web-ui/frontend/src/ErrorPage.tsx: -------------------------------------------------------------------------------- 1 | import { useRouteError } from 'react-router-dom' 2 | 3 | const ErrorPage = () => { 4 | // Use useRouteError to get the routing error information 5 | const error: any = useRouteError() 6 | console.error(error) 7 | 8 | // Refresh the page 9 | window.location.reload() 10 | 11 | return
12 | } 13 | export default ErrorPage 14 | -------------------------------------------------------------------------------- /web-ui/frontend/src/_defaultProps.tsx: -------------------------------------------------------------------------------- 1 | import { ChromeFilled, CrownFilled, SmileFilled, TabletFilled } from '@ant-design/icons' 2 | 3 | export default { 4 | route: { 5 | path: '/', 6 | routes: [] 7 | }, 8 | location: { 9 | pathname: '/' 10 | }, 11 | appList: [] 12 | } 13 | -------------------------------------------------------------------------------- /web-ui/frontend/src/assets/react.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /web-ui/frontend/src/assets/show.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/web-ui/frontend/src/assets/show.png -------------------------------------------------------------------------------- /web-ui/frontend/src/components/NotFound/index.tsx: -------------------------------------------------------------------------------- 1 | import { Button, Result } from 'antd' 2 | import React from 'react' 3 | import { Link } from 'react-router-dom' 4 | import type { NotFoundPropsType } from './type' 5 | 6 | const NotFound: React.FC = ({ 7 | status = '404', 8 | title = '404', 9 | subTitle = '对不起!您访问的页面不存在', 10 | extra = ( 11 | 14 | ) 15 | }) => { 16 | return ( 17 | <> 18 | 19 | 20 | ) 21 | } 22 | 23 | export default NotFound 24 | -------------------------------------------------------------------------------- /web-ui/frontend/src/components/NotFound/type.d.ts: -------------------------------------------------------------------------------- 1 | import type { ResultStatusType } from 'antd/lib/result' 2 | import type React from 'react' 3 | 4 | export type NotFoundPropsType = { 5 | status?: ResultStatusType 6 | title?: string 7 | subTitle?: string 8 | extra?: React.ReactDOM | React.JSXElementConstructor 9 | } 10 | -------------------------------------------------------------------------------- /web-ui/frontend/src/components/errorBoundary.jsx: -------------------------------------------------------------------------------- 1 | import React, { Component } from 'react'; 2 | 3 | class ErrorBoundary extends Component { 4 | constructor(props) { 5 | super(props); 6 | this.state = { hasError: false, error: null, errorInfo: null }; 7 | } 8 | 9 | static getDerivedStateFromError(error) { 10 | // 更新 state 使下一次渲染能够显示降级后的 UI 11 | return { hasError: true }; 12 | } 13 | 14 | componentDidCatch(error, errorInfo) { 15 | // 你也可以将错误日志上报给服务器 16 | console.error("ErrorBoundary caught an error:", error, errorInfo); 17 | this.setState({ 18 | hasError: true, 19 | error: error, 20 | errorInfo: errorInfo 21 | }); 22 | } 23 | 24 | render() { 25 | if (this.state.hasError) { 26 | // 你可以自定义降级后的 UI 并渲染 27 | return ( 28 |
29 | ); 30 | } 31 | 32 | // console.error("ErrorBoundary caught an error:", this.state.error, this.state.errorInfo); 33 | 34 | return this.props.children; 35 | } 36 | } 37 | 38 | export default ErrorBoundary; -------------------------------------------------------------------------------- /web-ui/frontend/src/components/loading/index.module.less: -------------------------------------------------------------------------------- 1 | .box { 2 | display: flex; 3 | justify-content: center; 4 | align-items: center; 5 | width: '100%'; 6 | min-height: 60vh; 7 | height: '100%'; 8 | 9 | .container { 10 | position: relative; 11 | height: 150px; 12 | width: 250px; 13 | -webkit-box-reflect: below 1px linear-gradient(transparent, rgb(227, 231, 238)); 14 | } 15 | 16 | .container > span { 17 | position: absolute; 18 | left: 50%; 19 | top: 50%; 20 | transform: translate(-50%, -50%); 21 | color: rgb(20, 129, 202); 22 | text-shadow: 0 0 10px rgb(20, 129, 202), 0 0 30px rgb(20, 129, 202), 0 0 60px rgb(20, 129, 202), 23 | 0 0 100px rgb(20, 129, 202); 24 | font-size: 18px; 25 | z-index: 1; 26 | } 27 | 28 | .circle { 29 | position: relative; 30 | margin: 0 auto; 31 | height: 150px; 32 | width: 150px; 33 | background-color: rgb(219, 50, 168); 34 | border-radius: 50%; 35 | animation: zhuan 2s linear infinite; 36 | } 37 | 38 | @keyframes zhuan { 39 | 0% { 40 | transform: rotate(0deg); 41 | } 42 | 43 | 100% { 44 | transform: rotate(360deg); 45 | } 46 | } 47 | 48 | .circle::after { 49 | content: ''; 50 | position: absolute; 51 | top: 10px; 52 | left: 10px; 53 | right: 10px; 54 | bottom: 10px; 55 | background-color: rgb(71, 21, 64); 56 | border-radius: 50%; 57 | } 58 | 59 | .ring { 60 | position: absolute; 61 | top: 0; 62 | left: 0; 63 | width: 75px; 64 | height: 150px; 65 | background-image: linear-gradient(180deg, rgb(22, 121, 252), transparent 80%); 66 | border-radius: 75px 0 0 75px; 67 | } 68 | 69 | .ring::after { 70 | content: ''; 71 | position: absolute; 72 | right: -5px; 73 | top: -2.5px; 74 | width: 15px; 75 | height: 15px; 76 | background-color: rgb(40, 124, 202); 77 | box-shadow: 0 0 5px rgb(40, 151, 202), 0 0 10px rgb(40, 124, 202), 0 0 20px rgb(40, 124, 202), 78 | 0 0 30px rgb(40, 124, 202), 0 0 40px rgb(40, 124, 202), 0 0 50px rgb(40, 124, 202), 79 | 0 0 60px rgb(40, 124, 202), 0 0 60px rgb(40, 124, 202); 80 | border-radius: 50%; 81 | z-index: 1; 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /web-ui/frontend/src/components/loading/index.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import styles from './index.module.less' 3 | 4 | const Loading: React.FC = props => { 5 | return ( 6 |
7 |
8 | Loading... 9 |
10 |
11 |
12 |
13 |
14 | ) 15 | } 16 | 17 | // Loading.defaultProps = { 18 | // imgUrl: loading 19 | // } 20 | 21 | export default Loading 22 | -------------------------------------------------------------------------------- /web-ui/frontend/src/layout/BasicLayout.tsx: -------------------------------------------------------------------------------- 1 | import { storeGlobalUser } from '@/store/globalUser' 2 | import { storage } from '@/utils' 3 | import { PageContainer, ProLayout } from '@ant-design/pro-components' 4 | import { RouteType, router } from '@config/routes' 5 | import { useAsyncEffect } from 'ahooks' 6 | import { Dropdown, MenuProps } from 'antd' 7 | import { useEffect, useState } from 'react' 8 | import { Outlet, matchRoutes, useLocation, useNavigate } from 'react-router-dom' 9 | import defaultProps from '@/_defaultProps' 10 | import Settings from '@config/defaultSettings' 11 | import { observer } from 'mobx-react' 12 | import React from 'react' 13 | import { routers } from '@config/routes/routers' 14 | 15 | export enum ComponTypeEnum { 16 | MENU, 17 | PAGE, 18 | COMPON 19 | } 20 | 21 | export const GlobalUserInfo = React.createContext>({}) 22 | 23 | const BasicLayout: React.FC = props => { 24 | const [pathname, setPathname] = useState(window.location.hash.replace('#', '')) 25 | const navigate = useNavigate() 26 | const location = useLocation() 27 | const matchRoute = matchRoutes(routers, location) 28 | 29 | const [showLayout, setShowLayout] = useState(false) 30 | 31 | /** 处理菜单权限隐藏菜单 */ 32 | const reduceRouter = (routers: RouteType[]): RouteType[] => { 33 | const authMenus = storeGlobalUser?.userInfo?.menus 34 | ?.filter(item => item?.type === ComponTypeEnum.MENU || item?.type === ComponTypeEnum.PAGE) 35 | ?.map(item => item?.title) 36 | 37 | return routers?.map(item => { 38 | if (item?.children) { 39 | const { children, ...extra } = item 40 | return { 41 | ...extra, 42 | routes: reduceRouter(item?.children), 43 | hideInMenu: item?.hideInMenu 44 | } 45 | } 46 | return { 47 | ...item, 48 | hideInMenu: item?.hideInMenu 49 | } 50 | }) as any 51 | } 52 | 53 | useEffect(() => { 54 | setPathname(window.location.hash.replace('#', '')) 55 | setShowLayout(!matchRoute?.[matchRoute?.length - 1]?.route?.hideLayout) 56 | }, [window.location.hash]) 57 | 58 | useAsyncEffect(async () => { 59 | if (pathname !== '/login') { 60 | await storeGlobalUser.getUserDetail() 61 | } 62 | }, []) 63 | 64 | const items: MenuProps['items'] = [ 65 | { 66 | key: 'out', 67 | label: ( 68 |
{ 70 | storage.clear() 71 | // navigate('login', { replace: true }) 72 | }} 73 | > 74 | 退出登录 75 |
76 | ) 77 | } 78 | ] 79 | 80 | return ( 81 | 82 | {showLayout ? ( 83 | { 94 | return {defaultDom} 95 | } 96 | }} 97 | menuFooterRender={props => { 98 | return ( 99 |
105 |
106 |
107 |
108 | ) 109 | }} 110 | menuProps={{ 111 | onClick: ({ key }) => { 112 | navigate(key || '/') 113 | } 114 | }} 115 | ErrorBoundary={false} 116 | {...Settings} 117 | > 118 | 119 | 120 | 121 |
122 | ) : ( 123 | 124 | )} 125 |
126 | ) 127 | } 128 | 129 | export default observer(BasicLayout) 130 | -------------------------------------------------------------------------------- /web-ui/frontend/src/main.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react' 2 | import ReactDOM from 'react-dom/client' 3 | import { RouterProvider } from 'react-router-dom' 4 | import { router } from '../config/routes' 5 | 6 | import Loading from './components/loading' 7 | 8 | ReactDOM.createRoot(document.getElementById('root') as HTMLElement).render( 9 | } /> 10 | ) 11 | -------------------------------------------------------------------------------- /web-ui/frontend/src/pages/Home/index.jsx: -------------------------------------------------------------------------------- 1 | import { 2 | Attachments, 3 | Bubble, 4 | Conversations, 5 | Prompts, 6 | Sender, 7 | Welcome, 8 | useXAgent, 9 | useXChat 10 | } from '@ant-design/x' 11 | import { createStyles } from 'antd-style' 12 | import { observer } from 'mobx-react' 13 | import React, { useEffect } from 'react' 14 | import { 15 | CloudUploadOutlined, 16 | CommentOutlined, 17 | EllipsisOutlined, 18 | FireOutlined, 19 | HeartOutlined, 20 | PaperClipOutlined, 21 | PlusOutlined, 22 | ReadOutlined, 23 | ShareAltOutlined, 24 | SmileOutlined, 25 | RightOutlined 26 | } from '@ant-design/icons' 27 | import { Badge, Button, Space } from 'antd' 28 | const renderTitle = (icon, title) => ( 29 | 30 | {icon} 31 | {title} 32 | 33 | ) 34 | const defaultConversationsItems = [ 35 | { 36 | key: '0', 37 | label: 'What is Hyper-RAG?' 38 | } 39 | ] 40 | const useStyle = createStyles(({ token, css }) => { 41 | return { 42 | topMenu: css` 43 | display: flex; 44 | justify-content: space-between; 45 | align-items: center; 46 | padding: 0 ${token.padding}px; 47 | background: ${token.colorBgContainer}; 48 | border-radius: ${token.borderRadius}px; 49 | display: flex; 50 | height: 80px; 51 | margin-bottom: 10px; 52 | background: ${token.colorBgContainer}; 53 | font-family: AlibabaPuHuiTi, ${token.fontFamily}, sans-serif; 54 | `, 55 | topCard: css` 56 | height: 60px; 57 | width: 100px; 58 | `, 59 | layout: css` 60 | width: 100%; 61 | min-width: 1000px; 62 | height: 600px; 63 | border-radius: ${token.borderRadius}px; 64 | display: flex; 65 | background: ${token.colorBgContainer}; 66 | font-family: AlibabaPuHuiTi, ${token.fontFamily}, sans-serif; 67 | 68 | .ant-prompts { 69 | color: ${token.colorText}; 70 | } 71 | `, 72 | menu: css` 73 | background: ${token.colorBgLayout}80; 74 | width: 280px; 75 | height: 100%; 76 | display: flex; 77 | flex-direction: column; 78 | `, 79 | conversations: css` 80 | padding: 0 12px; 81 | flex: 1; 82 | overflow-y: auto; 83 | `, 84 | chat: css` 85 | height: 100%; 86 | width: 100%; 87 | max-width: 700px; 88 | margin: 0 auto; 89 | box-sizing: border-box; 90 | display: flex; 91 | flex-direction: column; 92 | padding: ${token.paddingLG}px; 93 | gap: 16px; 94 | `, 95 | messages: css` 96 | flex: 1; 97 | `, 98 | placeholder: css` 99 | padding-top: 32px; 100 | `, 101 | sender: css` 102 | box-shadow: ${token.boxShadow}; 103 | `, 104 | logo: css` 105 | display: flex; 106 | height: 72px; 107 | align-items: center; 108 | justify-content: start; 109 | padding: 0 24px; 110 | box-sizing: border-box; 111 | 112 | img { 113 | width: 24px; 114 | height: 24px; 115 | display: inline-block; 116 | } 117 | 118 | span { 119 | display: inline-block; 120 | margin: 0 8px; 121 | font-weight: bold; 122 | color: ${token.colorText}; 123 | font-size: 16px; 124 | } 
125 | `, 126 | addBtn: css` 127 | background: #1677ff0f; 128 | border: 1px solid #1677ff34; 129 | width: calc(100% - 24px); 130 | margin: 0 12px 24px 12px; 131 | ` 132 | } 133 | }) 134 | const placeholderPromptsItems = [ 135 | { 136 | key: '1', 137 | label: renderTitle( 138 | , 143 | 'Hot Topics' 144 | ), 145 | description: 'What are you interested in?', 146 | children: [ 147 | { 148 | key: '1-1', 149 | description: `What's new in RAG?` 150 | }, 151 | { 152 | key: '1-2', 153 | description: `What's Hyper-RAG?` 154 | }, 155 | { 156 | key: '1-3', 157 | description: `Where is the doc?` 158 | } 159 | ] 160 | }, 161 | { 162 | key: '2', 163 | label: renderTitle( 164 | , 169 | 'Design Guide' 170 | ), 171 | description: 'How to design a good product?', 172 | children: [ 173 | { 174 | key: '2-1', 175 | icon: , 176 | description: `Know the well` 177 | }, 178 | { 179 | key: '2-2', 180 | icon: , 181 | description: `Set the AI role` 182 | }, 183 | { 184 | key: '2-3', 185 | icon: , 186 | description: `Express the feeling` 187 | } 188 | ] 189 | } 190 | ] 191 | const senderPromptsItems = [ 192 | { 193 | key: '1', 194 | description: 'Hot Topics', 195 | icon: ( 196 | 201 | ) 202 | }, 203 | { 204 | key: '2', 205 | description: 'Design Guide', 206 | icon: ( 207 | 212 | ) 213 | } 214 | ] 215 | const roles = { 216 | ai: { 217 | placement: 'start', 218 | typing: { 219 | step: 5, 220 | interval: 20 221 | }, 222 | styles: { 223 | content: { 224 | borderRadius: 16 225 | } 226 | } 227 | }, 228 | local: { 229 | placement: 'end', 230 | variant: 'shadow' 231 | } 232 | } 233 | const Independent = () => { 234 | // ==================== Style ==================== 235 | const { styles } = useStyle() 236 | 237 | // ==================== State ==================== 238 | const [headerOpen, setHeaderOpen] = React.useState(false) 239 | const [content, setContent] = React.useState('') 240 | const [conversationsItems, setConversationsItems] = React.useState(defaultConversationsItems) 241 | const [activeKey, setActiveKey] = React.useState(defaultConversationsItems[0].key) 242 | const [attachedFiles, setAttachedFiles] = React.useState([]) 243 | 244 | // ==================== Runtime ==================== 245 | const [agent] = useXAgent({ 246 | request: async ({ message }, { onSuccess }) => { 247 | onSuccess(`Mock success return. 
You said: ${message}`) 248 | } 249 | }) 250 | const { onRequest, messages, setMessages } = useXChat({ 251 | agent 252 | }) 253 | useEffect(() => { 254 | if (activeKey !== undefined) { 255 | setMessages([]) 256 | } 257 | }, [activeKey]) 258 | 259 | // ==================== Event ==================== 260 | const onSubmit = async nextContent => { 261 | if (!nextContent) return 262 | try { 263 | const response = await fetch('http://127.0.0.1:8000/process_message', { 264 | method: 'POST', 265 | headers: { 266 | 'Content-Type': 'application/json', 267 | }, 268 | body: JSON.stringify({ message: nextContent }), 269 | }); 270 | if (!response.ok) { 271 | throw new Error('Network response error'); 272 | } 273 | const data = await response.json(); 274 | 275 | setMessages(prevMessages => [ 276 | ...prevMessages, 277 | 278 | { 279 | id: Date.now(), // use id instead of key 280 | message: nextContent, 281 | status: 'local', 282 | }, 283 | { 284 | id: Date.now() + 1, 285 | message: data.response || 'No content returned', 286 | status: 'ai', 287 | } 288 | ]); 289 | // handle the message returned by the backend 290 | onRequest(data.response); 291 | } catch (error) { 292 | console.error('Error while sending message:', error); 293 | } 294 | // onRequest(nextContent) 295 | setContent('') 296 | }; 297 | const onPromptsItemClick = info => { 298 | onRequest(info.data.description) 299 | } 300 | const onAddConversation = () => { 301 | setConversationsItems([ 302 | ...conversationsItems, 303 | { 304 | key: `${conversationsItems.length}`, 305 | label: `New Conversation ${conversationsItems.length}` 306 | } 307 | ]) 308 | setActiveKey(`${conversationsItems.length}`) 309 | } 310 | const onConversationClick = key => { 311 | setActiveKey(key) 312 | } 313 | const handleFileChange = info => setAttachedFiles(info.fileList) 314 | 315 | // ==================== Nodes ==================== 316 | const placeholderNode = ( 317 | 318 | 325 | 409 | 413 | 414 | 418 | 422 | 426 | 430 | 434 | 438 | 442 | 443 |
445 |
446 |
447 |
448 | {/* 🌟 Logo */} 449 | {logoNode} 450 | {/* 🌟 Add conversation */} 451 | 459 | {/* 🌟 Conversation management */} 460 | 466 |
467 |
468 | {/* 🌟 Message list */} 469 | <Bubble.List 470 | items={ 471 | items.length > 0 472 | ? items 473 | : [ 474 | { 475 | content: placeholderNode, 476 | variant: 'borderless' 477 | } 478 | ] 479 | } 480 | roles={roles} 481 | className={styles.messages} 482 | /> 483 | {/* 🌟 Prompts */} 484 | 485 | {/* 🌟 Input box */} 486 | 495 |
496 |
497 | 498 | ) 499 | } 500 | export default observer(Independent) 501 | -------------------------------------------------------------------------------- /web-ui/frontend/src/pages/Hyper/DB/index.jsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iMoonLab/Hyper-RAG/133529d4250273049a7e11192b142db0f92e8ac3/web-ui/frontend/src/pages/Hyper/DB/index.jsx -------------------------------------------------------------------------------- /web-ui/frontend/src/pages/Hyper/Files/index.jsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { InboxOutlined } from '@ant-design/icons'; 3 | import { message, Upload, Table, Tag, Space } from 'antd'; 4 | const { Dragger } = Upload; 5 | const props = { 6 | name: 'file', 7 | multiple: true, 8 | action: 'https://660d2bd96ddfa2943b33731c.mockapi.io/api/upload', 9 | onChange(info) { 10 | const { status } = info.file; 11 | if (status !== 'uploading') { 12 | console.log(info.file, info.fileList); 13 | } 14 | if (status === 'done') { 15 | message.success(`${info.file.name} file uploaded successfully.`); 16 | } else if (status === 'error') { 17 | message.error(`${info.file.name} file upload failed.`); 18 | } 19 | }, 20 | onDrop(e) { 21 | console.log('Dropped files', e.dataTransfer.files); 22 | }, 23 | }; 24 | 25 | const columns = [ 26 | { 27 | title: 'Name', 28 | dataIndex: 'name', 29 | key: 'name', 30 | render: (text) => {text}, 31 | }, 32 | { 33 | title: 'Age', 34 | dataIndex: 'age', 35 | key: 'age', 36 | }, 37 | { 38 | title: 'Address', 39 | dataIndex: 'address', 40 | key: 'address', 41 | }, 42 | { 43 | title: 'Tags', 44 | key: 'tags', 45 | dataIndex: 'tags', 46 | render: (_, { tags }) => ( 47 | <> 48 | {tags.map((tag) => { 49 | let color = tag.length > 5 ? 'geekblue' : 'green'; 50 | if (tag === 'loser') { 51 | color = 'volcano'; 52 | } 53 | return ( 54 | 55 | {tag.toUpperCase()} 56 | 57 | ); 58 | })} 59 | 60 | ), 61 | }, 62 | { 63 | title: 'Action', 64 | key: 'action', 65 | render: (_, record) => ( 66 | 67 | Invite {record.name} 68 | Delete 69 | 70 | ), 71 | }, 72 | ]; 73 | const data = [ 74 | { 75 | key: '1', 76 | name: 'John Brown', 77 | age: 32, 78 | address: 'New York No. 1 Lake Park', 79 | tags: ['nice', 'developer'], 80 | }, 81 | { 82 | key: '2', 83 | name: 'Jim Green', 84 | age: 42, 85 | address: 'London No. 1 Lake Park', 86 | tags: ['loser'], 87 | }, 88 | { 89 | key: '3', 90 | name: 'Joe Black', 91 | age: 32, 92 | address: 'Sydney No. 1 Lake Park', 93 | tags: ['cool', 'teacher'], 94 | }, 95 | ]; 96 | 97 | const App = () => ( 98 | <> 99 | 100 |

101 | <InboxOutlined /> 102 | </p> 103 | <p className="ant-upload-text">Click or drag file to this area to upload</p> 104 | <p className="ant-upload-hint"> 105 | Support for a single or bulk upload. Strictly prohibited from uploading company data or other 106 | banned files. 107 | </p> 108 | </Dragger>
109 | 110 | ); 111 | export default App; 112 | 113 | -------------------------------------------------------------------------------- /web-ui/frontend/src/pages/Hyper/Graph/data.js: -------------------------------------------------------------------------------- 1 | export default () => { 2 | return { 3 | vertices: { 4 | "MONSTROUS CHIN": { 5 | "entity_type": "CONCEPT", 6 | "description": "A physical attribute symbolizing greed and gluttony, represented by the character with a large, exaggerated chin in the dialogue amongst the merchants.", 7 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 8 | "additional_properties": "greed, excess", 9 | "entity_name": "MONSTROUS CHIN" 10 | }, 11 | "SNUFF-BOX": { 12 | "entity_type": "CONCEPT", 13 | "description": "An item used for consuming snuff, representing luxury and leisure among wealthy individuals, highlighting their indifference to the death of a fellow merchant.", 14 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 15 | "additional_properties": "luxury, indifference", 16 | "entity_name": "SNUFF-BOX" 17 | }, 18 | "FAT MAN": { 19 | "entity_type": "PERSON", 20 | "description": "A character grouped with other merchants, characterized by his obesity and lethargy, seemingly apathetic towards societal issues.", 21 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 22 | "additional_properties": "apathy, wealth", 23 | "entity_name": "FAT MAN" 24 | }, 25 | "RED-FACED GENTLEMAN": { 26 | "entity_type": "PERSON", 27 | "description": "Another merchant characterized by his physical appearance and demeanor, illustrating the attitude of the wealthy towards the death of others.", 28 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 29 | "additional_properties": "indifference, wealth", 30 | "entity_name": "RED-FACED GENTLEMAN" 31 | }, 32 | "GREAT GOLD SEALS": { 33 | "entity_type": "CONCEPT", 34 | "description": "Symbols of wealth and status among merchants, representing the materialistic values of the time.", 35 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 36 | "additional_properties": "status, wealth", 37 | "entity_name": "GREAT GOLD SEALS" 38 | } 39 | }, 40 | edges: { 41 | "FAT MAN|#|MONSTROUS CHIN": { 42 | "description": "The character of the fat man, with his monstrous chin, embodies the excesses and apathetic nature of the wealthy merchants towards societal issues.", 43 | "keywords": "greed, excess", 44 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 45 | "weight": 7.0 46 | }, 47 | "RED-FACED GENTLEMAN|#|SNUFF-BOX": { 48 | "description": "The use of a snuff-box by the red-faced gentleman represents his wealth while showcasing his indifference to the serious topic of death discussed.", 49 | "keywords": "luxury, indifference", 50 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 51 | "weight": 6.0 52 | }, 53 | "FAT MAN|#|GREAT GOLD SEALS": { 54 | "description": "The great gold seals symbolize the status of the fat man and reinforce the materialistic values held by the businessmen in the scene.", 55 | "keywords": "status, materialism", 56 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 57 | "weight": 7.0 58 | }, 59 | "FAT MAN|#|GREAT GOLD SEALS|#|MONSTROUS CHIN|#|RED-FACED GENTLEMAN|#|SNUFF-BOX": { 60 | "description": "The interactions and physical portrayals among the fat man and the red-faced gentleman, along with exaggerated attributes such as the monstrous chin and items like the snuff-box and great gold seals, reflect the overarching theme of materialism and apathy towards mortality in 
wealthy society, emphasizing their disconnect from the value of human life.", 61 | "keywords": "materialism, apathy, wealth", 62 | "source_id": "chunk-02baee20cc9463dbe08170a8e1043e32", 63 | "weight": 8.0 64 | } 65 | } 66 | } 67 | } -------------------------------------------------------------------------------- /web-ui/frontend/src/pages/Hyper/Graph/index.jsx: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useMemo, useState } from 'react'; 2 | import { Graphin } from '@antv/graphin'; 3 | 4 | import { Select, Card, Tag } from 'antd'; 5 | 6 | const SERVER_URL = import.meta.env.VITE_SERVER_URL; 7 | 8 | const colors = [ 9 | '#F6BD16', 10 | '#00C9C9', 11 | '#F08F56', 12 | '#D580FF', 13 | '#FF3D00', 14 | '#16f69c', 15 | '#004ac9', 16 | '#f056d1', 17 | '#a680ff', 18 | '#c8ff00', 19 | ] 20 | 21 | export default () => { 22 | const [data, setData] = useState(undefined); 23 | const [keys, setKeys] = useState(undefined); 24 | const [key, setKey] = useState(undefined); 25 | const [item, setItem] = useState({ 26 | entity_name: '', 27 | entity_type: '', 28 | descriptions: [''], 29 | properties: [''] 30 | }); 31 | 32 | useEffect(() => { 33 | fetch(SERVER_URL + '/db/vertices') 34 | .then((res) => res.json()) 35 | .then((data) => { 36 | setKeys(data); 37 | } 38 | ) 39 | fetch(SERVER_URL + '/db/vertices_neighbor/' + '刘伯钦') 40 | .then((res) => res.json()) 41 | .then((data) => { 42 | setData(data); 43 | const item = data.vertices['刘伯钦']; 44 | setItem({ 45 | entity_name: item.entity_name, 46 | entity_type: item.entity_type, 47 | descriptions: item.description.split(''), 48 | properties: item.additional_properties.split('') 49 | }); 50 | } 51 | ) 52 | }, []); 53 | 54 | useEffect(() => { 55 | if (!key) return; 56 | fetch(SERVER_URL + '/db/vertices_neighbor/' + key) 57 | .then((res) => res.json()) 58 | .then((data) => { 59 | setData(data); 60 | const item = data.vertices[key]; 61 | setItem({ 62 | entity_name: item.entity_name, 63 | entity_type: item.entity_type, 64 | descriptions: item.description.split(''), 65 | properties: item.additional_properties.split('') 66 | }); 67 | } 68 | ) 69 | }, [key]); 70 | 71 | const options = useMemo( 72 | () => { 73 | let groupedNodesByCluster = {}; 74 | let createStyle = () => ({}); 75 | 76 | let hyperData = { 77 | nodes: [], 78 | edges: [], 79 | }; 80 | let plugins = []; 81 | if (data) { 82 | for (const key in data.vertices) { 83 | hyperData.nodes.push({ 84 | id: key, 85 | label: key, 86 | ...data.vertices[key], 87 | }); 88 | } 89 | 90 | createStyle = (baseColor) => ({ 91 | fill: baseColor, 92 | stroke: baseColor, 93 | labelFill: '#fff', 94 | labelPadding: 2, 95 | labelBackgroundFill: baseColor, 96 | labelBackgroundRadius: 5, 97 | labelPlacement: 'center', 98 | labelAutoRotate: false, 99 | // bubblesets 100 | maxRoutingIterations: 100, 101 | maxMarchingIterations: 20, 102 | pixelGroup: 4, 103 | edgeR0: 10, 104 | edgeR1: 60, 105 | nodeR0: 15, 106 | nodeR1: 50, 107 | morphBuffer: 10, 108 | threshold: 1, 109 | memberInfluenceFactor: 1, 110 | edgeInfluenceFactor: 1, 111 | nonMemberInfluenceFactor: -0.8, 112 | virtualEdges: true, 113 | }); 114 | 115 | const keys = Object.keys(data.edges); 116 | for (let i = 0; i < keys.length; i++) { 117 | const key = keys[i]; 118 | const edge = data.edges[key]; 119 | const nodes = key.split('|#|'); 120 | groupedNodesByCluster[key] = nodes; 121 | plugins.push({ 122 | key: `bubble-sets-${key}`, 123 | type: 'bubble-sets', 124 | members: nodes, 125 | labelText: '' + edge.keywords, 126 | 
...createStyle(colors[i % 10]), 127 | }); 128 | } 129 | } 130 | 131 | plugins.push({ 132 | type: 'tooltip', 133 | getContent: (e, items) => { 134 | let result = ''; 135 | items.forEach((item) => { 136 | result += `

${item.id}

${item.description}

`; 137 | }); 138 | return result; 139 | }, 140 | }) 141 | 142 | console.log(hyperData); 143 | 144 | return { 145 | autoResize: true, 146 | data: hyperData, 147 | node: { 148 | palette: { field: 'cluster' }, 149 | style: { 150 | labelText: d => d.id, 151 | } 152 | }, 153 | animate: false, 154 | behaviors: [ 155 | // { 156 | // type: 'click-select', 157 | // degree: 1, 158 | // state: 'active', 159 | // unselectedState: 'inactive', 160 | // multiple: true, 161 | // trigger: ['shift'], 162 | // }, 163 | 'zoom-canvas', 'drag-canvas', 'drag-element', 164 | ], 165 | autoFit: 'center', 166 | layout: { 167 | type: 'force', 168 | // enableWorker: true, 169 | clustering: true, 170 | preventOverlap: true, 171 | // linkDistance: 700, 172 | nodeClusterBy: 'entity_type', 173 | gravity: 20 174 | }, 175 | plugins, 176 | } 177 | }, 178 | [data], 179 | ); 180 | 181 | if (!data) return

<div>Loading...</div>; 182 | 183 | return <> 184 | Select entity: 189 |
190 | 193 | 197 |
198 |

Type: {item.entity_type}

199 | 200 |
201 | Description: 202 | {item.descriptions?.map((desc, index) => ( 203 |

{desc}

204 | ))} 205 |
206 | 207 |
208 | Properties: 209 |
210 | {item.properties?.map((prop, index) => ( 211 | 212 | {prop} 213 | 214 | ))} 215 |
216 |
217 |
218 |
219 | 220 | 221 | }
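The hyperedge keys in Graph/data.js above double as the edge's member list: every vertex name is joined with '|#|', and Graph/index.jsx splits on that token before handing each member group to a bubble-sets plugin. A minimal TypeScript sketch of that payload shape and key convention follows; the interfaces and the edgeMembers helper are illustrative, not part of the repository.

```ts
// Sketch of the hypergraph payload consumed by the Graph page.
// Field names mirror Graph/data.js; the helper below is illustrative only.
interface HyperVertex {
  entity_type: string
  description: string
  source_id: string
  additional_properties: string
  entity_name: string
}

interface HyperEdge {
  description: string
  keywords: string
  source_id: string
  weight: number
}

interface HypergraphPayload {
  vertices: Record<string, HyperVertex>
  edges: Record<string, HyperEdge> // keys are member names joined with '|#|'
}

// "FAT MAN|#|GREAT GOLD SEALS|#|MONSTROUS CHIN" lists every vertex the edge
// connects; splitting on '|#|' recovers the member set, exactly as index.jsx
// does before building a bubble-sets plugin per edge.
const edgeMembers = (edgeKey: string): string[] => edgeKey.split('|#|')
```

Keys with two members behave like ordinary pairwise edges; keys with three or more members are the hyperedges that Graphin renders as bubble sets.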
-------------------------------------------------------------------------------- /web-ui/frontend/src/store/globalUser.ts: -------------------------------------------------------------------------------- 1 | import { makeAutoObservable } from 'mobx' 2 | 3 | class GlobalUser { 4 | userInfo: Partial = {} 5 | constructor() { 6 | makeAutoObservable(this) 7 | } 8 | 9 | async getUserDetail() { 10 | // const res = await getCurrentUserInfo() 11 | // this.userInfo = res?.data 12 | // new WebSee(res?.data?.username) 13 | this.userInfo = { 14 | roles: [ 15 | { 16 | id: 5, 17 | name: 'Super Administrator', 18 | description: 'Has all view and operation permissions', 19 | adminCount: 0, 20 | status: 1, 21 | sort: 5 22 | } 23 | ], 24 | icon: 'http://jinpika-1308276765.cos.ap-shanghai.myqcloud.com/bootdemo-file/20221220/src=http___desk-fd.zol-img.com.cn_t_s960x600c5_g2_M00_00_0B_ChMlWl6yKqyILFoCACn-5rom2uIAAO4DgEODxAAKf7-298.jpg&refer=http___desk-fd.zol-img.com.png', 25 | username: 'admin' 26 | } 27 | } 28 | 29 | setUserInfo(user: Partial) { 30 | this.userInfo = user 31 | } 32 | } 33 | 34 | export const storeGlobalUser = new GlobalUser() 35 | -------------------------------------------------------------------------------- /web-ui/frontend/src/utils/index.js: -------------------------------------------------------------------------------- 1 | export const storage = function (key, value) { 2 | localStorage.setItem(key, JSON.stringify(value)) 3 | return value 4 | }
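The storage helper above JSON-stringifies every value before writing it to localStorage, so anything read back has to be parsed again. A minimal sketch of a matching reader under that assumption; readStorage is hypothetical and not part of the repository.

```ts
// Hypothetical counterpart to storage(): values are written with JSON.stringify,
// so reads must JSON.parse and tolerate missing or malformed entries.
export const readStorage = <T = unknown>(key: string, fallback: T | null = null): T | null => {
  const raw = localStorage.getItem(key)
  if (raw === null) return fallback
  try {
    return JSON.parse(raw) as T
  } catch {
    // Entry was not written by storage(); fall back rather than throw.
    return fallback
  }
}
```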
-------------------------------------------------------------------------------- /web-ui/frontend/src/vite-env.d.ts: -------------------------------------------------------------------------------- 1 | /// <reference types="vite/client" /> 2 | 3 | interface ImportMetaEnv { 4 | readonly VITE_MODE: string 5 | readonly VITE_SERVER_URL: string 6 | readonly VITE_APP_URL: string 7 | // more environment variables... 8 | } 9 | 10 | interface ImportMeta { 11 | readonly env: ImportMetaEnv 12 | } 13 | 14 | interface Window { 15 | COS: any 16 | } 17 | -------------------------------------------------------------------------------- /web-ui/frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ESNext", 4 | "useDefineForClassFields": true, 5 | "lib": ["DOM", "DOM.Iterable", "ESNext"], 6 | "allowJs": true, 7 | "skipLibCheck": true, 8 | "esModuleInterop": false, 9 | "allowSyntheticDefaultImports": true, 10 | "strict": false, 11 | "forceConsistentCasingInFileNames": true, 12 | "module": "ESNext", 13 | "moduleResolution": "Node", 14 | "resolveJsonModule": true, 15 | "isolatedModules": true, 16 | "experimentalDecorators": true, 17 | "noEmit": true, 18 | "jsx": "react-jsx", 19 | "baseUrl": "./", 20 | "paths": { 21 | "@/*": ["./src/*"], 22 | "@components/*": ["./src/components/*"], 23 | "@config/*": ["./config/*"] 24 | } 25 | }, 26 | // "include": ["src"], 27 | "exclude": ["node_modules", "build", "dist", "scripts", "webpack", "jest", "server"] 28 | } 29 | -------------------------------------------------------------------------------- /web-ui/frontend/vite.config.ts: -------------------------------------------------------------------------------- 1 | import { defineConfig } from 'vite' 2 | import react from '@vitejs/plugin-react' 3 | import eslintPlugin from 'vite-plugin-eslint' 4 | import { viteMockServe } from 'vite-plugin-mock' 5 | 6 | import path from 'path' 7 | import proxy from './config/proxy' 8 | 9 | // https://vitejs.dev/config/ 10 | export default defineConfig(({ mode }) => ({ 11 | plugins: [ 12 | react(), 13 | eslintPlugin({ 14 | include: ['src/**/*.js', 'src/**/*.ts', 'src/**/*.tsx', 'src/*.js', 'src/*.ts', 'src/*.tsx'] 15 | }), 16 | viteMockServe({ 17 | // default 18 | localEnabled: mode === 'mock', 19 | mockPath: './config/mock' 20 | }) 21 | ], 22 | resolve: { 23 | extensions: ['.mjs', '.js', '.jsx', '.ts', '.tsx', '.json', '.sass', '.scss'], // extensions that can be omitted when importing 24 | alias: [ 25 | { find: /^~/, replacement: '' }, 26 | { find: '@', replacement: path.resolve(__dirname, 'src') }, 27 | { find: '~', replacement: path.resolve(__dirname, './node_modules') }, 28 | { 29 | find: '@components', 30 | replacement: path.resolve(__dirname, 'src/components') 31 | }, 32 | { find: '@config', replacement: path.resolve(__dirname, 'config') } 33 | // { 34 | // find: '@antd/dist/reset.css', 35 | // replacement: path.join(__dirname, 'node_modules/antd/dist/reset.css') 36 | // } 37 | // // { find: 'antd', replacement: path.join(__dirname, 'node_modules/antd/dist/antd.js') }, 38 | // // { 39 | // // find: '@ant-design/icons', 40 | // // replacement: path.join(__dirname, 'node_modules/@ant-design/icons/dist/index.umd.js') 41 | // // } 42 | ] 43 | }, 44 | css: { 45 | preprocessorOptions: { 46 | less: { 47 | // enable inline JavaScript in Less 48 | javascriptEnabled: true 49 | } 50 | } 51 | }, 52 | server: { 53 | proxy: proxy[mode] 54 | }, 55 | build: { 56 | // whether to emit source map files in the build 57 | sourcemap: false 58 | } 59 | })) 60 | --------------------------------------------------------------------------------
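Taken together, the pages above assume a small HTTP contract with the FastAPI backend in web-ui/backend: Home posts { message } to /process_message and reads response from the reply, while the Graph page resolves VITE_SERVER_URL and calls /db/vertices (assumed here to return a list of entity names) and /db/vertices_neighbor/{entity} (returning the { vertices, edges } shape shown in Graph/data.js). A hedged sketch of a typed client for those routes; the fallback base URL and any response fields beyond the ones the components actually read are assumptions.

```ts
// Minimal sketch of a typed client for the backend routes used by Home and Graph.
// Routes are taken from the components above; response shapes beyond the fields
// they actually read (response, vertices, edges) are assumptions.
const SERVER_URL = import.meta.env.VITE_SERVER_URL ?? 'http://127.0.0.1:8000'

export async function processMessage(message: string): Promise<string> {
  const res = await fetch(`${SERVER_URL}/process_message`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ message })
  })
  if (!res.ok) throw new Error(`process_message failed: ${res.status}`)
  const data = await res.json()
  return data.response
}

export async function fetchVertexKeys(): Promise<string[]> {
  const res = await fetch(`${SERVER_URL}/db/vertices`)
  if (!res.ok) throw new Error(`vertices failed: ${res.status}`)
  return res.json()
}

export async function fetchVertexNeighborhood(entity: string) {
  const res = await fetch(`${SERVER_URL}/db/vertices_neighbor/${encodeURIComponent(entity)}`)
  if (!res.ok) throw new Error(`vertices_neighbor failed: ${res.status}`)
  return res.json() // { vertices: {...}, edges: {...} } as in Graph/data.js
}
```

Centralizing these calls would also let the Home page reuse VITE_SERVER_URL instead of the hard-coded http://127.0.0.1:8000 in its onSubmit handler.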